config 文件介绍¶
使用 HAT
算法包训练模型通常只需使用一条命令就可以了,即:python3 tools/train.py --step TRAINING_STEPS --config /PATH/TO/CONFIG
,其中 /PATH/TO/CONFIG
就是模型训练对应的 config
文件,它负责定义了模型结构、数据集加载、以及整套的训练流程。
这篇教程通过介绍 config
文件中一些固定的全局的关键字,以及它们是如何配置的,让用户对 config
文件的内容以及作用有个大致的了解。
全局关键字¶
training_step
: 模型训练的各个阶段,包括 float
。
device_ids
: 模型训练使用的 gpu
列表。
cudnn_benchmark
: 是否打开cudnn benchmark。通常默认为 True
。
seed
: 是否设置随机数种子。通常默认为 None
。
log_rank_zero_only
: 简化多卡训练时的日志打印,只在第0卡上输出日志。通常默认为 True
。
model
: 参与 training
过程中的模型结构。type
表示模型的类型,如 Classifier
、Segmentor
、RetinaNet
等等,分别对应分类、分割、检测中的某一类模型。它会在使用过程中被 build
成具体的类,余下的参数都是用于初始化这个类。
test_model
: 参与 test
过程的模型结构,主要用于模型编译。和 model
相比,大多数情况下只需要把损失函数以及后处理部分设置为 None
即可。
test_inputs
: test
过程的模拟输入。不用关心具体的数值,只要保证格式满足输入要求即可。
data_loader
: 训练阶段的数据集加载流程。它的 type
是一个具体的类 torch.utils.data.DataLoader
,余下的参数都是用于初始化这个类。相关参数的含义也可以参考 pytorch
官网提供的接口文档。这里 dataset
表示读取某个具体的数据集,例如ImageNet
、MSCOCO
、VOC
等等,它的 transforms
表示在数据读取过程中添加的数据增强操作。
val_data_loader
: 验证模型性能阶段的数据集加载流程。和 data_loader
不同的地方在于 data_path
不同,以及去掉了 transforms
的过程和 sample
的过程。
batch_processor
: 模型在训练过程中每个迭代 step
进行的操作,包括前向计算、梯度回传、参数更新等等。如果包含 batch_transforms
参数,表示一些数据增强的操作是在 gpu
上进行的,这可以大大加快训练速度。
val_batch_processor
: 模型在验证过程中每个迭代 step
进行的操作,只包含前向计算。
metric_updater
: 模型训练过程中更新指标的方法,这个指标是用来验证训练的模型性能是否在提升。它通常是和 float_trainer
下面的 train_metrics
配合着使用。train_metrics
是具体的指标形式,metric_updater
只是提供一种更新方法。
val_metric_updater
: 训练出来的模型在验证性能的过程中更新指标的方法,这个指标用来验证最终训练出来的模型性能到底如何。它通常是和 float_trainer
下面的 val_metrics
配合着使用,和 metric_updater
同理。
float_trainer
: 浮点模型训练流程的配置。type
类型为 distributed_data_parallel_trainer
表示支持分布式训练,其余的参数分别定义了模型、数据集加载、优化器、训练 epoch
长度等等。其中 callbacks
表示训练过程中进行的一系列操作,比如模型保存、学习率更新、精度验证等等。
float_solver
: 直接被 tools/train.py
文件调用的变量,核心主要是 float_trainer
。
之所以称这些变量为全局关键字,是因为几乎每个 config
文件中都定义了以上这些变量,且对应的功能基本一致。因此通过对这篇文档的学习,用户可以大致理解任意一个 config
文件实现的功能。
如何配置¶
这里主要介绍数据类型为 dict
的全局关键字的配置。
数据类型为 dict
的全局关键字可以为两种,包含 type
的,例如 model
、 data_loader
、 float_trainer
等,和不包含 type
的,例如
float_solver
、step2solver
等。
它们的区别在于包含 type
的全局关键字本质可以看作是一个 class
,它的 type
值可以是一个 string
变量,也可以是一个具体的
class
,如果是 string
, 在程序运行中同样会被 build
成一个相应的 class
。这个 dict
中除掉 type
之外的其它 keys
的值都用于初始化这个 class
。和全局关键字属性类似,这些
keys
的值可以是一个数值,也可以是一个包含 type
变量的 dict
,例如 data_loader
中的 dataset
属性,以及这个 dataset
下面的 transforms
属性。
对于没有 type
变量的全局关键字来说,它就是一个普通类型的 dict
变量,代码在运行过程中会通过其 keys
获取对应的 values
。
.. note::
所有已经提供的config配置,可以保证正常运行和复现精度。如果因为环境配置和训练时间等原因,需要修改配置的话,那么相对的训练策略可能也需要更改。直接修改config中的个别配置有时候并不能得到想要的结果。