社区QAT

QAT的基本原理是在模型中插入量化节点,使得模型在训练中可以感知到量化误差,减少量化损失的精度。社区QAT有eager模式和fx graph模式,eager模式需要用户自己编写fuse_model和set_qconfig方法,上手难度大。而fx graph模式不需要编写这两个方法,使用起来更加方便,所以HAT里的社区QAT功能是基于fx graph模式开发的。这篇教程以fcos-efficientnetb0为例,告诉大家如何使用HAT算法包训练一个社区QAT检测模型。同理,fcos-efficientnetb2,fcos-efficientnetb3也可以用同样的方法训练QAT模型。默认用户已经熟悉如何使用HAT训练模型,如果有QAT以外的问题,请参考其他与模型训练相关的文档。

训练流程

整个训练流程分为以下几步:

  1. 训练float模型:这一步请参考其他训练浮点模型的说明文档,这里不再赘述。

  2. 训练calibration模型:QAT插入伪量化节点的初始scale并不可靠,而calibration的主要作用是在训练开始前调整这些scale。虽然calibration不是必须的,但绝大多数情况下,calibration对QAT精度都是有一定提升的,有益无害。

  3. 训练qat模型。

社区qat的入口为 tools/train.py,提供 float -> calibration -> qat 三个阶段的功能。此入口对应上述训练流程的运行命令为:

python3 tools/train.py --step float --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py
python3 tools/train.py --step calibration --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py
python3 tools/train.py --step qat --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py

注意:calibration与qat模型相比于float模型多了更多的参数,如果发生显存不足的情况,需要将配置文件中的batch_size_per_gpu改小一些。

配置文件

configs/detection/fcos/fcos_efficientnetb0_mscoco.py 提供了一个配置文件样例。使用社区qat需要在配置中添加 calibration_trainercalibration_solverqat_trainerqat_solver。 添加的方式与浮点的trainer和solver一致,重点解释独有的字段 custom_config_dictcustom_config_dict 中包含五部分内容 modules, qconfig_dictprepare_custom_config_dictquantize_inputquantize_outputmodules 用于指定需要prepare的模块,如果不指定,则默认会prepare整个模型,每个需要prepare的模块都可以设置自己的 custom_config_dictmodules 不能与其余四项同时出现,因为如果指定了 modules ,那么当前模型不需要整体prepare。qconfig_dictprepare_custom_config_dict 与pytorch中的prepare接口参数完全一致。

calibration配置

以fcos-efficientnetb0为例, 做calibration时,需要prepare backbone , neck , head 三部分。neck的输入不需要量化,因为neck的输入直接来自于backbone,而backbone的输出已经量化过了,head同理也不需要量化输入。head的输出不需要量化,因为芯片支持高精度输出。qconfig_dict 使用horizon_nn提供的calibration qconfig。prepare_custom_config_dict 使用horizon_nn提供的HorizonPrepareCustomConfigDict。这些配置已为用户处理了一些特殊情况,也将部分算子的量化方式与BPU对齐,如无特殊需求,不需要用户修改。

custom_config_dict={
    "modules": {
        "backbone": {
            "qconfig_dict": calibration_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
        },
        "neck": {
            "qconfig_dict": calibration_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
            "quantize_input": False,
        },
        "head": {
            "qconfig_dict": calibration_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
            "quantize_input": False,
            "quantize_output": False,
        },
    }
},

配置custom_config_dict的高阶技巧

如果对模型进行合理的模块划分,绝大部分情况下是可以通过简单的配置完成QAT流程的,如果需要更复杂的功能,参考下方内容对 qconfig_dictprepare_custom_config_dict 进行更高级的配置。

qconfig_dict 支持对象类型,模块名正则,模块名,全局四种配置方式,优先级从高到低。用户可以通过这种方式指定哪个模块用哪个qconfig。

qconfig_dict = {
    "": qconfig,
    "object_type": [
        (torch.nn.Conv2d, qconfig),
        (torch.nn.functional.add, qconfig),
        ...,
    ],
    "module_name": [
        (foo.bar, qconfig)
        ...,
    ],
    "module_name_regex": [
        (foo.bar.conv[0-9]+, qconfig)
        ...,
    ],
}

用户可以参考horizon_nn提供的HorizonPrepareCustomConfigDict添加自己需要的配置,具体来说,prepare_custom_config_dict 用法如下,

prepare_custom_config_dict = {
    # 会被fx trace,但不会进入的module
    "standalone_module_name": [],
    "standalone_module_class": [],

    # 不会被fx trace的module
    "non_traceable_module_name": [],
    "non_traceable_module_class": [],

    # 额外的fuse方法
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d) fuse_conv_bn
    },

    # 额外的模块替换
    "additional_qat_module_mapping": {
        torch.nn.intrinsic.ConvBn2d torch.nn.qat.ConvBn2d
    },

    # 额外的fuse pattern
    "additional_fusion_pattern": {
        (torch.nn.BatchNorm2d, torch.nn.Conv2d) ConvReluFusionhandler
    },

    # 额外的quantize pattern
    "additional_quant_pattern": {
        torch.nn.Conv2d: ConvReluQuantizeHandler,
        (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
    }

    # 没有在forward中使用的属性会被丢掉,使用这一项保留必要的属性
    "preserved_attributes": ["preserved_attr"],
}

qat配置

在做qat时,需要固定calibration得到的activation fake quantize,这里``qconfig_dict`` 使用horizon_nn提供的HorizonCalibratedQConfig。其余配置与calibration完全一致。如果未使用calibration,则``qconfig_dict`` 使用horizon_nn提供的HorizonQConfig。

custom_config_dict={
    "modules": {
        "backbone": {
            "qconfig_dict": qat_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
        },
        "neck": {
            "qconfig_dict": qat_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
            "quantize_input": False,
        },
        "head": {
            "qconfig_dict": qat_qconfig_dict,
            "prepare_custom_config_dict": HorizonPrepareCustomConfigDict,
            "quantize_input": False,
            "quantize_output": False,
        },
    }
},

模型结构

fx graph模式的社区qat不需要编写fuse_model和set_qconfig,但由于fx自身的局限性,用户在编写模型的forward方法时,需要注意与fx的适配。HAT中的fcos已为用户处理了模型与pytorch fx的适配问题,但如果用户有自定义模型的需求,那么需要在编写模型的时候注意这些与fx有关的适配问题,并且在prepare完模型之后检查伪量化节点是否已经正确插入。

编写模型注意事项

  • 避免在forward中编写不运行在training状态的逻辑。社区的prepare_qat方法需要模型处于training状态,如果此时进行trace,那么生成的graph module会丢失这一部分与training无关的逻辑。当然,如果确认这一部分逻辑丢掉不会对evaluation等其他阶段产生影响,不修改也没有关系。

def forward(self, inputs):
    if not self.training:
        # 这一部分逻辑不会存在于prepare之后的graph module中
        ...
  • fx不支持动态控制流,避免在forward中使用与动态输入有关的if、for、assert… 社区关于这个问题的讨论见这里。事实上,绝大部分与控制流相关的动态输入都并非真正的动态,比如:height, width … 这些都可以以成员变量的形式预先存储在模型中。如果实在无法避免与动态输入相关的控制流,可以将这部分逻辑写为一个函数,使用wrap方法装饰起来,用法详见这里

def forward(self, inputs):
    assert len(inputs) == len(self.in_channels)

# assert中用到了inputs,会报如下错误:
 File pytorch/torch/fx/proxy.py, line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError symbolically traced variables cannot be used as inputs to control flow
  • 不支持trace部分python内置方法,比如:len。可以使用wrap()修饰不需要被trace的方法,社区关于这个问题的讨论见这里

# 不用wrap的话,fx会尝试trace内置的len方法
torch.fx.wrap('len')

def forward(self, inputs):
    length = len(self.inputs)

# 报如下错误:
 File pytorch/torch/fx/proxy.py, line 161, in __len__
    raise RuntimeError('len' is not supported in symbolic tracing by default. If you want
RuntimeError 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
  • 参数共享问题。以一个简单的例子说明:假设我们有一个conv被两个BN共享,prepare QAT模型时,会先做fuse操作,在匹配到conv+bn的模式之后,conv和第一个BN会被标记为matched,第二个bn将无法被conv+bn的pattern匹配,模型变为convBN和一个单独的BN,所以第二个BN的分支会出现一些问题,导致模型预测完全错误。pytorch在fuse这里的代码中留了todo,后续版本应该会解决这个问题。在当前版本,为了避免这个问题,推荐重点检查参数共享中有无处理完一个分支后影响另一个分支的情况。以上述例子来说,这类参数共享问题的解决方法有两种:1. 编写fuse pattern,并通过custom_config_dict传入,详细做法见社区QAT源码。2. 将QAT模型的conv拆开,这样两个BN都有各自对应conv。实际使用中更推荐第二种做法,只需要在代码中加上少量deepcopy即可解决问题。

模型结构检查

为避免踩坑未知错误,训练模型之前最好进行如下检查:

  • 打印模型,检查量化节点的位置是否正确。

  • 查看输出的量化节点是否已经禁用。如果没有禁用,需要将对应模块的``quantize_output`` 置为false。

# 输出量化节点的fake_quant_enabled应为0
(conv_centerness_4_activation_post_process_0): FixedActivationFakeQuantize(
    fake_quant_enabled=tensor([0], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)
  • 检查非必要模块是否插入了量化节点。loss,post_process等与inference无关或没有量化需求的模块不需要插入量化节点,否则会引入不必要的误差。

# 这些模块均没有插入量化节点
(loss_cls): FocalLoss()
(loss_reg): GIoULoss()
(loss_centerness): CrossEntropyLossV2()
(post_process): FCOSDecoder()

调参技巧

  • batch size 尽可能大一些,最好打满显存。

  • 确保输入数据分布合理,均值为0,最好是均匀分布,其次为高斯分布,避免长尾分布。

  • weight decay一般设置为5e-5,可根据实际实验情况调整。

  • learning rate一般从0.001左右开始设置,可根据实际实验情况调整。一般可以搭配StepLrUpdater做1-2次scale=0.1的decay。

  • learning rate的最小值最好不要小于1e-6

  • 不推荐使用warmup。warmup初期学习率过小,对QAT几乎没有加成,甚至会降低QAT精度。

  • epoch长度不固定,一般选为float epoch大小的十分之一到二分之一不等。

  • 最好的QAT模型一般在第一个epoch结果的基础上提升不超过3个点,如果第一个epoch的指标较低,那么基本可以断定最后模型的结果不会很好。

  • 减弱data augmentation有时效果较好。

  • 多选用不同epoch的浮点模型做QAT,有时并非最好的浮点模型就能训出最好的QAT模型。最好的浮点模型往往处于过拟合的边缘,此时进行QAT不一定最好。

  • 调大averaging_constant,这样可以增加当前观察值对scale的影响,在scale初始化不靠谱的情况下,这种方法往往非常有效。

  • calibration对QAT效果的提升非常大,一般是有益无害的。

  • 如果单次训练的batch size较小,固定住BN的均值和方差可能取得意想不到的效果。

转换流程

得到QAT模型之后,需要对模型进行转换和编译才能上板,我们提供了配套的工具,入口为 tools/convert.py。运行如下命令即可转换fcos qat模型。

python tools/convert.py --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --qat_model_path tmp_models/fcos_efficentnetb0_mscoco/qat-checkpoint-best.pth.tar --output_dir converted_model --march bernoulli2

该工具需要指定模型的配置文件,需要转换的qat模型路径,定点模型的输出路径以及芯片架构。如果使用的是XJ3芯片,march选择bernoulli2,J5芯片的march选择bayes。

转换成功后可以在输出路径下看到QAT模型导出的ONNX模型,转换后的ONNX模型,编译之后的bin以及perf文件。