10.1.3.5. 注册机制

注册机制是辅助构建 config 的重要模块,也是HAT的重要组成部分。

本小节通过自定义模块的例子,为您说明如何在注册机制下在增加新的模块并在 config 中正常使用。

10.1.3.5.1. 自定义模块

backbone 为例,这里展示一下如何开发以 mobilenet 为例的新模块。

10.1.3.5.1.1. 定义一个新的backbone(如MobileNet):

新建一个新文件:hat/models/backbones/mobilenet.py

import torch.nn as nn
from hat.registry import OBJECT_REGISTRY

__all__ = ["MobileNet"]

@OBJECT_REGISTRY.register
class MobileNet(nn.Module):
    def __init__(self, args1, args2):
        pass
    def forward(self, x):
        pass

10.1.3.5.1.2. 导入新定义的模块

可以在 hat/models/backbones/__init__.py 中增加导入模块的行。


from .mobilenet import MobileNet

10.1.3.5.1.3. config中使用新的backbone


model = dict(
    ...
    backbone=dict(
        type="MobileNet",
        args1=xxx,
        args2=xxx,
    )
    ...
)

以此类推,其他任何可注册的模块,都可以使用这种方法来完成开发和使用。