1 引言

各位朋友大家好,欢迎来到月来客栈。经过前面一系列文章的介绍,相信大家对于Transformer的原理应该有了一个比较清晰的认识。不过要想做到灵活运用Transformer结构,那就还得再看看其它情况下的运用场景。在接下来的这篇文章中,笔者将会以AG_News数据集为例,来搭建一个基于Transformer结构的文本分类模型。

图 1. Transformer文本分类网络结构图

如图1所示便是一个基于Transformer结构的文本分类模型。不过准确的说应该只是一个基于Transformer中Encoder的文本分类模型。这是因为在文本分类任务中并没有解码这一过程,所以我们只需要将Encoder编码得到的向量输入到分类器中进行分类即可。同时需要注意的是,Encoder部分最后输出张量的形状为[batch_size,d_model,src_len](图1中Encoder输出的src_len为7),我们需要根据相应策略来进行下一步的处理,具体见后文。

2 数据预处理

2.1 语料介绍

在正式介绍模型之前,我们还是先来看看后续所要用到的AG_News数据集。AG_News新闻主题分类数据集是通过从原始语料库中选择 4 个最大的类构建的。每个类包含 30000 个训练样本和 1900 个测试样本。训练样本总数为 120000条,测试总数为 7600条。AG_News原始数据大概长这样:

上述一共包含有3个样本,每1行为1个样本。同时,所有样本均使用逗号作为分隔符,一共包含有 3 列,分别对应类标(1 到 4)、标题和新闻描述。在本篇文中中,我们暂时只使用新闻描述作为输入(当然也可以用title作为输入来进行分类)。

对于该数据的载入,你可以使用Pytorch中的方法来下载并使用[1]:

也可以自己下载原始数据来进行处理。在这篇文章中,为了延续使用与熟悉上一篇文章[2]中介绍的预处理代码,所以这里我们暂不使用Pytorch内置的代码。

2.2 数据集构建

由于分类模型数据集的构建过程并不复杂,所以这里笔者就只是简单的介绍一下即可,详细内容可参考文章[3] [4]中的内容。

第1步:定义tokenize

如果是对类似英文这样的语料进行处理,那就是直接按空格切分即可。但是需要注意的是要把其中的逗号、句号等也给分割出来。因此,这部分代码可以根据如下方式进行实现:

可以看到,其实也非常简单。例如对于如下文本来说

其tokenize后的结果为:

第2步:定义字符串清理

从2.1节中的示例语料中可以看到, 原始语料中有很多奇奇怪怪的字符,因此还需要对其稍微做一点处理。例如①只保留、数字、以及常用标点;②全部转换为小写字母;③把缩写还原等等。当然,你也可以自己再添加其它处理方式。具体代码如下:

第3步:建立词表

在介绍完tokenize和字符串清理的实现方法后,我们就可以正式通过torchtext.vocab中的Vocab方法来构建词典了,代码如下:

在上述代码中,第3行代码用来指定特殊的字符;第5-8行代码用来遍历文件中的每一个样本(每行一个)并进行tokenize和计数,其中对于counter.update进行介绍可以参考[3];第9行则是返回最后得到词典。

在完成上述过程后,我们将得到一个Vocab类的实例化对象:

接下来,我们就需要定义一个类,并在类的初始化过程中根据训练语料完成字典的构建,代码如下:

第4步:转换为Token序列

在得到构建的字典后,便可以通过如下函数来将训练集和测试集转换成Token序列:

在上述代码中,第11-4行分别用来将原始输入序列转换为对应词表中的Token形式。在处理完成后,就会得到类似如下的结果:

第5步:padding处理

由于对于不同的样本来说其对应的序列长度通常来说都是不同的,但是在将数据输入到相应模型时却需要保持同样的长度。因此在这里我们就需要对Token序列化后的样本进行padding处理,具体代码如下:

在上述代码中,max_len 表示 最大句子长度,默认为None,即在每个batch中以最长样本的长度对其它样本进行padding;当然同样也可以指定max_len的值为整个数据集中最长样本的长度进行padding处理。padding处理后的结果类似如下:

末尾的1即是padding的部分。

在定义完pad_sequence这个函数后,我们便可以通过它来对每个batch中的数据集进行padding处理:

第6步:构造DataLoade与使用示例

经过前面5步的操作,整个数据集的构建就算是已经基本完成了,只需要再构造一个DataLoader迭代器即可,代码如下:

在上述代码中,第2-5行便是分别用来将训练集和测试集转换为Token序列;第6-9行则是分别构造2个DataLoader,其中generate_batch将作为一个参数传入来对每个batch的样本进行处理。在完成类LoadSentenceClassificationDataset所有的编码过程后,便可以通过如下形式进行使用:

最后,由于Encoder只会在padding部分有mask操作,所以每个样本的key_padding_mask向量我们在训练部分再生成即可。下面,我们正式进入到文本分类模型部分的介绍。

3 基于Transformer的文本分类模型

3.1 网络结构

总体来说,基于Transformer的翻译模型的网络结构其实就是图1所展示的所有部分,当然你还可以使用多个Encoder进行堆叠。最后,只需要将Encdder的输出喂入到一个softmax分类器即可完成分类任务。不过这里有两个细节的地方需要大家注意:

①根据前一篇文章[5]的介绍可知,Encoder在编码结束后输出的形状为[src_len,batch_size,embed_dim](这里的src_len也可以理解为LSTM中time step的概念)。因此,在构造最后分类器的输入时就可以有多种不同的形式,例如只取最后一个位置上的向量、或者是取所有位置向量的平均(求和)等都可以。后面笔者也会将这3种方式都实现供大家参考。

②由于每个样本长度给不相同,因此在对样本进行padding的时候就有两种方式。一般来说在大多数模型中多需要保持所有的样本具有相同的长度,不过由于这里我们使用的是自注意力的编码机制,因此只需要保持同一个batch中的样本长度一致即可。不过后面笔者对这两种方式都进行了实现,只需要通过max_sen_len这个参数来控制即可。

首先,我们需要定义一个名为ClassificationModel的类,其前向传播过程代码如下所示:

在上述代码中,第10-17行用来定义Transformer中的Encoder;第18-20行用来定义一个分类器。整个网络的前向传播过程如下:

在上述代码中,第7-11行用来执行编码器的前向传播过程;第13-18行便是用来选择以何种方式来选择分类器的输入,经笔者实验后发现取各个位置的平均值效果最好;第20-21行便是将经过分类器后的输出进行返回。

在定义完logits的前向传播过后,便可以通过如下形式进行使用:

3.2 模型训练

在定义完成整个分类模型的网络结构后下面就可以开始训练模型了。由于这部分代码较长,所以下面笔者依旧以分块的形式进行介绍:

第1步:载入数据集

首先我们可以根据前面的介绍,通过类LoadSentenceClassificationDataset来载入数据集,其中config中定义了模型所涉及到的所有配置参数。同时,可以通过max_sen_len参数来控制padding时保持所有样本一样还是仅在每个batch内部一样。

第2步:定义模型并初始化权重

在载入数据后,便可以定义一个翻译模型ClassificationModel,并根据相关参数对其进行实例化;同时,可以对整个模型中的所有参数进行一个初始化操作。

第3步:定义损失学习率与优化器

在上述代码中,第1行是定义交叉熵损失函数;第2行代码则是论文中所提出来的动态学习率计算过程,其计算公式为:

具体实现代码为:

通过CustomSchedule,就能够在训练过程中动态的调整学习率。学习率随step增加而变换的结果如图2所示:

图 2. 动态学习率变化过程图

从图2可以看出,在前warm_up个step中,学习率是线性增长的,在这之后便是非线性下降,直至收敛与0.0004。

第4步:开始训练

在上述代码中,第7行代码用来生成每个样本对应的padding mask向量;第15-16行是将每个step更新后的学习率送入到模型中。以下便是模型训练过程中的输出:

以上完整代码可参见[7]。

4 总结

在这篇文章中,笔者首先介绍了文本分类模型的数据预处理过程,然后再一步步地通过编码实现了整个数据集的构造过程;接着笔者介绍了基于Transformer结构的文本分类模型的整体构成,然后循序渐进地带着各位读者来实现了整个分类模型,包括基础结构的搭建、模型训练的详细实现、动态学习率的调整实现等;最后介绍了如何来实现模型在实际预测过程中的处理流程等。在下一篇文章中,笔者将会介绍如何基于Transformer结构来搭建一个对联模型(本质上和翻译模型一样),这同时也是介绍Transformer内容的最后一篇文章。

本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎分享至一位你的朋友!若有任何疑问与建议,请添加笔者微信'nulls8'或加群进行交流。青山不改,绿水长流,我们月来客栈见!

引用

[1] https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html

[2] This post is all you need(基于Transformer的翻译模型)

[3] 你还在手动构造词表?试试torchtext.vocab

[4] 文本数据如何快速从零构建成DataLoader

[5] This post is all you need(Transformer的实现过程)

[7] https://github.com/moon-hotel/TransformerClassification