MMDetection-kitty
Published by onesixx on
1. sixx_make_cfg.py
# sixx_make_cfg.py
# Build a custom MMDetection config for the KittyTinyDataset: start from the
# stock Faster R-CNN R50-FPN 1x COCO config in the mmdetection checkout,
# retarget the data splits / model head / schedule / runtime, and write the
# result to <WORK_DIR>/configs/faster_rcnn_r50_fpn_1x_tidy.py.
import datetime
import os
import os.path as osp

from mmcv import Config
from mmdet.apis import set_random_seed

# Directory that contains the mmdetection checkout.
ROOT_DIR = '/home/oschung_skcc/git'
# Project directory: the parent of the directory holding this script.
WORK_DIR = osp.dirname(osp.dirname(osp.realpath(__file__)))

cfg = Config.fromfile(
    osp.join(ROOT_DIR, 'mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'))
# print(cfg.pretty_text)

# Where the generated config is written, and the pretrained weights to start from.
config_file = osp.join(WORK_DIR, 'configs/faster_rcnn_r50_fpn_1x_tidy.py')
checkpoint_file = osp.join(
    WORK_DIR, 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')

# Dataset-level parameters (stock values: 'CocoDataset' and 'data/coco/').
cfg.dataset_type = 'KittyTinyDataset'
cfg.data_root = WORK_DIR

###---sixx---:: DATA
# Per-split type, data_root, ann_file and img_prefix. ann_file / img_prefix are
# relative paths, presumably resolved against data_root by mmdet's
# CustomDataset -- confirm if the layout changes.
_SPLIT_ANN_FILES = {
    'train': 'data/train.txt',  # stock: 'data/coco/annotations/instances_train2017.json'
    'val': 'data/valid.txt',
    'test': 'data/valid.txt',   # the test split reuses the validation list
}
for _split, _ann_file in _SPLIT_ANN_FILES.items():
    _split_cfg = cfg.data[_split]
    _split_cfg.type = cfg.dataset_type
    _split_cfg.data_root = cfg.data_root
    _split_cfg.ann_file = _ann_file
    _split_cfg.img_prefix = 'data/image_2'  # stock (train): 'data/coco/train2017/'
###---sixx--::: train_pipeline
###---sixx--::: test_pipeline

###---sixx---:: MODEL
# Replace the 80 COCO classes with this dataset's class count.
cfg.model.roi_head.bbox_head.num_classes = 4
cfg.load_from = checkpoint_file  # pretrained model (check that the path exists)
# Directory for training weights and logs.
cfg.work_dir = osp.join(WORK_DIR, 'tutorial_exps')

# Schedule / default_runtime overrides.
# Learning rate scaled down from the stock 0.02 (presumably the 8-GPU value).
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None  # stock: 'linear'
cfg.log_config.interval = 10
# Set explicitly: the policy value was observed to go missing each time the
# config was processed.
cfg.lr_config.policy = 'step'

# The custom dataset is evaluated with mAP instead of COCO 'bbox'.
cfg.evaluation.metric = 'mAP'
# Evaluate less often to cut evaluation time (stock: 1).
cfg.evaluation.interval = 12
# Save checkpoints less often to cut storage (stock: 1).
cfg.checkpoint_config.interval = 12

# Fix the seed so results are more reproducible.
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# print(f'Config:\n{cfg.pretty_text}')

###############################################################################
# Create configs/ on a fresh checkout; otherwise open(..., 'w') fails.
os.makedirs(osp.dirname(config_file), exist_ok=True)
with open(config_file, 'w', encoding='utf-8') as f:
    print(cfg.pretty_text, file=f)

print(f"---created custom config file by sixx on {datetime.datetime.now()}")
print(f"{cfg.dataset_type} :: {osp.relpath(config_file)}")
2. sixx_middle_dataset.py
# Directory that contains the mmdetection checkout.
ROOT_DIR = '/home/oschung_skcc/git'
# Backward-compatible alias for the original (misspelled) name.
RROOT_DIR = ROOT_DIR

import os
import os.path as osp
import sys

# Project directory: the parent of the directory holding this file
# (e.g. 'mymm/kitty_tiny').
WORK_DIR = osp.dirname(osp.dirname(osp.realpath(__file__)))
# Images live under DATA_DIR/IMG_PREFIX; per-image label files under DATA_DIR/ANN_PREFIX.
DATA_DIR = osp.join(WORK_DIR, 'data')
IMG_PREFIX = 'image_2'
ANN_PREFIX = 'label_2'

# Make the mmdetection checkout importable. The path is derived from ROOT_DIR
# rather than hard-coded a second time. The membership check avoids piling up
# duplicate entries if this module is re-executed.
_MMDET_DIR = osp.join(ROOT_DIR, 'mmdetection')
if _MMDET_DIR not in sys.path:
    sys.path.append(_MMDET_DIR)
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
import mmcv
import cv2
import numpy as np
import pandas as pd
def get_bboxes_from_xml(annPath):
    """Parse a per-image label file and return its object names and boxes.

    Despite the function name, the file is plain text, not XML. It holds one
    object per line as whitespace-separated fields: the class name is in
    column 0 and the box is in columns 4-7. This is presumably the KITTI
    ``label_2`` layout (left, top, right, bottom in pixels); confirm against
    the dataset.

    Args:
        annPath: Path to the label text file.

    Returns:
        A tuple ``(bbox_names, bboxes)``. ``bbox_names`` is a list of
        class-name strings. ``bboxes`` is a parallel list of 4-float lists
        taken from columns 4-7.

    Raises:
        ValueError: If a non-blank line has fewer than 8 fields, or a box
            field is not a number.
    """
    bbox_names = []
    bboxes = []
    # Builtin open() gives the same lines as mmcv.list_from_file for local
    # paths and keeps this helper free of an mmcv dependency.
    with open(annPath, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            # split() with no argument collapses runs of whitespace, so repeated
            # spaces or tabs do not shift the column indices used below.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) describe no object.
                continue
            if len(fields) < 8:
                raise ValueError(
                    f'{annPath}:{lineno}: expected at least 8 fields, got {len(fields)}')
            bbox_names.append(fields[0])
            bboxes.append([float(v) for v in fields[4:8]])
    return bbox_names, bboxes
# imgFileNm = '000006'
# ann_file = '/home/oschung_skcc/git/mymm/kitty_tiny/data/label_2/000006.txt'
@DATASETS.register_module(force=True)
class KittyTinyDataset(CustomDataset):
CLASSES = ('Car', 'Truck', 'Pedestrian', 'Cyclist')
def load_annotations(self, ann_file):
# mmdetection 프레임웍이 Config에서 ann_file(path)인자를 찾아 파라미터로 사용.
cat2label = {k:i for i, k in enumerate(self.CLASSES)}
annFileNm_list = mmcv.list_from_file(self.ann_file)
data_info = []
for imgFileNm in annFileNm_list:
if imgFileNm is None:
continue
### IMAGE metadata
imgBaseNm = str(imgFileNm)+'.jpeg'
imgPath = osp.join(DATA_DIR, IMG_PREFIX, imgBaseNm)
image = cv2.imread(imgPath)
height, width = image.shape[:2]
img_metaData = {
'filename': imgBaseNm,
'width': width,
'height': height
}
### Annotation metadata
annBaseNm = str(imgFileNm)+'.txt'
annPath = osp.join(DATA_DIR, ANN_PREFIX, annBaseNm)
if not osp.exists(annPath):
continue
elif os.stat(annPath).st_size==0:
continue
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_labels_ignore = []
bbox_names, bboxes = get_bboxes_from_xml(annPath)
for bboxNm, bbox in zip(bbox_names, bboxes):
if bboxNm in cat2label:
gt_bboxes.append(bbox)
gt_labels.append(cat2label[bboxNm])
else:
gt_bboxes_ignore.append(bbox)
gt_labels_ignore.append(-1)
ann_metaData = {
'bboxes': np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
'labels': np.array(gt_labels, dtype=np.c