解决方案是围绕 TextVectorization 对象定义一个包装器类,并把自定义标准化函数实现为该类的一个方法。此外,在把配置保存到 pickle 文件时,需要排除(或替换)无法序列化的可调用对象。以下是修复后的代码:
@keras.utils.register_keras_serializable(package='custom_layers', name='TextVectorizer')
class TextVectorizer(layers.Layer):
    """English-Spanish text vectorizer.

    Wraps ``keras.layers.TextVectorization`` so that a custom standardization
    callable can be used while the layer's configuration remains picklable:
    the callable itself is never pickled — ``get_config`` replaces it with a
    string marker, and ``from_config`` re-attaches the :meth:`standardize`
    method when it sees that marker.

    Args:
        max_tokens: Maximum vocabulary size (forwarded to TextVectorization).
        output_mode: Output mode of the wrapped layer (e.g. ``'int'``).
        output_sequence_length: Pad/truncate length for ``'int'`` output.
        standardize: ``'lower_and_strip_punctuation'`` selects the built-in
            Keras default; any other value selects the custom
            :meth:`standardize` method.
        vocabulary: Optional pre-built vocabulary.
        config: If given, all other arguments are ignored and the wrapped
            layer is rebuilt from this (previously pickled) config dict.
    """

    # Marker written into the serialized config in place of the unpicklable
    # custom standardization callable.
    _CUSTOM_STANDARDIZE = 'custom_standardize'

    def __init__(self, max_tokens=None, output_mode='int',
                 output_sequence_length=None,
                 standardize='lower_and_strip_punctuation',
                 vocabulary=None, config=None):
        super().__init__()
        if config:
            # Restoring from a saved config: re-attach the custom
            # standardizer if get_config() recorded that one was in use.
            # Configs saved by older versions (key dropped entirely) still
            # load, falling back to the Keras default standardization.
            if config.get('standardize') == self._CUSTOM_STANDARDIZE:
                config = dict(config)  # don't mutate the caller's dict
                config['standardize'] = self.standardize
            self.vectorization = layers.TextVectorization.from_config(config)
        else:
            self.max_tokens = max_tokens
            self.output_mode = output_mode
            self.output_sequence_length = output_sequence_length
            self.vocabulary = vocabulary
            if standardize != 'lower_and_strip_punctuation':
                # Any non-default value selects the custom method below.
                self.vectorization = layers.TextVectorization(
                    max_tokens=self.max_tokens,
                    output_mode=self.output_mode,
                    output_sequence_length=self.output_sequence_length,
                    vocabulary=self.vocabulary,
                    standardize=self.standardize)
            else:
                self.vectorization = layers.TextVectorization(
                    max_tokens=self.max_tokens,
                    output_mode=self.output_mode,
                    output_sequence_length=self.output_sequence_length,
                    vocabulary=self.vocabulary)

    def standardize(self, input_string, preserve=['[', ']'], add=['¿']) -> str:
        """Lowercase and strip punctuation, keeping ``preserve`` characters
        and additionally stripping the ``add`` characters (Spanish '¿')."""
        strip_chars = string.punctuation
        for item in add:
            strip_chars += item
        for item in preserve:
            strip_chars = strip_chars.replace(item, '')
        lowercase = tf.strings.lower(input_string)
        output = tf.strings.regex_replace(
            lowercase, f'[{re.escape(strip_chars)}]', '')
        return output

    def __call__(self, *args, **kwargs):
        # Delegate directly to the wrapped layer.
        return self.vectorization(*args, **kwargs)

    def get_config(self):
        """Return a picklable config dict.

        The custom ``standardize`` callable is replaced by a string marker
        (so ``from_config`` can restore it); any other callable values are
        dropped, matching the previous behavior.
        """
        config = self.vectorization.get_config()
        out = {}
        for k, v in config.items():
            if callable(v):
                if k == 'standardize':
                    out[k] = self._CUSTOM_STANDARDIZE
                # other callables are intentionally omitted — unpicklable
            else:
                out[k] = v
        return out

    @classmethod
    def from_config(cls, config):
        """Rebuild a vectorizer from a dict produced by :meth:`get_config`."""
        return cls(config=config)

    def get_weights(self):
        # Fix: delegate explicitly, mirroring set_weights(); the inherited
        # Layer.get_weights is not guaranteed to see the wrapped layer's
        # lookup-table state.
        return self.vectorization.get_weights()

    def set_weights(self, weights):
        self.vectorization.set_weights(weights)

    def adapt(self, dataset):
        """Build the vocabulary from a dataset/list of strings."""
        self.vectorization.adapt(dataset)

    def get_vocabulary(self):
        return self.vectorization.get_vocabulary()
在训练阶段适应并保存权重:
import os  # for makedirs; hoist to the file's top-level imports if preferred

vocab_size = 15000
sequence_length = 20

# English side: default lower-and-strip-punctuation standardization.
source_vectorization = TextVectorizer(max_tokens=vocab_size,
                                      output_mode='int',
                                      output_sequence_length=sequence_length)
# Spanish side: one extra token (target is shifted during training) and the
# custom Spanish-aware standardizer.
target_vectorization = TextVectorizer(max_tokens=vocab_size,
                                      output_mode='int',
                                      output_sequence_length=sequence_length + 1,
                                      standardize='spanish')

# NOTE(review): assumes `train_pairs` — an iterable of
# (english_text, spanish_text) pairs — is defined earlier in the file.
train_english_texts = [pair[0] for pair in train_pairs]
train_spanish_texts = [pair[1] for pair in train_pairs]
source_vectorization.adapt(train_english_texts)
target_vectorization.adapt(train_spanish_texts)

# Fix: ensure the checkpoint directory exists; open() below would otherwise
# raise FileNotFoundError on a fresh checkout.
os.makedirs('ckpts', exist_ok=True)
with open('ckpts/english_vectorization.pkl', 'wb') as f:
    pickle.dump({'config': source_vectorization.get_config(),
                 'weights': source_vectorization.get_weights()}, f)
with open('ckpts/spanish_vectorization.pkl', 'wb') as f:
    pickle.dump({'config': target_vectorization.get_config(),
                 'weights': target_vectorization.get_weights()}, f)
在推理阶段加载并使用它们:
def _load_vectorizer(path):
    """Rebuild a TextVectorizer from a pickled {'config', 'weights'} bundle."""
    with open(path, 'rb') as fh:
        payload = pickle.load(fh)
    vectorizer = TextVectorizer.from_config(payload['config'])
    vectorizer.set_weights(payload['weights'])
    return vectorizer

source_vectorization = _load_vectorizer('ckpts/english_vectorization.pkl')
target_vectorization = _load_vectorizer('ckpts/spanish_vectorization.pkl')