config 文件介绍

使用 HAT 算法包训练模型通常只需使用一条命令就可以了,即:python3 tools/train.py --step TRAINING_STEPS --config /PATH/TO/CONFIG,其中 /PATH/TO/CONFIG 就是模型训练对应的 config 文件,它负责定义了模型结构、数据集加载、以及整套的训练流程。
这篇教程通过介绍 config 文件中一些固定的全局的关键字,以及它们是如何配置的,让用户对 config 文件的内容以及作用有个大致的了解。

全局关键字

training_step: 模型训练的各个阶段,包括 float

device_ids: 模型训练使用的 gpu 列表。

cudnn_benchmark: 是否打开cudnn benchmark。通常默认为 True

seed: 是否设置随机数种子。通常默认为 None

log_rank_zero_only: 简化多卡训练时的日志打印,只在第0卡上输出日志。通常默认为 True

model: 参与 training 过程中的模型结构。type 表示模型的类型,如 ClassifierSegmentorRetinaNet等等,分别对应分类、分割、检测中的某一类模型。它会在使用过程中被 build 成具体的类,余下的参数都是用于初始化这个类。

test_model: 参与 test 过程的模型结构,主要用于模型编译。和 model 相比,大多数情况下只需要把损失函数以及后处理部分设置为 None 即可。

test_inputs: test 过程的模拟输入。不用关心具体的数值,只要保证格式满足输入要求即可。

data_loader: 训练阶段的数据集加载流程。它的 type 是一个具体的类 torch.utils.data.DataLoader ,余下的参数都是用于初始化这个类。相关参数的含义也可以参考 pytorch 官网提供的接口文档。这里 dataset 表示读取某个具体的数据集,例如ImageNetMSCOCOVOC等等,它的 transforms 表示在数据读取过程中添加的数据增强操作。

val_data_loader: 验证模型性能阶段的数据集加载流程。和 data_loader 不同的地方在于 data_path 不同,以及去掉了 transforms 的过程和 sample 的过程。

batch_processor: 模型在训练过程中每个迭代 step 进行的操作,包括前向计算、梯度回传、参数更新等等。如果包含 batch_transforms 参数,表示一些数据增强的操作是在 gpu 上进行的,这可以大大加快训练速度。

val_batch_processor: 模型在验证过程中每个迭代 step 进行的操作,只包含前向计算。

metric_updater: 模型训练过程中更新指标的方法,这个指标是用来验证训练的模型性能是否在提升。它通常是和 float_trainer 下面的 train_metrics 配合着使用。train_metrics 是具体的指标形式,metric_updater 只是提供一种更新方法。

val_metric_updater: 训练出来的模型在验证性能的过程中更新指标的方法,这个指标用来验证最终训练出来的模型性能到底如何。它通常是和 float_trainer 下面的 val_metrics 配合着使用,和 metric_updater 同理。

float_trainer: 浮点模型训练流程的配置。type 类型为 distributed_data_parallel_trainer 表示支持分布式训练,其余的参数分别定义了模型、数据集加载、优化器、训练 epoch 长度等等。其中 callbacks 表示训练过程中进行的一系列操作,比如模型保存、学习率更新、精度验证等等。

float_solver: 直接被 tools/train.py 文件调用的变量,核心主要是 float_trainer

之所以称这些变量为全局关键字,是因为几乎每个 config 文件中都定义了以上这些变量,且对应的功能基本一致。因此通过对这篇文档的学习,用户可以大致理解任意一个 config 文件实现的功能。

如何配置

这里主要介绍数据类型为 dict 的全局关键字的配置。

数据类型为 dict 的全局关键字可以为两种,包含 type的,例如 modeldata_loaderfloat_trainer等,和不包含 type的,例如 float_solverstep2solver等。

它们的区别在于包含 type 的全局关键字本质可以看作是一个 class,它的 type 值可以是一个 string 变量,也可以是一个具体的 class,如果是 string, 在程序运行中同样会被 build 成一个相应的 class。这个 dict 中除掉 type 之外的其它 keys 的值都用于初始化这个 class。和全局关键字属性类似,这些 keys 的值可以是一个数值,也可以是一个包含 type 变量的 dict,例如 data_loader 中的 dataset 属性,以及这个 dataset 下面的 transforms 属性。

对于没有 type 变量的全局关键字来说,它就是一个普通类型的 dict 变量,代码在运行过程中会通过其 keys 获取对应的 values

.. note::

  所有已经提供的config配置,可以保证正常运行和复现精度。如果因为环境配置和训练时间等原因,需要修改配置的话,那么相对的训练策略可能也需要更改。直接修改config中的个别配置有时候并不能得到想要的结果。