1 引言
各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。
不知道各位客官在行走江湖的过程中有没有遇到类似这样的问题:由于数据集过于庞大或者说数据结构很复杂,导致每次都需要花费很长的时间来等待数据集的预处理过程。例如掌柜最近在研究SQuAD任务时就发现每次数据预处理都需要等待很长的时间。虽然在这期间掌柜也想过将预处理过后的结果给缓存下来(之前也是这么做的),当下次载入数据集时先进行判断,如果本地存在缓存则直接载入缓存;但是想想这次写了下次换个场景又需要重写这些代码,觉得麻烦又给放弃了。
不过由于最终还是没能忍受等待时间太长,于是思考了一下写了一个通用的缓存方法,这样在任何地方只需要调用该函数便可以实现上述目的。并且为了使得调用方便以及代码简介,掌柜还特地将它实现为了Python中的修饰器。下面,掌柜首先就带着大家简单地了解一下Python中修饰器的作用及用法。
2 修饰器简介
关于什么是修饰器(或装饰器Decorator)掌柜这里就不从Python语法上来做详细的解释了。简单一句话,修饰器的作用的就是在正式执行某个功能函数之前,预先执行你想要执行的某些操作。下面,我们直接从用法的层面来逐步了解Python中的修饰器,因为这样带着目的去学习能够更快的入门。
2.1 修饰器用例
首先来看这样一个场景,假如你已经定义了很多功能函数,但是你现在想在日志文件中同时也打印出当前主程序正在调用哪个功能函数的信息。例如:
xxxxxxxxxx
51def func1(str="moon hotel"):
2 print(str)
3
4def func2(str="月来客栈"):
5 print(str)
要实习这样一个功能,最直接的做法就是在原始的函数里面加上一句函数的输出信息:
xxxxxxxxxx
71def func1(str="moon hotel"):
2 print(f"正在函数 {sys._getframe().f_code.co_name}() 里面!")
3 print(str)
4
5def func2(str="月来客栈"):
6 print(f"正在函数 {sys._getframe().f_code.co_name}() 里面!")
7 print(str)
这样我们在调用func1
和func2
这两个函数时就能够分别输出对应的信息:
xxxxxxxxxx
41正在函数 func1() 里面!
2moon hotel
3正在函数 func2() 里面!
4月来客栈
虽然说上面这个做法稍微有点麻烦,但似乎还能接受。不过又过了一会儿,你还想在此基础上打印出进入每个函数时的具体时间该怎么办呢?继续像刚刚那样再加一行代码?万一有100个函数怎么操作?
面对这样一个问题,Python中的修饰器便可以出来大展身手了。在使用修饰器之前,需要先定义一个完成该功能的函数,如下:
xxxxxxxxxx
61def get_info(func):
2 def wrapper(*args, **kwargs):
3 print(f"正在函数 {func.__name__}() 里面!")
4 print(f"当前时间是 {datetime.now()}")
5 return func(*args, **kwargs)
6 return wrapper
可以看到,get_info
似乎就像是定义了一个嵌套的函数一样。因此,我们可以通过函数调用的方式来使用get_info()
方法:
xxxxxxxxxx
61if __name__ == '__main__':
2 get_info(func1)(str="nulls8")
3#
4正在函数 func1() 里面!
5当前时间是 2021-12-11 21:23:56.721417
6nulls8
这样,对于后续需要增加的任何操作,只需要在函数get_info
中加入即可而不需要在调用的地方进行改动。不过这样调用稍微有点麻烦,需要到每个调用该函数的地方修改函数传入的方式。所以,我们还可以通过一个更加简洁的方式来进行调用,那就是直接在该函数定义的地方将它作为修饰器使用:
xxxxxxxxxx
71
2def func1(str="moon hotel"):
3 print(str)
4
5
6def func2(str="月来客栈"):
7 print(str)
这样,我们直接通过调用对应的功能函数就能够实现输出该函数名和时间的信息:
xxxxxxxxxx
101if __name__ == '__main__':
2 func1()
3 func2()
4
5正在函数 func1() 里面!
6当前时间是 2021-12-11 21:45:00.748882
7moon hotel
8正在函数 func2() 里面!
9当前时间是 2021-12-11 21:45:00.748908
10月来客栈
介绍到这里,相信各位客官对于Python中修饰器的定义与使用已经有了一个基本的了解。下面掌柜再来总结一下修饰器的使用方法,简称套路。
2.2 修饰器格式
通过上面的示例介绍可以发现,定义修饰器函数的大致格式如下:
xxxxxxxxxx
51def Decorator(func):
2 def wrapper(*args, **kwargs):
3 print(f"在这里执行你需要预先执行的代码语句")
4 return func(*args, **kwargs)
5 return wrapper
在上述代码中,Decorator
为修饰器的名称;func
为使用该修饰器的函数;*args, **kwargs
则为使用该修饰器的函数的相关参数。同时,由于通过@
符号来将Decorator
作为修饰器调用本质上只是一种快速简洁的方式,所以@Decorator
还等价于Decorator(func)(*args, **kwargs)
这样的调用方式。因此,通过后者我们还能够更加清晰的认识到整个修饰器的工作流程。
虽然上面掌柜给出了一个定义修饰器的大致格式,但是在理解了整个修饰器的工作流程后,我们还可以根据自己的需要灵活的做出相应的修改。例如需要再统计每个函数的运行时间则可以修改为:
xxxxxxxxxx
91def get_info(func):
2 def wrapper(*args, **kwargs):
3 print(f"正在函数 {func.__name__}() 里面!")
4 start_time = time.time()
5 func(*args, **kwargs)
6 end_time = time.time()
7 print(f"一共耗时{(end_time - start_time):.3f}s")
8
9 return wrapper
总结起来就是,通过@
符号来将get_info
作为修饰器调用本质上只是一种快速简洁的函数调用方式,因此对于get_info
函数内部的整个处理流程依旧等价于普通的函数定义流程。
3 缓存预处理结果
在介绍完修饰器的基本原理及用法之后再来看如何缓存数据预处理结果就变得十分容易了。总结起来就是在正式载入数据集之前首先判断本地是否存在缓存,如果存在则直接载入缓存,如果不存在则再调用函数进行数据预处理并进行缓存。
3.1 定义数据载入类
熟悉掌柜的客官都知道,对于数据预处理部分掌柜一般都喜欢将其定义为一个类,并在各个成员函数内实现相应的处理逻辑。并且通常来说,这个类至少会包含3个方法:__init__
、data_process
和load_train_test_data
,其中__init__
用来初始化类中的相关参数(如batch_size
、max_len
、数据集路径等等;data_process
用来对数据集进行预处理返回预处理后的结果;load_train_test_data
用来构造最后模型训练时的DataLoader
。
如下所示便是一个简单的数据载入类(实战示例可参考文章):
xxxxxxxxxx
201class LoadData(object):
2 def __init__(self):
3 self.x = torch.randn((10, 5))
4 self.y = torch.randint(2, [10])
5 self.max_len = 5
6 self.batch_size = 2
7
8
9 def data_process(self, file_dir='./', postfix=f'cache'):
10 print("正在进行预处理数据")
11 data = {"x": self.x, "y": self.y}
12 return data
13
14 def load_train_test_data(self, file_dir='./'):
15 postfix = f'cache_{self.max_len}_{self.batch_size}'
16 data = self.data_process(file_dir=file_dir, postfix=postfix)
17 x, y = data['x'], data['y']
18 data_iter = TensorDataset(x, y)
19 data_iter = DataLoader(data_iter, batch_size=self.batch_size)
20 return data_iter
在上述代码中,第8行data_process
方法返回的便是预处理后的结果;第15行代码则是定义的一个缓存文件名的后缀,因为在一些场景中可能会存在可调节的参数(例如NLP处理时对于句子的长度等等),因此对于不同参数对应的缓存应该加以区分。最后,我们只需要将data_process
处理后的结果进行缓存即可。
当然,虽然这里是以构造PyTorch中的DataLoader为例进行的代码示例,但是同样的处理逻辑一样可以运用到机器学习中。
3.2 定义缓存修饰器
如下所示便是根据我们实际的需要所定义的一个预处理数据缓存的修饰器。在经过第1节内容的介绍之后,掌柜相信各位客官应该很容易的就能看懂这些代码:
xxxxxxxxxx
161def cache(func):
2 def wrapper(*args, **kwargs):
3 file_dir = kwargs['file_dir']
4 postfix = kwargs['postfix']
5 data_path = os.path.join(file_dir, f"{postfix}.pt")
6 if not os.path.exists(data_path):
7 print(f"缓存文件 {data_path} 不存在,重新处理并缓存!")
8 data = func(*args, **kwargs)
9 with open(data_path, 'wb') as f:
10 torch.save(data, f)
11 else:
12 print(f"缓存文件 {data_path} 存在,直接载入缓存文件!")
13 with open(data_path, 'rb') as f:
14 data = torch.load(f)
15 return data
16 return wrapper
在上述代码中,第3-4行用来通过命名参数获得对应的路径以及缓存文件名的后缀(这就意味着在调用data_process
函数时必须以命名参数的形式进行,即data_process(file_dir=file_dir, postfix=postfix)
;第6-10行则是当缓存文件不存在时,则调用data_process
进行数据处理,并缓存处理有的结果;第11-14行则是当对应的缓存文件存在时,直接从本地进行载入;第15行则是返回对应的预处理结果。
在这里可以发现,对于上述缓存修饰器你还可以几乎不加修改的运用到任何一个场景中,而只需要将data_process
预处理后的结果构造成一个字典即可。
最后,当第1次通过load_train_test_data
载入数据集时会看到类似如下的结果:
xxxxxxxxxx
131if __name__ == '__main__':
2 data_loader = LoadData()
3 data_iter = data_loader.load_train_test_data()
4 for x, y in data_iter:
5 print(x.shape)
6 print(y)
7
8缓存文件 ./cache_5.pt 不存在,重新处理并缓存!
9正在进行预处理数据
10torch.Size([5, 5])
11tensor([0, 1, 1, 1, 1])
12torch.Size([5, 5])
13tensor([0, 0, 0, 1, 0])
当后续再次载入同一份数据预处理结果时,则会看到类似如下的结果:
xxxxxxxxxx
51缓存文件 ./cache_5.pt 存在,直接载入缓存文件!
2torch.Size([5, 5])
3tensor([0, 1, 1, 1, 1])
4torch.Size([5, 5])
5tensor([0, 0, 0, 1, 0])
从输出提示来看,当第2次载入同一份数据预处理文件时,会直接从缓存中载入而并不需要再次运行data_process
函数中的处理逻辑。
到此,对于如何利用Python修饰器来便捷缓存数据预处理结果的内容就介绍完了。
4 总结
在这篇文章中,掌柜首先从使用示例的角度来介绍了Python修饰器的用法及工作原理,即其本质上只是Python中所支持的一种快速简洁的函数调用方式;然后介绍了在机器学习中数据预处理时的一种可借鉴流程;最后介绍了如何实现一个可通用数据预处理缓存修饰器。有了这个修饰器的加持,相信各位客官在建模过程中一定能够极大的节约数据预处理的等待时间。
本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎点赞转发分享!若有任何疑问与建议,请添加掌柜微信nulls8(备注来源)或加群进行交流。青山不改,绿水长流,我们月来客栈见!