PyTorch Lightning 是面向专业AI研究人员和机器学习工程师的深度学习框架,该项目旨在不牺牲大规模性能的情况下获得最大的开发灵活性。框架详见:Pytorch Lighting
在本文档中,主要包含以下方面的内容:
- PyTorch Lightning的核心概念
- 模板的文件说明
- 使用模板的注意事项
- 一个以MNIST手写体数字识别的例子,参考PLighting.yaml提供的包版本配置环境,可直接运行main.py,生成结果。方便你理解和调试。
该部分用于定义模型的训练、验证、测试步骤、optimizer、lr_scheduler。
Pytorch-Lighting的核心设计**是“自给自足”。在定义自己的模型时,需要继承pytorch_lightning.LightningModule类,并在定义模型的过程中同时实现如何训练、如何测试、优化器定义等内容。 具体的,这些内容通常由以下几个类函数实现:
- def forward : 定义模型的前向传播过程
- def configure_loss : 定义训练模型的损失函数
- def training_step : 定义每个batch的训练步骤,在该函数中主要返回训练Loss,用于更新模型参数
- def on_train_epoch_end : 定义模型在训练时,一个epoch结束时进行的操作
- def validation_step : 定义每个batch的验证步骤,包括输入batch数据,并得到验证loss等
- def on_validation_epoch_end :定义模型验证时,一个epoch结束时的操作
- def test_step : 定义每个batch的测试步骤,基本和validation_step类似
- def on_test_epoch_end : 定义模型测试时,一个epoch结束时的操作。如计算一个epoch的总体准确率等操作
- def predict_step : 用于定义模型预测时的操作,通常用于模型推理阶段
- def configure_optimizers : 定义训练过程中,更新模型所使用的优化器。如SGD,Adam,AdamW等。此外,在该函数中还可以定义用于动态调整学习率的lr_scheduler,如StepLR, MultiStepLR, CosineAnnealingLR等
该部分定义用于训练、验证和测试的数据集和对应的DataLoader。通常由三个类函数构成:
- def setup : 通常用于初始化Dataset。或者定义模型在训练(fit)和测试(test)阶段所使用的不同数据集。
- def train_dataloader : 利用返回setup函数中定义的训练、验证和测试数据集,定义模型的训练dataloader
- def val_dataloader : 利用返回setup函数中定义的训练、验证和测试数据集,定义模型的验证dataloader
- def test_dataloader : 利用返回setup函数中定义的训练、验证和测试数据集,定义模型的测试dataloader
本项目提供了一个易用于大型项目、容易迁移、易于复用的模板。 基于该模板,我们需要做的,就是像填空一样,填模板中的这些函数。从而,只需要将精力放在定义模型结构和数据集上, 而无需定义优化器,避免繁杂的中间处理流程。 自定义pytorch lighting模板文件结构如下:
root-
|-data
|-__init__.py
|-data_interface.py
|-xxxstandard_data1.py
|-xxxstandard_data2.py
|-...
|-example_Minist_data
|-Logs
|-model
|-__init__.py
|-model_interface.py
|-xxxstandard_net1.py
|-xxxstandard_net2.py
|-...
|-utils
|-__init__.py
|-xxxutils1.py
|-xxxutils2.py
|-...
|-main.py
-
模板文件说明:其中data包用于为模型提供训练、验证和测试数据集。model包用提供自定义模型。utils包用于提供模型评估指标、或常用函数等通用模块。Logs用于存储模型训练的日志。 main.py用于模型训练所需的callback函数、实例化数据和模型接口,控制超参数。example_Minist_data文件夹提供了示例MNIST手写数字识别数据集。 同时,在data包下实现了其对应的训练、验证和测试Dataset,分别为example_traindata.py、 example_evaldata.py和example_testdata.py。在定义自己的数据集时,可参考它们进行实现。 在model包下提供了自定义模型的基本格式文件standard_net.py和一个用于MNIST手写数字识别的示例模型example_net.py。 在定义自己的模型时,可参考它们实现。
-
有关模板更加详细的信息如下: 如果对每个模型直接上pl.LightningModule,对于已有项目、别人的代码等的转换将相当耗时。另外,这样的话,你需要给每个模型都加上一些相似的代码, 如training_step,validation_step。显然,这并不是我们想要的,如果真的这样做,不但不易于维护,反而可能会更加杂乱。 同理,如果把每个数据集类都直接转换成pl的DataModule,也会面临相似的问题。基于这样的考量,我建议使用上述架构:
- 主目录下只放一个main.py文件。
- data和model两个文件夹中放入__init__.py文件,做成包。这样方便导入。两个init文件分别是:
- from .data_interface import DInterface
- from .model_interface import MInterface
- utils文件夹中也加入__init__.py文件,做成包。在该包中,可以定义好常用的函数。如评估指标函数。
- 在data_interface中建立一个class DInterface(pl.LightningDataModule):用作所有数据集文件的接口。init()函数中import相应Dataset类,setup()进行实例化,并老老实实加入所需要的的train_dataloader, val_dataloader, test_dataloader函数。 这些函数往往都是相似的,可以用几个输入args控制不同的部分。
- 同理,在model_interface中建立class MInterface(pl.LightningModule):类,作为模型的中间接口。init()函数中import相应模型类,然后老老实实加入configure_optimizers, training_step, validation_step等函数,用一个接口类控制所有模型。不同部分使用输入参数控制。
- main.py函数只负责:
- 定义parser,添加parse项
- 选好需要的callback函数们
- 实例化MInterface, DInterface, Trainer
- 在使用该模板定义自己实验中需要的数据集时,仅需要参考standard_data.py的样式定义多个数据集,并在main.py的train_dataset、eval_datasets和test_datasets参数中指定对应的数据集。 通常,我们在一个数据集上训练模型,并需要在多个验证集或测试集上评估模型性能。因此,本模板考虑了这一点,在验证集和测试集中,能以列表的形式指定多个验证集和测试集,模型会自动在多个数据集上执行验证,并将结果保存到logs文件中。
- 在model或data包中,一个model或Dataset需在单独的.py文件中定义。同时,模型和Dataset的命名规则必须与模板文件的相同。 即.py文件的命名和定义的模型名称或Dataset的名称对应,如standard_net.py与class StandardData(data.Dataset)。 因为,所使用的model和Dataset需要通过main.py中的超参数指定,并通过data_interface.pydata_interface.py和model_interface.pymodel_interface.py中的load_data_module、load_model和instancialize函数实例化。
- 用于控制dataloader的num_workers、batch_size等参数可以直接在main.py中直接添加或修改,并通过main(args)中的args传递到data_interface.py中的DataLoader中。对应的,模型结构的超参数、 优化器、损失函数、动态调整学习率的lr_scheduler的参数也通过args传递到model_interface.py的pl.LightningModule中。 这种方式无需再重复修改pl.LightningModule中的其它代码,仅需要我们定义自己的模型和数据集。
- 当model_interface.py不包含你所需的损失函数、optimizer、lr_scheduler时,你可以在configure_loss、configure_optimizers函数中自行添加。
- main.py中定义了常用的callbacks函数,如
TorchMetrics 最初是作为 PyTorch Lightning 的一部分而创建的。 TorchMetrics 是 100+ PyTorch 指标实现和易于使用的 API 的集合,用于创建自定义指标。 虽然 TorchMetrics 是为与原生 PyTorch 一起使用而构建的,但将 TorchMetrics 与 Lightning 结合使用可提供额外的好处:
- 模块化度量标准在 LightningModule 中正确定义后,会自动放置在正确的设备上。这意味着您的数据将始终与度量值放置在同一设备上。无需再调用!.to(device)
- 原生的支持使用 LightningModule 内的 self.log 在 Lightning 中记录metric。
- metric 的.reset()方法的度量在一个epoch结束后自动被调用
如果你想使用一个还不被支持的指标,你可以使用TorchMetrics的API来实现你自己的自定义指标,只需子类化torchmetrics.Metric并实现以下方法:
- init():每个状态变量都应该使用self.add_state(…)调用。
- update():任何需要更新内部度量状态的代码。
- compute():从度量值的状态计算一个最终值。