1 引言

各位朋友大家好,欢迎来到月来客栈。经过前面6篇文章的介绍,对于Transformer相信大家都应该理解得差不多了。不过要想做到灵活运用Transformer结构,那就还得多看看其它场景下的运用。在接下来的这篇文章中,笔者将会以一个含有70余万条的对联数据集为例,来搭建一个基于Transformer结构的对联生成模型。同时,这也是介绍Transformer结构系列文章的最后一篇。

图 1. 基于Transformer的对联生成模型

如图1所示便是一个基于Transformer结构的对联生成模型。可以看出,其实它与前面介绍的基于Transformer结构的翻译模型没有特别的变换,唯一不同的可能就是在对联生成模型中解码器和编码器共用了同一个词表(因为两者都是中文)。

2 数据预处理

2.1 语料介绍

按老规矩,在正式介绍模型搭建之前我们还是先来看看数据集都长什么样。本次所使用到的数据集是一个网上公开的对联数据集,在github中搜索"couplet-dataset”就能找到,其一共包含有770491条训练样本,4000条测试样本。和翻译数据集类似,对联数据集也包含上下两句:

如上所示便是3条样本,分别存放在in.txtout.txt这两个文件中。可以看出,原始数据已经做了分字这步操作,所有后续我们只需要进行简单的split操作即可。

2.2 数据集构建

总体上来说对联生成模型的数据集构建过程和翻译模型的数据集构建过程基本上没有太大差别,主要步骤同样也是:①构建字典;②将文本中的每一个词(字)转换为Token序列;③对不同长度的样本序列按照某个标准进行padding处理;④构建DataLoader类。

第1步:定义tokenize

由于原始数据每个字已经被空格隔开了,所以这里tokenizer的定义只需要进行split操作即可,代码如下:

可以看到,其实也非常简单。例如对于如下文本来说

其tokenize后的结果为:

第2步:建立词表

在介绍完tokenize的实现方法后,我们就可以正式通过torchtext.vocab中的Vocab方法来构建词典了,代码如下:

在上述代码中,第3行代码用来指定特殊的字符;第5-10行分别用来遍历in.txt文件和out.txt文件中的每一个样本(每行一个)并进行tokenize和计数,其中对于counter.update进行介绍可以参考[1];第8行则是返回最后得到词典。值得注意的是,由于在对联生成这一场景中编码器和解码器共用的是一个词表,所以这里同时对in.txtout.txt文件进行了遍历。

在完成上述过程后,我们将得到一个Vocab类的实例化对象,即:

此时,我们就需要定义一个类,并在类的初始化过程中根据训练语料完成字典的构建,代码如下:

第3步:转换为Token序列

在得到构建的字典后,便可以通过如下函数来将训练集和测试集转换成Token序列:

在上述代码中,第11-4行分别用来将原始序列上联和目标序列下联转换为对应词表中的Token形式。在处理完成后,就会得到类似如下的结果:

其中左边的一列就是原始序列上联的Token形式,右边一列就是目标序列下联的Token形式,每一行构成一个样本。

第4步:padding处理

从上面的输出结果可以看到,无论是对于原始序列来说还是目标序列来说,在不同的样本中其对应长度都不尽相同。但是在将数据输入到相应模型时却需要保持同样的长度,因此在这里我们就需要对Token序列化后的样本进行padding处理。同时需要注意的是,一般在这种生成模型中,模型在训练过程中只需要保证同一个batch中所有的原始序列等长,所有的目标序列等长即可,也就是说不需要在整个数据集中所有样本都保证等长。

因此,在实际处理过程中无论是原始序列还是目标序列都会以每个batch中最长的样本为标准对其它样本进行padding,具体代码如下:

在上述代码中,第6-7行用来在目标序列的首尾加上特定的起止符;第9-10行则是分别对一个batch中的原始序列和目标序列以各自当中最长的样本为标准进行padding(这里的pad_sequence导入自torch.nn.utils.rnn)。

第5步:构造mask向量

在处理完成前面几个步骤后,进一步需要根据src_inputtgt_input来构造相关的mask向量,具体代码如下:

在上述代码中,第1-4行是用来生成一个形状为[sz,sz]的注意力掩码矩阵,用于在解码过程中掩盖当前position之后的position;第6-17行用来返回Transformer中各种情况下的mask矩阵,其中src_mask在这里并没有作用。

第6步:构造DataLoade与使用示例

经过前面5步的操作,整个数据集的构建就算是已经基本完成了,只需要再构造一个DataLoader迭代器即可,代码如下:

在上述代码中,第2-3行便是分别用来将训练集和测试集转换为Token序列;第4-7行则是分别构造2个DataLoader,其中generate_batch将作为一个参数传入来对每个batch的样本进行处理。在完成类LoadCoupletDataset所有的编码过程后,便可以通过如下形式进行使用:

在介绍完数据集构建的整个过程后,下面就开始正式进入到翻译模型的构建中。如果对于这部分不是特别理解的话,建议先看这篇文章[2]中的数据处理流程图进行理解。

3 基于Transformer的对联生成模型

3.1 网络结构

总体来说,基于Transformer的对联生成模型的网络结构其实就是图1所展示的所有部分,只是在前面介绍Transformer网络结构[3]时笔者并没有把Embedding部分的实现给加进去。这是因为对于不同的文本生成模型,其Embedding部分会不一样(例如在本场景中编码器和解码器共用一个TokenEmbedding即可,而在翻译模型中就需要两个),所以将两者进行了拆分。同时,待模型训练完成后,在inference过程中Encoder只需要执行一次,所以在此过程中也需要单独使用Transformer中的Encoder和Decoder。

首先,我们需要定义一个名为CoupletModel的类,其前向传播过程代码如下所示:

在上述代码中,第7-12行便是用来定义一个Transformer结构;第13-15分别用来定义Positional Embedding、Token Embedding和最后的分类器(需要注意的是这里是共用同一个Token Embedding);第28-38行便是用来执行整个前向传播过程,其中Transformer的整个前向传播过程在前一篇[3]文章中已经介绍过,在这里就不再赘述。

在定义完logits的前向传播过后,便可以通过如下形式进行使用:

接着,我们需要再定义一个EncoderDecoder在inference中进行使用,代码如下:

在上述代码中,第1-5行用于在inference时对输入序列进行编码并得到memory(只需要执行一次);第7-11行用于根据memory和当前解码时刻的输入对输出进行预测,需要循环执行多次,这部分内容详见模型预测部分。

3.2 模型训练

在定义完成整个对联生成模型的网络结构后下面就可以开始训练模型了。由于这部分代码较长,所以下面笔者依旧以分块的形式进行介绍:

第1步:载入数据集

首先我们可以根据前面的介绍,通过类LoadCoupletDataset来载入数据集,其中config中定义了模型所涉及到的所有配置参数。

第2步:定义模型并初始化权重

在载入数据后,便可以定义模型CoupletModel,并根据相关参数对其进行实例化;同时,可以对整个模型中的所有参数进行一个初始化操作。

第3步:定义损失学习率与优化器

在上述代码中,第1行是定义交叉熵损失函数,并同时指定需要忽略的索引ignore_index。因为根据tgt_output可知,有些位置上的标签值其实是Padding后的结果,因此在计算损失的时候需要将这些位置给忽略掉。第2行代码则是论文[4]中所提出来的动态学习率计算过程,其计算公式为:

具体实现代码为:

通过CustomSchedule,就能够在训练过程中动态的调整学习率。学习率随step增加而变换的结果如图2所示:

图 2. 动态学习率变化过程图

从图2可以看出,在前warm_up个step中,学习率是线性增长的,在这之后便是非线性下降,直至收敛与0.0004。

第4步:开始训练

在上述代码中,第5-9行是用来得到模型各个部分的输入;第10-18行是计算模型整个前向传播的过程;第20-24行则是执行损失计算与反向传播;第26-28则是将每个step更新后的学习率送入到模型中并进行参数更新;第30行是用来计算模型预测的准确率,具体过程将在后续文章中进行介绍。以下便是模型训练过程中的输出:

3.3 模型预测

在介绍完模型的训练过程后接下来就来看模型的预测部分。生成模型的预测部分不像普通的分类任务只需要将网络最后的输出做argmax操作即可,生成模型在预测过程中往往需要按时刻一步步进行来进行。因此,下面我们这里定义一个couplet函数来执行这一过程,具体代码如下:

在上述代码中,第5行是将待翻译的源序列进行序列化操作;第7-10行则是通过函数greedy_decode函数来对输入进行解码;第11行则是将最后解码后的结果由Token序列在转换成实际的目标语言。同时,greedy_decode函数的实现如下:

在上述代码中,第3行是将源序列输入到Transformer的编码器中进行编码并得到Memory;第4-5行是初始化解码阶段输入的第1个时刻的,在这里也就是' ';第6-18行则是整个循环解码过程,在下一个时刻为EOS_IDX或者达到最大长度后停止;第8-9行是根据当前解码器输入的长度生成注意力掩码矩阵tgt_mask;第10行是根据memory以及当前时刻的输入对当前时刻的输出进行解码;第12-14行则是分类得到当前时刻的解码输出结果;第15行则是将当前时刻的解码输出结果头当前时刻之前所有的输入进行拼接,以此再对下一个时刻的输出进行预测。

最后,我们只需要调用如下函数便可以完成对原始输入上联的下联生成任务:

在上述代码中,第6-15行是定义网络结构,以及恢复本地保存的网络权重;第16行则是开始执行下联生成任务;第19-26行为生成示例,其输出结果为:

以上完整代码可参见[5]。

4 总结

在这篇文章中,笔者首先介绍了对联生成模型的整个数据预处理过程;接着笔者介绍了基于Transformer结构的对联模型的整体构成,然后循序渐进地带着各位读者来实现了整个模型,包括基础结构的搭建、模型训练的详细实现、动态学习率的调整实现等;最后介绍了如何来实现模型在实际预测过程中的处理流程等,包括源输入序列的构建、解码时刻输入序列的构建等。

本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎分享至一位你的朋友!若有任何疑问与建议,请添加笔者微信或加群进行交流。青山不改,绿水长流,我们月来客栈见!

引用

[1] 你还在手动构造词表?试试torch.vocab

[2] This post is all you need(⑤基于Transformer的翻译模型)

[3] This post is all you need(④Transformer的实现过程)

[4] Attention is all you need

[5] https://github.com/moon-hotel/TransformerCouplet