Commit 696e7be3 by Ting PAN

Initial repository

0 parents
Showing with 4864 additions and 0 deletions
[flake8]
max-line-length = 120
ignore = E741, # ambiguous variable name
F403, # ‘from module import *’ used; unable to detect undefined names
F405, # name may be undefined, or defined from star imports: module
F811, # redefinition of unused name from line N
F821, # undefined name
W503, # line break before binary operator
W504 # line break after binary operator
# module imported but unused
per-file-ignores = __init__.py: F401
# Compiled Object files
*.slo
*.lo
*.o
*.cuo
# Compiled Dynamic libraries
*.so
*.dll
*.dylib
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Compiled python
*.pyc
__pycache__
# Compiled MATLAB
*.mex*
# IPython notebook checkpoints
.ipynb_checkpoints
# Editor temporaries
*.swp
*~
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# QtCreator files
*.user
# VSCode files
.vscode
# IDEA files
.idea
# OSX dir files
.DS_Store
# Android files
.gradle
*.iml
local.properties
Copyright (c) 2017, SeetaTech, Co.,Ltd. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Benchmark and Model Zoo
## Introduction
### Pretrained Models
Refer to [Pretrained Models](data/pretrained) for details.
## Baselines
### Faster R-CNN
Refer to [Faster R-CNN](configs/faster_rcnn) for details.
### Mask R-CNN
Refer to [Mask R-CNN](configs/mask_rcnn) for details.
### Pascal VOC
Refer to [Pascal VOC](configs/pascal_voc) for details.
# SeetaDet
SeetaDet is a platform implementing popular object detection algorithms.
This platform works with [**SeetaDragon**](https://dragon.seetatech.com), and uses the [**PyTorch**](https://dragon.seetatech.com/api/python/#pytorch) style.
<img src="https://dragon.seetatech.com/download/seetadet/assets/banner.png"/>
## Installation
Install from PyPI:
```bash
pip install seeta-det
```
Or, clone this repository to local disk and install:
```bash
cd seetadet && pip install .
```
You can also install from the remote repository:
```bash
pip install git+ssh://git@github.com/seetaresearch/seetadet.git
```
If you prefer to develop locally, build without installing into ***site-packages***:
```bash
cd seetadet && python setup.py build
```
## Quick Start
### Train a detection model
```bash
cd tools
python train.py --cfg <MODEL_YAML>
```
Default YAML examples are provided in [configs](configs).
### Test a detection model
```bash
cd tools
python test.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
### Export a detection model to ONNX
```bash
cd tools
python export.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
### Serve a detection model
```bash
cd tools
python serve.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
## Benchmark and Model Zoo
Results and models are available in the [Model Zoo](MODEL_ZOO.md).
## License
[BSD 2-Clause license](LICENSE)
# Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks
## Introduction
```
@article{Ren_2017,
title={Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
publisher={Institute of Electrical and Electronics Engineers (IEEE)},
author={Ren, Shaoqing and He, Kaiming and Girshick, Ross and Sun, Jian},
year={2017},
month={Jun},
}
```
## COCO Object Detection Baselines
| Model | Lr sched | Infer time (fps) | box AP | Download |
| :---: | :------: | :--------------: | :----: | :-----: |
| [R50-FPN](coco_faster_rcnn_R_50_FPN_1x.yml) | 1x | 37.04 | 37.7 | [model](https://dragon.seetatech.com/download/seetadet/faster_rcnn/coco_faster_rcnn_R_50_FPN_1x/model_7abb52ab.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/faster_rcnn/coco_faster_rcnn_R_50_FPN_1x/logs.json) |
| [R50-FPN](coco_faster_rcnn_R_50_FPN_3x.yml) | 3x | 37.04 | 39.8 | [model](https://dragon.seetatech.com/download/seetadet/faster_rcnn/coco_faster_rcnn_R_50_FPN_3x/model_04e548ca.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/faster_rcnn/coco_faster_rcnn_R_50_FPN_3x/logs.json) |
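
The `Lr sched` labels can be related to the solver settings in the accompanying configs (`MAX_STEPS: 90000` for 1x, `MAX_STEPS: 270000` for 3x, with `NUM_GPUS: 8` and `IMS_PER_BATCH: 2`). The sketch below estimates how many passes over the training set each schedule makes. It assumes that `IMS_PER_BATCH` is counted per GPU and that `coco_train2017` holds the standard 118,287 images; neither is stated in this repository, and `approx_epochs` is a hypothetical helper.

```python
# Assumption: size of the standard COCO train2017 split.
COCO_TRAIN2017_IMAGES = 118_287


def approx_epochs(max_steps, num_gpus, ims_per_batch, dataset_size):
    """Estimate dataset passes, assuming ims_per_batch is per GPU."""
    images_seen = max_steps * num_gpus * ims_per_batch
    return images_seen / dataset_size


# Values taken from the 1x and 3x configs.
epochs_1x = approx_epochs(90_000, 8, 2, COCO_TRAIN2017_IMAGES)
epochs_3x = approx_epochs(270_000, 8, 2, COCO_TRAIN2017_IMAGES)
print(f"1x: about {epochs_1x:.1f} epochs")
print(f"3x: about {epochs_3x:.1f} epochs")
```

Under these assumptions the 1x schedule comes out at roughly 12 epochs and the 3x schedule at roughly 36.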
NUM_GPUS: 8
MODEL:
TYPE: 'faster_rcnn'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
FPN:
MIN_LEVEL: 2
MAX_LEVEL: 6
ANCHOR_GENERATOR:
STRIDES: [4, 8, 16, 32, 64]
SOLVER:
BASE_LR: 0.02
DECAY_STEPS: [60000, 80000]
MAX_STEPS: 90000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_faster_rcnn_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
NUM_GPUS: 8
MODEL:
TYPE: 'faster_rcnn'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
FPN:
MIN_LEVEL: 2
MAX_LEVEL: 6
ANCHOR_GENERATOR:
STRIDES: [4, 8, 16, 32, 64]
SOLVER:
BASE_LR: 0.02
DECAY_STEPS: [210000, 250000]
MAX_STEPS: 270000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_faster_rcnn_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
# Mask R-CNN
## Introduction
```
@article{He_2017,
title={Mask R-CNN},
journal={2017 IEEE International Conference on Computer Vision (ICCV)},
publisher={IEEE},
author={He, Kaiming and Gkioxari, Georgia and Dollar, Piotr and Girshick, Ross},
year={2017},
month={Oct}
}
```
## COCO Instance Segmentation Baselines
| Model | Lr sched | Infer time (fps) | box AP | mask AP | Download |
| :---: | :------: | :---------------: | :----: | :-----: | :------: |
| [R50-FPN](coco_mask_rcnn_R_50_FPN_1x.yml) | 1x | 30.30 | 38.3 | 34.9 | [model](https://dragon.seetatech.com/download/seetadet/mask_rcnn/coco_mask_rcnn_R_50_FPN_1x/model_b27317db.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/mask_rcnn/coco_mask_rcnn_R_50_FPN_1x/logs.json) |
| [R50-FPN](coco_mask_rcnn_R_50_FPN_3x.yml) | 3x | 30.30 | 40.7 | 36.8 | [model](https://dragon.seetatech.com/download/seetadet/mask_rcnn/coco_mask_rcnn_R_50_FPN_3x/model_6f7e3878.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/mask_rcnn/coco_mask_rcnn_R_50_FPN_3x/logs.json) |
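
The COCO configs train with `SCALES: [640, 672, 704, 736, 768, 800]` and `MAX_SIZE: 1333`, and test with `SCALES: [800]`. The sketch below shows one common reading of such settings in Detectron-style pipelines. It assumes that a scale targets the shorter image side, that `MAX_SIZE` caps the longer side, and that training draws one scale at random per image. None of this is confirmed by the files shown here, and `resize_factor` is a hypothetical helper, not a SeetaDet function.

```python
import random

TRAIN_SCALES = [640, 672, 704, 736, 768, 800]  # TRAIN.SCALES from the configs
MAX_SIZE = 1333                                # MAX_SIZE from the configs


def resize_factor(height, width, scale, max_size):
    """Map the shorter side to `scale`, unless the longer side would
    then exceed `max_size`, in which case the longer side is capped."""
    short_side, long_side = min(height, width), max(height, width)
    factor = scale / short_side
    if long_side * factor > max_size:
        factor = max_size / long_side
    return factor


# Assumption: training samples one scale per image.
train_scale = random.choice(TRAIN_SCALES)
print(resize_factor(480, 640, train_scale, MAX_SIZE))

# A wide image at the test scale hits the MAX_SIZE cap instead.
print(resize_factor(500, 1000, 800, MAX_SIZE))
```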
NUM_GPUS: 8
MODEL:
TYPE: 'mask_rcnn'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
FPN:
MIN_LEVEL: 2
MAX_LEVEL: 6
ANCHOR_GENERATOR:
STRIDES: [4, 8, 16, 32, 64]
SOLVER:
BASE_LR: 0.02
DECAY_STEPS: [60000, 80000]
MAX_STEPS: 90000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_mask_rcnn_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
LOADER: 'mask_train'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
NUM_GPUS: 8
MODEL:
TYPE: 'mask_rcnn'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
FPN:
MIN_LEVEL: 2
MAX_LEVEL: 6
ANCHOR_GENERATOR:
STRIDES: [4, 8, 16, 32, 64]
SOLVER:
BASE_LR: 0.02
DECAY_STEPS: [210000, 250000]
MAX_STEPS: 270000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_mask_rcnn_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
LOADER: 'mask_train'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
# Pascal VOC
## Introduction
```latex
@Article{Everingham10,
author = "Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.",
title = "The Pascal Visual Object Classes (VOC) Challenge",
journal = "International Journal of Computer Vision",
volume = "88",
year = "2010",
number = "2",
month = jun,
pages = "303--338",
}
```
## Object Detection Baselines
### Faster R-CNN
| Model | Lr sched | Infer time (fps) | box AP | Download |
| :---: | :------: | :--------------: | :----: | :------: |
| [R50-FPN](voc_faster_rcnn_R_50_FPN_15e.yml) | 15e | 47.62 | 82.1 | [model](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_faster_rcnn_R_50_FPN_15e/model_3dcb03f9.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_faster_rcnn_R_50_FPN_15e/logs.json) |
### RetinaNet
| Model | Lr sched | Infer time (fps) | box AP | Download |
| :---: | :------: | :--------------: | :----: | :------: |
| [R50-FPN](voc_retinanet_R_50_FPN_120e.yml) | 120e | 58.82 | 82.4 | [model](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_retinanet_R_50_FPN_120e/model_1ae4cd3d.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_retinanet_R_50_FPN_120e/logs.json) |
### SSD
| Model | Lr sched | Infer time (fps) | box AP | Download |
| :---: | :------: | :--------------: | :----: | :------: |
| [VGG16-SSD300](voc_ssd300_VGG_16_120e.yml) | 120e | 125 | 77.8 | [model](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_ssd300_VGG_16_120e/model_3417d961.pkl) &#124; [log](https://dragon.seetatech.com/download/seetadet/pascal_voc/voc_ssd300_VGG_16_120e/logs.json) |
NUM_GPUS: 2
MODEL:
TYPE: 'faster_rcnn'
PRECISION: 'float16'
CLASSES: ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
BACKBONE:
TYPE: 'resnet50.fpn'
FPN:
MIN_LEVEL: 2
MAX_LEVEL: 6
ANCHOR_GENERATOR:
STRIDES: [4, 8, 16, 32, 64]
FAST_RCNN:
BBOX_REG_LOSS_TYPE: 'smooth_l1'
SOLVER:
BASE_LR: 0.002
DECAY_STEPS: [80000, 100000]
MAX_STEPS: 120000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'voc_faster_rcnn_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50_in1k_cls90e.pkl'
DATASET: '../data/datasets/voc_trainval0712'
USE_DIFF: True
IMS_PER_BATCH: 2
SCALES: [480, 512, 544, 576, 608, 640]
MAX_SIZE: 1000
TEST:
DATASET: '../data/datasets/voc_test2007'
JSON_DATASET: '../data/datasets/voc_test2007.json'
EVALUATOR: 'voc2007'
IMS_PER_BATCH: 1
SCALES: [640]
MAX_SIZE: 1000
NMS_THRESH: 0.45
NUM_GPUS: 1
MODEL:
TYPE: 'retinanet'
PRECISION: 'float32'
CLASSES: ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
BACKBONE:
TYPE: 'resnet50.fpn'
SOLVER:
BASE_LR: 0.01
WARM_UP_STEPS: 3000
DECAY_STEPS: [80000, 100000]
MAX_STEPS: 120000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'voc_retinanet_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50_in1k_cls90e.pkl'
DATASET: '../data/datasets/voc_trainval0712'
USE_DIFF: True
IMS_PER_BATCH: 16
SCALES: [512]
SCALES_RANGE: [0.1, 2.0]
MAX_SIZE: 512
CROP_SIZE: 512
COLOR_JITTER: 0.5
TEST:
DATASET: '../data/datasets/voc_test2007'
JSON_DATASET: '../data/datasets/voc_test2007.json'
EVALUATOR: 'voc2007'
IMS_PER_BATCH: 1
SCALES: [512]
MAX_SIZE: 512
CROP_SIZE: 512
NMS_THRESH: 0.45
NUM_GPUS: 1
MODEL:
TYPE: 'ssd'
PRECISION: 'float16'
CLASSES: ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
BACKBONE:
TYPE: 'vgg16_fcn.ssd300'
NORM: ''
FREEZE_AT: 0
COARSEST_STRIDE: 300
FPN:
ACTIVATION: 'ReLU'
ANCHOR_GENERATOR:
STRIDES: [8, 16, 32, 64, 100, 300]
SIZES: [[30, 60], [60, 110],[110, 162],
[162, 213], [213, 264], [264, 315]]
ASPECT_RATIOS: [[1, 2, 0.5],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5],
[1, 2, 0.5]]
SOLVER:
BASE_LR: 0.001
WEIGHT_DECAY: 0.0005
DECAY_STEPS: [80000, 100000]
MAX_STEPS: 120000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'voc_ssd300_VGG_16'
TRAIN:
WEIGHTS: '../data/pretrained/VGG-16-FCN_in1k.pkl'
DATASET: '../data/datasets/voc_trainval0712'
LOADER: 'ssd_train'
USE_DIFF: True
IMS_PER_BATCH: 16
SCALES: [300]
SCALES_RANGE: [0.25, 1.0]
COLOR_JITTER: 0.5
TEST:
DATASET: '../data/datasets/voc_test2007'
JSON_DATASET: '../data/datasets/voc_test2007.json'
EVALUATOR: 'voc2007'
IMS_PER_BATCH: 8
SCALES: [300]
NMS_THRESH: 0.45
SCORE_THRESH: 0.01
NUM_GPUS: 1
MODEL:
TYPE: 'ssd'
PRECISION: 'float16'
CLASSES: ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
BACKBONE:
TYPE: 'vgg16_fcn.ssd512'
NORM: ''
FREEZE_AT: 0
COARSEST_STRIDE: 512
FPN:
ACTIVATION: 'ReLU'
ANCHOR_GENERATOR:
STRIDES: [8, 16, 32, 64, 128, 256, 512]
SIZES: [[35.84, 76.8],
[76.8, 153.6],
[153.6, 230.4],
[230.4, 307.2],
[307.2, 384.0],
[384.0, 460.8],
[460.8, 537.6]]
ASPECT_RATIOS: [[1, 2, 0.5],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5, 3, 0.33],
[1, 2, 0.5],
[1, 2, 0.5]]
SOLVER:
BASE_LR: 0.001
WEIGHT_DECAY: 0.0005
DECAY_STEPS: [80000, 100000]
MAX_STEPS: 120000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'voc_ssd512_VGG_16'
AUG:
COLOR_JITTER: 0.5
TRAIN:
WEIGHTS: '../data/pretrained/VGG-16-FCN_in1k.pkl'
DATASET: '../data/datasets/voc_trainval0712'
IMS_PER_BATCH: 16
SCALES: [512]
SCALES_RANGE: [0.25, 1.0]
LOADER: 'ssd_train'
TEST:
DATASET: '../data/datasets/voc_test2007'
JSON_DATASET: '../data/datasets/voc_test2007.json'
EVALUATOR: 'voc2007'
IMS_PER_BATCH: 1
SCALES: [512]
NMS_THRESH: 0.45
SCORE_THRESH: 0.01
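
The `ANCHOR_GENERATOR.SIZES` entries in the SSD512 config above look like fixed fractions of the 512-pixel input, paired so that each level gets a minimum and a maximum size. The sketch below reproduces the listed numbers under that reading. The fraction list is an assumption chosen to match the config, not something taken from SeetaDet, and `anchor_size_pairs` is a hypothetical helper.

```python
INPUT_SIZE = 512
# Assumed fractions of the input size; they reproduce the SIZES values.
FRACTIONS = [0.07, 0.15, 0.30, 0.45, 0.60, 0.75, 0.90, 1.05]


def anchor_size_pairs(input_size, fractions):
    """Turn n+1 size fractions into n [min_size, max_size] pairs."""
    sizes = [round(input_size * f, 2) for f in fractions]
    return [[small, large] for small, large in zip(sizes[:-1], sizes[1:])]


pairs = anchor_size_pairs(INPUT_SIZE, FRACTIONS)
for pair in pairs:
    print(pair)
```

Eight fractions yield seven pairs, one for each entry in `STRIDES: [8, 16, 32, 64, 128, 256, 512]`.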
# Focal Loss for Dense Object Detection
## Introduction
```
@inproceedings{lin2017focal,
title={Focal loss for dense object detection},
author={Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
booktitle={Proceedings of the IEEE international conference on computer vision},
year={2017}
}
```
## COCO Object Detection Baselines
| Model | Lr sched | Infer time (s/im) | box AP | Download |
| :---: | :------: | :---------------: | :----: | :------: |
| [R-50-FPN-800](coco_retinanet_R-50-FPN_800_1x.yml) | 1x | 0.051 | 37.4 | [model](https://dragon.seetatech.com/download/models/seetadet/retinanet/coco_retinanet_R-50-FPN_800_1x/model_final.pkl) |
| [R-50-FPN-800](coco_retinanet_R-50-FPN_800_2x.yml) | 2x | 0.051 | 39.1 | [model](https://dragon.seetatech.com/download/models/seetadet/retinanet/coco_retinanet_R-50-FPN_800_2x/model_final.pkl) |
## Pascal VOC Object Detection Baselines
| Model | Lr sched | Infer time (s/im) | AP@0.5 | Download |
| :---: | :------: | :---------------: | :----: | :------: |
| [R-50-FPN-512](voc_retinanet_R-50-FPN_512_120e.yml) | 120e | 0.017 | 83.0 | [model](https://dragon.seetatech.com/download/models/seetadet/retinanet/voc_retinanet_R-50-FPN_512/model_final.pkl) |
| [R-50-FPN-512](voc_retinanet_R-50-FPN_640_120e.yml) | 120e | 0.017 | 83.0 | [model](https://dragon.seetatech.com/download/models/seetadet/retinanet/voc_retinanet_R-50-FPN_512/model_final.pkl) |
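
These tables report inference time in seconds per image, while the Faster R-CNN, Mask R-CNN and Pascal VOC pages report frames per second. To compare them, take the reciprocal, as in the sketch below. It assumes both units describe the same single-image measurement; `seconds_per_image_to_fps` is a hypothetical helper.

```python
def seconds_per_image_to_fps(seconds_per_image):
    """Convert a per-image latency into frames per second."""
    return 1.0 / seconds_per_image


# Latencies copied from the tables above.
latencies = [("COCO R-50-FPN-800", 0.051), ("VOC R-50-FPN-512", 0.017)]
for label, s_per_im in latencies:
    print(f"{label}: {seconds_per_image_to_fps(s_per_im):.2f} fps")
```

The VOC figure converts to about 58.82 fps, which appears consistent with the RetinaNet row on the Pascal VOC page.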
NUM_GPUS: 8
MODEL:
TYPE: 'retinanet'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
SOLVER:
BASE_LR: 0.01
DECAY_STEPS: [60000, 80000]
MAX_STEPS: 90000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_retinanet_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
NUM_GPUS: 8
MODEL:
TYPE: 'retinanet'
PRECISION: 'float16'
CLASSES: ['__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
BACKBONE:
TYPE: 'resnet50_v1a.fpn'
SOLVER:
BASE_LR: 0.01
DECAY_STEPS: [210000, 250000]
MAX_STEPS: 270000
SNAPSHOT_EVERY: 5000
SNAPSHOT_PREFIX: 'coco_retinanet_R_50_FPN'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-A_in1k_cls120e.pkl'
DATASET: '../data/datasets/coco_train2017'
IMS_PER_BATCH: 2
SCALES: [640, 672, 704, 736, 768, 800]
MAX_SIZE: 1333
TEST:
DATASET: '../data/datasets/coco_val2017'
JSON_DATASET: '../data/datasets/coco_instances_val2017.json'
EVALUATOR: 'coco'
IMS_PER_BATCH: 1
SCALES: [800]
MAX_SIZE: 1333
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_OPERATORS_MASK_OP_H_
#define DRAGON_EXTENSION_OPERATORS_MASK_OP_H_
#include <dragon/core/operator.h>
namespace dragon {
template <class Context>
class PasteMaskOp final : public Operator<Context> {
public:
PasteMaskOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
mask_threshold_(OP_SINGLE_ARG(float, "mask_threshold", 0.5f)) {
INITIALIZE_OP_REPEATED_ARG(int64_t, sizes);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::TypesBase<float>>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
float mask_threshold_;
DECLARE_OP_REPEATED_ARG(int64_t, sizes);
};
DEFINE_OP_REPEATED_ARG(int64_t, PasteMaskOp, sizes);
} // namespace dragon
#endif // DRAGON_EXTENSION_OPERATORS_MASK_OP_H_
#include <dragon/core/workspace.h>
#include "../operators/mask_op.h"
#include "../utils/detection.h"
namespace dragon {
template <class Context>
template <typename T>
void PasteMaskOp<Context>::DoRunWithType() {
auto &X_masks = Input(0), &X_boxes = Input(1), *Y = Output(0);
vector<int64_t> Y_dims({X_masks.dim(0)});
int num_sizes;
sizes(0, &num_sizes);
for (int i = 0; i < num_sizes; ++i) {
Y_dims.push_back(sizes(i));
}
if (num_sizes == 2) {
detection::PasteMask(
Y_dims[0], // N
Y_dims[1], // H
Y_dims[2], // W
X_masks.dim(1), // mask_h
X_masks.dim(2), // mask_w
mask_threshold_,
X_masks.template data<T, Context>(),
X_boxes.template data<float, Context>(),
Y->Reshape(Y_dims)->template mutable_data<uint8_t, Context>(),
ctx());
} else {
LOG(FATAL) << "PasteMask" << num_sizes << "d is not supported.";
}
}
DEPLOY_CPU_OPERATOR(PasteMask);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(PasteMask);
#endif
#ifdef USE_MPS
DEPLOY_MPS_OPERATOR(PasteMask, PasteMask);
#endif
OPERATOR_SCHEMA(PasteMask).NumInputs(2).NumOutputs(1);
NO_GRADIENT(PasteMask);
} // namespace dragon
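
`PasteMaskOp` takes masks of shape `(N, mask_h, mask_w)`, one box per mask, a `mask_threshold` (default 0.5) and the output `sizes` `(H, W)`. It writes a `uint8` output of shape `(N, H, W)`. The kernel `detection::PasteMask` is not part of this file, so the Python sketch below only illustrates the idea for a single mask. It assumes boxes are `(x1, y1, x2, y2)` in output pixel coordinates and uses nearest-neighbour sampling, which may differ from the interpolation the real kernel performs.

```python
def paste_mask(mask, box, height, width, threshold=0.5):
    """Paste one soft mask into a (height, width) binary canvas."""
    x1, y1, x2, y2 = box
    mask_h, mask_w = len(mask), len(mask[0])
    box_w = max(x2 - x1, 1e-6)  # guard against degenerate boxes
    box_h = max(y2 - y1, 1e-6)
    canvas = [[0] * width for _ in range(height)]
    for y in range(height):
        # Pixel centre expressed in mask coordinates.
        my = (y + 0.5 - y1) / box_h * mask_h
        if not 0 <= my < mask_h:
            continue
        for x in range(width):
            mx = (x + 0.5 - x1) / box_w * mask_w
            if not 0 <= mx < mask_w:
                continue
            # Nearest-neighbour lookup, then binarise (assumed rule: >=).
            if mask[int(my)][int(mx)] >= threshold:
                canvas[y][x] = 1
    return canvas


soft_mask = [[0.9, 0.1],
             [0.2, 0.8]]
pasted = paste_mask(soft_mask, box=(1, 1, 3, 3), height=4, width=4)
for row in pasted:
    print(row)
```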
#include "../operators/nms_op.h"
#include "../utils/detection.h"
namespace dragon {
template <class Context>
template <typename T>
void NonMaxSuppressionOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
CHECK(X.ndim() == 2 && X.dim(1) == 5)
<< "\nThe dimensions of boxes should be (num_boxes, 5).";
detection::ApplyNMS(
X.dim(0),
X.dim(0),
0,
iou_threshold_,
X.template mutable_data<T, Context>(),
out_indices_,
ctx());
Y->template CopyFrom<int64_t>(out_indices_);
}
DEPLOY_CPU_OPERATOR(NonMaxSuppression);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(NonMaxSuppression);
#endif
#ifdef USE_MPS
DEPLOY_MPS_OPERATOR(NonMaxSuppression, NonMaxSuppression);
#endif
OPERATOR_SCHEMA(NonMaxSuppression).NumInputs(1).NumOutputs(1);
NO_GRADIENT(NonMaxSuppression);
} // namespace dragon
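
`NonMaxSuppressionOp` checks that its input has shape `(num_boxes, 5)`, passes it to `detection::ApplyNMS` with `iou_threshold_`, and copies the kept indices to the output. `ApplyNMS` itself is not shown here, so the Python sketch below is a generic greedy NMS rather than that implementation. It assumes each row is `(x1, y1, x2, y2, score)`, sorts by score itself, and measures areas in continuous coordinates.

```python
def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2, ...) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0.0) * max(iy2 - iy1, 0.0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, iou_threshold=0.5):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: boxes[i][4], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap any kept box too much.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


boxes = [
    (0.0, 0.0, 10.0, 10.0, 0.9),
    (1.0, 1.0, 11.0, 11.0, 0.8),    # overlaps the first box heavily
    (20.0, 20.0, 30.0, 30.0, 0.7),
]
kept = greedy_nms(boxes, iou_threshold=0.5)
print(kept)
```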
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_OPERATORS_NMS_OP_H_
#define DRAGON_EXTENSION_OPERATORS_NMS_OP_H_
#include <dragon/core/operator.h>
namespace dragon {
template <class Context>
class NonMaxSuppressionOp final : public Operator<Context> {
public:
NonMaxSuppressionOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
iou_threshold_(OP_SINGLE_ARG(float, "iou_threshold", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::TypesBase<float>>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
float iou_threshold_;
vector<int64_t> out_indices_;
};
} // namespace dragon
#endif // DRAGON_EXTENSION_OPERATORS_NMS_OP_H_
#include "../operators/retinanet_decoder_op.h"
#include "../utils/detection.h"
namespace dragon {
template <class Context>
template <typename T>
void RetinaNetDecoderOp<Context>::DoRunWithType() {
auto N = Input(SCORES).dim(0);
auto AxK = Input(SCORES).dim(1);
auto C = Input(SCORES).dim(2);
auto AxKxC = AxK * C;
auto A = int64_t(ratios_.size() * scales_.size());
auto num_lvls = int64_t(strides_.size());
// Generate anchors.
CHECK_EQ(Input(GRID_INFO).dim(0), num_lvls);
cell_anchors_.resize(strides_.size());
vector<detection::GridArgs<int64_t>> grid_args(strides_.size());
for (int i = 0; i < strides_.size(); ++i) {
grid_args[i].stride = strides_[i];
auto& anchors = cell_anchors_[i];
if (int64_t(anchors.size()) == A * 4) continue;
anchors.resize(A * 4);
detection::GenerateAnchors(
strides_[i],
int64_t(ratios_.size()),
int64_t(scales_.size()),
ratios_.data(),
scales_.data(),
anchors.data());
}
// Set grid arguments.
auto* grid_info = Input(GRID_INFO).template data<int64_t, CPUContext>();
detection::SetGridArgs(AxK, A, grid_info, grid_args);
// Decode detections.
auto* scores = Input(SCORES).template data<T, Context>();
auto* deltas = Input(DELTAS).template data<T, CPUContext>();
auto* im_info = Input(IM_INFO).template data<float, CPUContext>();
auto* Y = Output(0)->Reshape({N * num_lvls * pre_nms_topk_, 7});
auto* dets = Y->template mutable_data<float, CPUContext>();
int64_t size_dets = 0;
for (int batch_ind = 0; batch_ind < N; ++batch_ind) {
detection::ImageArgs<int64_t> im_args(im_info + batch_ind * 4);
im_args.batch_ind = batch_ind;
for (int lvl_ind = 0; lvl_ind < num_lvls; ++lvl_ind) {
detection::SelectTopK(
grid_args[lvl_ind].size * C,
pre_nms_topk_,
score_thresh_,
scores + batch_ind * AxKxC + grid_args[lvl_ind].offset * C,
scores_,
indices_,
ctx());
auto* offset_dets = dets + size_dets * 7;
auto num_dets = int64_t(indices_.size());
size_dets += num_dets;
detection::GetAnchors(
num_dets,
A, // num_cell_anchors
C, // num_classes
grid_args[lvl_ind],
cell_anchors_[lvl_ind].data(),
indices_.data(),
offset_dets);
detection::DecodeDetections(
num_dets,
AxK, // num_anchors
C, // num_classes
im_args,
grid_args[lvl_ind],
scores_.data(),
deltas + batch_ind * Input(DELTAS).stride(0),
indices_.data(),
offset_dets);
}
}
// Shrink to the correct dimensions.
Y->Reshape({size_dets, 7});
}
DEPLOY_CPU_OPERATOR(RetinaNetDecoder);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(RetinaNetDecoder);
#endif
#ifdef USE_MPS
REGISTER_MPS_OPERATOR(RetinaNetDecoder, RetinaNetDecoderOp<CPUContext>);
#endif
OPERATOR_SCHEMA(RetinaNetDecoder).NumInputs(4).NumOutputs(1);
NO_GRADIENT(RetinaNetDecoder);
} // namespace dragon
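
For each image and pyramid level, the decoder above calls `detection::SelectTopK` on a slice of the `(N, AxK, C)` score tensor with `pre_nms_topk_` and `score_thresh_` (defaults 1000 and 0.05 in the header). The selection helper is not shown, so the Python sketch below only illustrates the technique. It assumes a strict `>` comparison against the threshold and assumes that a flat index splits into `(anchor, class)` through division by `C`, which follows from the layout but is not confirmed by the source.

```python
import heapq


def select_topk(scores, k, score_thresh):
    """Return the k highest scores above the threshold and their indices."""
    candidates = [(s, i) for i, s in enumerate(scores) if s > score_thresh]
    top = heapq.nlargest(k, candidates)
    return [s for s, _ in top], [i for _, i in top]


def split_index(flat_index, num_classes):
    """Split a flat index over (anchors x classes) into its two parts."""
    return divmod(flat_index, num_classes)  # (anchor_index, class_index)


# Three anchors with two classes each, flattened row by row.
flat_scores = [0.01, 0.9, 0.2, 0.6, 0.03, 0.7]
top_scores, top_indices = select_topk(flat_scores, k=3, score_thresh=0.05)
for score, index in zip(top_scores, top_indices):
    anchor, cls = split_index(index, num_classes=2)
    print(f"score={score} anchor={anchor} class={cls}")
```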
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_OPERATORS_RETINANET_DECODER_OP_H_
#define DRAGON_EXTENSION_OPERATORS_RETINANET_DECODER_OP_H_
#include <dragon/core/operator.h>
namespace dragon {
template <class Context>
class RetinaNetDecoderOp final : public Operator<Context> {
public:
RetinaNetDecoderOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
strides_(OP_REPEATED_ARG(int64_t, "strides")),
ratios_(OP_REPEATED_ARG(float, "ratios")),
scales_(OP_REPEATED_ARG(float, "scales")),
pre_nms_topk_(OP_SINGLE_ARG(int64_t, "pre_nms_topk", 1000)),
score_thresh_(OP_SINGLE_ARG(float, "score_thresh", 0.05f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::TypesBase<float>>::Call(this, Input(SCORES));
}
template <typename T>
void DoRunWithType();
enum INPUT_TAGS { SCORES = 0, DELTAS = 1, IM_INFO = 2, GRID_INFO = 3 };
protected:
float score_thresh_;
vector<int64_t> strides_;
vector<float> ratios_, scales_;
int64_t pre_nms_topk_;
vector<float> scores_;
vector<int64_t> indices_;
vector<vector<float>> cell_anchors_;
};
} // namespace dragon
#endif // DRAGON_EXTENSION_OPERATORS_RETINANET_DECODER_OP_H_
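
The decoder stores `strides`, `ratios` and `scales`, and builds `A = ratios.size() * scales.size()` cell anchors per level, four coordinates each, through `detection::GenerateAnchors`. That function is not included here, so the Python sketch below is a simplified stand-in. It assumes a ratio means height over width, that the anchor area is `(stride * scale) ** 2`, and that anchors are centred on the origin with ratios in the outer loop. The real routine may round or order them differently.

```python
import math


def generate_cell_anchors(stride, ratios, scales):
    """Return one (x1, y1, x2, y2) anchor per (ratio, scale) pair."""
    anchors = []
    for ratio in ratios:
        for scale in scales:
            area = (stride * scale) ** 2
            width = math.sqrt(area / ratio)
            height = width * ratio  # assumed: ratio = height / width
            anchors.append((-width / 2, -height / 2, width / 2, height / 2))
    return anchors


cell_anchors = generate_cell_anchors(8, ratios=[0.5, 1.0, 2.0], scales=[4.0])
for anchor in cell_anchors:
    print(anchor)
```

Each anchor keeps the same area while the ratio changes its shape, so the count per cell is `len(ratios) * len(scales)`.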
#include "../operators/rpn_decoder_op.h"
#include "../utils/detection.h"
namespace dragon {
template <class Context>
template <typename T>
void RPNDecoderOp<Context>::DoRunWithType() {
auto N = Input(SCORES).dim(0);
auto AxK = Input(SCORES).dim(1);
auto A = int64_t(ratios_.size() * scales_.size());
auto num_lvls = int64_t(strides_.size());
// Generate anchors.
CHECK_EQ(Input(GRID_INFO).dim(0), num_lvls);
cell_anchors_.resize(strides_.size());
vector<detection::GridArgs<int64_t>> grid_args(strides_.size());
for (int i = 0; i < strides_.size(); ++i) {
grid_args[i].stride = strides_[i];
auto& anchors = cell_anchors_[i];
if (int64_t(anchors.size()) == A * 4) continue;
anchors.resize(A * 4);
detection::GenerateAnchors(
strides_[i],
int64_t(ratios_.size()),
int64_t(scales_.size()),
ratios_.data(),
scales_.data(),
anchors.data());
}
// Set grid arguments.
auto* grid_info = Input(GRID_INFO).template data<int64_t, CPUContext>();
detection::SetGridArgs(AxK, A, grid_info, grid_args);
// Decode proposals.
auto* scores = Input(SCORES).template data<T, CPUContext>();
auto* deltas = Input(DELTAS).template data<T, CPUContext>();
auto* im_info = Input(IM_INFO).template data<float, CPUContext>();
auto* Y = Output("Y")->Reshape({N * num_lvls * pre_nms_topk_, 5});
auto* dets = Y->template mutable_data<float, CPUContext>();
for (int batch_ind = 0; batch_ind < N; ++batch_ind) {
detection::ImageArgs<int64_t> im_args(im_info + batch_ind * 4);
im_args.batch_ind = batch_ind;
for (int lvl_ind = 0; lvl_ind < num_lvls; ++lvl_ind) {
detection::SelectTopK(
grid_args[lvl_ind].size,
pre_nms_topk_,
0.f,
scores + batch_ind * AxK + grid_args[lvl_ind].offset,
scores_,
indices_,
(CPUContext*)nullptr); // Faster.
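// Pad to pre_nms_topk_ by repeating the last index; back() on an empty vector is undefined.
indices_.resize(pre_nms_topk_, indices_.empty() ? int64_t(0) : indices_.back());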
auto* offset_dets = dets + lvl_ind * pre_nms_topk_ * 5;
detection::GetAnchors(
pre_nms_topk_,
A, // num_cell_anchors
grid_args[lvl_ind],
cell_anchors_[lvl_ind].data(),
indices_.data(),
offset_dets);
detection::DecodeProposals(
pre_nms_topk_,
AxK, // num_anchors
im_args,
grid_args[lvl_ind],
scores_.data(),
deltas + batch_ind * Input(DELTAS).stride(0),
indices_.data(),
offset_dets);
detection::SortBoxes<T, detection::Box5d<T>>(pre_nms_topk_, offset_dets);
}
}
// Apply NMS.
auto* dets_v2 = Y->template data<float, Context>();
int64_t size_rois = 0;
scores_.resize(N * post_nms_topk_);
indices_.resize(N * post_nms_topk_);
for (int batch_ind = 0; batch_ind < N; ++batch_ind) {
std::priority_queue<std::pair<float, int64_t>> pq;
for (int lvl_ind = 0; lvl_ind < num_lvls; ++lvl_ind) {
const auto offset = lvl_ind * pre_nms_topk_;
detection::ApplyNMS(
pre_nms_topk_, // N
pre_nms_topk_, // K
offset * 5, // boxes_offset
nms_thresh_,
dets_v2,
nms_indices_,
ctx());
for (size_t i = 0; i < nms_indices_.size(); ++i) {
const auto index = nms_indices_[i] + offset;
pq.push(std::make_pair(*(dets + index * 5 + 4), index));
}
}
for (int i = 0; i < post_nms_topk_ && !pq.empty(); ++i) {
scores_[size_rois] = batch_ind;
indices_[size_rois++] = pq.top().second;
pq.pop();
}
}
// Apply Histogram.
detection::ApplyHistogram(
size_rois,
min_level_,
max_level_,
canonical_level_,
canonical_scale_,
dets,
scores_.data(),
indices_.data(),
output_rois_);
// Copy to outputs.
for (int i = 0; i < OutputSize(); ++i) {
const auto& rois = output_rois_[i];
vector<int64_t> dims({int64_t(rois.size()) / 5, 5});
auto* Yi = Output(i)->Reshape(dims);
std::memcpy(
Yi->template mutable_data<T, CPUContext>(),
rois.data(),
sizeof(T) * rois.size());
}
}
DEPLOY_CPU_OPERATOR(RPNDecoder);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(RPNDecoder);
#endif
#ifdef USE_MPS
DEPLOY_MPS_OPERATOR(RPNDecoder, RPNDecoder);
#endif
OPERATOR_SCHEMA(RPNDecoder).NumInputs(4).NumOutputs(1, INT_MAX);
NO_GRADIENT(RPNDecoder);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_OPERATORS_RPN_DECODER_OP_H_
#define DRAGON_EXTENSION_OPERATORS_RPN_DECODER_OP_H_
#include <dragon/core/operator.h>
namespace dragon {
template <class Context>
class RPNDecoderOp final : public Operator<Context> {
public:
RPNDecoderOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
strides_(OP_REPEATED_ARG(int64_t, "strides")),
ratios_(OP_REPEATED_ARG(float, "ratios")),
scales_(OP_REPEATED_ARG(float, "scales")),
pre_nms_topk_(OP_SINGLE_ARG(int64_t, "pre_nms_topk", 1000)),
post_nms_topk_(OP_SINGLE_ARG(int64_t, "post_nms_topk", 1000)),
nms_thresh_(OP_SINGLE_ARG(float, "nms_thresh", 0.7f)),
min_level_(OP_SINGLE_ARG(int64_t, "min_level", 2)),
max_level_(OP_SINGLE_ARG(int64_t, "max_level", 5)),
canonical_level_(OP_SINGLE_ARG(int64_t, "canonical_level", 4)),
canonical_scale_(OP_SINGLE_ARG(int64_t, "canonical_scale", 224)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::TypesBase<float>>::Call(this, Input(SCORES));
}
template <typename T>
void DoRunWithType();
enum INPUT_TAGS { SCORES = 0, DELTAS = 1, IM_INFO = 2, GRID_INFO = 3 };
protected:
float nms_thresh_;
vector<int64_t> strides_;
vector<float> ratios_, scales_;
int64_t min_level_, max_level_;
int64_t pre_nms_topk_, post_nms_topk_;
int64_t canonical_level_, canonical_scale_;
vector<float> scores_;
vector<int64_t> indices_, nms_indices_;
vector<vector<float>> cell_anchors_;
vector<vector<float>> output_rois_;
};
} // namespace dragon
#endif // DRAGON_EXTENSION_OPERATORS_RPN_DECODER_OP_H_
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build cpp extensions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import dragon
from dragon.utils import cpp_extension
from setuptools import setup
Extension = cpp_extension.CppExtension
if (dragon.cuda.is_available() and
        cpp_extension.CUDA_HOME is not None):
    Extension = cpp_extension.CUDAExtension
elif dragon.mps.is_available():
    Extension = cpp_extension.MPSExtension


def find_sources(*dirs):
    ext_suffixes = ['.cc']
    if Extension is cpp_extension.CUDAExtension:
        ext_suffixes.append('.cu')
    elif Extension is cpp_extension.MPSExtension:
        ext_suffixes.append('.mm')
    sources = []
    for path in dirs:
        for ext_suffix in ext_suffixes:
            sources += glob.glob(path + '/*' + ext_suffix, recursive=True)
    return sources


ext_modules = [
    Extension(
        name='seetadet.ops._C',
        sources=find_sources('**'),
    ),
]
setup(
    name='seetadet',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_H_
#include "../utils/detection/anchors.h"
#include "../utils/detection/bbox.h"
#include "../utils/detection/mask.h"
#include "../utils/detection/nms.h"
#include "../utils/detection/proposals.h"
#include "../utils/detection/types.h"
#endif // DRAGON_EXTENSION_UTILS_DETECTION_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_ANCHORS_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_ANCHORS_H_
#include "../../utils/detection/types.h"
namespace dragon {
namespace detection {
/*!
* Anchor Functions.
*/
template <typename IndexT>
inline void SetGridArgs(
const int num_anchors,
const int num_cell_anchors,
const IndexT* grid_info,
vector<GridArgs<IndexT>>& grid_args) {
IndexT grid_offset = 0;
for (int i = 0; i < grid_args.size(); ++i, grid_info += 2) {
auto& args = grid_args[i];
args.h = grid_info[0];
args.w = grid_info[1];
args.size = num_cell_anchors * args.h * args.w;
args.offset = grid_offset;
grid_offset += args.size;
}
std::stringstream ss;
if (grid_offset != num_anchors) {
ss << "Mismatched number of anchors. (Expected ";
ss << num_anchors << ", Got " << grid_offset << ")";
for (int i = 0; i < grid_args.size(); ++i) {
ss << "\nGrid #" << i << ": "
<< "A=" << num_cell_anchors << ", H=" << grid_args[i].h
<< ", W=" << grid_args[i].w << "\n";
}
}
if (!ss.str().empty()) LOG(FATAL) << ss.str();
}
template <typename T>
inline void GenerateAnchors(
const int stride,
const int num_ratios,
const int num_scales,
const T* ratios,
const T* scales,
T* anchors) {
T* offset_anchors = anchors;
T x = T(0.5) * T(stride), y = T(0.5) * T(stride);
for (int i = 0; i < num_ratios; ++i) {
const T ratio_w = std::sqrt(T(1) / ratios[i]);
const T ratio_h = ratio_w * ratios[i];
for (int j = 0; j < num_scales; ++j) {
offset_anchors[0] = -x * ratio_w * scales[j];
offset_anchors[1] = -y * ratio_h * scales[j];
offset_anchors[2] = x * ratio_w * scales[j];
offset_anchors[3] = y * ratio_h * scales[j];
offset_anchors += 4;
}
}
}
template <typename T>
inline void GetAnchors(
const int num_anchors,
const int num_cell_anchors,
const GridArgs<int64_t>& args,
const T* cell_anchors,
const int64_t* indices,
T* anchors) {
for (int i = 0; i < num_anchors; ++i) {
auto index = indices[i];
const auto w = index % args.w;
index /= args.w;
const auto h = index % args.h;
index /= args.h;
const auto shift_x = T(w * args.stride);
const auto shift_y = T(h * args.stride);
auto* offset_anchors = anchors + i * 5;
const auto* offset_cell_anchors = cell_anchors + index * 4;
offset_anchors[0] = shift_x + offset_cell_anchors[0];
offset_anchors[1] = shift_y + offset_cell_anchors[1];
offset_anchors[2] = shift_x + offset_cell_anchors[2];
offset_anchors[3] = shift_y + offset_cell_anchors[3];
}
}
template <typename T>
inline void GetAnchors(
const int num_anchors,
const int num_cell_anchors,
const int num_classes,
const GridArgs<int64_t>& args,
const T* cell_anchors,
const int64_t* indices,
T* anchors) {
for (int i = 0; i < num_anchors; ++i) {
auto index = indices[i];
index /= num_classes;
const auto w = index % args.w;
index /= args.w;
const auto h = index % args.h;
index /= args.h;
const auto shift_x = T(w * args.stride);
const auto shift_y = T(h * args.stride);
auto* offset_anchors = anchors + i * 7 + 1;
const auto* offset_cell_anchors = cell_anchors + index * 4;
offset_anchors[0] = shift_x + offset_cell_anchors[0];
offset_anchors[1] = shift_y + offset_cell_anchors[1];
offset_anchors[2] = shift_x + offset_cell_anchors[2];
offset_anchors[3] = shift_y + offset_cell_anchors[3];
}
}
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_ANCHORS_H_
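The anchor helpers above can be exercised outside the framework with a small standalone sketch. It assumes ratio means height / width and that the flat anchor index is laid out as (a * H + h) * W + w, which is what the modulo chain in GetAnchors implies. The names MakeCellAnchors, GridPos and UnravelIndex are hypothetical and are not part of the extension.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Cell anchors as flat (x1, y1, x2, y2) quadruples centred on the origin,
// ordered ratio-major then scale, mirroring the loop order in GenerateAnchors.
std::vector<float> MakeCellAnchors(
    int stride,
    const std::vector<float>& ratios,
    const std::vector<float>& scales) {
  assert(stride > 0);
  std::vector<float> anchors;
  anchors.reserve(ratios.size() * scales.size() * 4);
  const float half = 0.5f * static_cast<float>(stride);
  for (const float ratio : ratios) {
    // Assumption: ratio = h / w. The two factors multiply to 1, so the
    // anchor area stays (stride * scale)^2 for every ratio.
    const float ratio_w = std::sqrt(1.f / ratio);
    const float ratio_h = ratio_w * ratio;
    for (const float scale : scales) {
      anchors.push_back(-half * ratio_w * scale);
      anchors.push_back(-half * ratio_h * scale);
      anchors.push_back(half * ratio_w * scale);
      anchors.push_back(half * ratio_h * scale);
    }
  }
  return anchors;
}

// Position of one anchor inside an (A, H, W) grid (hypothetical helper type).
struct GridPos {
  int64_t a, h, w;
};

// Unravel a flat index laid out as (a * H + h) * W + w.
GridPos UnravelIndex(int64_t index, int64_t H, int64_t W) {
  assert(index >= 0 && H > 0 && W > 0);
  GridPos pos;
  pos.w = index % W;
  index /= W;
  pos.h = index % H;
  pos.a = index / H;
  return pos;
}
```

With stride 16, ratio 4 and scale 8, the sketch yields the cell anchor (-32, -128, 32, 128): a 64 x 256 box whose area equals (16 * 8)^2. The final anchor is then this box shifted by (w * stride, h * stride).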
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_BBOX_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_BBOX_H_
#include "../../utils/detection/types.h"
#if defined(__CUDACC__)
#define HOSTDEVICE_DECL inline __host__ __device__
#else
#define HOSTDEVICE_DECL inline
#endif
namespace dragon {
namespace detection {
/*
* BBox Functions.
*/
template <typename T, class BoxT>
inline void SortBoxes(const int N, T* data, bool descend = true) {
auto* boxes = reinterpret_cast<BoxT*>(data);
std::sort(boxes, boxes + N, [descend](const BoxT& lhs, const BoxT& rhs) {
return descend ? (lhs.score > rhs.score) : (lhs.score < rhs.score);
});
}
/*
* BBox Utilities.
*/
namespace utils {
template <typename T>
HOSTDEVICE_DECL bool CheckIoU(const T thresh, const T* a, const T* b) {
#if defined(__CUDACC__)
const T x1 = max(a[0], b[0]);
const T y1 = max(a[1], b[1]);
const T x2 = min(a[2], b[2]);
const T y2 = min(a[3], b[3]);
const T width = max(T(0), x2 - x1);
const T height = max(T(0), y2 - y1);
#else
const T x1 = std::max(a[0], b[0]);
const T y1 = std::max(a[1], b[1]);
const T x2 = std::min(a[2], b[2]);
const T y2 = std::min(a[3], b[3]);
const T width = std::max(T(0), x2 - x1);
const T height = std::max(T(0), y2 - y1);
#endif
const T inter = width * height;
const T Sa = (a[2] - a[0]) * (a[3] - a[1]);
const T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return inter >= thresh * (Sa + Sb - inter);
}
template <typename T>
inline void BBoxTransform(
const T dx,
const T dy,
const T dw,
const T dh,
const T im_w,
const T im_h,
const T im_scale_h,
const T im_scale_w,
T* bbox) {
const T w = bbox[2] - bbox[0];
const T h = bbox[3] - bbox[1];
const T ctr_x = bbox[0] + T(0.5) * w;
const T ctr_y = bbox[1] + T(0.5) * h;
const T pred_ctr_x = dx * w + ctr_x;
const T pred_ctr_y = dy * h + ctr_y;
const T pred_w = std::exp(dw) * w;
const T pred_h = std::exp(dh) * h;
const T x1 = pred_ctr_x - T(0.5) * pred_w;
const T y1 = pred_ctr_y - T(0.5) * pred_h;
const T x2 = pred_ctr_x + T(0.5) * pred_w;
const T y2 = pred_ctr_y + T(0.5) * pred_h;
bbox[0] = std::max(T(0), std::min(x1, im_w)) / im_scale_w;
bbox[1] = std::max(T(0), std::min(y1, im_h)) / im_scale_h;
bbox[2] = std::max(T(0), std::min(x2, im_w)) / im_scale_w;
bbox[3] = std::max(T(0), std::min(y2, im_h)) / im_scale_h;
}
template <typename T>
inline int GetBBoxLevel(
const int lvl_min,
const int lvl_max,
const int lvl0,
const int s0,
T* bbox) {
const T w = bbox[2] - bbox[0];
const T h = bbox[3] - bbox[1];
if (w <= T(0) || h <= T(0)) return -1;
const T s = std::sqrt(w * h);
const int lvl = lvl0 + std::log2(s / s0 + T(1e-6));
return std::min(std::max(lvl, lvl_min), lvl_max);
}
} // namespace utils
} // namespace detection
} // namespace dragon
#undef HOSTDEVICE_DECL
#endif // DRAGON_EXTENSION_UTILS_DETECTION_BBOX_H_
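Two of the utilities in this header are easy to check by hand: the division-free IoU test and the mapping from box size to pyramid level. The sketch below restates both as plain float functions so they can be compiled on their own. IoUAtLeast and MapToLevel are hypothetical names, and the float-only signatures are a simplification of the templated originals.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// True when IoU(a, b) >= thresh. Boxes are (x1, y1, x2, y2).
// inter / union >= thresh is rewritten as inter >= thresh * union,
// which avoids a division and a zero-union special case.
bool IoUAtLeast(float thresh, const float* a, const float* b) {
  const float w = std::max(0.f, std::min(a[2], b[2]) - std::max(a[0], b[0]));
  const float h = std::max(0.f, std::min(a[3], b[3]) - std::max(a[1], b[1]));
  const float inter = w * h;
  const float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  return inter >= thresh * (area_a + area_b - inter);
}

// Level = lvl0 + log2(sqrt(w * h) / s0), truncated and clamped to
// [lvl_min, lvl_max]. Returns -1 for a degenerate box.
int MapToLevel(int lvl_min, int lvl_max, int lvl0, int s0, float w, float h) {
  assert(lvl_min <= lvl_max && s0 > 0);
  if (w <= 0.f || h <= 0.f) return -1;
  const float s = std::sqrt(w * h);
  // The small epsilon keeps exact powers of two (e.g. s = s0 / 2) from
  // landing just below an integer before truncation.
  const float lvl = static_cast<float>(lvl0) +
      std::log2(s / static_cast<float>(s0) + 1e-6f);
  return std::min(std::max(static_cast<int>(lvl), lvl_min), lvl_max);
}
```

With the operator defaults visible in RPNDecoderOp (min_level 2, max_level 5, canonical_level 4, canonical_scale 224), a 224 x 224 box maps to level 4, a 112 x 112 box to level 3, and anything much larger is clamped to level 5.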
#include <dragon/core/context.h>
#include "../../../utils/detection/mask.h"
namespace dragon {
namespace detection {
namespace {
template <typename IndexT>
inline bool WithinBounds2d(IndexT h, IndexT w, IndexT H, IndexT W) {
return h >= IndexT(0) && h < H && w >= IndexT(0) && w < W;
}
template <typename T>
void _PasteMask(
const int N,
const int H,
const int W,
const int mask_h,
const int mask_w,
const T thresh,
const T* masks,
const float* boxes,
uint8_t* im) {
const auto HxW = H * W;
for (int n = 0; n < N; ++n) {
const float* box = boxes + n * 4;
const T* mask = masks + n * mask_h * mask_w;
uint8_t* offset_im = im + n * H * W;
const float box_w_half = (box[2] - box[0]) * 0.5f;
const float box_h_half = (box[3] - box[1]) * 0.5f;
const float mask_w_half = float(mask_w) * 0.5f;
const float mask_h_half = float(mask_h) * 0.5f;
for (int index = 0; index < HxW; ++index) {
const int w = index % W;
const int h = index / W;
const float gx = (float(w) + 0.5f - box[0]) / box_w_half;
const float gy = (float(h) + 0.5f - box[1]) / box_h_half;
const float ix = gx * mask_w_half - 0.5f;
const float iy = gy * mask_h_half - 0.5f;
const int ix_nw = floorf(ix);
const int iy_nw = floorf(iy);
const int ix_ne = ix_nw + 1;
const int iy_ne = iy_nw;
const int ix_sw = ix_nw;
const int iy_sw = iy_nw + 1;
const int ix_se = ix_nw + 1;
const int iy_se = iy_nw + 1;
T nw = T((ix_se - ix) * (iy_se - iy));
T ne = T((ix - ix_sw) * (iy_sw - iy));
T sw = T((ix_ne - ix) * (iy - iy_ne));
T se = T((ix - ix_nw) * (iy - iy_nw));
T val = T(0);
if (WithinBounds2d(iy_nw, ix_nw, mask_h, mask_w)) {
val += mask[iy_nw * mask_w + ix_nw] * nw;
}
if (WithinBounds2d(iy_ne, ix_ne, mask_h, mask_w)) {
val += mask[iy_ne * mask_w + ix_ne] * ne;
}
if (WithinBounds2d(iy_sw, ix_sw, mask_h, mask_w)) {
val += mask[iy_sw * mask_w + ix_sw] * sw;
}
if (WithinBounds2d(iy_se, ix_se, mask_h, mask_w)) {
val += mask[iy_se * mask_w + ix_se] * se;
}
*(offset_im++) = (val >= thresh ? uint8_t(1) : uint8_t(0));
}
}
}
} // namespace
template <>
void PasteMask<float, CPUContext>(
const int N,
const int H,
const int W,
const int mask_h,
const int mask_w,
const float thresh,
const float* masks,
const float* boxes,
uint8_t* im,
CPUContext* ctx) {
_PasteMask(N, H, W, mask_h, mask_w, thresh, masks, boxes, im);
}
} // namespace detection
} // namespace dragon
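The sampling step in _PasteMask is ordinary bilinear interpolation with zero padding outside the mask. A minimal standalone sketch of that step follows. SampleBilinear and ToMaskCoord are hypothetical helper names, and the fractional-offset form (fx, fy) is an algebraic rewrite of the four corner weights used above, not the original code.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Bilinear sample of a row-major mask at continuous coordinate (ix, iy).
// Taps that fall outside the mask contribute zero.
float SampleBilinear(
    const std::vector<float>& mask,
    int mask_h,
    int mask_w,
    float ix,
    float iy) {
  assert(static_cast<int>(mask.size()) == mask_h * mask_w);
  const int x0 = static_cast<int>(std::floor(ix));
  const int y0 = static_cast<int>(std::floor(iy));
  const float fx = ix - static_cast<float>(x0);  // fractional offsets in [0, 1)
  const float fy = iy - static_cast<float>(y0);
  auto at = [&](int y, int x) -> float {
    const bool inside = y >= 0 && y < mask_h && x >= 0 && x < mask_w;
    return inside ? mask[y * mask_w + x] : 0.f;
  };
  return at(y0, x0) * (1.f - fx) * (1.f - fy) +  // north-west
      at(y0, x0 + 1) * fx * (1.f - fy) +         // north-east
      at(y0 + 1, x0) * (1.f - fx) * fy +         // south-west
      at(y0 + 1, x0 + 1) * fx * fy;              // south-east
}

// Map the centre of an image pixel to a mask coordinate along one axis,
// given the box extent [box_lo, box_hi) on that axis.
float ToMaskCoord(int pixel, float box_lo, float box_hi, int mask_size) {
  return (static_cast<float>(pixel) + 0.5f - box_lo) / (box_hi - box_lo) *
      static_cast<float>(mask_size) -
      0.5f;
}
```

The vertical coordinate must be scaled by the mask height and the horizontal one by the mask width; the two only coincide for square masks.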
#include <dragon/core/context.h>
#include "../../../utils/detection/bbox.h"
#include "../../../utils/detection/nms.h"
namespace dragon {
namespace detection {
template <>
void ApplyNMS<float, CPUContext>(
const int N,
const int K,
const int boxes_offset,
const float thresh,
const float* boxes,
vector<int64_t>& indices,
CPUContext* ctx) {
boxes = boxes + boxes_offset;
int num_selected = 0;
indices.resize(K);
vector<char> is_dead(N, 0);
for (int i = 0; i < N; ++i) {
if (is_dead[i]) continue;
indices[num_selected++] = i;
if (num_selected >= K) break;
for (int j = i + 1; j < N; ++j) {
if (is_dead[j]) continue;
if (!utils::CheckIoU(thresh, &boxes[i * 5], &boxes[j * 5])) continue;
is_dead[j] = 1;
}
}
indices.resize(num_selected);
}
} // namespace detection
} // namespace dragon
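The CPU path above is the classic greedy NMS loop over boxes that are already sorted by descending score. A self-contained sketch of the same loop, assuming a flat (x1, y1, x2, y2, score) layout with five floats per box, is given here for experimentation. GreedyNMS and Overlaps are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// True when IoU(a, b) >= thresh, using the division-free comparison.
inline bool Overlaps(float thresh, const float* a, const float* b) {
  const float w = std::max(0.f, std::min(a[2], b[2]) - std::max(a[0], b[0]));
  const float h = std::max(0.f, std::min(a[3], b[3]) - std::max(a[1], b[1]));
  const float inter = w * h;
  const float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  return inter >= thresh * (area_a + area_b - inter);
}

// Greedy NMS over boxes sorted by descending score (assumed by the caller).
// Returns the indices of the kept boxes, at most max_keep of them.
std::vector<int64_t> GreedyNMS(
    const std::vector<float>& boxes,
    float thresh,
    int max_keep) {
  assert(boxes.size() % 5 == 0 && max_keep > 0);
  const int n = static_cast<int>(boxes.size() / 5);
  std::vector<char> is_dead(n, 0);
  std::vector<int64_t> keep;
  for (int i = 0; i < n; ++i) {
    if (is_dead[i]) continue;
    keep.push_back(i);
    if (static_cast<int>(keep.size()) >= max_keep) break;
    // Suppress every lower-scored box that overlaps the kept one.
    for (int j = i + 1; j < n; ++j) {
      if (!is_dead[j] && Overlaps(thresh, &boxes[i * 5], &boxes[j * 5])) {
        is_dead[j] = 1;
      }
    }
  }
  return keep;
}
```

The loop is O(N^2) in the worst case, which is acceptable here because N is bounded by pre_nms_topk per level.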
#include <dragon/core/context.h>
#include "../../../utils/detection/proposals.h"
namespace dragon {
namespace detection {
namespace {
template <typename KeyT, typename ValueT>
inline void
ArgPartition(const int N, const int K, const ValueT* values, KeyT* keys) {
std::nth_element(keys, keys + K, keys + N, [&values](KeyT lhs, KeyT rhs) {
return values[lhs] > values[rhs];
});
}
} // namespace
template <>
void SelectTopK<float, CPUContext>(
const int N,
const int K,
const float thresh,
const float* scores,
vector<float>& out_scores,
vector<int64_t>& out_indices,
CPUContext* ctx) {
int num_selected = 0;
out_indices.resize(N);
if (thresh > 0.f) {
for (int i = 0; i < N; ++i) {
if (scores[i] > thresh) {
out_indices[num_selected++] = i;
}
}
} else {
num_selected = N;
std::iota(out_indices.begin(), out_indices.end(), 0);
}
if (num_selected > K) {
ArgPartition(num_selected, K, scores, out_indices.data());
out_scores.resize(K);
out_indices.resize(K);
for (int i = 0; i < K; ++i) {
out_scores[i] = scores[out_indices[i]];
}
} else {
out_scores.resize(num_selected);
out_indices.resize(num_selected);
for (int i = 0; i < num_selected; ++i) {
out_scores[i] = scores[out_indices[i]];
}
}
}
} // namespace detection
} // namespace dragon
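SelectTopK filters by threshold and then uses std::nth_element over an index array, which partitions the top K without fully sorting them. The sketch below shows that pattern in isolation; TopKIndices is a hypothetical name and returns only the indices.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Indices of (at most) the k highest scores, in unspecified order.
// A positive thresh first drops every score that is not above it.
std::vector<int64_t> TopKIndices(
    const std::vector<float>& scores,
    int k,
    float thresh) {
  assert(k > 0);
  std::vector<int64_t> keys;
  if (thresh > 0.f) {
    for (std::size_t i = 0; i < scores.size(); ++i) {
      if (scores[i] > thresh) keys.push_back(static_cast<int64_t>(i));
    }
  } else {
    keys.resize(scores.size());
    std::iota(keys.begin(), keys.end(), int64_t(0));
  }
  if (static_cast<int64_t>(keys.size()) > k) {
    // Partition so the k best-scoring indices come first; average O(N).
    std::nth_element(
        keys.begin(),
        keys.begin() + k,
        keys.end(),
        [&scores](int64_t lhs, int64_t rhs) {
          return scores[lhs] > scores[rhs];
        });
    keys.resize(k);
  }
  return keys;
}
```

Because nth_element leaves the selected elements unordered, a later sort (SortBoxes in the operators) is still needed before NMS.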
#include <dragon/core/context_cuda.h>
#include "../../../utils/detection/mask.h"
namespace dragon {
namespace detection {
namespace {
template <typename IndexT>
inline __device__ bool WithinBounds2d(IndexT h, IndexT w, IndexT H, IndexT W) {
return h >= IndexT(0) && h < H && w >= IndexT(0) && w < W;
}
template <typename T>
__global__ void _PasteMask(
const int nthreads,
const int H,
const int W,
const int mask_h,
const int mask_w,
const T thresh,
const T* masks,
const float* boxes,
uint8_t* im) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int w = index % W;
const int h = index / W % H;
const int n = index / (H * W);
const float* box = boxes + n * 4;
const T* mask = masks + n * mask_h * mask_w;
const float gx = (float(w) + 0.5f - box[0]) / (box[2] - box[0]) * 2.f;
const float gy = (float(h) + 0.5f - box[1]) / (box[3] - box[1]) * 2.f;
const float ix = (gx * float(mask_w) - 1.f) * 0.5f;
const float iy = (gy * float(mask_h) - 1.f) * 0.5f;
const int ix_nw = floorf(ix);
const int iy_nw = floorf(iy);
const int ix_ne = ix_nw + 1;
const int iy_ne = iy_nw;
const int ix_sw = ix_nw;
const int iy_sw = iy_nw + 1;
const int ix_se = ix_nw + 1;
const int iy_se = iy_nw + 1;
T nw = T((ix_se - ix) * (iy_se - iy));
T ne = T((ix - ix_sw) * (iy_sw - iy));
T sw = T((ix_ne - ix) * (iy - iy_ne));
T se = T((ix - ix_nw) * (iy - iy_nw));
T val = T(0);
if (WithinBounds2d(iy_nw, ix_nw, mask_h, mask_w)) {
val += mask[iy_nw * mask_w + ix_nw] * nw;
}
if (WithinBounds2d(iy_ne, ix_ne, mask_h, mask_w)) {
val += mask[iy_ne * mask_w + ix_ne] * ne;
}
if (WithinBounds2d(iy_sw, ix_sw, mask_h, mask_w)) {
val += mask[iy_sw * mask_w + ix_sw] * sw;
}
if (WithinBounds2d(iy_se, ix_se, mask_h, mask_w)) {
val += mask[iy_se * mask_w + ix_se] * se;
}
im[index] = (val >= thresh ? uint8_t(1) : uint8_t(0));
}
}
} // namespace
template <>
void PasteMask<float, CUDAContext>(
const int N,
const int H,
const int W,
const int mask_h,
const int mask_w,
const float thresh,
const float* masks,
const float* boxes,
uint8_t* im,
CUDAContext* ctx) {
const auto NxHxW = N * H * W;
_PasteMask<<<CUDA_BLOCKS(NxHxW), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
NxHxW, H, W, mask_h, mask_w, thresh, masks, boxes, im);
}
} // namespace detection
} // namespace dragon
#include <dragon/core/context_cuda.h>
#include <dragon/core/workspace.h>
#include "../../../utils/detection/bbox.h"
#include "../../../utils/detection/nms.h"
#include "../../../utils/detection/utils.h"
namespace dragon {
namespace detection {
namespace {
#define NUM_THREADS 64
template <typename T>
__global__ void _NonMaxSuppression(
const int N,
const T thresh,
const T* boxes,
uint64_t* mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
if (row_start > col_start) return;
const int row_size = min(N - row_start * NUM_THREADS, NUM_THREADS);
const int col_size = min(N - col_start * NUM_THREADS, NUM_THREADS);
__shared__ T block_boxes[NUM_THREADS * 4];
if (threadIdx.x < col_size) {
auto* offset_block_boxes = block_boxes + threadIdx.x * 4;
auto* offset_boxes = boxes + (col_start * NUM_THREADS + threadIdx.x) * 5;
#pragma unroll
for (int i = 0; i < 4; ++i) {
*(offset_block_boxes++) = *(offset_boxes++);
}
}
__syncthreads();
if (threadIdx.x < row_size) {
const int index = row_start * NUM_THREADS + threadIdx.x;
const T* offset_boxes = boxes + index * 5;
uint64_t val = 0;
const int start = (row_start == col_start) ? (threadIdx.x + 1) : 0;
for (int i = start; i < col_size; ++i) {
if (utils::CheckIoU(thresh, offset_boxes, block_boxes + i * 4)) {
val |= (uint64_t(1) << i);
}
}
mask[index * gridDim.x + col_start] = val;
}
}
} // namespace
template <>
void ApplyNMS<float, CUDAContext>(
const int N,
const int K,
const int boxes_offset,
const float thresh,
const float* boxes,
vector<int64_t>& indices,
CUDAContext* ctx) {
boxes = boxes + boxes_offset;
const auto num_blocks = utils::DivUp(N, NUM_THREADS);
auto* NMS_mask = ctx->workspace()->CreateTensor("NMS_mask");
NMS_mask->Reshape({N * num_blocks});
auto* mask = reinterpret_cast<uint64_t*>(
NMS_mask->template mutable_data<int64_t, CUDAContext>());
vector<uint64_t> mask_host(N * num_blocks);
_NonMaxSuppression<<<
dim3(num_blocks, num_blocks),
NUM_THREADS,
0,
ctx->cuda_stream()>>>(N, thresh, boxes, mask);
CUDA_CHECK(cudaMemcpyAsync(
mask_host.data(),
mask,
mask_host.size() * sizeof(uint64_t),
cudaMemcpyDeviceToHost,
ctx->cuda_stream()));
ctx->FinishDeviceComputation();
vector<uint64_t> is_dead(num_blocks, 0);
int num_selected = 0;
indices.resize(K);
for (int i = 0; i < N; ++i) {
const int nblock = i / NUM_THREADS, inblock = i % NUM_THREADS;
if (!(is_dead[nblock] & (uint64_t(1) << inblock))) {
indices[num_selected++] = i;
if (num_selected >= K) break;
auto* offset_mask = &mask_host[0] + i * num_blocks;
for (int j = nblock; j < num_blocks; ++j) {
is_dead[j] |= offset_mask[j];
}
}
}
indices.resize(num_selected);
}
} // namespace detection
} // namespace dragon
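The CUDA path splits NMS in two: the kernel writes one 64-bit suppression word per (box, column block), and the host then walks the boxes in score order and ORs the words of each kept box into a running "dead" bitset. The sketch below reproduces that data layout on the CPU so the host-side reduction can be tested without a GPU. BuildSuppressionMask is a hypothetical CPU stand-in for the kernel, and ReduceSuppressionMask mirrors the host loop; neither name exists in the extension.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int kBlock = 64;  // bits per mask word, same as NUM_THREADS

inline bool Overlaps(float thresh, const float* a, const float* b) {
  const float w = std::max(0.f, std::min(a[2], b[2]) - std::max(a[0], b[0]));
  const float h = std::max(0.f, std::min(a[3], b[3]) - std::max(a[1], b[1]));
  const float inter = w * h;
  const float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  return inter >= thresh * (area_a + area_b - inter);
}

// CPU stand-in for the kernel: bit (j % 64) of mask[i * num_blocks + j / 64]
// is set when box j > i overlaps box i. Boxes are 5 floats each.
std::vector<uint64_t> BuildSuppressionMask(
    const std::vector<float>& boxes,
    float thresh) {
  const int n = static_cast<int>(boxes.size() / 5);
  const int num_blocks = (n + kBlock - 1) / kBlock;
  std::vector<uint64_t> mask(static_cast<std::size_t>(n) * num_blocks, 0);
  for (int i = 0; i < n; ++i) {
    for (int j = i + 1; j < n; ++j) {
      if (Overlaps(thresh, &boxes[i * 5], &boxes[j * 5])) {
        mask[static_cast<std::size_t>(i) * num_blocks + j / kBlock] |=
            uint64_t(1) << (j % kBlock);
      }
    }
  }
  return mask;
}

// Host-side reduction: keep a box unless an earlier kept box marked it dead.
std::vector<int64_t> ReduceSuppressionMask(
    const std::vector<uint64_t>& mask,
    int n,
    int max_keep) {
  assert(max_keep > 0);
  const int num_blocks = (n + kBlock - 1) / kBlock;
  std::vector<uint64_t> is_dead(num_blocks, 0);
  std::vector<int64_t> keep;
  for (int i = 0; i < n; ++i) {
    const int nblock = i / kBlock, inblock = i % kBlock;
    if (is_dead[nblock] & (uint64_t(1) << inblock)) continue;
    keep.push_back(i);
    if (static_cast<int>(keep.size()) >= max_keep) break;
    for (int j = nblock; j < num_blocks; ++j) {
      is_dead[j] |= mask[static_cast<std::size_t>(i) * num_blocks + j];
    }
  }
  return keep;
}
```

The reduction only reads words with column block >= the box's own block, which is why the kernel can skip blocks where row_start > col_start.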
#include <dragon/core/context_cuda.h>
#include <dragon/core/workspace.h>
#include <dragon/utils/device/common_thrust.h>
#include "../../../utils/detection/iterator.h"
#include "../../../utils/detection/proposals.h"
namespace dragon {
namespace detection {
namespace {
template <typename KeyT, typename ValueT>
struct ThresholdFunctor {
ThresholdFunctor(ValueT thresh) : thresh_(thresh) {}
inline __device__ bool operator()(
const thrust::tuple<KeyT, ValueT>& kv) const {
return thrust::get<1>(kv) > thresh_;
}
ValueT thresh_;
};
template <typename IterT>
inline void ArgPartition(const int N, const int K, IterT data) {
std::nth_element(
data,
data + K,
data + N,
[](const typename IterT::value_type& lhs,
const typename IterT::value_type& rhs) {
return *lhs.value_ptr > *rhs.value_ptr;
});
}
} // namespace
template <>
void SelectTopK<float, CUDAContext>(
const int N,
const int K,
const float thresh,
const float* scores,
vector<float>& out_scores,
vector<int64_t>& out_indices,
CUDAContext* ctx) {
int num_selected = N;
int64_t* indices = nullptr;
if (thresh > 0.f) {
indices = ctx->workspace()->data<int64_t, CUDAContext>(N, "BufferKernel");
auto policy = thrust::cuda::par.on(ctx->cuda_stream());
auto functor = ThresholdFunctor<int64_t, float>(thresh);
thrust::sequence(policy, indices, indices + N);
auto kv = thrust::make_tuple(indices, const_cast<float*>(scores));
auto first = thrust::make_zip_iterator(kv);
auto last = thrust::partition(policy, first, first + N, functor);
num_selected = last - first;
}
out_scores.resize(num_selected);
out_indices.resize(num_selected);
CUDA_CHECK(cudaMemcpyAsync(
out_scores.data(),
scores,
num_selected * sizeof(float),
cudaMemcpyDeviceToHost,
ctx->cuda_stream()));
if (thresh > 0.f) {
CUDA_CHECK(cudaMemcpyAsync(
out_indices.data(),
indices,
num_selected * sizeof(int64_t),
cudaMemcpyDeviceToHost,
ctx->cuda_stream()));
} else {
std::iota(out_indices.begin(), out_indices.end(), 0);
}
ctx->FinishDeviceComputation();
if (num_selected > K) {
auto iter = KeyValueMapIterator<KeyValueMap<int64_t, float>>(
out_indices.data(), out_scores.data());
ArgPartition(num_selected, K, iter);
out_scores.resize(K);
out_indices.resize(K);
}
}
} // namespace detection
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_ITERATOR_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_ITERATOR_H_
#include <dragon/core/common.h>
namespace dragon {
namespace detection {
template <typename MapT>
class KeyValueMapIterator
: public std::iterator<std::input_iterator_tag, MapT> {
public:
typedef KeyValueMapIterator self_type;
typedef ptrdiff_t difference_type;
typedef MapT value_type;
typedef MapT& reference;
KeyValueMapIterator(
typename MapT::key_type* key_ptr,
typename MapT::value_type* value_ptr)
: key_ptr_(key_ptr), value_ptr_(value_ptr) {}
self_type operator++(int) {
self_type ret = *this;
key_ptr_++;
value_ptr_++;
return ret;
}
self_type& operator++() {
key_ptr_++;
value_ptr_++;
return *this;
}
self_type& operator--() {
key_ptr_--;
value_ptr_--;
return *this;
}
self_type operator--(int) {
self_type ret = *this;
key_ptr_--;
value_ptr_--;
return ret;
}
reference operator*() const {
if (map_.key_ptr != key_ptr_) {
map_.key_ptr = key_ptr_;
map_.value_ptr = value_ptr_;
}
return map_;
}
self_type operator+(difference_type n) const {
return self_type(key_ptr_ + n, value_ptr_ + n);
}
self_type& operator+=(difference_type n) {
key_ptr_ += n;
value_ptr_ += n;
return *this;
}
self_type operator-(difference_type n) const {
return self_type(key_ptr_ - n, value_ptr_ - n);
}
self_type& operator-=(difference_type n) {
key_ptr_ -= n;
value_ptr_ -= n;
return *this;
}
difference_type operator-(self_type other) const {
return key_ptr_ - other.key_ptr_;
}
bool operator<(const self_type& rhs) const {
return key_ptr_ < rhs.key_ptr_;
}
bool operator<=(const self_type& rhs) const {
return key_ptr_ <= rhs.key_ptr_;
}
bool operator==(const self_type& rhs) const {
return key_ptr_ == rhs.key_ptr_;
}
bool operator!=(const self_type& rhs) const {
return key_ptr_ != rhs.key_ptr_;
}
private:
mutable MapT map_;
typename MapT::key_type* key_ptr_;
typename MapT::value_type* value_ptr_;
};
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_ITERATOR_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_MASK_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_MASK_H_
#include "../../utils/detection/types.h"
namespace dragon {
namespace detection {
/*
* Mask Functions.
*/
template <typename T, class Context>
void PasteMask(
const int N,
const int H,
const int W,
const int mask_h,
const int mask_w,
const float thresh,
const T* masks,
const float* boxes,
uint8_t* im,
Context* ctx);
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_MASK_H_
#include <dragon/core/context_mps.h>
#include "../../../utils/detection/mask.h"
namespace dragon {
namespace detection {
namespace {
const static string METAL_SHADERS = R"(
#include <metal_stdlib>
using namespace metal;
constant int int_arg1 [[function_constant(0)]]; // H
constant int int_arg2 [[function_constant(1)]]; // W
constant int int_arg3 [[function_constant(2)]]; // mask_h
constant int int_arg4 [[function_constant(3)]]; // mask_w
constant float float_arg1 [[function_constant(4)]]; // thresh
template <typename IndexT>
bool WithinBounds2d(IndexT h, IndexT w, IndexT H, IndexT W) {
return h >= IndexT(0) && h < H && w >= IndexT(0) && w < W;
}
template <typename T>
kernel void PasteMask(
device const T* masks,
device const float* boxes,
device uint8_t* im,
const uint index [[thread_position_in_grid]]) {
const int w = int(index) % int_arg2;
const int h = int(index) / int_arg2 % int_arg1;
const int n = int(index) / (int_arg2 * int_arg1);
device const float* box = boxes + n * 4;
device const T* mask = masks + n * int_arg3 * int_arg4;
const float gx = (float(w) + 0.5f - box[0]) / (box[2] - box[0]) * 2.f;
const float gy = (float(h) + 0.5f - box[1]) / (box[3] - box[1]) * 2.f;
const float ix = (gx * float(int_arg4) - 1.f) * 0.5f;
const float iy = (gy * float(int_arg3) - 1.f) * 0.5f;
const int ix_nw = floor(ix);
const int iy_nw = floor(iy);
const int ix_ne = ix_nw + 1;
const int iy_ne = iy_nw;
const int ix_sw = ix_nw;
const int iy_sw = iy_nw + 1;
const int ix_se = ix_nw + 1;
const int iy_se = iy_nw + 1;
T nw = T((ix_se - ix) * (iy_se - iy));
T ne = T((ix - ix_sw) * (iy_sw - iy));
T sw = T((ix_ne - ix) * (iy - iy_ne));
T se = T((ix - ix_nw) * (iy - iy_nw));
T val = T(0);
if (WithinBounds2d(iy_nw, ix_nw, int_arg3, int_arg4)) {
val += mask[iy_nw * int_arg4 + ix_nw] * nw;
}
if (WithinBounds2d(iy_ne, ix_ne, int_arg3, int_arg4)) {
val += mask[iy_ne * int_arg4 + ix_ne] * ne;
}
if (WithinBounds2d(iy_sw, ix_sw, int_arg3, int_arg4)) {
val += mask[iy_sw * int_arg4 + ix_sw] * sw;
}
if (WithinBounds2d(iy_se, ix_se, int_arg3, int_arg4)) {
val += mask[iy_se * int_arg4 + ix_se] * se;
}
im[index] = (val >= T(float_arg1) ? uint8_t(1) : uint8_t(0));
}
#define INSTANTIATE_KERNEL(T) \
template [[host_name("PasteMask_"#T)]] \
kernel void PasteMask( \
device const T*, device const float*, device uint8_t*, uint);
INSTANTIATE_KERNEL(float);
#undef INSTANTIATE_KERNEL
)";
} // namespace
template <>
void PasteMask<float, MPSContext>(
const int N,
const int H,
const int W,
const int mask_h,
const int mask_w,
const float thresh,
const float* masks,
const float* boxes,
uint8_t* im,
MPSContext* ctx) {
auto kernel = MPSKernel::TypedString<float>("PasteMask");
auto args = vector<MPSConstant>({
MPSConstant(&H, MTLDataTypeInt, 0),
MPSConstant(&W, MTLDataTypeInt, 1),
MPSConstant(&mask_h, MTLDataTypeInt, 2),
MPSConstant(&mask_w, MTLDataTypeInt, 3),
MPSConstant(&thresh, MTLDataTypeFloat, 4),
});
auto* command_buffer = ctx->mps_stream()->command_buffer();
auto* encoder = [command_buffer computeCommandEncoder];
auto* pso = MPSKernel(kernel, METAL_SHADERS).GetState(ctx, args);
[encoder setComputePipelineState:pso];
[encoder setBuffer:id<MTLBuffer>(masks) offset:0 atIndex:0];
[encoder setBuffer:id<MTLBuffer>(boxes) offset:0 atIndex:1];
[encoder setBuffer:id<MTLBuffer>(im) offset:0 atIndex:2];
MPSDispatchThreads((N * H * W), encoder, pso);
[encoder endEncoding];
[encoder release];
}
} // namespace detection
} // namespace dragon
#include <dragon/core/context_mps.h>
#include <dragon/core/workspace.h>
#include "../../../utils/detection/nms.h"
#include "../../../utils/detection/utils.h"
namespace dragon {
namespace detection {
namespace {
#define NUM_THREADS 64
const static string METAL_SHADERS = R"(
#include <metal_stdlib>
using namespace metal;
constant uint uint_arg1 [[function_constant(0)]];
constant float float_arg1 [[function_constant(1)]];
template <typename T>
bool CheckIoU(const T thresh, device const T* a, threadgroup T* b) {
const T x1 = max(a[0], b[0]);
const T y1 = max(a[1], b[1]);
const T x2 = min(a[2], b[2]);
const T y2 = min(a[3], b[3]);
const T width = max(T(0), x2 - x1);
const T height = max(T(0), y2 - y1);
const T inter = width * height;
const T Sa = (a[2] - a[0]) * (a[3] - a[1]);
const T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return inter >= thresh * (Sa + Sb - inter);
}
template <typename T>
kernel void NonMaxSuppression(
device const T* boxes,
device uint64_t* mask,
const uint2 gridDim [[threadgroups_per_grid]],
const uint2 blockIdx [[threadgroup_position_in_grid]],
const uint2 threadIdx [[thread_position_in_threadgroup]]) {
const uint row_start = blockIdx.y;
const uint col_start = blockIdx.x;
if (row_start > col_start) return;
const uint row_size = min(uint_arg1 - row_start * uint(64), uint(64));
const uint col_size = min(uint_arg1 - col_start * uint(64), uint(64));
threadgroup T block_boxes[256];
if (threadIdx.x < col_size) {
threadgroup T* offset_block_boxes = block_boxes + threadIdx.x * 4;
device const T* offset_boxes = boxes + (col_start * uint(64) + threadIdx.x) * 5;
for (int i = 0; i < 4; ++i) {
*(offset_block_boxes++) = *(offset_boxes++);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (threadIdx.x < row_size) {
const uint index = row_start * uint(64) + threadIdx.x;
device const T* offset_boxes = boxes + index * 5;
uint64_t val = 0;
const uint start = (row_start == col_start) ? (threadIdx.x + 1) : 0;
for (uint i = start; i < col_size; ++i) {
if (CheckIoU(T(float_arg1), offset_boxes, block_boxes + i * 4)) {
val |= (uint64_t(1) << i);
}
}
mask[index * gridDim.x + col_start] = val;
}
}
#define INSTANTIATE_KERNEL(T) \
template [[host_name("NonMaxSuppression_"#T)]] \
kernel void NonMaxSuppression( \
device const T*, device uint64_t*, uint2, uint2, uint2);
INSTANTIATE_KERNEL(float);
#undef INSTANTIATE_KERNEL
)";
} // namespace
template <>
void ApplyNMS<float, MPSContext>(
const int N,
const int K,
const int boxes_offset,
const float thresh,
const float* boxes,
vector<int64_t>& indices,
MPSContext* ctx) {
const auto num_blocks = utils::DivUp(N, NUM_THREADS);
auto* NMS_mask = ctx->workspace()->CreateTensor("NMS_mask");
NMS_mask->Reshape({N * num_blocks});
auto* mask = reinterpret_cast<uint64_t*>(
NMS_mask->template mutable_data<int64_t, MPSContext>());
auto kernel = MPSKernel::TypedString<float>("NonMaxSuppression");
const uint arg1 = N;
auto args = vector<MPSConstant>({
MPSConstant(&arg1, MTLDataTypeUInt, 0),
MPSConstant(&thresh, MTLDataTypeFloat, 1),
});
auto* command_buffer = ctx->mps_stream()->command_buffer();
auto* encoder = [command_buffer computeCommandEncoder];
auto* pso = MPSKernel(kernel, METAL_SHADERS).GetState(ctx, args);
[encoder setComputePipelineState:pso];
[encoder setBuffer:id<MTLBuffer>(boxes) offset:boxes_offset * 4 atIndex:0];
[encoder setBuffer:id<MTLBuffer>(mask) offset:0 atIndex:1];
[encoder dispatchThreadgroups:MTLSizeMake(num_blocks, num_blocks, 1)
threadsPerThreadgroup:MTLSizeMake(NUM_THREADS, 1, 1)];
[encoder endEncoding];
[encoder release];
ctx->FinishDeviceComputation();
mask = reinterpret_cast<uint64_t*>(
const_cast<int64_t*>(NMS_mask->template data<int64_t, CPUContext>()));
vector<uint64_t> is_dead(num_blocks);
memset(&is_dead[0], 0, sizeof(uint64_t) * num_blocks);
int num_selected = 0;
indices.resize(K);
for (int i = 0; i < N; ++i) {
const int nblock = i / NUM_THREADS, inblock = i % NUM_THREADS;
if (!(is_dead[nblock] & (uint64_t(1) << inblock))) {
indices[num_selected++] = i;
if (num_selected >= K) break;
auto* offset_mask = mask + i * num_blocks;
for (int j = nblock; j < num_blocks; ++j) {
is_dead[j] |= offset_mask[j];
}
}
}
indices.resize(num_selected);
}
} // namespace detection
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_NMS_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_NMS_H_
#include "../../utils/detection/types.h"
namespace dragon {
namespace detection {
template <typename T, class Context>
void ApplyNMS(
const int N,
const int K,
const int boxes_offset,
const T thresh,
const T* boxes,
vector<int64_t>& indices,
Context* ctx);
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_NMS_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_PROPOSALS_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_PROPOSALS_H_
#include "../../utils/detection/bbox.h"
#include "../../utils/detection/types.h"
namespace dragon {
namespace detection {
template <typename T, class Context>
void SelectTopK(
const int N,
const int K,
const float thresh,
const T* input_scores,
vector<T>& output_scores,
vector<int64_t>& output_indices,
Context* ctx);
template <typename T>
void DecodeProposals(
const int num_proposals,
const int num_anchors,
const ImageArgs<int64_t>& im_args,
const GridArgs<int64_t>& grid_args,
const T* scores,
const T* deltas,
const int64_t* indices,
T* proposals) {
T* offset_proposals = proposals;
const int64_t index_min = grid_args.offset;
const T* offset_dx = deltas;
const T* offset_dy = deltas + num_anchors;
const T* offset_dw = deltas + num_anchors * 2;
const T* offset_dh = deltas + num_anchors * 3;
for (int i = 0; i < num_proposals; ++i) {
const auto index = indices[i] + index_min;
utils::BBoxTransform(
offset_dx[index],
offset_dy[index],
offset_dw[index],
offset_dh[index],
T(im_args.w),
T(im_args.h),
T(1),
T(1),
offset_proposals);
offset_proposals[4] = scores[i];
offset_proposals += 5;
}
}
template <typename T>
void DecodeDetections(
const int num_dets,
const int num_anchors,
const int num_classes,
const ImageArgs<int64_t>& im_args,
const GridArgs<int64_t>& grid_args,
const T* scores,
const T* deltas,
const int64_t* indices,
T* dets) {
T* offset_dets = dets;
const int64_t index_min = num_classes * grid_args.offset;
const T* offset_dx = deltas;
const T* offset_dy = deltas + num_anchors;
const T* offset_dw = deltas + num_anchors * 2;
const T* offset_dh = deltas + num_anchors * 3;
for (int i = 0; i < num_dets; ++i) {
const auto index = (indices[i] + index_min) / num_classes;
utils::BBoxTransform(
offset_dx[index],
offset_dy[index],
offset_dw[index],
offset_dh[index],
T(im_args.w),
T(im_args.h),
T(im_args.scale_h),
T(im_args.scale_w),
offset_dets + 1);
offset_dets[0] = T(im_args.batch_ind);
offset_dets[5] = scores[i];
offset_dets[6] = T((indices[i] + index_min) % num_classes + 1);
offset_dets += 7;
}
}
template <typename T>
inline void ApplyHistogram(
const int N,
const int lvl_min,
const int lvl_max,
const int lvl0,
const int s0,
const T* boxes,
const T* batch_indices,
const int64_t* box_indices,
vector<vector<T>>& output_rois) {
int K = 0;
vector<int> keep_indices(N), bin_indices(N);
vector<int> bin_count(lvl_max - lvl_min + 1, 0);
for (int i = 0; i < N; ++i) {
const T* offset_boxes = boxes + box_indices[i] * 5;
auto lvl = utils::GetBBoxLevel(lvl_min, lvl_max, lvl0, s0, offset_boxes);
if (lvl < 0) continue; // Empty.
keep_indices[K++] = i;
bin_indices[i] = lvl - lvl_min;
bin_count[lvl - lvl_min]++;
}
keep_indices.resize(K);
output_rois.resize(lvl_max - lvl_min + 1);
for (int i = 0; i < output_rois.size(); ++i) {
auto& rois = output_rois[i];
rois.resize(std::max(bin_count[i], 1) * 5, T(0));
if (bin_count[i] == 0) rois[0] = T(-1); // Ignored.
}
for (auto i : keep_indices) {
const T* offset_boxes = boxes + box_indices[i] * 5;
const auto bin_index = bin_indices[i];
const auto roi_index = --bin_count[bin_index];
auto& rois = output_rois[bin_index];
T* offset_rois = rois.data() + roi_index * 5;
offset_rois[0] = batch_indices[i];
offset_rois[1] = offset_boxes[0];
offset_rois[2] = offset_boxes[1];
offset_rois[3] = offset_boxes[2];
offset_rois[4] = offset_boxes[3];
}
}
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_PROPOSALS_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_TYPES_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_TYPES_H_
#include <dragon/core/common.h>
namespace dragon {
namespace detection {
template <typename T>
struct Box4d {
T x1, y1, x2, y2;
};
template <typename T>
struct Box5d {
T x1, y1, x2, y2, score;
};
template <typename IndexT>
struct ImageArgs {
ImageArgs(const float* im_info) {
h = im_info[0], w = im_info[1];
scale_h = im_info[2], scale_w = im_info[3];
}
IndexT batch_ind, h, w;
float scale_h, scale_w;
};
template <typename IndexT>
struct GridArgs {
IndexT h, w, stride, size, offset;
};
template <typename KeyT, typename ValueT>
struct KeyValueMap {
typedef KeyT key_type;
typedef ValueT value_type;
friend void swap(KeyValueMap& x, KeyValueMap& y) {
std::swap(*x.key_ptr, *y.key_ptr);
std::swap(*x.value_ptr, *y.value_ptr);
}
KeyT* key_ptr = nullptr;
ValueT* value_ptr = nullptr;
};
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_TYPES_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_DETECTION_UTILS_H_
#define DRAGON_EXTENSION_UTILS_DETECTION_UTILS_H_
namespace dragon {
namespace detection {
/*
* Detection Utilities.
*/
namespace utils {
template <typename T>
inline T DivUp(const T a, const T b) {
return (a + b - T(1)) / b;
}
} // namespace utils
} // namespace detection
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_DETECTION_UTILS_H_
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Sergey Karayev
# --------------------------------------------------------
cimport cython
import numpy as np
cimport numpy as np
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False)
def bbox_overlaps(
        np.ndarray[DTYPE_t, ndim=2] boxes,
        np.ndarray[DTYPE_t, ndim=2] query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) ndarray of float
    query_boxes: (K, 4) ndarray of float
    Returns
    -------
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    cdef unsigned int N = boxes.shape[0]
    cdef unsigned int K = query_boxes.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE)
    cdef DTYPE_t iw, ih, box_area
    cdef DTYPE_t ua
    cdef unsigned int k, n
    with nogil:
        for k in range(K):
            box_area = (
                (query_boxes[k, 2] - query_boxes[k, 0]) *
                (query_boxes[k, 3] - query_boxes[k, 1])
            )
            for n in range(N):
                iw = (
                    min(boxes[n, 2], query_boxes[k, 2]) -
                    max(boxes[n, 0], query_boxes[k, 0])
                )
                if iw > 0:
                    ih = (
                        min(boxes[n, 3], query_boxes[k, 3]) -
                        max(boxes[n, 1], query_boxes[k, 1])
                    )
                    if ih > 0:
                        ua = float(
                            (boxes[n, 2] - boxes[n, 0]) *
                            (boxes[n, 3] - boxes[n, 1]) +
                            box_area - iw * ih
                        )
                        overlaps[n, k] = iw * ih / ua
    return overlaps
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
cimport cython
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
return a if a >= b else b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, float thresh):
    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1) * (y2 - y1)
    cdef np.ndarray[np.intp_t, ndim=1] order = scores.argsort()[::-1]
    cdef int ndets = dets.shape[0]
    # The np.int alias was removed from NumPy; use the pointer-sized integer.
    cdef np.ndarray[np.intp_t, ndim=1] suppressed = \
        np.zeros((ndets), dtype=np.intp)
    # nominal indices
    cdef int _i, _j
    # sorted indices
    cdef int i, j
    # temp variables for box i's (the box currently under consideration)
    cdef np.float32_t ix1, iy1, ix2, iy2, iarea
    # variables for computing overlap with box j (lower scoring box)
    cdef np.float32_t xx1, yy1, xx2, yy2
    cdef np.float32_t w, h
    cdef np.float32_t inter, ovr
    keep = []
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        keep.append(i)
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1)
            h = max(0.0, yy2 - yy1)
            inter = w * h
            ovr = inter / (iarea + areas[j] - inter)
            if ovr >= thresh:
                suppressed[j] = 1
    return keep
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float thresh,
unsigned int method=0, float sigma=0.5, float score_thresh=0.001):
cdef unsigned int N = boxes.shape[0]
cdef float iw, ih, box_area
cdef float ua
cdef int pos = 0
cdef float maxscore = 0
cdef int maxpos = 0
cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
for i in range(N):
maxscore = boxes[i, 4]
maxpos = i
tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
pos = i + 1
# get max box
while pos < N:
if maxscore < boxes[pos, 4]:
maxscore = boxes[pos, 4]
maxpos = pos
pos = pos + 1
# add max box as a detection
boxes[i,0] = boxes[maxpos,0]
boxes[i,1] = boxes[maxpos,1]
boxes[i,2] = boxes[maxpos,2]
boxes[i,3] = boxes[maxpos,3]
boxes[i,4] = boxes[maxpos,4]
# swap ith box with position of max box
boxes[maxpos,0] = tx1
boxes[maxpos,1] = ty1
boxes[maxpos,2] = tx2
boxes[maxpos,3] = ty2
boxes[maxpos,4] = ts
tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
pos = i + 1
# NMS iterations, note that N changes if detection boxes fall below threshold
while pos < N:
x1 = boxes[pos, 0]
y1 = boxes[pos, 1]
x2 = boxes[pos, 2]
y2 = boxes[pos, 3]
s = boxes[pos, 4]
area = (x2 - x1) * (y2 - y1)
iw = min(tx2, x2) - max(tx1, x1)
if iw > 0:
ih = min(ty2, y2) - max(ty1, y1)
if ih > 0:
ua = float((tx2 - tx1) * (ty2 - ty1) + area - iw * ih)
ov = iw * ih / ua #iou between max box and detection box
if method == 1: # linear
if ov > thresh:
weight = 1 - ov
else:
weight = 1
elif method == 2: # gaussian
weight = np.exp(-(ov * ov) / sigma)
else: # original NMS
if ov > thresh:
weight = 0
else:
weight = 1
boxes[pos, 4] = weight * boxes[pos, 4]
# if box score falls below threshold, discard the box by swapping with last box
# update N
if boxes[pos, 4] < score_thresh:
boxes[pos,0] = boxes[N-1, 0]
boxes[pos,1] = boxes[N-1, 1]
boxes[pos,2] = boxes[N-1, 2]
boxes[pos,3] = boxes[N-1, 3]
boxes[pos,4] = boxes[N-1, 4]
N = N - 1
pos = pos - 1
pos = pos + 1
keep = [i for i in range(N)]
return keep
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build cython extensions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from setuptools import setup
from setuptools import Extension
import os
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy as np
def clean_builds():
    """Clean the builds."""
    for file in os.listdir('./'):
        if file.endswith('.c'):
            os.remove(file)
ext_modules = [
Extension(
'seetadet.utils.bbox.cython_bbox',
['cython_bbox.pyx'],
extra_compile_args=['-w'],
include_dirs=[np.get_include()],
),
Extension(
'seetadet.utils.nms.cython_nms',
['cython_nms.pyx'],
extra_compile_args=['-w'],
include_dirs=[np.get_include()],
),
]
setup(
name='seetadet',
ext_modules=cythonize(
ext_modules, compiler_directives={'language_level': '3'}),
cmdclass={'build_ext': build_ext},
)
clean_builds()
# Datasets
## Introduction
This folder holds the record and JSON datasets.
Please prepare the datasets following the [documentation](../../scripts/datasets/README.md).
# Pretrained Models
## Introduction
This folder holds the pretrained models.
## ImageNet Pretrained Models
### Training settings
- ResNet models trained for 200 epochs follow the procedure in arXiv:1812.01187.
### ResNet
| Model | Lr sched | Acc@1 | Acc@5 | Source |
| :---: | :------: | :---: | :---: | :----: |
| [R50](https://dragon.seetatech.com/download/seetadet/pretrained/R-50_in1k_cls90e.pkl) | 90e | 76.53 | 93.16 | Ours |
| [R50](https://dragon.seetatech.com/download/seetadet/pretrained/R-50_in1k_cls200e.pkl) | 200e | 78.64 | 94.30 | Ours |
| [R50-A](https://dragon.seetatech.com/download/seetadet/pretrained/R-50-A_in1k_cls120e.pkl) | 120e | 75.30 | 92.20 | MSRA |
### MobileNet
| Model | Lr sched | Acc@1 | Acc@5 | Source |
| :---: | :------: | :---: | :---: | :----: |
| [MobileNetV2](https://dragon.seetatech.com/download/seetadet/pretrained/MobileNetV2_in1k_cls300e.pkl) | 300e | 71.88 | 90.29 | TorchVision |
| [MobileNetV3L](https://dragon.seetatech.com/download/seetadet/pretrained/MobileNetV3L_in1k_cls600e.pkl) | 600e | 74.04 | 91.34 | TorchVision |
### VGG
| Model | Lr sched | Acc@1 | Acc@5 | Source |
| :---: | :------: | :---: | :---: | :----: |
| [VGG16-FCN](https://dragon.seetatech.com/download/seetadet/pretrained/VGG-16-FCN_in1k.pkl) | - | - | - | weiliu89 |
# Python dependencies required for development.
opencv-python
Pillow
pyyaml
prettytable
matplotlib
codewithgpu
shapely
Cython
pycocotools
# Prepare Datasets
## Create Datasets for PASCAL VOC
We assume that the raw dataset has the following structure:
```
VOC<year>
|_ JPEGImages
| |_ <im-1-name>.jpg
| |_ ...
| |_ <im-N-name>.jpg
|_ Annotations
| |_ <im-1-name>.xml
| |_ ...
| |_ <im-N-name>.xml
|_ ImageSets
| |_ Main
| | |_ trainval.txt
| | |_ test.txt
| | |_ ...
```
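Given this layout, each line of a split file names one image, and the image and annotation paths follow from that name. A minimal sketch of the lookup, modeled on what `pascal_voc.py` does (the helper name `resolve_entries` is hypothetical and not part of the scripts):

```python
import os


def resolve_entries(split_file, images_dir, annotations_dir):
    """Return (image, annotation) path pairs for every name in a split file.

    Hypothetical helper: it assumes one image name per line and the
    '.jpg' / '.xml' suffixes shown in the layout above.
    """
    entries = []
    with open(split_file, 'r') as f:
        for line in f:
            name = line.strip()
            if not name:
                continue  # Skip blank lines.
            entries.append((
                os.path.join(images_dir, name + '.jpg'),
                os.path.join(annotations_dir, name + '.xml'),
            ))
    return entries
```

Calling it once per `(split, images, annotations)` triple and concatenating the results gives the full list of files that end up in the record.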
Create the record and JSON datasets with:
```
python pascal_voc.py \
--rec /path/to/datasets/voc_trainval0712 \
--gt /path/to/datasets/voc_trainval0712.json \
--images /path/to/VOC2007/JPEGImages \
/path/to/VOC2012/JPEGImages \
--annotations /path/to/VOC2007/Annotations \
/path/to/VOC2012/Annotations \
--splits /path/to/VOC2007/ImageSets/Main/trainval.txt \
/path/to/VOC2012/ImageSets/Main/trainval.txt
```
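The JSON ground truth stores each box in `[x, y, width, height]` form, while the record stores corner coordinates. A minimal sketch of that conversion, modeled on `json_dataset.py` (the function name `to_json_annotation` and its argument list are assumptions made for illustration):

```python
def to_json_annotation(obj, image_id, ann_id, cat_to_cat_id):
    """Convert one record object into a COCO-style annotation dict.

    Hypothetical helper: `obj` is assumed to carry the 'xmin', 'ymin',
    'xmax', 'ymax', 'name' and optional 'difficult' fields.
    """
    x1, y1, x2, y2 = obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']
    w, h = x2 - x1, y2 - y1
    return {
        'id': str(ann_id),
        'bbox': [x1, y1, w, h],
        'area': w * h,
        # The 'difficult' flag is stored in the 'iscrowd' slot.
        'iscrowd': obj.get('difficult', 0),
        'image_id': image_id,
        'category_id': cat_to_cat_id[obj['name']],
    }
```

Category ids start at 1 and follow the order in which the category names are listed.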
## Create Datasets for COCO
We assume that the raw dataset has the following structure:
```
COCO
|_ images
| |_ train2017
| | |_ <im-1-name>.jpg
| | |_ ...
| | |_ <im-N-name>.jpg
|_ annotations
| |_ instances_train2017.json
| |_ ...
```
Create the record dataset with:
```
python coco.py \
--rec /path/to/datasets/coco_train2017 \
--images /path/to/COCO/images/train2017 \
--annotations /path/to/COCO/annotations/instances_train2017.json
```
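COCO annotations store boxes as `[x, y, width, height]`, whereas the record keeps corner coordinates clipped to the image. A minimal sketch of that step, modeled on `coco.py` (the function name `coco_bbox_to_corners` is hypothetical):

```python
def coco_bbox_to_corners(bbox, width, height):
    """Convert a COCO [x, y, w, h] box to clipped (x1, y1, x2, y2) corners.

    Hypothetical helper: negative origins and sizes are clamped to zero,
    and the far corner is clamped to the image bounds.
    """
    x, y, w, h = bbox
    x1 = float(max(0, x))
    y1 = float(max(0, y))
    x2 = float(min(width, x1 + max(0, w)))
    y2 = float(min(height, y1 + max(0, h)))
    return x1, y1, x2, y2
```

For example, a box that starts left of the image and extends past its bottom edge is pulled back inside the image:

```python
coco_bbox_to_corners([-5, 10, 50, 600], width=100, height=200)
# (0.0, 10.0, 50.0, 200.0)
```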
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Prepare MS COCO datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import codewithgpu
from pycocotools.coco import COCO
from pycocotools.mask import frPyObjects
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Prepare MS COCO datasets')
parser.add_argument(
'--rec',
default=None,
help='path to write record dataset')
parser.add_argument(
'--images',
nargs='+',
type=str,
default=None,
help='path of images folder')
parser.add_argument(
'--annotations',
nargs='+',
type=str,
default=None,
help='path of annotations folder')
parser.add_argument(
'--splits',
nargs='+',
type=str,
default=None,
help='path of split file')
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_example(img_id, img_file, cocoGt):
"""Return the record example."""
img_meta = cocoGt.imgs[img_id]
img_anns = cocoGt.loadAnns(cocoGt.getAnnIds(imgIds=[img_id]))
cat_id_to_cat = dict((v['id'], v['name'])
for v in cocoGt.cats.values())
with open(img_file, 'rb') as f:
img_bytes = bytes(f.read())
height, width = img_meta['height'], img_meta['width']
example = {'id': str(img_id), 'height': height, 'width': width,
'depth': 3, 'content': img_bytes, 'object': []}
for ann in img_anns:
x1 = float(max(0, ann['bbox'][0]))
y1 = float(max(0, ann['bbox'][1]))
x2 = float(min(width, x1 + max(0, ann['bbox'][2])))
y2 = float(min(height, y1 + max(0, ann['bbox'][3])))
mask, polygons = b'', []
segm = ann.get('segmentation', None)
if segm is not None and isinstance(segm, list):
for p in ann['segmentation']:
if len(p) < 6:
print('Remove Invalid segm.')
# Valid polygons have >= 3 points, so require >= 6 coordinates
polygons = [p for p in ann['segmentation'] if len(p) >= 6]
elif segm is not None:
# Crowd masks.
# Some are encoded with wrong height or width.
# Do not use them or decoding error is inevitable.
rle = frPyObjects(ann['segmentation'], height, width)
assert type(rle) == dict
mask = rle['counts']
example['object'].append({
'name': cat_id_to_cat[ann['category_id']],
'xmin': x1, 'ymin': y1, 'xmax': x2, 'ymax': y2,
'mask': mask, 'polygons': polygons,
'difficult': ann.get('iscrowd', 0)})
return example
def write_dataset(args):
assert len(args.images) == len(args.annotations)
if os.path.exists(args.rec):
raise ValueError('The record path already exists.')
os.makedirs(args.rec)
print('Write record dataset to {}'.format(args.rec))
writer = codewithgpu.RecordWriter(
path=args.rec,
features={
'id': 'string',
'content': 'bytes',
'height': 'int64',
'width': 'int64',
'depth': 'int64',
'object': [{
'name': 'string',
'xmin': 'float64',
'ymin': 'float64',
'xmax': 'float64',
'ymax': 'float64',
'mask': 'bytes',
'polygons': [['float64']],
'difficult': 'int64',
}]
}
)
# Scan all available entries.
print('Scan entries...')
entries, cocoGts = [], []
for ann_file in args.annotations:
cocoGts.append(COCO(ann_file))
if args.splits is not None:
assert len(args.splits) == len(args.images)
for i, split in enumerate(args.splits):
f = open(split, 'r')
for line in f.readlines():
filename = line.strip()
img_id = int(filename)
img_file = os.path.join(args.images[i], filename + '.jpg')
entries.append((img_id, img_file, cocoGts[i]))
f.close()
else:
for i, cocoGt in enumerate(cocoGts):
for info in cocoGt.imgs.values():
img_id = info['id']
img_file = os.path.join(args.images[i], info['file_name'])
entries.append((img_id, img_file, cocoGts[i]))
print('Start Time:', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()
for i, entry in enumerate(entries):
if i > 0 and i % 2000 == 0:
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
i, len(entries), now_time - start_time))
writer.write(make_example(*entry))
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
len(entries), len(entries), now_time - start_time))
writer.close()
end_time = time.time()
data_size = os.path.getsize(args.rec + '/00000.data') * 1e-6
print('{} images take {:.2f} MB in {:.2f} sec.'
.format(len(entries), data_size, end_time - start_time))
if __name__ == '__main__':
args = parse_args()
if args.rec is not None:
write_dataset(args)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Prepare JSON datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import sys
import codewithgpu
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Prepare JSON datasets')
parser.add_argument(
'--rec',
default=None,
help='path to read record')
parser.add_argument(
'--gt',
default=None,
help='path to write json ground-truth')
parser.add_argument(
'--categories',
nargs='+',
type=str,
default=None,
help='dataset object categories')
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def get_image_id(image_name):
    image_id = image_name.split('_')[-1].split('.')[0]
    try:
        return int(image_id)
    except ValueError:
        return image_name
def write_dataset(args):
dataset = {'images': [], 'categories': [], 'annotations': []}
record_dataset = codewithgpu.RecordDataset(args.rec)
cat_to_cat_id = dict(zip(args.categories,
range(1, len(args.categories) + 1)))
print('Writing json dataset to {}'.format(args.gt))
for cat in args.categories:
dataset['categories'].append({
'name': cat, 'id': cat_to_cat_id[cat]})
for example in record_dataset:
image_id = get_image_id(example['id'])
dataset['images'].append({
'id': image_id, 'height': example['height'],
'width': example['width']})
for obj in example['object']:
if 'x2' in obj:
x1, y1, x2, y2 = obj['x1'], obj['y1'], obj['x2'], obj['y2']
elif 'xmin' in obj:
x1, y1, x2, y2 = obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']
else:
x1, y1, x2, y2 = obj['bbox']
w, h = x2 - x1, y2 - y1
dataset['annotations'].append({
'id': str(len(dataset['annotations'])),
'bbox': [x1, y1, w, h],
'area': w * h,
'iscrowd': obj.get('difficult', 0),
'image_id': image_id,
'category_id': cat_to_cat_id[obj['name']]})
with open(args.gt, 'w') as f:
json.dump(dataset, f)
if __name__ == '__main__':
args = parse_args()
if args.rec is None or not os.path.exists(args.rec):
raise ValueError('Specify the prepared record dataset.')
if args.gt is None:
raise ValueError('Specify the path to write json dataset.')
write_dataset(args)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Prepare PASCAL VOC datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import codewithgpu
import cv2
import numpy as np
import xml.etree.ElementTree
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Prepare PASCAL VOC datasets')
parser.add_argument(
'--rec',
default=None,
help='path to write record dataset')
parser.add_argument(
'--gt',
default=None,
help='path to write json dataset')
parser.add_argument(
'--images',
nargs='+',
type=str,
default=None,
help='path of images folder')
parser.add_argument(
'--annotations',
nargs='+',
type=str,
default=None,
help='path of annotations folder')
parser.add_argument(
'--splits',
nargs='+',
type=str,
default=None,
help='path of split file')
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_example(img_file, xml_file):
"""Return the record example."""
tree = xml.etree.ElementTree.parse(xml_file)
filename = os.path.split(xml_file)[-1]
objects = tree.findall('object')
size = tree.find('size')
example = {'id': filename.split('.')[0], 'object': []}
with open(img_file, 'rb') as f:
img_bytes = bytes(f.read())
if size is not None:
example['height'] = int(size.find('height').text)
example['width'] = int(size.find('width').text)
example['depth'] = int(size.find('depth').text)
else:
img = cv2.imdecode(np.frombuffer(img_bytes, 'uint8'), 3)
example['height'], example['width'], example['depth'] = img.shape
example['content'] = img_bytes
for obj in objects:
bbox = obj.find('bndbox')
is_diff = 0
if obj.find('difficult') is not None:
is_diff = int(int(obj.find('difficult').text) == 1)
example['object'].append({
'name': obj.find('name').text.strip(),
'xmin': float(bbox.find('xmin').text),
'ymin': float(bbox.find('ymin').text),
'xmax': float(bbox.find('xmax').text),
'ymax': float(bbox.find('ymax').text),
'difficult': is_diff})
return example
def write_dataset(args):
"""Write the record dataset."""
assert len(args.splits) == len(args.images)
assert len(args.splits) == len(args.annotations)
if os.path.exists(args.rec):
raise ValueError('The record path already exists.')
os.makedirs(args.rec)
print('Write record dataset to {}'.format(args.rec))
writer = codewithgpu.RecordWriter(
path=args.rec,
features={
'id': 'string',
'content': 'bytes',
'height': 'int64',
'width': 'int64',
'depth': 'int64',
'object': [{
'name': 'string',
'xmin': 'float64',
'ymin': 'float64',
'xmax': 'float64',
'ymax': 'float64',
'difficult': 'int64',
}]
}
)
# Scan all available entries.
print('Scan entries...')
entries = []
for i, split in enumerate(args.splits):
with open(split, 'r') as f:
lines = f.readlines()
for line in lines:
filename = line.strip()
img_file = os.path.join(args.images[i], filename + '.jpg')
ann_file = os.path.join(args.annotations[i], filename + '.xml')
entries.append((img_file, ann_file))
# Parse and write into record file.
print('Start Time:', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()
for i, (img_file, xml_file) in enumerate(entries):
if i > 0 and i % 2000 == 0:
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
i, len(entries), now_time - start_time))
writer.write(make_example(img_file, xml_file))
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
len(entries), len(entries), now_time - start_time))
writer.close()
end_time = time.time()
data_size = os.path.getsize(args.rec + '/00000.data') * 1e-6
print('{} images take {:.2f} MB in {:.2f} sec.'
.format(len(entries), data_size, end_time - start_time))
def write_json_dataset(args):
    """Write the json dataset."""
    categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                  'bottle', 'bus', 'car', 'cat', 'chair',
                  'cow', 'diningtable', 'dog', 'horse',
                  'motorbike', 'person', 'pottedplant',
                  'sheep', 'sofa', 'train', 'tvmonitor']
    import subprocess
    script = os.path.dirname(os.path.abspath(__file__)) + '/json_dataset.py'
    cmd = '{} {} '.format(sys.executable, script)
    cmd += '--rec {} --gt {} '.format(args.rec, args.gt)
    cmd += '--categories {} '.format(' '.join(categories))
    return subprocess.call(cmd, shell=True)
if __name__ == '__main__':
args = parse_args()
if args.rec is not None:
write_dataset(args)
if args.gt is not None:
write_json_dataset(args)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A platform implementing popular object detection algorithms."""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
# Version
from seetadet.version import version as __version__
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Platform configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Variables
from seetadet.core.config.defaults import cfg # noqa
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/yacs/blob/master/yacs/config.py>
#
# ------------------------------------------------------------
"""Yet Another Configuration System (YACS)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import yaml
class CfgNode(dict):
"""Node for configuration options."""
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(CfgNode, self).__init__(*args, **kwargs)
self.__dict__[CfgNode.IMMUTABLE] = False
def clone(self):
"""Recursively copy this CfgNode."""
return copy.deepcopy(self)
def freeze(self):
"""Make this CfgNode and all of its children immutable."""
self._immutable(True)
def is_frozen(self):
"""Return whether this CfgNode is immutable."""
return self.__dict__[CfgNode.IMMUTABLE]
def merge_from_file(self, cfg_filename):
"""Load a yaml config file and merge it into this CfgNode."""
with open(cfg_filename, 'r') as f:
other_cfg = CfgNode(yaml.safe_load(f))
self.merge_from_other_cfg(other_cfg)
def merge_from_list(self, cfg_list):
"""Merge config (keys, values) in a list into this CfgNode."""
assert len(cfg_list) % 2 == 0
from ast import literal_eval
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
key_list = k.split('.')
d = self
for sub_key in key_list[:-1]:
assert sub_key in d
d = d[sub_key]
sub_key = key_list[-1]
assert sub_key in d
try:
value = literal_eval(v)
except (ValueError, SyntaxError):
# Handle the case when v is a string literal
value = v
if type(value) != type(d[sub_key]): # noqa
raise TypeError('Type {} does not match original type {}'
.format(type(value), type(d[sub_key])))
d[sub_key] = value
def merge_from_other_cfg(self, other_cfg):
"""Merge ``other_cfg`` into this CfgNode."""
_merge_a_into_b(other_cfg, self)
def _immutable(self, is_immutable):
"""Set immutability recursively to all nested CfgNode."""
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
for v in self.__dict__.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
for v in self.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __repr__(self):
return "{}({})".format(self.__class__.__name__,
super(CfgNode, self).__repr__())
def __setattr__(self, name, value):
if not self.__dict__[CfgNode.IMMUTABLE]:
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else:
raise AttributeError(
'Attempted to set "{}" to "{}", but CfgNode is immutable'
.format(name, value))
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
r = ""
s = []
for k, v in sorted(self.items()):
separator = "\n" if isinstance(v, CfgNode) else " "
attr_str = "{}:{}{}".format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
return r
def _merge_a_into_b(a, b):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a."""
if not isinstance(a, dict):
return
for k, v in a.items():
# a must specify keys that are in b
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))
# The types must match, too
v = _check_and_coerce_cfg_value_type(v, b[k], k)
# Recursively merge dicts
if type(v) is CfgNode:
try:
_merge_a_into_b(a[k], b[k])
except: # noqa
print('Error under config key: {}'.format(k))
raise
else:
b[k] = v
def _check_and_coerce_cfg_value_type(value_a, value_b, key):
"""Check that the value types match, coercing ``value_a`` when possible."""
type_a, type_b = type(value_a), type(value_b)
if type_a is type_b:
return value_a
if type_b is float and type_a is int:
return float(value_a)
# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, np.ndarray):
value_a = np.array(value_a, dtype=value_b.dtype)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
elif isinstance(value_a, dict) and isinstance(value_b, CfgNode):
value_a = CfgNode(value_a)
else:
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(type_b, type_a, value_b, value_a, key))
return value_a
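# Illustrative sketch, not part of the repository: how a flat [key, value, ...] list,
# as consumed by CfgNode.merge_from_list(), can be split into dotted key paths and
# decoded values. `parse_overrides` is a hypothetical helper; unlike the method above
# it does not check the result against an existing config.

```python
from ast import literal_eval


def parse_overrides(cfg_list):
    """Turn ['A.B', '1', ...] into [(['A', 'B'], 1), ...]."""
    if len(cfg_list) % 2 != 0:
        raise ValueError('Expected an even number of items (key, value pairs).')
    parsed = []
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        try:
            # Numbers, lists, tuples, booleans, quoted strings, ...
            value = literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything that is not a Python literal is kept as a plain string.
            value = raw
        parsed.append((key.split('.'), value))
    return parsed
```

# Example: 'SGD' is not a Python literal, so it is kept as the string 'SGD',
# while '[60000, 80000]' is decoded into a list of integers.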
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Experiment coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import time
import numpy as np
from seetadet.core.config import cfg
from seetadet.utils import logging
class Coordinator(object):
"""Manage the unique experiments."""
def __init__(self, cfg_file, exp_dir=None):
cfg.merge_from_file(cfg_file)
if exp_dir is None:
name = time.strftime('%Y%m%d_%H%M%S',
time.localtime(time.time()))
exp_dir = '../experiments/{}'.format(name)
if not osp.exists(exp_dir):
os.makedirs(exp_dir)
else:
if not osp.exists(exp_dir):
raise ValueError('Invalid experiment dir: ' + exp_dir)
self.exp_dir = exp_dir
def path_at(self, file, auto_create=True):
try:
path = osp.abspath(osp.join(self.exp_dir, file))
if auto_create and not osp.exists(path):
os.makedirs(path)
except OSError:
path = osp.abspath(osp.join('/tmp', file))
if auto_create and not osp.exists(path):
os.makedirs(path)
return path
def get_checkpoint(self, step=None, last_idx=1, wait=False):
path = self.path_at('checkpoints')
def locate(last_idx=None):
files = os.listdir(path)
files = list(filter(lambda x: '_iter_' in x and
x.endswith('.pkl'), files))
file_steps = []
for i, file in enumerate(files):
file_step = int(file.split('_iter_')[-1].split('.')[0])
if step == file_step:
return osp.join(path, files[i]), file_step
file_steps.append(file_step)
if step is None:
if len(files) == 0:
return None, 0
if last_idx > len(files):
return None, 0
file = files[np.argsort(file_steps)[-last_idx]]
file_step = file_steps[np.argsort(file_steps)[-last_idx]]
return osp.join(path, file), file_step
return None, 0
file, file_step = locate(last_idx)
while file is None and wait:
logging.info('Wait for checkpoint at {}.'.format(step))
time.sleep(10)
file, file_step = locate(last_idx)
return file, file_step
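# Illustrative sketch, not part of the repository: the step lookup performed by
# Coordinator.get_checkpoint(), reduced to a pure function over file names so it can
# be tried without a checkpoint directory. `latest_checkpoint` is a hypothetical name.

```python
def latest_checkpoint(files, last_idx=1):
    """Return (file name, step) of the last_idx-th newest '*_iter_<step>.pkl' file."""
    steps = {}
    for name in files:
        if '_iter_' in name and name.endswith('.pkl'):
            steps[name] = int(name.split('_iter_')[-1].split('.')[0])
    # Sort by the parsed step, not by name: 'iter_200' sorts after 'iter_1000' as text.
    ordered = sorted(steps, key=steps.get)
    if last_idx > len(ordered):
        return None, 0
    name = ordered[-last_idx]
    return name, steps[name]
```

# With files 'model_iter_200.pkl', 'model_iter_1000.pkl' and 'model_iter_30000.pkl',
# last_idx=1 selects step 30000 and last_idx=2 selects step 1000.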
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Builders for the training engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from seetadet.core.config import cfg
from seetadet.core.engine import lr_scheduler
def build_optimizer(params, **kwargs):
"""Build the optimizer."""
args = {'lr': cfg.SOLVER.BASE_LR,
'weight_decay': cfg.SOLVER.WEIGHT_DECAY,
'clip_norm': cfg.SOLVER.CLIP_NORM,
'grad_scale': 1.0 / cfg.SOLVER.LOSS_SCALE}
optimizer = kwargs.pop('optimizer', cfg.SOLVER.OPTIMIZER)
if optimizer == 'SGD':
args['momentum'] = cfg.SOLVER.MOMENTUM
args.update(kwargs)
return getattr(torch.optim, optimizer)(params, **args)
def build_lr_scheduler(**kwargs):
"""Build the LR scheduler."""
args = {'lr_max': cfg.SOLVER.BASE_LR,
'lr_min': cfg.SOLVER.MIN_LR,
'warmup_steps': cfg.SOLVER.WARM_UP_STEPS,
'warmup_factor': cfg.SOLVER.WARM_UP_FACTOR}
policy = kwargs.pop('policy', cfg.SOLVER.LR_POLICY)
args.update(kwargs)
if policy == 'steps_with_decay':
return lr_scheduler.MultiStepLR(
decay_steps=cfg.SOLVER.DECAY_STEPS,
decay_gamma=cfg.SOLVER.DECAY_GAMMA, **args)
elif policy == 'linear_decay':
return lr_scheduler.LinearLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
elif policy == 'cosine_decay':
return lr_scheduler.CosineLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
return lr_scheduler.ConstantLR(**args)
def build_tensorboard(log_dir):
"""Build the tensorboard."""
try:
from dragon.utils.tensorboard import tf
from dragon.utils.tensorboard import TensorBoard
# Prevent the TF API from using GPUs.
if tf is not None:
tf.config.set_visible_devices([], 'GPU')
return TensorBoard(log_dir)
except ImportError:
return None
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Learning rate schedulers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
class ConstantLR(object):
"""Constant LR scheduler."""
def __init__(self, **kwargs):
self._lr_max = kwargs.pop('lr_max')
self._lr_min = kwargs.pop('lr_min', 0)
self._warmup_steps = kwargs.pop('warmup_steps', 0)
self._warmup_factor = kwargs.pop('warmup_factor', 0)
if kwargs:
raise ValueError('Unexpected arguments: ' + ','.join(v for v in kwargs))
self._step_count = 0
self._last_decay = 1.
def step(self):
self._step_count += 1
def get_lr(self):
if self._step_count < self._warmup_steps:
alpha = (self._step_count + 1.) / self._warmup_steps
return self._lr_max * (alpha + (1. - alpha) * self._warmup_factor)
return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay()
def get_decay(self):
return self._last_decay
class CosineLR(ConstantLR):
"""LR scheduler with cosine decay."""
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = .5 * (1. + math.cos(math.pi * t / t_max))
return self._last_decay
class MultiStepLR(ConstantLR):
"""LR scheduler with multi-steps decay."""
def __init__(self, lr_max, decay_steps, decay_gamma, **kwargs):
super(MultiStepLR, self).__init__(lr_max=lr_max, **kwargs)
self._decay_steps = decay_steps
self._decay_gamma = decay_gamma
self._stage_count = 0
self._num_stages = len(decay_steps)
def get_decay(self):
if self._stage_count < self._num_stages:
k = self._decay_steps[self._stage_count]
while self._step_count >= k:
self._stage_count += 1
if self._stage_count >= self._num_stages:
break
k = self._decay_steps[self._stage_count]
self._last_decay = self._decay_gamma ** self._stage_count
return self._last_decay
class LinearLR(ConstantLR):
"""LR scheduler with linear decay."""
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
super(LinearLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = 1. - float(t) / t_max
return self._last_decay
if __name__ == '__main__':
def extract_label(scheduler):
class_name = scheduler.__class__.__name__
label = class_name + '('
if class_name == 'CosineLR':
label += 'α=' + str(scheduler._decay_step)
elif class_name == 'LinearLR':
label += 'α=' + str(scheduler._decay_step)
elif class_name == 'MultiStepLR':
label += 'α=' + str(scheduler._decay_steps) + ', '
label += 'γ=' + str(scheduler._decay_gamma)
elif class_name == 'StepLR':
label += 'α=' + str(scheduler._decay_step) + ', '
label += 'γ=' + str(scheduler._decay_gamma)
label += ')'
return label
vis = True
max_steps = 120
shared_args = {
'lr_max': 0.0004,
'warmup_steps': 0,
'warmup_factor': 0.,
}
schedulers = [
# CosineLR(lr_min=0., decay_step=1, max_steps=max_steps, **shared_args),
CosineLR(lr_min=1e-6, decay_step=1, max_steps=140, **shared_args),
]
for i in range(max_steps):
info = 'Step = %d\n' % i
for scheduler in schedulers:
if i == 0:
scheduler.lr_seq = []
info += ' * {}: {}\n'.format(
extract_label(scheduler),
scheduler.get_lr())
scheduler.lr_seq.append(scheduler.get_lr())
scheduler.step()
if not vis:
print(info)
if vis:
import matplotlib.pyplot as plt
plt.figure(1)
plt.title('Visualization of different LR Schedulers')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
line = '-'
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
for i, scheduler in enumerate(schedulers):
plt.plot(
range(max_steps),
scheduler.lr_seq,
colors[i] + line,
linewidth=1.,
label=extract_label(scheduler),
)
plt.legend()
plt.grid(linestyle='--')
plt.savefig('x.png')
plt.show()
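# Illustrative sketch, not part of the repository: stateless versions of the warmup,
# cosine and multi-step rules used by the scheduler classes above. It assumes
# decay_step=1 for the cosine rule and sorted decay_steps for the multi-step rule.
# `cosine_lr` and `multistep_decay` are hypothetical names.

```python
import bisect
import math


def cosine_lr(step, lr_max, max_steps, lr_min=0., warmup_steps=0, warmup_factor=0.):
    """Learning rate at `step` with linear warmup followed by cosine decay."""
    if step < warmup_steps:
        alpha = (step + 1.) / warmup_steps
        return lr_max * (alpha + (1. - alpha) * warmup_factor)
    t = step - warmup_steps
    t_max = max_steps - warmup_steps
    decay = .5 * (1. + math.cos(math.pi * t / t_max))
    return lr_min + (lr_max - lr_min) * decay


def multistep_decay(step, decay_steps, decay_gamma):
    """Decay factor gamma ** k, where k counts the decay steps already reached."""
    num_stages = bisect.bisect_right(decay_steps, step)
    return decay_gamma ** num_stages
```

# The cosine rule starts at lr_max, reaches the midpoint of (lr_min, lr_max) halfway
# through, and ends at lr_min when step equals max_steps.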
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Testing engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import datetime
import multiprocessing as mp
import codewithgpu
from dragon.vm import torch
import numpy as np
from seetadet.core.config import cfg
from seetadet.data.build import build_evaluator
from seetadet.models.build import build_detector
from seetadet.modules.build import build_inference
from seetadet.utils import logging
from seetadet.utils import profiler
from seetadet.utils import vis
class InferenceCommand(codewithgpu.InferenceCommand):
"""Command to run inference."""
def __init__(self, input_queue, output_queue, kwargs):
super(InferenceCommand, self).__init__(input_queue, output_queue)
self.kwargs = kwargs
def build_env(self):
"""Build the environment."""
cfg.merge_from_other_cfg(self.kwargs['cfg'])
cfg.GPU_ID = self.kwargs['device']
cfg.freeze()
logging.set_root(self.kwargs.get('verbose', True))
self.batch_size = cfg.TEST.IMS_PER_BATCH
self.batch_timeout = self.kwargs.get('batch_timeout', None)
if self.kwargs.get('deterministic', False):
torch.backends.cudnn.deterministic = True
def build_model(self):
"""Build and return the model."""
return build_detector(self.kwargs['device'], self.kwargs['weights'])
def build_module(self, model):
"""Build and return the inference module."""
return build_inference(model)
def send_results(self, module, indices, imgs):
"""Send the batch results."""
results = module.get_results(imgs)
time_diffs = module.get_time_diffs()
time_diffs['im_detect'] += time_diffs.pop('im_detect_mask', 0.)
for i, outputs in enumerate(results):
outputs['im_shape'] = imgs[i].shape
self.output_queue.put((indices[i], time_diffs, outputs))
def filter_outputs(outputs, max_dets=100):
"""Limit the max number of detections."""
if max_dets <= 0:
return outputs
boxes = outputs.pop('boxes')
masks = outputs.pop('masks', None)
scores, num_classes = [], len(boxes)
for i in range(num_classes):
if len(boxes[i]) > 0:
scores.append(boxes[i][:, -1])
scores = np.hstack(scores) if len(scores) > 0 else []
if len(scores) > max_dets:
thr = np.sort(scores)[-max_dets]
for i in range(num_classes):
if len(boxes[i]) < 1:
continue
keep = np.where(boxes[i][:, -1] >= thr)[0]
boxes[i] = boxes[i][keep]
if masks is not None:
masks[i] = masks[i][keep]
outputs['boxes'] = boxes
outputs['masks'] = masks
return outputs
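# Illustrative sketch, not part of the repository: the thresholding idea behind
# filter_outputs(), written with plain lists of scores instead of NumPy box arrays.
# `keep_top_detections` is a hypothetical name; it returns the kept indices per class.

```python
def keep_top_detections(scores_per_class, max_dets=100):
    """Keep detections whose score reaches the max_dets-th highest score overall."""
    all_scores = sorted(s for scores in scores_per_class for s in scores)
    if max_dets <= 0 or len(all_scores) <= max_dets:
        # Nothing to filter: keep every index.
        return [list(range(len(scores))) for scores in scores_per_class]
    thr = all_scores[-max_dets]
    # '>=' means tied scores at the threshold are all kept, so the result
    # can hold slightly more than max_dets detections.
    return [[i for i, s in enumerate(scores) if s >= thr]
            for scores in scores_per_class]
```

# Because the comparison uses '>=', ties at the threshold score are all retained.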
def extend_results(index, collection, results):
"""Add image results to the collection."""
if results is None:
return
for _ in range(len(results) - len(collection)):
collection.append([])
for i in range(1, len(results)):
for _ in range(index - len(collection[i]) + 1):
collection[i].append([])
collection[i][index] = results[i]
def run_test(
test_cfg,
weights,
output_dir,
devices,
deterministic=False,
read_every=100,
vis_thresh=0,
vis_output_dir=None,
):
"""Run model testing.
Parameters
----------
test_cfg : CfgNode
The cfg for testing.
weights : str
The path of model weights to load.
output_dir : str
The path to save results.
devices : Sequence[int]
The indices of the computing devices.
deterministic : bool, optional, default=False
Set cudnn deterministic or not.
read_every : int, optional, default=100
Read every N images to distribute to devices.
vis_thresh : float, optional, default=0
The score threshold for visualization.
vis_output_dir : str, optional
The path to save visualizations.
"""
cfg.merge_from_other_cfg(test_cfg)
evaluator = build_evaluator(output_dir)
devices = devices if devices else [cfg.GPU_ID]
num_devices = len(devices)
num_images = evaluator.num_images
max_dets = cfg.TEST.DETECTIONS_PER_IM
read_stride = float(num_devices * cfg.TEST.IMS_PER_BATCH)
read_every = int(np.ceil(read_every / read_stride) * read_stride)
visualizer = vis.Visualizer(cfg.MODEL.CLASSES, vis_thresh)
queues = [mp.Queue() for _ in range(num_devices + 1)]
commands = [InferenceCommand(
queues[i], queues[-1], kwargs={
'cfg': test_cfg,
'weights': weights,
'device': devices[i],
'deterministic': deterministic,
'verbose': i == 0,
}) for i in range(num_devices)]
actors = [mp.Process(target=command.run) for command in commands]
for actor in actors:
actor.start()
timers = collections.defaultdict(profiler.Timer)
all_boxes, all_masks, vis_images = [], [], {}
for count in range(1, num_images + 1):
img_id, img = evaluator.get_image()
queues[count % num_devices].put((count - 1, img))
if vis_thresh > 0 and vis_output_dir:
filename = vis_output_dir + '/%s.png' % img_id
vis_images[count - 1] = (filename, img)
if count % read_every > 0 and count < num_images:
continue
if count == num_images:
for i in range(num_devices):
queues[i].put((-1, None))
for _ in range(((count - 1) % read_every + 1)):
index, time_diffs, outputs = queues[-1].get()
outputs = filter_outputs(outputs, max_dets)
extend_results(index, all_boxes, outputs['boxes'])
extend_results(index, all_masks, outputs.get('masks', None))
for name, diff in time_diffs.items():
timers[name].add_diff(diff)
if vis_thresh > 0 and vis_output_dir:
filename, img = vis_images[index]
visualizer.draw_instances(
img=img,
boxes=outputs['boxes'],
masks=outputs.get('masks', None)).save(filename)
del vis_images[index]
avg_time = sum([t.average_time for t in timers.values()])
eta_seconds = avg_time * (num_images - count)
print('\rim_detect: {:d}/{:d} [{:.3f}s + {:.3f}s] (eta: {})'
.format(count, num_images,
timers['im_detect'].average_time,
timers['misc'].average_time,
str(datetime.timedelta(seconds=int(eta_seconds)))),
end='')
print('\nEvaluating detections...')
evaluator.eval_bbox(all_boxes)
if len(all_masks) > 0:
print('Evaluating segmentations...')
evaluator.eval_segm(all_boxes, all_masks)
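# Illustrative sketch, not part of the repository: the two pieces of arithmetic that
# run_test() uses to distribute images. `read_every` is rounded up to a multiple of
# (devices x batch size), and image `count` (1-based) goes to queue count % num_devices.
# Both function names are hypothetical.

```python
import math


def round_up_read_every(read_every, num_devices, ims_per_batch):
    """Round read_every up to a multiple of num_devices * ims_per_batch."""
    stride = float(num_devices * ims_per_batch)
    return int(math.ceil(read_every / stride) * stride)


def assign_queues(num_images, num_devices):
    """Return the input queue index chosen for each image, in reading order."""
    return [count % num_devices for count in range(1, num_images + 1)]
```

# With 3 devices and 8 images per batch the stride is 24, so read_every=100 becomes 120.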
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import os
from dragon.vm import torch
from seetadet.core.config import cfg
from seetadet.core.engine.build import build_lr_scheduler
from seetadet.core.engine.build import build_optimizer
from seetadet.core.engine.build import build_tensorboard
from seetadet.core.engine.utils import count_params
from seetadet.core.engine.utils import get_device
from seetadet.core.engine.utils import get_param_groups
from seetadet.data.build import build_loader_train
from seetadet.models.build import build_detector
from seetadet.utils import logging
from seetadet.utils import profiler
class Trainer(object):
"""Schedule the iterative model training."""
def __init__(self, coordinator, start_iter=0):
# Build loader.
self.loader = build_loader_train()
# Build model.
self.model = build_detector(training=True)
self.model.load_weights(cfg.TRAIN.WEIGHTS, strict=start_iter > 0)
self.model.to(device=get_device(cfg.GPU_ID))
if cfg.MODEL.PRECISION.lower() == 'float16':
self.model.half()
# Build optimizer.
self.loss_scale = cfg.SOLVER.LOSS_SCALE
param_groups_getter = get_param_groups
if cfg.SOLVER.LAYER_LR_DECAY < 1.0:
lr_scale_getter = functools.partial(
self.model.backbone.get_lr_scale,
decay=cfg.SOLVER.LAYER_LR_DECAY)
param_groups_getter = functools.partial(
param_groups_getter, lr_scale_getter=lr_scale_getter)
self.optimizer = build_optimizer(param_groups_getter(self.model))
self.scheduler = build_lr_scheduler()
# Build monitor.
self.coordinator = coordinator
self.metrics = collections.OrderedDict()
self.board = None
@property
def iter(self):
return self.scheduler._step_count
def snapshot(self):
"""Save a checkpoint for the current iteration."""
f = cfg.SOLVER.SNAPSHOT_PREFIX
f += '_iter_{}.pkl'.format(self.iter)
f = os.path.join(self.coordinator.path_at('checkpoints'), f)
if logging.is_root() and not os.path.exists(f):
torch.save(self.model.state_dict(), f, pickle_protocol=4)
logging.info('Wrote snapshot to: {:s}'.format(f))
def add_metrics(self, stats):
"""Add or update the metrics."""
for k, v in stats['metrics'].items():
if k not in self.metrics:
self.metrics[k] = profiler.SmoothedValue()
self.metrics[k].update(v)
def display_metrics(self, stats):
"""Send metrics to the monitor."""
logging.info('Iteration %d, lr = %.8f, time = %.2fs'
% (stats['iter'], stats['lr'], stats['time']))
for k, v in self.metrics.items():
logging.info(' ' * 4 + 'Train net output({}): {:.4f} ({:.4f})'
.format(k, stats['metrics'][k], v.average()))
if self.board is not None:
self.board.scalar_summary('lr', stats['lr'], stats['iter'])
self.board.scalar_summary('time', stats['time'], stats['iter'])
for k, v in self.metrics.items():
self.board.scalar_summary(k, v.average(), stats['iter'])
def step(self):
stats = {'iter': self.iter}
metrics = collections.defaultdict(float)
# Run forward.
timer = profiler.Timer().tic()
inputs = self.loader()
outputs, losses = self.model(inputs), []
for k, v in outputs.items():
if 'loss' in k:
if isinstance(v, (tuple, list)):
losses.append(sum(v[1:], v[0]).mul_(1. / len(v)))
metrics.update(dict(('stage%d_' % (i + 1) + k, float(x))
for i, x in enumerate(v)))
else:
losses.append(v)
metrics[k] += float(v)
# Run backward.
losses = sum(losses[1:], losses[0])
if self.loss_scale != 1.0:
losses *= self.loss_scale
losses.backward()
# Apply update.
stats['lr'] = self.scheduler.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = stats['lr'] * group.get('lr_scale', 1.0)
self.optimizer.step()
self.scheduler.step()
stats['time'] = timer.toc()
stats['metrics'] = collections.OrderedDict(sorted(metrics.items()))
return stats
def train_model(self, start_iter=0):
"""Network training loop."""
timer = profiler.Timer()
max_steps = cfg.SOLVER.MAX_STEPS
display_every = cfg.SOLVER.DISPLAY
progress_every = 10 * display_every
snapshot_every = cfg.SOLVER.SNAPSHOT_EVERY
self.scheduler._step_count = start_iter
while self.iter < max_steps:
with timer.tic_and_toc():
stats = self.step()
self.add_metrics(stats)
if stats['iter'] % display_every == 0:
self.display_metrics(stats)
if self.iter % progress_every == 0:
logging.info(profiler.get_progress(timer, self.iter, max_steps))
if self.iter % snapshot_every == 0:
self.snapshot()
self.metrics.clear()
def run_train(coordinator, start_iter=0, enable_tensorboard=False):
"""Start a network training task."""
trainer = Trainer(coordinator, start_iter=start_iter)
if enable_tensorboard and logging.is_root():
trainer.board = build_tensorboard(coordinator.path_at('logs'))
logging.info('#Params: %.2fM' % count_params(trainer.model))
logging.info('Start training...')
trainer.train_model(start_iter)
trainer.snapshot()
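# Illustrative sketch, not part of the repository: how Trainer.step() combines model
# outputs into one loss, using plain floats in place of tensors. List-valued losses
# (one entry per stage) are averaged, scalar losses are used as-is, and keys without
# 'loss' in the name are ignored. `aggregate_losses` and the output keys in the usage
# note are hypothetical.

```python
import collections


def aggregate_losses(outputs):
    """Return (total loss, sorted metrics dict) from a dict of model outputs."""
    losses, metrics = [], collections.defaultdict(float)
    for k, v in outputs.items():
        if 'loss' not in k:
            continue
        if isinstance(v, (tuple, list)):
            # Average the per-stage losses and log each stage separately.
            losses.append(sum(v[1:], v[0]) / len(v))
            for i, x in enumerate(v):
                metrics['stage%d_%s' % (i + 1, k)] = float(x)
        else:
            losses.append(v)
            metrics[k] += float(v)
    total = sum(losses[1:], losses[0])
    return total, collections.OrderedDict(sorted(metrics.items()))
```

# For outputs {'cls_loss': [1.0, 3.0], 'bbox_loss': 0.5, 'rois': 7} the total is
# the stage average 2.0 plus 0.5, and 'rois' does not contribute.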
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Engine utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import importlib.machinery
import os
import dragon
from dragon.core.framework import backend
from dragon.vm import torch
def count_params(module):
"""Return the number of parameters in millions."""
return sum([v.size().numel() for v in module.parameters()]) / 1e6
def freeze_module(module):
"""Freeze parameters of given module."""
module.eval()
for param in module.parameters():
param.requires_grad = False
def get_device(index):
"""Create the available device object."""
if torch.cuda.is_available():
return torch.device('cuda', index)
try:
if torch.backends.mps.is_available():
return torch.device('mps', index)
except AttributeError:
pass
return torch.device('cpu')
def get_param_groups(module, lr_scale_getter=None):
"""Separate parameters into groups."""
memo, groups = {}, collections.OrderedDict()
for name, param in module.named_parameters():
if not param.requires_grad:
continue
attrs = collections.OrderedDict()
if lr_scale_getter:
attrs['lr_scale'] = lr_scale_getter(name)
memo[name] = param.shape
no_weight_decay = not (name.endswith('weight') and param.dim() > 1)
no_weight_decay = getattr(param, 'no_weight_decay', no_weight_decay)
if no_weight_decay:
attrs['weight_decay'] = 0
group_name = '/'.join(['%s:%s' % (v[0], v[1]) for v in list(attrs.items())])
if group_name not in groups:
groups[group_name] = {'params': []}
groups[group_name].update(attrs)
groups[group_name]['params'].append(param)
return list(groups.values())
def load_library(library_prefix):
"""Load a shared library."""
loader_details = (importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES)
library_prefix = os.path.abspath(library_prefix)
lib_dir, fullname = os.path.split(library_prefix)
finder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = finder.find_spec(fullname)
if ext_specs is None:
raise ImportError('Could not find the pre-built library '
'for <%s>.' % library_prefix)
backend.load_library(ext_specs.origin)
def synchronize_device(device):
"""Synchronize the computation of device."""
if device.type == 'cuda':
torch.cuda.synchronize(device)
elif device.type == 'mps':
dragon.mps.synchronize(device.index)
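# Illustrative sketch, not part of the repository: the grouping rule of
# get_param_groups(), applied to (name, number of dimensions) pairs instead of real
# parameters. Only names ending in 'weight' with more than one dimension keep weight
# decay. `group_params` is a hypothetical name and the per-parameter
# `no_weight_decay` attribute override is left out.

```python
import collections


def group_params(named_dims, lr_scale_getter=None):
    """Group parameter names by their optimizer attributes."""
    groups = collections.OrderedDict()
    for name, ndim in named_dims:
        attrs = collections.OrderedDict()
        if lr_scale_getter:
            attrs['lr_scale'] = lr_scale_getter(name)
        if not (name.endswith('weight') and ndim > 1):
            attrs['weight_decay'] = 0
        # Parameters with identical attributes share one group.
        key = '/'.join('%s:%s' % (k, v) for k, v in attrs.items())
        if key not in groups:
            groups[key] = dict(attrs, params=[])
        groups[key]['params'].append(name)
    return list(groups.values())
```

# Biases and 1-D normalization weights end up in a group with weight_decay set to 0.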
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Registry class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
class Registry(object):
"""Registry class."""
def __init__(self, name):
self.name = name
self.registry = collections.OrderedDict()
def has(self, key):
return key in self.registry
def register(self, name, func=None, **kwargs):
def decorated(inner_function):
for key in (name if isinstance(
name, (tuple, list)) else [name]):
self.registry[key] = \
functools.partial(inner_function, **kwargs)
return inner_function
if func is not None:
return decorated(func)
return decorated
def get(self, name, default=None):
if name is None:
return None
if not self.has(name):
if default is not None:
return default
raise KeyError("`%s` is not registered in <%s>."
% (name, self.name))
return self.registry[name]
def try_get(self, name):
if self.has(name):
return self.get(name)
return None
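# Illustrative sketch, not part of the repository: a usage example for the registry
# pattern above. The class is repeated in reduced form so the example runs on its
# own. `BACKBONES` and `build_resnet` are hypothetical names.

```python
import collections
import functools


class Registry(object):
    """Reduced copy of the registry above, kept only to make the example runnable."""

    def __init__(self, name):
        self.name = name
        self.registry = collections.OrderedDict()

    def register(self, name, func=None, **kwargs):
        def decorated(inner_function):
            keys = name if isinstance(name, (tuple, list)) else [name]
            for key in keys:
                # Extra keyword arguments are bound at registration time.
                self.registry[key] = functools.partial(inner_function, **kwargs)
            return inner_function
        if func is not None:
            return decorated(func)
        return decorated

    def get(self, name):
        if name not in self.registry:
            raise KeyError('`%s` is not registered in <%s>.' % (name, self.name))
        return self.registry[name]


BACKBONES = Registry('backbones')


@BACKBONES.register(['resnet50', 'resnet_50'], pretrained=False)
def build_resnet(depth=50, pretrained=True):
    return 'resnet%d(pretrained=%s)' % (depth, pretrained)
```

# The decorator returns the original function unchanged, so calling build_resnet()
# directly still uses pretrained=True, while BACKBONES.get('resnet50')() uses the
# bound pretrained=False.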
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetadet.data import datasets
from seetadet.data import evaluators
from seetadet.data import pipelines
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Anchor generator for RPN head."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class AnchorGenerator(object):
    """Generate anchors for bbox regression."""

    def __init__(self, strides, sizes, aspect_ratios,
                 scales_per_octave=1):
        self.strides = strides
        self.sizes = _align_args(strides, sizes)
        self.aspect_ratios = _align_args(strides, aspect_ratios)
        for i in range(len(self.sizes)):
            octave_sizes = []
            for j in range(1, scales_per_octave):
                scale = 2 ** (float(j) / scales_per_octave)
                octave_sizes += [x * scale for x in self.sizes[i]]
            self.sizes[i] += octave_sizes
        self.scales = [[x / y for x in z] for y, z in zip(strides, self.sizes)]
        self.cell_anchors = []
        for i in range(len(strides)):
            self.cell_anchors.append(generate_anchors(
                strides[i], self.aspect_ratios[i], self.sizes[i]))
        self.grid_shapes = None
        self.grid_anchors = None
        self.grid_coords = None

    def reset_grid(self, max_size):
        """Reset the grid."""
        self.grid_shapes = [(int(np.ceil(max_size / x)),) * 2 for x in self.strides]
        self.grid_coords = self.get_coords(self.grid_shapes)
        self.grid_anchors = self.get_anchors(self.grid_shapes)

    def num_cell_anchors(self, index=0):
        """Return number of cell anchors."""
        return self.cell_anchors[index].shape[0]

    def num_anchors(self, shapes):
        """Return the number of grid anchors."""
        return sum(self.cell_anchors[i].shape[0] * np.prod(shapes[i])
                   for i in range(len(shapes)))

    def get_slices(self, shapes):
        slices, offset = [], 0
        for i, shape in enumerate(shapes):
            num = self.cell_anchors[i].shape[0] * np.prod(shape)
            slices.append(slice(offset, offset + num))
            offset = offset + num
        return slices

    def get_coords(self, shapes):
        """Return the x-y coordinates of grid anchors."""
        xs, ys = [], []
        for i in range(len(shapes)):
            height, width = shapes[i]
            x, y = np.arange(0, width), np.arange(0, height)
            x, y = np.meshgrid(x, y)
            # Add A anchors (A,) to cell K shifts (K,)
            # to get shift coords (A, K)
            xs.append(np.tile(x.flatten(), self.cell_anchors[i].shape[0]))
            ys.append(np.tile(y.flatten(), self.cell_anchors[i].shape[0]))
        return np.concatenate(xs), np.concatenate(ys)

    def get_anchors(self, shapes):
        """Return the grid anchors."""
        grid_anchors = []
        for i in range(len(shapes)):
            h, w = shapes[i]
            shift_x = np.arange(0, w) * self.strides[i]
            shift_y = np.arange(0, h) * self.strides[i]
            shift_x, shift_y = np.meshgrid(shift_x, shift_y)
            shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
                                shift_x.ravel(), shift_y.ravel())).transpose()
            shifts = shifts.astype(self.cell_anchors[i].dtype)
            # Add A anchors (A, 1, 4) to cell K shifts (1, K, 4)
            # to get shift anchors (A, K, 4)
            a, k = self.num_cell_anchors(i), shifts.shape[0]
            anchors = (self.cell_anchors[i].reshape((a, 1, 4)) +
                       shifts.reshape((1, k, 4)))
            grid_anchors.append(anchors.reshape((a * k, 4)))
        return np.vstack(grid_anchors)

    def narrow_anchors(self, shapes, inds, return_anchors=False):
        """Return the valid anchors on given shapes."""
        max_shapes = self.grid_shapes
        anchors = self.grid_anchors
        x_coords, y_coords = self.grid_coords
        offset1 = offset2 = num1 = num2 = 0
        out_inds, out_anchors = [], []
        for i in range(len(max_shapes)):
            num1 += self.num_cell_anchors(i) * np.prod(max_shapes[i])
            num2 += self.num_cell_anchors(i) * np.prod(shapes[i])
            inds_keep = inds[np.where((inds >= offset1) & (inds < num1))[0]]
            anchors_keep = anchors[inds_keep] if return_anchors else None
            x, y = x_coords[inds_keep], y_coords[inds_keep]
            z = ((inds_keep - offset1) // max_shapes[i][1]) // max_shapes[i][0]
            keep = np.where((x < shapes[i][1]) & (y < shapes[i][0]))[0]
            inds_keep = (z * shapes[i][0] + y) * shapes[i][1] + x + offset2
            out_inds.append(inds_keep[keep])
            out_anchors.append(anchors_keep[keep] if return_anchors else None)
            offset1, offset2 = num1, num2
        outputs = [np.concatenate(out_inds)]
        if return_anchors:
            outputs += [np.concatenate(out_anchors)]
        return outputs[0] if len(outputs) == 1 else outputs
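The index arithmetic in `narrow_anchors` can be followed on a single feature level with plain integers. The sketch below is illustrative only: `remap_index` is a hypothetical helper that is not part of seetadet, it ignores the per-level offsets (`offset1`, `offset2`), and it assumes the `(A, K)` layout produced by `get_anchors`, where a flat index equals `(a * H + y) * W + x`.

```python
def remap_index(index, max_shape, shape, num_cell_anchors):
    """Map a flat anchor index on the max grid to a smaller grid.

    Hypothetical helper for illustration. Returns None when the
    anchor's cell lies outside the smaller grid.
    """
    max_h, max_w = max_shape
    h, w = shape
    # Decompose the flat index into (anchor id, row, column).
    x = index % max_w
    y = (index // max_w) % max_h
    a = index // (max_w * max_h)
    assert 0 <= a < num_cell_anchors
    # Cells beyond the smaller grid have no counterpart there.
    if x >= w or y >= h:
        return None
    # Recompose the index using the smaller grid's height and width.
    return (a * h + y) * w + x


# Anchor 1 at row 1, column 2 of a 4x4 grid is index (1 * 4 + 1) * 4 + 2 = 22.
print(remap_index(22, (4, 4), (2, 3), 3))  # 11
```

Because both layouts enumerate anchors in the same (anchor, row, column) order, the surviving indices keep their relative order after remapping.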


def generate_anchors(stride=16, ratios=(0.5, 1, 2), sizes=(32, 64, 128, 256, 512)):
    """Generate anchors by enumerating aspect ratios and sizes."""
    scales = np.array(sizes) / stride
    base_anchor = np.array([-stride / 2., -stride / 2., stride / 2., stride / 2.])
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors.astype('float32')


def _whctrs(anchor):
    """Return the width, height and center of an anchor."""
    w = anchor[2] - anchor[0]
    h = anchor[3] - anchor[1]
    x_ctr = anchor[0] + 0.5 * w
    y_ctr = anchor[1] + 0.5 * h
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """Return a set of anchors by widths, heights and center."""
    ws, hs = ws[:, np.newaxis], hs[:, np.newaxis]
    return np.hstack((x_ctr - 0.5 * ws, y_ctr - 0.5 * hs,
                      x_ctr + 0.5 * ws, y_ctr + 0.5 * hs))


def _ratio_enum(anchor, ratios):
    """Enumerate a set of anchors by aspect ratios."""
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws = np.sqrt(w * h / ratios)
    hs = ws * ratios
    return _mkanchors(ws, hs, x_ctr, y_ctr)


def _scale_enum(anchor, scales):
    """Enumerate a set of anchors by scales."""
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws, hs = w * scales, h * scales
    return _mkanchors(ws, hs, x_ctr, y_ctr)


def _align_args(strides, args):
    """Align the args to the strides."""
    args = (args * len(strides)) if len(args) == 1 else args
    assert len(args) == len(strides)
    # Copy each entry into a new list: AnchorGenerator extends the
    # entries in place with ``+=``, which would fail on a tuple.
    return [[x] if not isinstance(x, (tuple, list)) else list(x) for x in args]
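
The size and aspect-ratio enumeration can be checked by hand without NumPy. This is a minimal sketch under stated assumptions: `octave_sizes` and `ratio_shapes` are hypothetical names that are not part of seetadet. They mirror the octave loop in `AnchorGenerator.__init__` and the width/height formula in `_ratio_enum` for a square base anchor of side `stride`.

```python
import math


def octave_sizes(sizes, scales_per_octave):
    """Extend sizes with intermediate scales 2 ** (j / scales_per_octave)."""
    out = list(sizes)
    for j in range(1, scales_per_octave):
        scale = 2 ** (float(j) / scales_per_octave)
        out += [x * scale for x in sizes]
    return out


def ratio_shapes(stride, ratios):
    """Return (width, height) per ratio, keeping the base area fixed."""
    area = float(stride * stride)
    shapes = []
    for r in ratios:
        w = math.sqrt(area / r)  # width shrinks as the ratio grows
        shapes.append((w, w * r))  # height / width equals the ratio
    return shapes


# Three scales per octave starting from size 32.
print(octave_sizes([32], 3))
# Shapes for the default ratios at stride 16; each has area 256.
print(ratio_shapes(16, (0.5, 1, 2)))
```

Each ratio `r` yields a box with `height / width == r` and the same area as the base anchor, so changing the ratio changes only the shape. The sizes are applied afterwards as a uniform scale.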