1 引言

各位朋友大家好,欢迎来到月来客栈。今天要和大家介绍的内容是如何在Pytorch框架中对模型进行保存和载入、以及模型的迁移和再训练。一般来说,最常见的场景就是模型完成训练后的推断过程。一个网络模型在完成训练后通常都需要对新样本进行预测,此时就只需要构建模型的前向传播过程,然后载入已训练好的参数初始化网络即可。

第2个场景就是模型的再训练过程。一个模型在一批数据上训练完成之后需要将其保存到本地,并且可能过了一段时间后又收集到了一批新的数据,因此这个时候就需要将之前的模型载入进行在新数据上进行增量训练(或者是在整个数据上进行全量训练)。

第3个应用场景就是模型的迁移学习。这个时候就是将别人已经训练好的预模型拿过来,作为你自己网络模型参数的一部分进行初始化。例如:你自己在Bert模型的基础上加了几个全连接层来做分类任务,那么你就需要将原始BERT模型中的参数载入并以此来初始化你的网络中的Bert部分的权重参数。

在接下来的这篇文章中,笔者就以上述3个场景为例来介绍如何利用Pytorch框架来完成上述过程。

2 模型的保存与复用

在Pytorch中,我们可以通过torch.save()torch.load()来完成上述场景中的主要步骤。下面,笔者将以之前介绍的LeNet5网络模型为例来分别进行介绍。不过在这之前,我们先来看看Pytorch中模型参数的保存形式。

2.1 查看网络模型参数

(1)查看参数

首先定义好LeNet5的网络模型结构,如下代码所示:

在定义好LeNet5这个网络结构的类之后,只要我们完成了这个类的实例化操作,那么网络中对应的权重参数也都完成了初始化的工作,即有了一个初始值。同时,我们可以通过如下方式来访问:

其输出的结果为:

可以发现,网络模型中的参数model.state_dict()其实是以字典的形式(实质上是collections模块中的OrderedDict)保存下来的:

(2)自定义参数前缀

同时,这里值得注意的地方有两点:①参数名中的fcconv前缀是根据你在上面定义nn.Sequential()时的名字所确定的;②参数名中的数字表示每个Sequential()中网络层所在的位置。例如将网络结构定义成如下形式:

那么其参数名则为:

理解了这一点对于后续我们去解析和载入一些预训练模型很有帮助。

除此之外,对于中的优化器等,其同样有对应的state_dict()方法来获取对于的参数,例如:

在介绍完模型参数的查看方法后,就可以进入到模型复用阶段的内容介绍了。

2.2 载入模型进行推断

(1) 模型保存

在Pytorch中,对于模型的保存来说是非常简单的,通常来说通过如下两行代码便可以实现:

在指定保存的模型名称时Pytorch官方建议的后缀为.pt或者.pth(当然也不是强制的)。最后,只需要在合适的地方加入第2行代码即可完成模型的保存。

同时,如果想要在训练过程中保存某个条件下的最优模型,那么应该通过如下方式:

而不是:

因为后者best_model_state得到只是model.state_dict()的引用,它依旧会随着训练过程而发生改变。

(2)复用模型进行推断

在推断过程中,首先需要完成网络的初始化,然后再载入已有的模型参数来覆盖网络中的权重参数即可,示例代码如下:

在上述代码中,4-7行便是用来载入本地模型参数,并用其覆盖网络模型中原有的参数。这样,便可以进行后续的推断工作:

2.3 载入模型进行训练

在介绍完模型的保存与复用之后,对于网络的追加训练就很简单了。最简便的一种方式就是在训练过程中只保存网络权重,然后在后续进行追加训练时只载入网络权重参数初始化网络进行训练即可,示例如下(完整代码参见[2]):

这样,便完成了模型的追加训练:

除此之外,你也可以在保存参数的时候,将优化器参数、损失值等一同保存下来,然后在恢复模型的时候连同其它参数一起恢复,示例如下:

载入方式如下:

2.4 载入模型进行迁移

(1)定义新模型

到目前为止,对于前面两种应用场景的介绍就算完成了,可以发现总体上并不复杂。但是对于第3中场景的应用来说就会略微复杂一点。

假设现在有一个LeNet6网络模型,它是在LeNet5的基础最后多加了一个全连接层,其定义如下:

接下来,我们需要将在LeNet5上训练得到的权重参数迁移到LeNet6网络中去。从上面LeNet6的定义可以发现,此时尽管只是多加了一个全连接层,但是倒数第2层参数的维度也发生了变换。因此,对于LeNet6来说只能复用LeNet5网络前面4层的权重参数。

(2)查看模型参数

在拿到一个模型参数后,首先我们可以将其载入,然查看相关参数的信息:

同时,对于LeNet6网络的参数信息为:

在理清楚了新旧模型的参数后,下面就可以将LeNet5中我们需要的参数给取出来,然后再换到LeNet6的网络中。

(3)模型迁移

虽然本地载入的模型参数(上面的loaded_paras)和模型初始化后的参数(上面的model.state_dict())都是一个字典的形式,但是我们并不能够直接改变model.state_dict()中的权重参数。这里需要先构造一个state_dict然后通过model.load_state_dict()方法来重新初始化网络中的参数。

同时,在这个过程中我们需要筛选掉本地模型中不可复用的部分,具体代码如下:

在上述代码中,第2行的作用是先拷贝网络中(LeNet6)原有的参数;第6-9行则是用本地的模型参数(LeNet5)中可以复用的替换掉LeNet6中的对应部分,其中第7行就是判断可用的条件。同时需要注意的是在不同的情况下筛选的方式可能不一样,因此具体情况需要具体分析,但是整体逻辑是一样的。

最后,我们只需要在模型训练之前调用该函数,然后重新初始化LeNet6中的部分权重参数即可[2]:

训练结果如下:

可以发现,在大约100个batch之后,模型的准确率就提升上来了。

3 总结

在本篇文章中,笔者首先介绍了模型复用的几种典型场景;然后介绍了如何查看Pytorch模型中的相关参数信息;接着介绍了如何载入模型、如何进行追加训练以及进行模型的迁移学习等。

本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎分享至一位你的朋友!若有任何疑问与建议,请添加笔者微信nulls8或加群进行交流。青山不改,绿水长流,我们月来客栈见

引用

[1] SAVING AND LOADING MODELS https://pytorch.org/tutorials/beginner/saving_loading_models.html

[2] 示例代码:https://github.com/moon-hotel/DeepLearningWithMe