mmdet 4: Customize Models
https://comlini8-8.tistory.com/86
TUTORIAL 4: CUSTOMIZE MODELS
MMDetection divides a model into five basic components:

| Component | Role | Examples |
|---|---|---|
| Backbone | Usually an FCN that extracts feature maps | ResNet, MobileNet |
| Neck | Connects the backbone and the head | FPN, PAFPN |
| Head | Performs a specific task | bbox prediction, mask prediction |
| RoI extractor | Extracts RoI features from the feature maps | RoI Align |
| Loss | Part of the head that computes the loss | FocalLoss, L1Loss, GHMLoss |
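To see how these components fit together, here is a framework-free sketch of the call order in a two-stage detector during training. Everything in it is hypothetical: `ToyTwoStageDetector` is not an MMDetection class, "feature maps" are plain lists instead of tensors, and a one-stage model such as RetinaNet would skip the RoI extractor and apply the head directly to the neck outputs.

```python
from typing import Callable, List, Sequence, Tuple

# In this toy a "feature map" is a list of floats; real ones are tensors.
FeatureMap = List[float]


class ToyTwoStageDetector:
    """Hypothetical sketch wiring the five components in training order."""

    def __init__(self,
                 backbone: Callable[[FeatureMap], Tuple[FeatureMap, ...]],
                 neck: Callable[[Tuple[FeatureMap, ...]], Tuple[FeatureMap, ...]],
                 roi_extractor: Callable[[Tuple[FeatureMap, ...], Sequence[int]], List[float]],
                 head: Callable[[List[float]], List[float]],
                 loss: Callable[[List[float], List[float]], float]):
        self.backbone = backbone
        self.neck = neck
        self.roi_extractor = roi_extractor
        self.head = head
        self.loss = loss

    def forward_train(self, img: FeatureMap, rois: Sequence[int],
                      targets: List[float]) -> float:
        feats = self.backbone(img)                    # multi-level features (a tuple)
        feats = self.neck(feats)                      # fused multi-level features
        roi_feats = self.roi_extractor(feats, rois)   # one feature per RoI
        preds = self.head(roi_feats)                  # task-specific predictions
        return self.loss(preds, targets)              # scalar training loss


# Stand-in callables, chosen only to make the data flow easy to follow.
det = ToyTwoStageDetector(
    backbone=lambda img: (img, img[::2]),             # two "levels"
    neck=lambda feats: tuple([2 * v for v in f] for f in feats),
    roi_extractor=lambda feats, rois: [feats[0][i] for i in rois],
    head=lambda x: [v + 1 for v in x],
    loss=lambda preds, targets: sum(abs(p - t) for p, t in zip(preds, targets)),
)
loss = det.forward_train([1.0, 2.0, 3.0, 4.0], rois=[0, 3], targets=[0.0, 0.0])  # 12.0
```

Keeping each stage behind its own callable mirrors why MMDetection can swap a backbone or neck through the config without touching the rest of the model.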
Develop new components
Add a new backbone
Let's look at how to add a new component that MMDetection does not provide.
Necks, heads, and losses can be added in a similar way (see the official documentation).
Here we show how to develop new components using MobileNet as an example.
1. Define a new backbone (e.g. MobileNet)
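Create a new file mmdet/models/backbones/mobilenet.py:

```python
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        # nn.Module subclasses must call the parent constructor
        # before registering any layers.
        super().__init__()

    def forward(self, x):
        # should return a tuple of feature maps (one per output stage)
        pass
```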
2. Import the module
Add the following line to mmdet/models/backbones/__init__.py:

```python
from .mobilenet import MobileNet
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.
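custom_imports works because registration is a side effect of importing: once the module is imported, the @BACKBONES.register_module() decorator has already run. A minimal sketch of that idea using only importlib; the helper name `import_custom_modules` is hypothetical, and MMCV's own config loader performs the real import with internals that may differ.

```python
import importlib
import warnings
from types import ModuleType
from typing import List, Optional


def import_custom_modules(imports: List[str],
                          allow_failed_imports: bool = False
                          ) -> List[Optional[ModuleType]]:
    """Import each dotted module path so its registration decorators run."""
    modules: List[Optional[ModuleType]] = []
    for name in imports:
        try:
            modules.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is ignored.')
            modules.append(None)
    return modules


# Entries are dotted import paths ('pkg.mod'), not file names ('pkg/mod.py').
mods = import_custom_modules(['json', 'collections.abc'])
```

Note that a path ending in ".py" is read as a submodule literally named py, so importing it fails.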
3. Use the backbone in your config file
```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...)
```
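The type key is what connects the config to the class registered in step 1. The sketch below shows the registry pattern in a simplified, self-contained form; its `Registry` and `build_from_cfg` are illustrative stand-ins named after, but not identical to, MMCV's implementation, and the MobileNet here is a plain class rather than an nn.Module.

```python
from typing import Any, Callable, Dict, Optional


class Registry:
    """Hypothetical minimal registry: maps a name to a class."""

    def __init__(self, name: str):
        self.name = name
        self._module_dict: Dict[str, type] = {}

    def __contains__(self, key: str) -> bool:
        return key in self._module_dict

    def get(self, key: str) -> type:
        return self._module_dict[key]

    def register_module(self, name: Optional[str] = None) -> Callable[[type], type]:
        """Return a class decorator that records the class under its name."""
        def _register(cls: type) -> type:
            key = name or cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # the decorated class itself is unchanged
        return _register


def build_from_cfg(cfg: Dict[str, Any], registry: Registry) -> Any:
    """Pop 'type', look up the class, and pass the remaining keys as kwargs."""
    args = dict(cfg)  # copy, so the caller's config dict is not mutated
    obj_type = args.pop('type')
    if obj_type not in registry:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return registry.get(obj_type)(**args)


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class MobileNet:  # plain class here; the real backbone subclasses nn.Module
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


# Equivalent of backbone=dict(type='MobileNet', arg1=xxx, arg2=xxx) in a config.
backbone = build_from_cfg(dict(type='MobileNet', arg1=1, arg2=2), BACKBONES)
```

A misspelled type therefore fails at build time with a KeyError, which is usually the first symptom of a module that was never imported.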
Note: the basic layout of a model config is shown below. Each part, such as the backbone and the neck, takes several arguments as above; change these argument values to suit your data and model.
```python
# one-stage model
model = dict(
    type='RetinaNet',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(),
    bbox_head=dict(),
    train_cfg=dict(),
    test_cfg=dict())
```
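As an example of what these backbone arguments mean: out_indices selects which of the num_stages residual stages return a feature map, frozen_stages freezes the stem plus the first N stages (-1 freezes nothing), and norm_eval=True keeps the BN layers in eval mode during training. The helper below is hypothetical and only spells out the first two; the stage names follow ResNet's layer1 to layer4 naming, and the 256/512/1024/2048 channel counts assume a bottleneck ResNet such as ResNet-50.

```python
from typing import Any, Dict, List, Tuple

# Output channels and strides of the four stages of a bottleneck ResNet (e.g. ResNet-50).
STAGE_CHANNELS = (256, 512, 1024, 2048)
STAGE_STRIDES = (4, 8, 16, 32)


def describe_resnet_cfg(cfg: Dict[str, Any]
                        ) -> Tuple[List[str], List[Tuple[str, int, int]]]:
    """Return (frozen module names, [(stage name, channels, stride)] that are output)."""
    num_stages = cfg.get('num_stages', 4)
    frozen_stages = cfg.get('frozen_stages', -1)
    assert 1 <= num_stages <= 4
    assert -1 <= frozen_stages <= num_stages

    # frozen_stages >= 0 freezes the stem; frozen_stages >= i also freezes layer{i}.
    frozen: List[str] = []
    if frozen_stages >= 0:
        frozen.append('stem')
        frozen += [f'layer{i}' for i in range(1, frozen_stages + 1)]

    outputs = []
    for idx in cfg.get('out_indices', (0, 1, 2, 3)):
        assert 0 <= idx < num_stages, 'out_indices must refer to an existing stage'
        outputs.append((f'layer{idx + 1}', STAGE_CHANNELS[idx], STAGE_STRIDES[idx]))
    return frozen, outputs


backbone_cfg = dict(type='ResNet', depth=50, num_stages=4,
                    out_indices=(0, 1, 2, 3), frozen_stages=1)
frozen, outputs = describe_resnet_cfg(backbone_cfg)
```

For the config above this yields the stem and layer1 as frozen, and four outputs with 256 to 2048 channels at strides 4 to 32.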
Add new necks
1. Define a neck (e.g. PAFPN)
Create a new file mmdet/models/necks/pafpn.py:

```python
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        super().__init__()
        # layer construction omitted

    def forward(self, inputs):
        # implementation is ignored
        pass
```
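The implementation is omitted above, but the idea behind PAFPN (from the Path Aggregation Network, PANet) is an FPN top-down pass followed by an extra bottom-up path. The toy below shows only that data flow on 1-D lists: every convolution is replaced by the identity, upsampling is nearest-neighbour repetition, and the stride-2 downsampling conv is replaced by taking every second element. These substitutions are simplifications for illustration, not how MMDetection's PAFPN computes features.

```python
from typing import List

Level = List[float]  # a 1-D stand-in for one feature map


def upsample2(x: Level) -> Level:
    """Nearest-neighbour upsampling by 2: [a, b] -> [a, a, b, b]."""
    return [v for v in x for _ in range(2)]


def downsample2(x: Level) -> Level:
    """Stride-2 subsampling, standing in for a stride-2 conv: [a, b, c, d] -> [a, c]."""
    return x[::2]


def add(a: Level, b: Level) -> Level:
    assert len(a) == len(b), 'levels must have matching sizes'
    return [x + y for x, y in zip(a, b)]


def toy_pafpn(inputs: List[Level]) -> List[Level]:
    """inputs[0] is the finest level; each following level is half the size."""
    # Lateral 1x1 convs would go here; in this toy they are the identity.
    laterals = [list(level) for level in inputs]

    # FPN top-down path: add each upsampled coarser level into the finer one.
    for i in range(len(laterals) - 1, 0, -1):
        laterals[i - 1] = add(laterals[i - 1], upsample2(laterals[i]))

    # PAFPN extra bottom-up path: add each downsampled finer level into the
    # coarser one, using the already-updated finer level at every step.
    outs = [list(level) for level in laterals]
    for i in range(len(outs) - 1):
        outs[i + 1] = add(outs[i + 1], downsample2(outs[i]))
    return outs


outs = toy_pafpn([[1.0] * 8, [2.0] * 4, [3.0] * 2, [4.0]])
# [[10.0]*8, [19.0]*4, [26.0]*2, [30.0]]
```

The bottom-up pass gives coarse levels a short path back to fine-grained information, which is the motivation for adding it on top of FPN.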
2. Import the module
You can either add the following line to mmdet/models/necks/__init__.py:

```python
from .pafpn import PAFPN
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)
```

to the config file and avoid modifying the original code.
3. Modify the config file
```python
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
```
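in_channels must list the channels of the backbone outputs in order (256/512/1024/2048 for ResNet-50 with out_indices=(0, 1, 2, 3)), and because num_outs=5 exceeds the four inputs, the neck appends one extra, coarser level. The hypothetical helper below estimates the output shapes for a square input, assuming the input size is divisible by 32, input strides of 4/8/16/32, and a 3x3, stride-2, padding-1 operation for each extra level. It is a planning aid, not part of MMDetection.

```python
from typing import List, Sequence, Tuple


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Standard convolution output-size formula."""
    return (size + 2 * padding - kernel) // stride + 1


def neck_output_shapes(img_size: int,
                       backbone_channels: Sequence[int],
                       in_channels: Sequence[int],
                       out_channels: int,
                       num_outs: int,
                       strides: Sequence[int] = (4, 8, 16, 32)
                       ) -> List[Tuple[int, int, int]]:
    """Return (channels, height, width) for each neck output level."""
    if list(in_channels) != list(backbone_channels):
        raise ValueError(f'neck in_channels {list(in_channels)} do not match '
                         f'backbone outputs {list(backbone_channels)}')
    assert num_outs >= len(in_channels)
    sizes = [img_size // s for s in strides[:len(in_channels)]]
    # Extra levels: repeatedly downsample the coarsest level.
    while len(sizes) < num_outs:
        sizes.append(conv_out(sizes[-1]))
    return [(out_channels, s, s) for s in sizes]


shapes = neck_output_shapes(800, [256, 512, 1024, 2048],
                            in_channels=[256, 512, 1024, 2048],
                            out_channels=256, num_outs=5)
# [(256, 200, 200), (256, 100, 100), (256, 50, 50), (256, 25, 25), (256, 13, 13)]
```

Checking in_channels against the backbone this way catches the most common mistake when swapping in a new backbone such as MobileNet, whose output channels differ from ResNet's.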
Add new heads
Here we show how to develop a new head, using Double Head R-CNN as an example.
First, add a new bbox head in mmdet/models/roi_heads/bbox_heads/double_bbox_head.py. Double Head R-CNN implements a new bbox head for object detection. To implement a bbox head, we basically need to implement three functions of the new module; the excerpt below shows __init__ and the signature of forward.
```python
from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead


@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
        # ... construction of the conv and fc branches is omitted in this excerpt

    def forward(self, x_cls, x_reg):
        # body omitted in this excerpt; it returns (cls_score, bbox_pred)
        pass
```
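The two inputs of forward(self, x_cls, x_reg) go to different branches: in MMDetection's DoubleConvFCBBoxHead, classification scores come from the fully connected branch and box regression from the convolutional branch. A framework-free sketch of that routing; the branch callables are hypothetical stand-ins for the real conv and fc stacks.

```python
from typing import Callable, List, Tuple

Vec = List[float]


def double_head_forward(x_cls: Vec,
                        x_reg: Vec,
                        conv_branch: Callable[[Vec], Vec],
                        fc_branch: Callable[[Vec], Vec],
                        fc_cls: Callable[[Vec], Vec],
                        fc_reg: Callable[[Vec], Vec]) -> Tuple[Vec, Vec]:
    """Route RoI features: fc branch -> class scores, conv branch -> box deltas."""
    # Regression path: conv branch applied to the regression features.
    bbox_pred = fc_reg(conv_branch(x_reg))
    # Classification path: fc branch applied to the classification features.
    cls_score = fc_cls(fc_branch(x_cls))
    return cls_score, bbox_pred


cls_score, bbox_pred = double_head_forward(
    x_cls=[1.0, 2.0],
    x_reg=[3.0, 4.0],
    conv_branch=lambda v: [10 * x for x in v],
    fc_branch=lambda v: [x + 1 for x in v],
    fc_cls=lambda v: [sum(v)],   # one toy "score"
    fc_reg=lambda v: v,          # pass-through toy "deltas"
)
```

Taking two separate inputs is what lets the RoI head feed each branch differently pooled features.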
Second, implement a new RoI head if necessary. The new DoubleHeadRoIHead inherits from StandardRoIHead, which already implements the following functions:
```python
import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head."""

    # Method bodies are omitted; only the signatures are listed.
    def init_assigner_sampler(self): ...

    def init_bbox_head(self, bbox_roi_extractor, bbox_head): ...

    def init_mask_head(self, mask_roi_extractor, mask_head): ...

    def forward_dummy(self, x, proposals): ...

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None): ...

    def _bbox_forward(self, x, rois): ...

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas): ...

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas): ...

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None): ...

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""
```
Double Head's modification is mainly in the _bbox_forward logic; it inherits the rest of its logic from StandardRoIHead. In mmdet/models/roi_heads/double_roi_head.py, we implement the new RoI head as follows:
```python
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN.

    https://arxiv.org/abs/1904.06493
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        # classification features: RoIs pooled as given
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        # regression features: RoIs rescaled by reg_roi_scale_factor before pooling
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results
```
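The only difference between the two bbox_roi_extractor calls is roi_scale_factor: the regression branch pools features from RoIs rescaled around their centres (enlarged 1.3x in the config below), giving it more context. A pure-Python sketch of that rescaling, assuming the (batch_idx, x1, y1, x2, y2) RoI layout produced by bbox2roi; the function name is hypothetical, and MMDetection performs the equivalent operation on tensors inside its RoI extractor.

```python
from typing import Sequence, Tuple

Roi = Tuple[float, float, float, float, float]  # (batch_idx, x1, y1, x2, y2)


def rescale_roi(roi: Sequence[float], scale_factor: float) -> Roi:
    """Scale an RoI's width and height by scale_factor, keeping its centre fixed.

    No clipping to the image bounds is done in this sketch.
    """
    batch_idx, x1, y1, x2, y2 = roi
    cx, cy = (x1 + x2) * 0.5, (y1 + y2) * 0.5
    new_w = (x2 - x1) * scale_factor
    new_h = (y2 - y1) * scale_factor
    return (batch_idx,
            cx - new_w * 0.5, cy - new_h * 0.5,
            cx + new_w * 0.5, cy + new_h * 0.5)


cls_roi = (0, 10.0, 10.0, 30.0, 50.0)
reg_roi = rescale_roi(cls_roi, 1.3)  # roughly (0, 7.0, 4.0, 33.0, 56.0)
```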
Last, add the modules to mmdet/models/roi_heads/bbox_heads/__init__.py and mmdet/models/roi_heads/__init__.py so that the corresponding registries can find and load them.
Alternatively, add

```python
custom_imports = dict(
    imports=['mmdet.models.roi_heads.double_roi_head',
             'mmdet.models.roi_heads.bbox_heads.double_bbox_head'])
```

to the config file to achieve the same goal.
The config file of Double Head R-CNN is as follows:
```python
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
```
Since MMDetection 2.0, the config system supports config inheritance, so users can focus on the modifications. Double Head R-CNN mainly uses a new DoubleHeadRoIHead and a new DoubleConvFCBBoxHead, and their arguments are set according to the __init__ function of each module. Setting _delete_=True in bbox_head discards the bbox_head inherited from the base config instead of merging the new keys into it.
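A simplified sketch of how an inherited config is combined with the child config: nested dicts are merged key by key, while a dict marked `_delete_=True` replaces the inherited one wholesale, so no inherited keys leak into the new head's arguments. `merge_cfg` is a hypothetical name; MMCV's real implementation also handles cases this sketch omits, such as loading the _base_ files themselves.

```python
import copy
from typing import Any, Dict


def merge_cfg(child: Dict[str, Any], base: Dict[str, Any]) -> Dict[str, Any]:
    """Merge child into a copy of base; `_delete_=True` replaces instead of merging."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if not isinstance(value, dict):
            merged[key] = value  # scalars and lists simply override
            continue
        value = dict(value)  # shallow copy, so popping _delete_ leaves child intact
        replace = value.pop('_delete_', False)
        if replace or not isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(value, {})  # take the child dict as-is
        else:
            merged[key] = merge_cfg(value, merged[key])  # recursive merge
    return merged


# Abbreviated stand-ins for the inherited Faster R-CNN config and the Double Head override.
base = dict(model=dict(roi_head=dict(
    type='StandardRoIHead',
    bbox_head=dict(type='Shared2FCBBoxHead', num_classes=80, fc_out_channels=1024))))
child = dict(model=dict(roi_head=dict(
    type='DoubleHeadRoIHead',
    reg_roi_scale_factor=1.3,
    bbox_head=dict(_delete_=True, type='DoubleConvFCBBoxHead', num_convs=4))))
merged = merge_cfg(child, base)
# merged['model']['roi_head']['bbox_head'] == {'type': 'DoubleConvFCBBoxHead', 'num_convs': 4}
```

Without `_delete_`, the same override would keep num_classes and fc_out_channels from the base head, which is harmless only if the new head accepts those arguments.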
Add new loss
Assume you want to add a new loss, MyLoss, for bounding box regression. To add a new loss function, implement it in mmdet/models/losses/my_loss.py. The decorator weighted_loss enables the loss to be weighted for each element.
```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
```
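What weighted_loss adds is the shared weight, reduction, and avg_factor handling around an element-wise loss, which is why my_loss itself only computes |pred - target|. A pure-Python sketch of those semantics on flat lists; `weight_reduce` and `weighted_loss_sketch` are hypothetical names, and the real decorator operates on tensors.

```python
import functools
from typing import Callable, List, Optional, Union

Values = List[float]


def weight_reduce(loss: Values,
                  weight: Optional[Values] = None,
                  reduction: str = 'mean',
                  avg_factor: Optional[float] = None) -> Union[float, Values]:
    """Apply optional element-wise weights, then reduce the loss."""
    if weight is not None:
        loss = [x * w for x, w in zip(loss, weight)]
    if avg_factor is None:
        if reduction == 'mean':
            return sum(loss) / len(loss)
        if reduction == 'sum':
            return sum(loss)
        return loss  # 'none': keep the element-wise values
    # With avg_factor, 'mean' means sum / avg_factor; 'sum' is not allowed.
    if reduction == 'mean':
        return sum(loss) / avg_factor
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can not be used with reduction="sum"')


def weighted_loss_sketch(loss_func: Callable[[Values, Values], Values]):
    """Decorator: give an element-wise loss weight/reduction/avg_factor arguments."""
    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        return weight_reduce(loss_func(pred, target), weight, reduction, avg_factor)
    return wrapper


@weighted_loss_sketch
def my_loss(pred: Values, target: Values) -> Values:
    assert len(pred) == len(target) and len(target) > 0
    return [abs(p - t) for p, t in zip(pred, target)]


mean_loss = my_loss([1.0, 2.0, 3.0], [0.0, 2.0, 5.0])  # 1.0
```

avg_factor is typically the number of positive samples, so the summed loss is normalized by it rather than by the element count.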
Then add it to mmdet/models/losses/__init__.py:

```python
from .my_loss import MyLoss, my_loss
```

Alternatively, add

```python
custom_imports = dict(
    imports=['mmdet.models.losses.my_loss'])
```

to the config file to achieve the same goal.
To use it, modify the loss_xxx field. Since MyLoss is for regression, modify the loss_bbox field in the head:

```python
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
```
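loss_weight scales this term before the detector combines all loss terms. In MMDetection 2.x, the detector sums every entry of the loss dict whose key contains 'loss' into the total that is back-propagated, while other entries (such as accuracy) are only logged. The sketch below mimics that on plain floats; the function name and all loss values are made up for illustration.

```python
from typing import Dict, List, Union


def parse_losses(losses: Dict[str, Union[float, List[float]]]) -> float:
    """Sum every entry whose key contains 'loss'; other entries are log-only."""
    log_vars = {}
    for name, value in losses.items():
        # A loss reported per level comes as a list; sum it into one number.
        log_vars[name] = sum(value) if isinstance(value, list) else value
    return sum(v for k, v in log_vars.items() if 'loss' in k)


loss_weight = 2.0       # as in loss_bbox=dict(type=..., loss_weight=2.0)
raw_bbox_loss = 0.125   # made-up unweighted value
losses = {
    'loss_rpn_cls': [0.125, 0.125],            # made-up per-level values
    'loss_cls': 0.5,
    'loss_bbox': loss_weight * raw_bbox_loss,
    'acc': 97.0,                               # logged, not part of the total
}
total = parse_losses(losses)  # 1.0
```

Because of this naming rule, a custom loss must return keys containing 'loss' for them to contribute to training.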