MMDetection-kitty

Published by onesixx on

1. sixx_make_cfg.py

ROOT_DIR = '/home/oschung_skcc/git'   # parent directory of the mmdetection checkout

import os
import os.path as osp
# Project directory: two directory levels above this file.
WORK_DIR = osp.dirname(osp.dirname(osp.realpath(__file__)))

from mmcv import Config
# Start from mmdetection's stock Faster R-CNN R50-FPN 1x COCO config and customize it.
cfg = Config.fromfile(osp.join(ROOT_DIR, "mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py"))
# print(cfg.pretty_text)
config_file     = osp.join(WORK_DIR, 'configs/faster_rcnn_r50_fpn_1x_tidy.py')   # customized config written below
checkpoint_file = osp.join(WORK_DIR, 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')

from mmdet.apis import set_random_seed

# Dataset-level parameters.
cfg.dataset_type = 'KittyTinyDataset'       # was 'CocoDataset'
cfg.data_root = WORK_DIR                    # was 'data/coco/'

###---sixx---:: DATA
# Per-split type, data_root, ann_file and img_prefix.
# ann_file / img_prefix are relative paths, intended to be resolved against data_root.
cfg.data.train.type = 'KittyTinyDataset'              # was 'CocoDataset'
cfg.data.train.data_root = WORK_DIR
cfg.data.train.ann_file = 'data/train.txt'            # was 'data/coco/annotations/instances_train2017.json'
cfg.data.train.img_prefix = 'data/image_2'            # was 'data/coco/train2017/'

cfg.data.val.type = 'KittyTinyDataset'
cfg.data.val.data_root = WORK_DIR
cfg.data.val.ann_file = 'data/valid.txt'
cfg.data.val.img_prefix = 'data/image_2'

cfg.data.test.type = 'KittyTinyDataset'
cfg.data.test.data_root = WORK_DIR
cfg.data.test.ann_file = 'data/valid.txt'
cfg.data.test.img_prefix = 'data/image_2'

###---sixx--::: train_pipeline

###---sixx--::: test_pipeline

###---sixx---:: MODEL
cfg.model.roi_head.bbox_head.num_classes = 4          # was 80 (COCO); must equal len(KittyTinyDataset.CLASSES)

cfg.load_from = checkpoint_file                       # pretrained weights (verify the path exists)
cfg.work_dir = osp.join(WORK_DIR, 'tutorial_exps')    # directory for training logs and weight files

# Overrides for the schedule / default_runtime settings.
cfg.optimizer.lr = 0.02 / 8           # was 0.02
cfg.lr_config.warmup = None           # was 'linear'
cfg.log_config.interval = 10
# Set explicitly: the author observed the 'policy' value going missing each time the config is processed.
cfg.lr_config.policy = 'step'
# Change the evaluation metric since we use a customized dataset.
cfg.evaluation.metric = 'mAP'         # was 'bbox'
# A larger evaluation interval reduces the number of evaluations.
cfg.evaluation.interval = 12          # was 1
# A larger checkpoint interval reduces storage cost.
cfg.checkpoint_config.interval = 12   # was 1

# Set the seed so results are more reproducible.
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# print(f'Config:\n{cfg.pretty_text}')

###############################################################################
# Write the customized config, creating its directory if needed.
os.makedirs(osp.dirname(config_file), exist_ok=True)
with open(config_file, 'w', encoding='utf-8') as f:
    print(cfg.pretty_text, file=f)

import datetime
print(f"---created custom config file by sixx on {datetime.datetime.now()}")
print(f"{cfg.dataset_type} :: {osp.relpath(config_file)}")

2. sixx_middle_dataset.py

ROOT_DIR = '/home/oschung_skcc/git'   # parent directory of the mmdetection checkout
import os
import os.path as osp
# Project directory: two directory levels above this file (e.g. 'mymm/kitty_tiny').
WORK_DIR = osp.dirname(osp.dirname(osp.realpath(__file__)))

DATA_DIR = osp.join(WORK_DIR, 'data')
IMG_PREFIX = 'image_2'   # image sub-directory of DATA_DIR
ANN_PREFIX = 'label_2'   # label (.txt) sub-directory of DATA_DIR

import sys
# Make the local mmdetection checkout importable before importing mmdet below.
sys.path.append(osp.join(ROOT_DIR, 'mmdetection'))
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom  import CustomDataset

import mmcv
import cv2
import numpy as np
import pandas as pd

# Parse one KITTI-style label .txt file (despite the function name, this is not XML).
def get_bboxes_from_xml(annPath):
    """Read object type names and 2D boxes from a KITTI label file.

    Each non-blank line is whitespace-separated: column 0 is the object type
    and columns 4-7 are the 2D box (left, top, right, bottom), per the KITTI
    label format.

    Args:
        annPath: path to the label ``.txt`` file.

    Returns:
        tuple[list[str], list[list[float]]]: type names and boxes, in file order.
    """
    bbox_names, bboxes = [], []
    with open(annPath, encoding='utf-8') as f:
        for line in f:
            # split() (not split(' ')) tolerates repeated spaces and trailing whitespace.
            fields = line.split()
            if not fields:
                continue  # blank line: would otherwise yield an empty box
            bbox_names.append(fields[0])
            bboxes.append([float(v) for v in fields[4:8]])
    return bbox_names, bboxes

# imgFileNm = '000006'
# ann_file  = '/home/oschung_skcc/git/mymm/kitty_tiny/data/label_2/000006.txt'
@DATASETS.register_module(force=True)
class KittyTinyDataset(CustomDataset):
    """KITTI-tiny detection dataset converted to MMDetection's middle format.

    The split file (``ann_file``) lists one image id per line, e.g. ``000006``.
    For each id, the image ``DATA_DIR/IMG_PREFIX/<id>.jpeg`` and the label file
    ``DATA_DIR/ANN_PREFIX/<id>.txt`` are read. Object types not in ``CLASSES``
    are stored in the ignore lists with label -1.
    """

    CLASSES = ('Car', 'Truck', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        """Build the per-image info dicts consumed by ``CustomDataset``.

        Args:
            ann_file: path of the split file, as passed in by MMDetection
                (the config's ``ann_file``).

        Returns:
            list[dict]: one dict per image with ``filename`` (base name only),
            ``width``, ``height`` and ``ann`` holding ``bboxes`` (N x 4 float32),
            ``labels`` (int64), ``bboxes_ignore`` and ``labels_ignore``.
            Ids whose label file is missing or empty are skipped.

        Raises:
            FileNotFoundError: if a listed image exists in the split but cannot be read.
        """
        cat2label = {name: idx for idx, name in enumerate(self.CLASSES)}
        img_ids = mmcv.list_from_file(ann_file)

        data_infos = []
        for img_id in img_ids:
            img_id = img_id.strip()
            if not img_id:
                continue  # blank line in the split file

            # Check the label file first so skipped ids never pay for an image decode.
            ann_path = osp.join(DATA_DIR, ANN_PREFIX, img_id + '.txt')
            if not osp.exists(ann_path) or os.stat(ann_path).st_size == 0:
                continue

            img_name = img_id + '.jpeg'
            img_path = osp.join(DATA_DIR, IMG_PREFIX, img_name)
            image = cv2.imread(img_path)
            if image is None:  # cv2.imread signals failure by returning None
                raise FileNotFoundError(f'Cannot read image: {img_path}')
            height, width = image.shape[:2]

            gt_bboxes, gt_labels = [], []
            gt_bboxes_ignore, gt_labels_ignore = [], []
            bbox_names, bboxes = get_bboxes_from_xml(ann_path)
            for name, bbox in zip(bbox_names, bboxes):
                if name in cat2label:
                    gt_bboxes.append(bbox)
                    gt_labels.append(cat2label[name])
                else:
                    # Types outside CLASSES are kept as ignored regions.
                    gt_bboxes_ignore.append(bbox)
                    gt_labels_ignore.append(-1)

            data_infos.append({
                'filename': img_name,
                'width': width,
                'height': height,
                'ann': {
                    'bboxes': np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                    # np.int64 replaces np.compat.long (np.compat was removed in NumPy 2.0).
                    'labels': np.array(gt_labels, dtype=np.int64),
                    'bboxes_ignore': np.array(gt_bboxes_ignore, dtype=np.float32).reshape(-1, 4),
                    'labels_ignore': np.array(gt_labels_ignore, dtype=np.int64),
                },
            })

        if data_infos:
            print(data_infos[0])  # debug output: one parsed sample
        return data_infos

import datetime
# Printed on import so the user can see the custom dataset module was loaded.
print(f"---registered custom dataset by sixx on {datetime.datetime.now()}")

3. sixx_train.py

  • build_dataset(): 학습용 dataset 생성
  • build_detector(): 모델 생성 — 이 시점의 weight는 랜덤 초기화 상태이거나 pretrained 가중치가 로드된 상태
  • train_detector(): 학습을 수행하며 weight 값이 갱신됨
$ python sixx_train.py configs/faster_rcnn_r50_fpn_1x_tidy.py
ROOT_DIR = '/home/oschung_skcc/git'

import os
import os.path as osp

# Directory that contains this script (e.g. 'mymm/kitty_tiny'). Run it from there:
#   $ python sixx_train.py configs/faster_rcnn_r50_fpn_1x_tidy.py
WORK_DIR = osp.dirname(osp.realpath(__file__))

# Imported only for its side effect: it registers KittyTinyDataset in mmdet's
# DATASETS registry so that build_dataset() can resolve type='KittyTinyDataset'.
import tools.sixx_middle_dataset  # noqa: F401

# -----------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import argparse
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

# -----------------------------------------------------------------------------
from mmdet import __version__

from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models   import build_detector

from mmdet.utils    import collect_env, get_root_logger, setup_multi_processes

def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Besides parsing, this exports ``--local_rank`` as the ``LOCAL_RANK``
    environment variable when the environment does not already define it.
    It also folds the deprecated ``--options`` into ``cfg_options``.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if both ``--options`` and ``--cfg-options`` are given.
    """
    ap = argparse.ArgumentParser(description='Train a detector')
    ap.add_argument('config', help='train config file path')
    ap.add_argument('--work-dir', help='the dir to save logs and models')
    ap.add_argument('--resume-from', help='the checkpoint file to resume from')
    ap.add_argument(
        '--auto-resume',
        action='store_true',
        help='resume from the latest checkpoint automatically')
    ap.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # At most one of the three GPU-selection flags may be given.
    gpu_choice = ap.add_mutually_exclusive_group()
    gpu_choice.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_choice.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_choice.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')

    ap.add_argument('--seed', type=int, default=None, help='random seed')
    ap.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    ap.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    ap.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    ap.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    ap.add_argument('--local_rank', type=int, default=0)

    ns = ap.parse_args()
    # Export LOCAL_RANK only when the environment has not set it already.
    os.environ.setdefault('LOCAL_RANK', str(ns.local_rank))

    if ns.options:
        if ns.cfg_options:
            raise ValueError(
                '--options and --cfg-options cannot be both '
                'specified, --options is deprecated in favor of --cfg-options')
        warnings.warn('--options is deprecated in favor of --cfg-options')
        ns.cfg_options = ns.options

    return ns

def main():
    """Train a detector from the config file named on the command line.

    Steps: load and override the config, resolve work_dir and GPU ids,
    optionally initialise distributed training, set up logging and seeds,
    build the training (and, for a 2-stage workflow, validation) datasets and
    the model, then call ``train_detector``.
    """
    args = parse_args()
    ### --- cfg ================================================================
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.auto_resume = args.auto_resume

    # Non-distributed training supports a single GPU only; the deprecated
    # --gpus / --gpu-ids flags are reduced to one GPU.
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir and dump the effective config into it
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file  = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger    = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # meta records environment info, config and seed; it is passed to train_detector.
    meta = dict()

    env_info_dict = collect_env()
    # One "key: value" entry per line (the previous separator was an empty string).
    env_info = '\n'.join(f'{k}: {v}' for k, v in env_info_dict.items())

    # dash_line = '-' * 60 + '\n'
    # logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)

    meta['env_info'] = env_info
    meta['config']   = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    # logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)
    ### --- datasets ===========================================================
    # Build the training dataset.
    datasets = [build_dataset(cfg.data.train)]

    # A 2-stage workflow also iterates over the val split, using the train pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    ### --- model ==============================================================
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    train_detector(
        model, datasets, cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()

4. sixx_inference.py

ROOT_DIR = '/home/oschung_skcc/git'

import os
import os.path as osp

# Directory that contains this script; data lives in <WORK_DIR>/data.
WORK_DIR = osp.dirname(osp.realpath(__file__))

DATA_DIR = osp.join(WORK_DIR, 'data')
IMG_PREFIX = 'image_2'   # image sub-directory of DATA_DIR
ANN_PREFIX = 'label_2'   # label sub-directory of DATA_DIR

import cv2
import numpy as np
from matplotlib import pyplot as plt
from mmdet.apis import init_detector, inference_detector, show_result_pyplot

def get_detected_img(model, imgPath, score_threshold=0.3, is_print=True, class_names=None):
    """Run the detector on one image and draw every confident detection.

    Args:
        model: detector as returned by ``init_detector``.
        imgPath: path of the image to read with ``cv2.imread`` (BGR channel order).
        score_threshold: only detections with score strictly greater than this are drawn.
        is_print: if True, print each drawn caption.
        class_names: optional sequence mapping class index -> name. Defaults to
            ``model.CLASSES``; the numeric index is used when no name is available.

    Returns:
        numpy.ndarray: the BGR image with boxes and captions drawn on it.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img_array = cv2.imread(imgPath)
    if img_array is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f'Cannot read image: {imgPath}')

    bbox_color = (0, 255, 0)   # green (identical in BGR and RGB)
    text_color = (0, 0, 255)   # red in OpenCV's BGR order

    if class_names is None:
        class_names = getattr(model, 'CLASSES', None)

    results = inference_detector(model, img_array)
    # Models with a mask head return (bbox_results, segm_results); only boxes are drawn here.
    if isinstance(results, tuple):
        results = results[0]

    # results[k] holds the detections of class k as rows of (x1, y1, x2, y2, score).
    for class_idx, class_dets in enumerate(results):
        if len(class_dets) == 0:
            continue
        kept = class_dets[class_dets[:, 4] > score_threshold]
        if class_names is not None and class_idx < len(class_names):
            label = class_names[class_idx]
        else:
            label = class_idx

        for det in kept:
            # top-left and bottom-right corners
            left, top, right, bottom = (int(v) for v in det[:4])
            cv2.rectangle(img_array, (left, top), (right, bottom), color=bbox_color, thickness=2)
            caption = f"{label}: {det[4]}"
            cv2.putText(img_array, caption, (left, top - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.37, text_color, 1)
            if is_print:
                print(caption)

    # Return only after all classes have been drawn.
    return img_array

config_file     = osp.join(WORK_DIR, 'configs/faster_rcnn_r50_fpn_1x_tidy.py')
checkpoint_file = osp.join(WORK_DIR, 'tutorial_exps/latest.pth')   # weights written by sixx_train.py

model = init_detector(config_file, checkpoint_file)
imgPath = osp.join(DATA_DIR, IMG_PREFIX, '000068.jpeg')

CLASSES = ('Car', 'Truck', 'Pedestrian', 'Cyclist')
labels_to_names_seq = {i: k for i, k in enumerate(CLASSES)}   # class index -> name
draw_img = get_detected_img(model, imgPath, score_threshold=0.3, is_print=True)

plt.figure(figsize=(4, 4))
# get_detected_img works on a cv2 (BGR) image; matplotlib expects RGB.
plt.imshow(cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
plt.show()   # needed to display the figure when run as a plain script
Categories: vision

onesixx

Blog Owner

Subscribe
Notify of
guest

0 Comments
Inline Feedbacks
View all comments
0
Would love your thoughts, please comment.x
()
x