1 引言
各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。
不知道大家在训练模型的时候有没有遇到类似这样的问题:①反复调试各种超参数组合,但是由于没有进行有效的记录到最后不知道那种参数组合的结果更好。②即使是做了相应的统计与记录,但是由于超参数太多,或者时不时的又会引入新的超参数,导致最后的记录结果很乱。③每次改动都要手动记录特别麻烦,而且记录的要素可能还会包括评估结果、epoch数量,甚至最后还要分析损失值以及某些权重参数等等。④多个模型在调用过程中希望输出相应的处理信息。
那有没有什么比较好的方法来保存这些信息呢?当然有,那就是将整个训练过程中的相关信息都打印成日志保存到本地,然后再按需取相关部分的信息进行分析。那怎么快速高效的在工程的各个模块中将相关信息打印到同一个日志文件中呢?
2 日志记录
在正式介绍如何高效地在各个模型里将相关信息打印到同一个日志文件之前,掌柜先来简单介绍一下如何管理模型中的各个参数。
2.1 参数管理
在模型的实现过程中,通常来说我们都会定义一个Config
类并用类的成员变量来保存各个模型参数的值。当然,如果你需要从本地文件中读取参数(例如bert的config.json文件),那么只需要将读取后的结果赋值到Config
类的各个成员变量即可。
如下所示便是我们定义的一个ModelConfig
类,里面的各个成员变量就是模型中的各个超参数。
x
1 class ModelConfig:
2 def __init__(self):
3 self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
4 self.dataset_dir = os.path.join(self.project_dir, 'data')
5 self.pretrained_model_dir = os.path.join(self.project_dir, "pretrained_model")
6 self.vocab_path = os.path.join(self.pretrained_model_dir, 'vocab.txt')
7 self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
8 self.batch_size = 64
9 self.max_sen_len = None
10 self.num_labels = 15
11 self.epochs = 10
在上述代码中,第3行用来获取当前工程的目录,这个需要根据自己的实际情况和ModelConfig
这个类所在的目录层级来获取;其中os.path.abspath(__file__)
表示取当前ModelConfig
所在文件的绝对路径,os.path.dirname()
函数用来取对应的目录路径。第4-6行代码分别用来指定数据集、预训练模型和词表所在的目录。这里之所以要把这些目录或者路径都定义出来是为了方便后续的使用,当然也可以根据自己的实际情况来进行处理。
如图1所示便是上面代码示例中的目录结构,其中ModelConfig
这个类在Tasks
文件夹中的某个文件中。
在完成上述步骤后,我们便可以通过实例化ModelConfig
这个类,然后将实例化后的对象作为参数传入到相关模块中进行使用即可,示例如下:
xxxxxxxxxx
31if __name__ == '__main__':
2 model_config = ModelConfig()
3 # train(model_config)
此时便可以通过以访问类成员变量model_config.max_sen_len
来访问各个模型参数;同时在向各个模块传递参数时也不需要像之前那样写一大串,而只需要传入model_config
即可。
2.2 定义日志初始化函数
在打印日志的过程中主要使用到的是logging
这个Python包,如果没有的话通过pip install logging
命令安装即可。同时,为了满足相同的训练信息在保存到日志的同时也能同时输出到控制端等功能,下面我们需要基于logging
再改进一下,代码如下:
x1import logging
2def logger_init(log_file_name='monitor',
3 log_level=logging.DEBUG,
4 log_dir='./logs/',
5 only_file=False):
6 # 指定路径
7 if not os.path.exists(log_dir):
8 os.makedirs(log_dir)
9
10 log_path = os.path.join(log_dir, log_file_name + '_' + str(datetime.now())[:10] + '.txt')
11 formatter = '[%(asctime)s] - %(levelname)s: %(message)s'
12 if only_file:
13 logging.basicConfig(filename=log_path,
14 level=log_level,
15 format=formatter,
16 datefmt='%Y-%d-%m %H:%M:%S')
17 else:
18 logging.basicConfig(level=log_level,
19 format=formatter,
20 datefmt='%Y-%d-%m %H:%M:%S',
21 handlers=[logging.FileHandler(log_path),
22 logging.StreamHandler(sys.stdout)]
23 )
在上述代码中,第7-8行用来创建logs
文件夹用于存放日志;第10行用来构造一个日志保存路径,文件名还包含有日期;第12-16行是用于只将日志文件输出到文件;第18-23行则是同时将日志输出到文件和终端。最后,logs
文件中将会生成一个类似名为monitor_2021-10-14.txt
的日志文件。
2.3 打印日志示例
在完成上述两步工作后,我们便可以在任意模块或者文件中使用logging
来进行日志记录。
首先在classA.py
中新建了一个名为classA
的类,代码如下:
xxxxxxxxxx
61import logging
2
3class classA(object):
4 def __init__(self):
5 logging.info(f"我在{__name__}中!")
6 logging.debug(f"我在{__name__}中,这是一条debug信息!")
接着在classB.py
中新建了一个名为classB
的类,代码如下:
x
1import logging
2
3class classB(object):
4 def __init__(self):
5 logging.info(f"我在{__name__}中!")
6 logging.debug(f"我在{__name__}中,这是一条debug信息!")
最后在main.py
中调用这两个类,并输出相应的日志信息,代码如下:
xxxxxxxxxx
101 from classA import classA
2 from classB import classB
3 from log_helper import logger_init
4 import logging
5
6 if __name__ == '__main__':
7 logger_init('nulls', log_level=logging.INFO, log_dir='./logs')
8 a = classA()
9 b = classB()
10 logging.info(f"我在{__name__}中!")
在运行完上述代码后,日志文件和终端里就会输出如下所示的日志信息:
xxxxxxxxxx
31[2021-11-07 07:54:23] - INFO: 我在classA中!
2[2021-11-07 07:54:23] - INFO: 我在classB中!
3[2021-11-07 07:54:23] - INFO: 我在__main__中!
可以发现,classA
和classB
这两个模块中的日志信息都有被打印出来,而这也就满足了我们跨模块日志打印的需求。但是可以发现,logging.debug
这样的信息并没有打印出来,其原因就在于通过logger_init()
函数初始化时指定的日志输出等级为logging.INFO
,这就意味着不会输出Debug的信息。当然,只需要将log_level
指定为logging.DEBUG
即可输出所有信息。
xxxxxxxxxx
71logger_init('nulls', log_level=logging.DEBUG, log_dir='./logs')
2
3[2021-11-07 08:00:12] - INFO: 我在classA中!
4[2021-11-07 08:00:12] - DEBUG: 我在classA中,这是一条debug信息!
5[2021-11-07 08:00:12] - INFO: 我在classB中!
6[2021-11-07 08:00:12] - DEBUG: 我在classB中,这是一条debug信息!
7[2021-11-07 08:00:12] - INFO: 我在__main__中!
2.4 打印模型参数
在介绍完日志的打印输出方法后,进一步只需要在上面的ModelConfig
类定义中加入如下几行代码便可以在模型训练时打印相关的模型信息:
xxxxxxxxxx
81self.logs_save_dir = os.path.join(self.project_dir, 'logs')
2logger_init('nulls', log_level=logging.INFO,
3 log_dir='./logs',
4 only_file=False)
5
6logging.info("\n\n\n\n\n######## <----------------------->")
7for key, value in self.__dict__.items():
8 logging.info(f"######## {key} = {value}")
此时,我们便可以通过类似logging.info()
的方式来进行相关信息的打印和写入日志。第7-8行便是用来将类ModelConfig
中所有的成员变量写入到日志中,也就达到了将所有参数保存下来的目的。
例如现在我们定义了一个函数来训练某个模型,其部分代码如下所示[1]:
xxxxxxxxxx
151 def train(config):
2 classification_model = BertForSequenceClassification(config,
3 config.num_labels,
4 config.pretrained_model_dir)
5 model_save_path = os.path.join(config.model_save_dir, 'model.pt')
6 if os.path.exists(model_save_path):
7 loaded_paras = torch.load(model_save_path)
8 classification_model.load_state_dict(loaded_paras)
9 logging.info("## 成功载入已有模型,进行追加训练......")
10 logging.debug("## 成功载入已有模型,进行追加训练......")
11# ......
12
13if __name__ == '__main__':
14 model_config = ModelConfig()
15 train(model_config)
从上述代码中可以看到,可以通过以第9-10行的方式来打印输出相关的训练信息;同时,在BertForSequenceClassification
类中我们依旧可以通过logging.info()
来对其中的相关信息进行打印输出。第9行和第10行的区别在于,当log_level=logging.INFO
时,只有第9行的信息会被打印出来同时写到日志文件中,当log_level=logging.DEBUG
时,两行的信息都会被打印及写入到日志中。
最后,在控制台和日志文件中便会输出类似如下的信息:
xxxxxxxxxx
121######## <----------------------->
2[2021-11-07 08:10:12] - INFO: ######## project_dir = ~/BertWithPretrained
3[2021-11-07 08:10:12] - INFO: ######## dataset_dir = ~/BertWithPretrained/data
4[2021-11-07 08:10:12] - INFO: ######## pretrained_model_dir = ~/BertWithPretrained/pretrained_model
5[2021-11-07 08:10:13] - INFO: ######## vocab_path = ~/BertWithPretrained/pretrained_model/vocab.txt
6[2021-11-07 08:10:13] - INFO: ######## logs_save_dir = ~/BertWithPretrained/logs
7[2021-11-07 08:10:13] - INFO: ######## split_sep = _!_
8[2021-11-07 08:10:13] - INFO: ######## batch_size = 64
9[2021-11-07 08:10:13] - INFO: ######## max_sen_len = None
10[2021-11-07 08:10:13] - INFO: ######## num_labels = 15
11[2021-11-07 08:10:13] - INFO: ######## epochs = 10
12[2021-11-07 08:10:14] - INFO: ## 成功载入已有模型,进行追加训练......
3 总结
在这篇文章中,掌柜首先介绍了如何管理模型参数;然后介绍了如何基于logging
来定义一个初始化函数;最后详细展示了如何来使用logging
在各个模块中将相关信息打印到同一个日志文件中。在实际使用过程中,我们只需要在需要输出信息的地方通过函数logging.info()
来进行打印,然后再主函数运行的地方调用logger_init()
函数来初始化即可完成日志信息的输出和打印。
本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎分享至一位你的朋友!若有任何疑问与建议,请添加掌柜微信nulls8或加群进行交流。青山不改,绿水长流,我们月来客栈见!