Commit 93b7acf1 by Ting PAN

Init repository

---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
[flake8]
max-line-length = 120
ignore = E741, # ambiguous variable name
F403, # 'from module import *' used; unable to detect undefined names
F405, # name may be undefined, or defined from star imports: module
F811, # redefinition of unused name from line N
F821, # undefined name
W503, # line break before binary operator
W504 # line break after binary operator
# module imported but unused
per-file-ignores = __init__.py: F401
# Compiled Object files
*.slo
*.lo
*.o
*.cuo
# Compiled Dynamic libraries
*.so
*.dll
*.dylib
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Compiled python
*.pyc
__pycache__
# Compiled MATLAB
*.mex*
# IPython notebook checkpoints
.ipynb_checkpoints
# Editor temporaries
*.swp
*~
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# QtCreator files
*.user
# VSCode files
.vscode
# IDEA files
.idea
# OSX dir files
.DS_Store
# Android files
.gradle
*.iml
local.properties
Copyright (c) 2017, SeetaTech, Co.,Ltd. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Benchmark and Model Zoo
## Introduction
### Pretrained Models
Refer to [Pretrained Models](data/pretrained) for details.
## Baselines
### UPerNet
Refer to [UPerNet](configs/upernet) for details.
# SeetaSeg
SeetaSeg is a platform implementing popular segmentation algorithms.
This repository is based on [seeta-dragon](https://github.com/seetaresearch/dragon),
while the coding style follows PyTorch.
<p align="center">
<img width="100%" src="https://dragon.seetatech.com/download/seetaseg/assets/banner.png"/>
</p>
## Installation
### Build From Source
If you prefer to develop modules as well as run experiments,
the following command builds the package without installing it into ***site-packages***:
```bash
cd seetaseg && python setup.py build
```
### Install From Source
Clone this repository to local disk and install:
```bash
cd seetaseg && python setup.py install
```
### Install From Git
You can also install it from the remote repository:
```bash
pip install git+https://gitlab.seetatech.com/seetaresearch/seetaseg.git
```
## Quick Start
### Train a segmentation model
```bash
cd tools
python train.py --cfg <MODEL_YAML>
```
Default YAML examples are provided in [configs](configs).
### Test a segmentation model
```bash
cd tools
python test.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
## Benchmark and Model Zoo
Results and models are available in the [Model Zoo](MODEL_ZOO.md).
## License
[BSD 2-Clause license](LICENSE)
# UPerNet: Unified Perceptual Parsing for Scene Understanding
## Introduction
```
@inproceedings{xiao2018unified,
title={Unified perceptual parsing for scene understanding},
author={Xiao, Tete and Liu, Yingcheng and Zhou, Bolei and Jiang, Yuning and Sun, Jian},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
pages={418--434},
year={2018}
}
```
## ADE20K Baselines
| Model | Lr sched | Infer speed (fps) | mIoU | Download |
| :---: | :------: | :--------------: | :----: | :-----: |
| [R50-512](ade20k_upernet_R_50_512_160k.yml) | 160k | 25.64 | 42.88 | [model](https://dragon.seetatech.com/download/seetaseg/upernet/ade20k_upernet_R_50_512_160k/model_6077a7d0.pkl) &#124; [log](https://dragon.seetatech.com/download/seetaseg/upernet/ade20k_upernet_R_50_512_160k/logs.json) |
NUM_GPUS: 8
MODEL:
TYPE: 'upernet'
PRECISION: 'float16'
CLASSES: ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed',
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag']
BACKBONE:
TYPE: 'resnet50_v1c'
NORM: 'SyncBN'
FPN:
DIM: 512
NORM: 'SyncBN'
DECODER:
DIM: 512
NORM: 'SyncBN'
DROPOUT_RATE: 0.1
SOLVER:
BASE_LR: 0.01
MIN_LR: 0.0001
WEIGHT_DECAY: 0.0005
LR_POLICY: 'poly_decay'
MAX_STEPS: 160000 # 128 epochs
SNAPSHOT_EVERY: 5000 # 4 epochs
SNAPSHOT_PREFIX: 'ade20k_upernet_R_50_512'
TRAIN:
WEIGHTS: '../data/pretrained/R-50-C_in1k_cls200e.pkl'
DATASET: '../data/datasets/ade20k_train'
IMS_PER_BATCH: 2
SCALES: [512]
SCALES_RANGE: [0.5, 2.0]
MAX_SIZE: 2048
CROP_SIZE: 512
TEST:
DATASET: '../data/datasets/ade20k_val'
IMS_PER_BATCH: 1
SCALES: [512]
MAX_SIZE: 2048
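As a usage sketch (relative paths assumed to match the repository layout), this config is the `<MODEL_YAML>` passed to the Quick Start training command:
```bash
cd tools
python train.py --cfg ../configs/upernet/ade20k_upernet_R_50_512_160k.yml
```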
# Datasets
## Introduction
This folder is kept for the record and JSON datasets.
Please prepare the datasets following the [documentation](../../scripts/datasets/README.md).
# Demo Images
## Introduction
This folder is kept for the demo images.
# Pretrained Models
## Introduction
This folder is kept for the pretrained models.
## ImageNet Pretrained Models
### ResNet
| Model | Lr sched | Acc@1 | Acc@5 | Source |
| :---: | :------: | :---: | :---: | :----: |
| [R50-C](https://dragon.seetatech.com/download/seetaseg/pretrained/R-50-C_in1k_cls200e.pkl) | 200e | 78.43 | 94.46 | MMSeg |
# Prepare Datasets
## Create Datasets for ADE
We assume the raw dataset has the following structure:
```
ADE
|_ images
| |_ training
| | |_ <im-1-name>.jpg
| | |_ ...
| | |_ <im-N-name>.jpg
|_ annotations
| |_ training
| | |_ <im-1-name>.png
| | |_ ...
| | |_ <im-N-name>.png
```
Create record dataset by:
```
python ade.py \
--rec /path/to/datasets/ade20k_train \
--images /path/to/ADE/images/training \
--annotations /path/to/ADE/annotations/training
```
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Prepare ADE datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import dragon
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Prepare ADE datasets')
parser.add_argument(
'--rec',
default=None,
help='path to write record dataset')
parser.add_argument(
'--images',
nargs='+',
type=str,
default=None,
help='path of images folder')
parser.add_argument(
'--annotations',
nargs='+',
type=str,
default=None,
help='path of annotations folder')
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_example(img_file, seg_file):
"""Return the record example."""
filename = os.path.split(img_file)[-1]
example = {'id': filename.split('.')[0], 'label': b''}
with open(img_file, 'rb') as f:
img_bytes = bytes(f.read())
example['data'] = img_bytes
if seg_file:
with open(seg_file, 'rb') as f:
seg_bytes = bytes(f.read())
example['label'] = seg_bytes
return example
def write_dataset(args):
"""Write the record dataset."""
if os.path.exists(args.rec):
raise ValueError('The record path already exists.')
os.makedirs(args.rec)
print('Write record dataset to {}'.format(args.rec))
writer = dragon.io.KPLRecordWriter(
path=args.rec, protocol={
'id': 'string', 'data': 'bytes',
'label': 'bytes'})
# Scan all available entries.
print('Scan entries...')
entries = []
for i, img_dir in enumerate(args.images):
seg_dir = args.annotations[i] if args.annotations else ''
img_files = os.listdir(img_dir)
img_files.sort()
for img_file in img_files:
seg_file = img_file.replace('.jpg', '.png')
img_file = os.path.join(img_dir, img_file)
seg_file = os.path.join(seg_dir, seg_file)
entries.append((img_file, seg_file if seg_dir else ''))
# Parse and write into record file.
print('Start Time:', time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()
for i, (img_file, seg_file) in enumerate(entries):
if i > 0 and i % 2000 == 0:
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
i, len(entries), now_time - start_time))
writer.write(make_example(img_file, seg_file))
now_time = time.time()
print('{} / {} in {:.2f} sec'.format(
len(entries), len(entries), now_time - start_time))
writer.close()
if __name__ == '__main__':
args = parse_args()
if args.rec is not None:
write_dataset(args)
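# Usage sketch (path assumed): the written record can be verified with
# dragon.io.KPLRecordDataset, whose ``size`` property is used elsewhere in
# this repository to count images:
#
#     dataset = dragon.io.KPLRecordDataset('/path/to/datasets/ade20k_train')
#     print('Entries written:', dataset.size)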
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A platform implementing popular segmentation algorithms."""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from seetaseg.version import version as __version__
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Platform configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Variables
from seetaseg.core.config.defaults import cfg # noqa
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Default configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.core.config.yacs import CfgNode
_C = cfg = CfgNode()
# ------------------------------------------------------------
# Training options
# ------------------------------------------------------------
_C.TRAIN = CfgNode()
# Initialize network with weights from this file
_C.TRAIN.WEIGHTS = ''
# The dataset for training
_C.TRAIN.DATASET = ''
# The loader type for training
_C.TRAIN.LOADER = 'seg_train'
# The number of workers to load training data
_C.TRAIN.NUM_WORKERS = 3
# Scales to use during training (can list multiple scales)
# Each scale is the pixel size of an image shortest side
_C.TRAIN.SCALES = (512,)
# Range to jitter the image scales randomly
_C.TRAIN.SCALES_RANGE = (1.0, 1.0)
# The longest side of training images
_C.TRAIN.MAX_SIZE = 2048
# The size to crop training images
_C.TRAIN.CROP_SIZE = 512
# The number of images per mini-batch
_C.TRAIN.IMS_PER_BATCH = 1
# ------------------------------------------------------------
# Testing options
# ------------------------------------------------------------
_C.TEST = CfgNode()
# The dataset for testing
_C.TEST.DATASET = ''
# The loader type for testing
_C.TEST.LOADER = 'seg_test'
# The evaluator type for dataset
_C.TEST.EVALUATOR = 'default'
# Scales to use during testing (can list multiple scales)
# Each scale is the pixel size of an image's shortest side
_C.TEST.SCALES = (512,)
# The longest side of testing images
_C.TEST.MAX_SIZE = 2048
# The size to crop training images
_C.TEST.CROP_SIZE = 0
# The number of images per mini-batch
_C.TEST.IMS_PER_BATCH = 1
# ------------------------------------------------------------
# Model options
# ------------------------------------------------------------
_C.MODEL = CfgNode()
# The model type
_C.MODEL.TYPE = ''
# The float precision for training and inference
# Values supported: 'float16', 'float32'
_C.MODEL.PRECISION = 'float32'
# Pixel mean and stddev values for image normalization (BGR order)
_C.MODEL.PIXEL_MEAN = [103.53, 116.28, 123.675]
_C.MODEL.PIXEL_STD = [57.375, 57.12, 58.395]
# The object class names
_C.MODEL.CLASSES = ['__background__']
# ------------------------------------------------------------
# Backbone options
# ------------------------------------------------------------
_C.BACKBONE = CfgNode()
# The backbone type
_C.BACKBONE.TYPE = ''
# Freeze backbone since the stage K
# The value of ``K`` is usually set to 2
_C.BACKBONE.FREEZE_AT = 0
# The normalization in backbone modules
_C.BACKBONE.NORM = 'BN'
# The drop path rate in backbone
_C.BACKBONE.DROP_PATH_RATE = 0.0
# ------------------------------------------------------------
# FPN options
# ------------------------------------------------------------
_C.FPN = CfgNode()
# The FPN type
_C.FPN.TYPE = 'FPN'
# Channel dimension of the FPN feature levels
_C.FPN.DIM = 512
# The fpn normalization module
_C.FPN.NORM = 'BN'
# The fpn activation module
_C.FPN.ACTIVATION = 'ReLU'
# ------------------------------------------------------------
# Decoder options
# ------------------------------------------------------------
_C.DECODER = CfgNode()
# The dimension of the decoder features
_C.DECODER.DIM = 512
# The normalization in decoder modules
_C.DECODER.NORM = 'BN'
# The activation in decoder modules
_C.DECODER.ACTIVATION = 'ReLU'
# The dropout rate in decoder modules
_C.DECODER.DROPOUT_RATE = 0.0
# ------------------------------------------------------------
# Solver options
# ------------------------------------------------------------
_C.SOLVER = CfgNode()
# The interval to display logs
_C.SOLVER.DISPLAY = 20
# The interval to snapshot a model
_C.SOLVER.SNAPSHOT_EVERY = 5000
# Prefix to yield the path: <prefix>_iter_XYZ.pkl
_C.SOLVER.SNAPSHOT_PREFIX = ''
# The loss scaling factor for mixed precision training
_C.SOLVER.LOSS_SCALE = 1024.0
# The maximum number of training steps
_C.SOLVER.MAX_STEPS = 80000
# The base learning rate
_C.SOLVER.BASE_LR = 0.001
# The minimal learning rate
_C.SOLVER.MIN_LR = 0.0
# The intervals to decay learning rate
_C.SOLVER.DECAY_STEPS = []
# The decay gamma to exponential lr policy
_C.SOLVER.DECAY_GAMMA = 0.1
# The decay power to poly lr policy
_C.SOLVER.DECAY_POWER = 0.9
# Warm up to ``BASE_LR`` over this number of steps
_C.SOLVER.WARM_UP_STEPS = 0
# Start the warm up from ``BASE_LR`` * ``FACTOR``
_C.SOLVER.WARM_UP_FACTOR = 0.1
# The optimizer type
_C.SOLVER.OPTIMIZER = 'SGD'
# The lr decay policy
_C.SOLVER.LR_POLICY = 'linear_decay'
# The decay factor for layer-wise lr scaling
_C.SOLVER.LAYER_LR_DECAY = 1.0
# The momentum to SGD optimizers
_C.SOLVER.MOMENTUM = 0.9
# The L2 regularization for weight parameters
_C.SOLVER.WEIGHT_DECAY = 0.0001
# The L2 norm factor for clipping gradients
_C.SOLVER.CLIP_NORM = 0.0
# ------------------------------------------------------------
# Misc options
# ------------------------------------------------------------
# Number of GPUs for distributed training
_C.NUM_GPUS = 1
# Random seed for reproducibility
_C.RNG_SEED = 3
# Default GPU device index
_C.GPU_ID = 0
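# Usage sketch (config path assumed): YAML files and key-value lists are
# merged onto these defaults through the CfgNode methods defined in yacs.py:
#
#     from seetaseg.core.config import cfg
#     cfg.merge_from_file('../configs/upernet/ade20k_upernet_R_50_512_160k.yml')
#     cfg.merge_from_list(['SOLVER.BASE_LR', '0.02', 'NUM_GPUS', '4'])
#     cfg.freeze()  # make the configuration immutable before training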
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/yacs/blob/master/yacs/config.py>
#
# ------------------------------------------------------------
"""Configuration utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import yaml
class CfgNode(dict):
"""Node for configuration options."""
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(CfgNode, self).__init__(*args, **kwargs)
self.__dict__[CfgNode.IMMUTABLE] = False
def clone(self):
"""Recursively copy this CfgNode."""
return copy.deepcopy(self)
def freeze(self):
"""Make this CfgNode and all of its children immutable."""
self._immutable(True)
def is_frozen(self):
"""Return mutability."""
return self.__dict__[CfgNode.IMMUTABLE]
def merge_from_file(self, cfg_filename):
"""Load a yaml config file and merge it into this CfgNode."""
with open(cfg_filename, 'r') as f:
other_cfg = CfgNode(yaml.safe_load(f))
self.merge_from_other_cfg(other_cfg)
def merge_from_list(self, cfg_list):
"""Merge config (keys, values) in a list into this CfgNode."""
assert len(cfg_list) % 2 == 0
from ast import literal_eval
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
key_list = k.split('.')
d = self
for sub_key in key_list[:-1]:
assert sub_key in d
d = d[sub_key]
sub_key = key_list[-1]
assert sub_key in d
try:
value = literal_eval(v)
except: # noqa
# Handle the case when v is a string literal
value = v
if type(value) != type(d[sub_key]): # noqa
raise TypeError('Type {} does not match original type {}'
.format(type(value), type(d[sub_key])))
d[sub_key] = value
def merge_from_other_cfg(self, other_cfg):
"""Merge ``other_cfg`` into this CfgNode."""
_merge_a_into_b(other_cfg, self)
def _immutable(self, is_immutable):
"""Set immutability recursively to all nested CfgNode."""
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
for v in self.__dict__.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
for v in self.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __repr__(self):
return "{}({})".format(self.__class__.__name__,
super(CfgNode, self).__repr__())
def __setattr__(self, name, value):
if not self.__dict__[CfgNode.IMMUTABLE]:
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else:
raise AttributeError(
'Attempted to set "{}" to "{}", but CfgNode is immutable'
.format(name, value))
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
r = ""
s = []
for k, v in sorted(self.items()):
separator = "\n" if isinstance(v, CfgNode) else " "
attr_str = "{}:{}{}".format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
return r
def _merge_a_into_b(a, b):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a."""
if not isinstance(a, dict):
return
for k, v in a.items():
# a must specify keys that are in b
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))
# The types must match, too
v = _check_and_coerce_cfg_value_type(v, b[k], k)
# Recursively merge dicts
if type(v) is CfgNode:
try:
_merge_a_into_b(a[k], b[k])
except: # noqa
print('Error under config key: {}'.format(k))
raise
else:
b[k] = v
def _check_and_coerce_cfg_value_type(value_a, value_b, key):
"""Check if the value type matched."""
type_a, type_b = type(value_a), type(value_b)
if type_a is type_b:
return value_a
if type_b is float and type_a is int:
return float(value_a)
# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, np.ndarray):
value_a = np.array(value_a, dtype=value_b.dtype)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
elif isinstance(value_a, dict) and isinstance(value_b, CfgNode):
value_a = CfgNode(value_a)
else:
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(type_b, type_a, value_b, value_a, key))
return value_a
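# Usage sketch: CfgNode is a dict with attribute access and type-checked,
# coercing merges (tuple <-> list handled by _check_and_coerce_cfg_value_type):
#
#     node = CfgNode({'A': 1, 'B': CfgNode({'C': (512,)})})
#     node.merge_from_other_cfg(CfgNode({'B': CfgNode({'C': [256, 512]})}))
#     assert node.B.C == (256, 512)  # the list was coerced back to a tuple
#     node.freeze()                  # further attribute writes now raise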
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Experiment coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import time
import numpy as np
from seetaseg.core.config import cfg
from seetaseg.utils import logging
class Coordinator(object):
"""Manage the unique experiments."""
def __init__(self, cfg_file, exp_dir=None):
cfg.merge_from_file(cfg_file)
if exp_dir is None:
name = time.strftime('%Y%m%d_%H%M%S',
time.localtime(time.time()))
exp_dir = '../experiments/{}'.format(name)
if not osp.exists(exp_dir):
os.makedirs(exp_dir)
else:
if not osp.exists(exp_dir):
raise ValueError('Invalid experiment dir: ' + exp_dir)
self.exp_dir = exp_dir
def path_at(self, name, auto_create=True):
try:
path = osp.abspath(osp.join(self.exp_dir, name))
if auto_create and not osp.exists(path):
os.makedirs(path)
except OSError:
path = osp.abspath(osp.join('/tmp', name))
if auto_create and not osp.exists(path):
os.makedirs(path)
return path
def get_checkpoint(self, step=None, last_idx=1, wait=False):
path = self.path_at('checkpoints')
def locate(last_idx=None):
files = os.listdir(path)
files = list(filter(lambda x: '_iter_' in x and
x.endswith('.pkl'), files))
file_steps = []
for i, file in enumerate(files):
file_step = int(file.split('_iter_')[-1].split('.')[0])
if step == file_step:
return osp.join(path, files[i]), file_step
file_steps.append(file_step)
if step is None:
if len(files) == 0:
return None, 0
if last_idx > len(files):
return None, 0
file = files[np.argsort(file_steps)[-last_idx]]
file_step = file_steps[np.argsort(file_steps)[-last_idx]]
return osp.join(path, file), file_step
return None, 0
file, file_step = locate(last_idx)
while file is None and wait:
logging.info('Waiting for checkpoint at step {}.'.format(step))
time.sleep(10)
file, file_step = locate(last_idx)
return file, file_step
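# Usage sketch (config path assumed): a Coordinator creates or reuses an
# experiment directory and locates checkpoints inside it:
#
#     coordinator = Coordinator('../configs/upernet/ade20k_upernet_R_50_512_160k.yml')
#     ckpt_dir = coordinator.path_at('checkpoints')       # auto-created
#     weights, start_iter = coordinator.get_checkpoint()  # latest, or (None, 0)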
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Registry class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
class Registry(object):
"""Registry class."""
def __init__(self, name):
self._name = name
self._registry = collections.OrderedDict()
def has(self, key):
return key in self._registry
def register(self, name, func=None, **kwargs):
def decorated(inner_function):
for key in (name if isinstance(
name, (tuple, list)) else [name]):
self._registry[key] = \
functools.partial(inner_function, **kwargs)
return inner_function
if func is not None:
return decorated(func)
return decorated
def get(self, name, default=None):
if name is None:
return None
if not self.has(name):
if default is not None:
return default
raise KeyError("`%s` is not registered in <%s>."
% (name, self._name))
return self._registry[name]
def try_get(self, name):
if self.has(name):
return self.get(name)
return None
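# Usage sketch (the 'backbones' registry here is hypothetical): config keys
# are mapped to factory functions through the decorator form of register():
#
#     BACKBONES = Registry('backbones')
#
#     @BACKBONES.register('resnet50')
#     def resnet50(**kwargs):
#         ...
#
#     build_fn = BACKBONES.get('resnet50')  # a functools.partial over resnet50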
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Testing engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import datetime
import multiprocessing as mp
import time
from dragon.vm import torch
import numpy as np
from seetaseg.core.config import cfg
from seetaseg.core.testing import test_server
from seetaseg.models.build import build_segmentor
from seetaseg.modules.build import build_model_inference
from seetaseg.utils import logging
from seetaseg.utils import profiler
from seetaseg.utils import vis
def extend_results(index, collection, results):
"""Add image results to the collection."""
if results is None:
return
for _ in range(index - len(collection) + 1):
collection.append([])
collection[index] = results
def test_segmentor(
test_cfg,
weights,
queues,
device,
deterministic=False,
verbose=True,
batch_timeout=None,
):
"""Test a segmentor.
Parameters
----------
test_cfg : CfgNode
The cfg for testing.
weights : str
The path of model weights to load.
queues : Sequence[multiprocessing.Queue]
The input and output queues.
device : int
The index of computing device.
deterministic : bool, optional, default=False
Set cudnn deterministic or not.
verbose : bool, optional, default=True
Print the information or not.
batch_timeout : number, optional
The timeout to wait for "IMS_PER_BATCH" images.
"""
cfg.merge_from_other_cfg(test_cfg)
cfg.GPU_ID = device
cfg.freeze()
logging.set_root(verbose)
if deterministic:
torch.backends.cudnn.deterministic = True
model = build_segmentor(device, weights)
module = build_model_inference(model)
input_queue, output_queue = queues
imgs_per_batch = cfg.TEST.IMS_PER_BATCH
must_stop = False
while not must_stop:
indices, imgs = [], []
deadline, timeout = None, None
for i in range(imgs_per_batch):
if batch_timeout and i == 1:
deadline = time.monotonic() + batch_timeout
if batch_timeout and i >= 1:
timeout = deadline - time.monotonic()
try:
index, img = input_queue.get(timeout=timeout)
if index < 0:
must_stop = True
break
indices.append(index)
imgs.append(img)
except Exception:
pass
if len(imgs) == 0:
continue
results = module.get_results(imgs)
time_diffs = module.get_time_diffs()
for i, outputs in enumerate(results):
output_queue.put((indices[i], time_diffs, outputs))
def run_test(
weights,
output_dir,
devices,
deterministic=False,
read_every=100,
palette=None,
):
"""Run a model testing.
Parameters
----------
weights : str
The path of network weights file.
output_dir : str
The path to save results.
devices : Sequence[int]
The indices of computing devices.
deterministic : bool, optional, default=False
Set cudnn deterministic or not.
read_every : int, optional, default=100
Read every N inputs to distribute to devices.
palette : str, optional
The palette for visualization.
"""
server = test_server.EvaluateServer(output_dir)
devices = devices if devices else [cfg.GPU_ID]
num_devices = len(devices)
num_images = server.dataset.num_images
read_stride = float(num_devices * cfg.TEST.IMS_PER_BATCH)
read_every = int(np.ceil(read_every / read_stride) * read_stride)
visualizer = vis.Visualizer(palette=palette)
queues = [mp.Queue() for _ in range(num_devices + 1)]
actors = [mp.Process(
target=test_segmentor,
kwargs={'test_cfg': cfg,
'weights': weights,
'queues': [queues[i], queues[-1]],
'device': devices[i],
'deterministic': deterministic,
'verbose': i == 0}) for i in range(num_devices)]
for actor in actors:
actor.start()
timers = collections.defaultdict(profiler.Timer)
segs, vis_images = [], {}
for count in range(1, num_images + 1):
img_id, img = server.get_image()
queues[count % num_devices].put((count - 1, img))
if palette is not None:
filename = server.get_save_filename(img_id)
vis_images[count - 1] = (filename, img)
if count % read_every > 0 and count < num_images:
continue
if count == num_images:
for i in range(num_devices):
queues[i].put((-1, None))
for _ in range(((count - 1) % read_every + 1)):
index, time_diffs, outputs = queues[-1].get()
extend_results(index, segs, outputs['seg'])
for name, diff in time_diffs.items():
timers[name].add_diff(diff)
if palette is not None:
filename, img = vis_images[index]
visualizer.draw_seg(img, outputs['seg']).save(filename)
del vis_images[index]
avg_time = sum([t.average_time for t in timers.values()])
eta_seconds = avg_time * (num_images - count)
print('\rim_segment: {:d}/{:d} [{:.3f}s + {:.3f}s] (eta: {})'
.format(count, num_images,
timers['im_segment'].average_time,
timers['misc'].average_time,
str(datetime.timedelta(seconds=int(eta_seconds)))),
end='')
print('\nEvaluating segmentations')
server.eval_seg(segs)
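# Usage sketch: tools/test.py (not shown in this excerpt) presumably invokes
# the engine along these lines, resolving weights via the Coordinator:
#
#     file, _ = coordinator.get_checkpoint(step=args.iter, wait=True)
#     run_test(weights=file, output_dir=coordinator.path_at('results'),
#              devices=[0, 1])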
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Testing servers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import cv2
from seetaseg.core.config import cfg
from seetaseg.data.build import build_dataset
from seetaseg.data.build import build_evaluator
from seetaseg.data.build import build_loader_test
class BaseServer(object):
"""Base server class."""
def __init__(self, output_dir):
self.output_dir = output_dir
self.vis_dir = os.path.join(self.output_dir, 'vis')
def get_image(self):
"""Return the image."""
def get_save_filename(self, image_id, ext='.png'):
if not os.path.exists(self.vis_dir):
os.makedirs(self.vis_dir)
return os.path.join(self.vis_dir, image_id + ext)
class EvaluateServer(BaseServer):
"""Server to evaluate model with ground-truth."""
def __init__(self, output_dir):
super(EvaluateServer, self).__init__(output_dir)
self.loader = build_loader_test()
self.dataset = build_dataset(cfg.TEST.DATASET)
self.evaluator = build_evaluator()
self.next_inputs = []
self.metas = collections.OrderedDict()
def get_image(self):
if len(self.next_inputs) == 0:
inputs = self.loader()
for i, img_meta in enumerate(inputs['img_meta']):
self.next_inputs.append({
'img': inputs['img'][i],
'seg': inputs['seg'][i],
'id': img_meta['id']})
inputs = self.next_inputs.pop(0)
img_id, img = inputs.pop('id'), inputs.pop('img')
self.metas[img_id] = inputs
return img_id, img
def eval_seg(self, segs):
gt_segs = [v['seg'] for v in self.metas.values()]
self.evaluator.eval_seg(segs, gt_segs)
def check_metas(self):
if len(self.metas) != self.dataset.num_images:
raise RuntimeError(
'Mismatched number of metas and images. ({} vs. {}).'
'\nCheck for duplicate image ids.'
.format(len(self.metas), self.dataset.num_images))
class InferServer(BaseServer):
"""Server to run model inference."""
def __init__(self, output_dir):
super(InferServer, self).__init__(output_dir)
self.images_dir = cfg.TEST.DATASET
self.images = os.listdir(self.images_dir)
self.classes = cfg.MODEL.CLASSES
self.num_images = len(self.images)
self.num_classes = len(cfg.MODEL.CLASSES)
self.output_dir = output_dir
self.image_idx = 0
def get_image(self):
image_name = self.images[self.image_idx]
image_index = image_name.split('.')[0]
image = cv2.imread(os.path.join(self.images_dir, image_name))
self.image_idx = (self.image_idx + 1) % self.num_images
return image_index, image
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.core.training.build import build_lr_scheduler
from seetaseg.core.training.build import build_optimizer
from seetaseg.core.training.build import build_tensorboard
from seetaseg.core.training.utils import count_params
from seetaseg.core.training.utils import freeze_module
from seetaseg.core.training.utils import get_param_groups
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for training library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from seetaseg.core.config import cfg
from seetaseg.core.training import lr_scheduler
def build_optimizer(params, **kwargs):
"""Build the optimizer."""
args = {'lr': cfg.SOLVER.BASE_LR,
'weight_decay': cfg.SOLVER.WEIGHT_DECAY,
'clip_norm': cfg.SOLVER.CLIP_NORM,
'grad_scale': 1.0 / cfg.SOLVER.LOSS_SCALE}
optimizer = kwargs.pop('optimizer', cfg.SOLVER.OPTIMIZER)
if optimizer == 'SGD':
args['momentum'] = cfg.SOLVER.MOMENTUM
args.update(kwargs)
return getattr(torch.optim, optimizer)(params, **args)
def build_lr_scheduler(**kwargs):
"""Build the LR scheduler."""
args = {'lr_max': cfg.SOLVER.BASE_LR,
'lr_min': cfg.SOLVER.MIN_LR,
'warmup_steps': cfg.SOLVER.WARM_UP_STEPS,
'warmup_factor': cfg.SOLVER.WARM_UP_FACTOR}
policy = kwargs.pop('policy', cfg.SOLVER.LR_POLICY)
args.update(kwargs)
if policy == 'cosine_decay':
return lr_scheduler.CosineLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
elif policy == 'steps_with_decay':
return lr_scheduler.MultiStepLR(
decay_steps=cfg.SOLVER.DECAY_STEPS,
decay_gamma=cfg.SOLVER.DECAY_GAMMA, **args)
elif policy == 'linear_decay':
return lr_scheduler.LinearLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
elif policy == 'poly_decay':
return lr_scheduler.PolyLR(
decay_power=cfg.SOLVER.DECAY_POWER,
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
return lr_scheduler.ConstantLR(**args)
def build_tensorboard(log_dir):
"""Build the tensorboard."""
try:
from dragon.utils.tensorboard import tf
from dragon.utils.tensorboard import TensorBoard
# Prevent the TF API from occupying GPUs.
if tf is not None:
tf.config.set_visible_devices([], 'GPU')
return TensorBoard(log_dir)
except ImportError:
return None
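# Usage sketch: with the UPerNet config above (SGD + 'poly_decay'), these
# builders reduce to an SGD optimizer with momentum 0.9 and a PolyLR
# scheduler over 160k steps:
#
#     optimizer = build_optimizer(get_param_groups(model))  # model assumed
#     scheduler = build_lr_scheduler()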
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""LearningRate schedulers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
class ConstantLR(object):
"""Constant LR scheduler."""
def __init__(self, **kwargs):
self._lr_max = kwargs.pop('lr_max')
self._lr_min = kwargs.pop('lr_min', 0)
self._warmup_steps = kwargs.pop('warmup_steps', 0)
self._warmup_factor = kwargs.pop('warmup_factor', 0)
if kwargs:
raise ValueError('Unexpected arguments: ' + ','.join(v for v in kwargs))
self._step_count = 0
self._last_decay = 1.
def step(self):
self._step_count += 1
def get_lr(self):
if self._step_count < self._warmup_steps:
alpha = (self._step_count + 1.) / self._warmup_steps
return self._lr_max * (alpha + (1. - alpha) * self._warmup_factor)
return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay()
def get_decay(self):
return self._last_decay
class CosineLR(ConstantLR):
"""LR scheduler with cosine decay."""
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = .5 * (1. + math.cos(math.pi * t / t_max))
return self._last_decay
class MultiStepLR(ConstantLR):
"""LR scheduler with multi-steps decay."""
def __init__(self, lr_max, decay_steps, decay_gamma, **kwargs):
super(MultiStepLR, self).__init__(lr_max=lr_max, **kwargs)
self._decay_steps = decay_steps
self._decay_gamma = decay_gamma
self._stage_count = 0
self._num_stages = len(decay_steps)
def get_decay(self):
if self._stage_count < self._num_stages:
k = self._decay_steps[self._stage_count]
while self._step_count >= k:
self._stage_count += 1
if self._stage_count >= self._num_stages:
break
k = self._decay_steps[self._stage_count]
self._last_decay = self._decay_gamma ** self._stage_count
return self._last_decay
class PolyLR(ConstantLR):
"""LR scheduler with poly decay."""
def __init__(self, lr_max, max_steps, lr_min=0,
decay_step=1, decay_power=0.9, **kwargs):
super(PolyLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._decay_power = decay_power
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = (1. - float(t) / t_max) ** self._decay_power
return self._last_decay
class LinearLR(ConstantLR):
"""LR scheduler with linear decay."""
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
super(LinearLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = 1. - float(t) / t_max
return self._last_decay
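# Usage sketch: PolyLR implements the 'poly_decay' policy of the UPerNet
# config (BASE_LR 0.01, MIN_LR 0.0001, 160k steps, power 0.9); the trainer
# assigns ``_step_count`` directly when resuming:
#
#     sched = PolyLR(lr_max=0.01, lr_min=0.0001, max_steps=160000)
#     sched._step_count = 80000
#     print(sched.get_lr())  # ~0.0054, i.e. (1 - t / t_max) ** 0.9 decay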
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import os
from dragon.vm import torch
from seetaseg.core.config import cfg
from seetaseg.core.training.build import build_lr_scheduler
from seetaseg.core.training.build import build_optimizer
from seetaseg.core.training.build import build_tensorboard
from seetaseg.core.training.utils import count_params
from seetaseg.core.training.utils import get_param_groups
from seetaseg.data.build import build_loader_train
from seetaseg.models.build import build_segmentor
from seetaseg.utils import logging
from seetaseg.utils import profiler
class Trainer(object):
"""Schedule the iterative model training."""
def __init__(self, coordinator):
# Build loader.
self.loader = build_loader_train()
# Build model.
self.model = build_segmentor(training=True)
self.model.load_weights(cfg.TRAIN.WEIGHTS)
self.model.cuda(cfg.GPU_ID)
if cfg.MODEL.PRECISION.lower() == 'float16':
self.model.half()
# Build optimizer.
self.loss_scale = cfg.SOLVER.LOSS_SCALE
param_groups_getter = get_param_groups
if cfg.SOLVER.LAYER_LR_DECAY < 1.0:
lr_scale_getter = functools.partial(
self.model.backbone.get_lr_scale,
decay=cfg.SOLVER.LAYER_LR_DECAY)
param_groups_getter = functools.partial(
param_groups_getter, lr_scale_getter=lr_scale_getter)
self.optimizer = build_optimizer(param_groups_getter(self.model))
self.scheduler = build_lr_scheduler()
# Build monitor.
self.coordinator = coordinator
self.metrics = collections.OrderedDict()
self.board = None
@property
def iter(self):
return self.scheduler._step_count
def snapshot(self):
"""Save the checkpoint of current iterative step."""
f = cfg.SOLVER.SNAPSHOT_PREFIX
f += '_iter_{}.pkl'.format(self.iter)
f = os.path.join(self.coordinator.path_at('checkpoints'), f)
if logging.is_root() and not os.path.exists(f):
torch.save(self.model.state_dict(), f, pickle_protocol=4)
logging.info('Wrote snapshot to: {:s}'.format(f))
def add_metrics(self, stats):
"""Add or update the metrics."""
for k, v in stats['metrics'].items():
if k not in self.metrics:
self.metrics[k] = profiler.SmoothedValue()
self.metrics[k].update(v)
def display_metrics(self, stats):
"""Show metrics."""
logging.info('Iteration %d, lr = %.8f, time = %.2fs'
% (stats['iter'], stats['lr'], stats['time']))
for k, v in self.metrics.items():
logging.info(' ' * 4 + 'Train net output({}): {:.4f} ({:.4f})'
.format(k, stats['metrics'][k], v.average()))
if self.board is not None:
self.board.scalar_summary('lr', stats['lr'], stats['iter'])
self.board.scalar_summary('time', stats['time'], stats['iter'])
for k, v in self.metrics.items():
self.board.scalar_summary(k, v.average(), stats['iter'])
def step(self):
stats = {'iter': self.iter}
metrics = collections.defaultdict(float)
# Run forward.
timer = profiler.Timer().tic()
inputs = self.loader()
outputs, losses = self.model(inputs), []
for k, v in outputs.items():
if 'loss' in k:
losses.append(v)
metrics[k] += float(v)
elif 'acc' in k:
metrics[k] += float(v)
# Run backward.
losses = sum(losses[1:], losses[0])
if self.loss_scale != 1.0:
losses *= self.loss_scale
losses.backward()
# Apply update.
stats['lr'] = self.scheduler.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = stats['lr'] * group.get('lr_scale', 1.0)
self.optimizer.step()
self.scheduler.step()
stats['time'] = timer.toc()
stats['metrics'] = collections.OrderedDict(sorted(metrics.items()))
return stats
def train_model(self, start_iter=0):
"""Model training loop."""
timer = profiler.Timer()
max_steps = cfg.SOLVER.MAX_STEPS
display_every = cfg.SOLVER.DISPLAY
progress_every = 10 * display_every
snapshot_every = cfg.SOLVER.SNAPSHOT_EVERY
self.scheduler._step_count = start_iter
while self.iter < max_steps:
with timer.tic_and_toc():
stats = self.step()
self.add_metrics(stats)
if stats['iter'] % display_every == 0:
self.display_metrics(stats)
if self.iter % progress_every == 0:
logging.info(profiler.get_progress(timer, self.iter, max_steps))
if self.iter % snapshot_every == 0:
self.snapshot()
self.metrics.clear()
def run_train(coordinator, start_iter=0, enable_tensorboard=False):
"""Start a model training task."""
trainer = Trainer(coordinator)
if enable_tensorboard and logging.is_root():
trainer.board = build_tensorboard(coordinator.path_at('logs'))
logging.info('#Params: %.2fM' % count_params(trainer.model))
logging.info('Start training...')
trainer.train_model(start_iter)
trainer.snapshot()
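# Usage sketch: tools/train.py (not shown in this excerpt) presumably wires
# the pieces together along these lines:
#
#     coordinator = Coordinator('../configs/upernet/ade20k_upernet_R_50_512_160k.yml')
#     run_train(coordinator, enable_tensorboard=True)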
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
def count_params(module):
"""Return the number of parameters in MB."""
return sum([v.size().numel() for v in module.parameters()]) / 1e6
def freeze_module(module):
"""Freeze parameters of given module."""
module.eval()
for param in module.parameters():
param.requires_grad = False
def get_param_groups(module, lr_scale_getter=None):
"""Separate parameters into groups."""
groups = collections.OrderedDict()
for name, param in module.named_parameters():
if not param.requires_grad:
continue
attrs = collections.OrderedDict()
if lr_scale_getter:
attrs['lr_scale'] = lr_scale_getter(name)
no_weight_decay = not (name.endswith('weight') and param.dim() > 1)
no_weight_decay = getattr(param, 'no_weight_decay', no_weight_decay)
if no_weight_decay:
attrs['weight_decay'] = 0
group_name = '/'.join(['%s:%s' % (v[0], v[1]) for v in list(attrs.items())])
if group_name not in groups:
groups[group_name] = {'params': []}
groups[group_name].update(attrs)
groups[group_name]['params'].append(param)
return list(groups.values())
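# Usage sketch (Conv2d via dragon.vm.torch assumed): biases and other 1-D
# parameters land in a zero weight-decay group, while conv/matrix weights
# keep the solver's weight decay:
#
#     from dragon.vm import torch
#     groups = get_param_groups(torch.nn.Conv2d(3, 8, 3))
#     # -> [{'params': [weight]}, {'weight_decay': 0, 'params': [bias]}]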
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.data import datasets
from seetaseg.data import evaluators
from seetaseg.data import pipelines
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for data library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.core.config import cfg
from seetaseg.core.registry import Registry
LOADERS = Registry('loaders')
DATASETS = Registry('datasets')
EVALUATORS = Registry('evaluators')
def build_dataset(path):
"""Build the dataset."""
keys = path.split('://')
if len(keys) >= 2:
return DATASETS.get(keys[0])(keys[1])
return DATASETS.get('kpl')(path)
def build_loader_train(**kwargs):
"""Build the train loader."""
args = {'dataset': cfg.TRAIN.DATASET,
'batch_size': cfg.TRAIN.IMS_PER_BATCH,
'num_workers': cfg.TRAIN.NUM_WORKERS,
'shuffle': True, 'contiguous': True}
args.update(kwargs)
return LOADERS.get(cfg.TRAIN.LOADER)(**args)
def build_loader_test(**kwargs):
"""Build the test loader."""
args = {'dataset': cfg.TEST.DATASET,
'batch_size': cfg.TEST.IMS_PER_BATCH,
'shuffle': False, 'contiguous': False}
args.update(kwargs)
return LOADERS.get(cfg.TEST.LOADER)(**args)
def build_evaluator(**kwargs):
evaluator_type = cfg.TEST.EVALUATOR
if not evaluator_type:
return None
args = {'classes': cfg.MODEL.CLASSES}
args.update(kwargs)
return EVALUATORS.get(evaluator_type)(**args)
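# Usage sketch: dataset paths may carry a '<type>://' scheme prefix; plain
# paths fall back to the 'kpl' record format registered in kpl_dataset.py:
#
#     ds = build_dataset('/data/datasets/ade20k_train')        # implicit 'kpl'
#     ds = build_dataset('kpl:///data/datasets/ade20k_train')  # explicit scheme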
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.data.datasets import kpl_dataset
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Base dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.core.config import cfg
class Dataset(object):
"""Base dataset class."""
def __init__(self, source):
self.source = source
self.num_images = 0
self.classes = cfg.MODEL.CLASSES
self.num_classes = len(self.classes)
self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))
@property
def type(self):
return type(self)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""KPLRecord dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dragon
from seetaseg.data.build import DATASETS
from seetaseg.data.datasets.dataset import Dataset
@DATASETS.register('kpl')
class KPLRecordDataset(Dataset):
"""KPLRecordDataset."""
def __init__(self, source):
super(KPLRecordDataset, self).__init__(source)
self.num_images = self.type(self.source).size
@property
def type(self):
return dragon.io.KPLRecordDataset
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Evaluators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.data.evaluators import evaluator
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Base evaluator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import prettytable
from seetaseg.data.build import EVALUATORS
from seetaseg.utils.seg.metrics import seg_overlaps
@EVALUATORS.register('default')
class Evaluator(object):
"""Evaluator using COCO json dataset format."""
def __init__(self, classes):
self.classes = classes
self.num_classes = len(self.classes)
self.reduce_zero_label = self.classes[0] != '__background__'
def eval_seg(self, segs, gt_segs):
cnts = seg_overlaps(segs, gt_segs, self.num_classes,
reduce_zero_label=self.reduce_zero_label)
metrics = {'IoU': cnts[0].astype('float64') / cnts[1].astype('float64'),
'Acc': cnts[0].astype('float64') / cnts[3].astype('float64')}
self.print_eval_results(metrics)
def print_eval_results(self, metrics):
"""Print the evaluation results."""
class_table = prettytable.PrettyTable()
summary_table = prettytable.PrettyTable()
for k, v in metrics.items():
metrics[k] = np.nan_to_num(v, nan=0)
class_table.add_column(k, np.round(metrics[k] * 100, 2))
summary_table.add_column('m' + k, [np.round(np.mean(metrics[k]) * 100, 2)])
class_table.add_column('Class', self.classes)
print('Per class results:\n' + class_table.get_string(), '\n')
print('Summary:\n' + summary_table.get_string())
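# Sketch of the per-class counts behind these metrics (pure NumPy, independent
# of the internal seg_overlaps helper): for class k over flattened label maps,
# IoU_k = inter_k / union_k and Acc_k = inter_k / gt_k:
#
#     import numpy as np
#     def iou_acc(pred, gt, num_classes):
#         inter = np.array([((pred == k) & (gt == k)).sum() for k in range(num_classes)])
#         union = np.array([((pred == k) | (gt == k)).sum() for k in range(num_classes)])
#         gt_cnt = np.array([(gt == k).sum() for k in range(num_classes)])
#         return inter / np.maximum(union, 1), inter / np.maximum(gt_cnt, 1)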
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import multiprocessing as mp
import time
import threading
import queue
import dragon
from seetaseg.data.build import build_dataset
from seetaseg.utils import logging
from seetaseg.utils.blob import blob_vstack
class BalancedQueues(object):
"""Balanced queues."""
def __init__(self, base_queue, num=1):
self.queues = [base_queue]
self.queues += [mp.Queue(base_queue._maxsize) for _ in range(num - 1)]
self.index = 0
def put(self, obj, block=True, timeout=None):
q = self.queues[self.index]
q.put(obj, block=block, timeout=timeout)
self.index = (self.index + 1) % len(self.queues)
def get(self, block=True, timeout=None):
q = self.queues[self.index]
obj = q.get(block=block, timeout=timeout)
self.index = (self.index + 1) % len(self.queues)
return obj
def get_n(self, num=1):
outputs = []
while len(outputs) < num:
obj = self.get()
if obj is not None:
outputs.append(obj)
return outputs
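# Round-robin sketch: with num=2, successive put() calls alternate over
# the two queues, so downstream workers receive a balanced share:
#
#   q = BalancedQueues(mp.Queue(4), num=2)
#   for i in range(4):
#       q.put(i)  # items 0, 2 -> queues[0]; items 1, 3 -> queues[1]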
class DataLoaderBase(threading.Thread):
"""Base class of data loader."""
def __init__(self, worker, **kwargs):
super(DataLoaderBase, self).__init__(daemon=True)
self.batch_size = kwargs.get('batch_size', 2)
self.num_readers = kwargs.get('num_readers', 1)
self.num_workers = kwargs.get('num_workers', 2)
self.queue_depth = kwargs.get('queue_depth', 2)
# Initialize distributed group.
rank, group_size = 0, 1
dist_group = dragon.distributed.get_group()
if dist_group is not None:
group_size = dist_group.size
rank = dragon.distributed.get_rank(dist_group)
# Build queues.
self.reader_queue = mp.Queue(self.queue_depth * self.batch_size)
self.worker_queue = mp.Queue(self.queue_depth * self.batch_size)
self.batch_queue = queue.Queue(self.queue_depth)
self.reader_queue = BalancedQueues(self.reader_queue, self.num_workers)
self.worker_queue = BalancedQueues(self.worker_queue, self.num_workers)
# Build readers.
self.readers = []
for i in range(self.num_readers):
part_idx, num_parts = i, self.num_readers
num_parts *= group_size
part_idx += rank * self.num_readers
self.readers.append(dragon.io.DataReader(**kwargs))
self.readers[i]._part_idx = part_idx
self.readers[i]._num_parts = num_parts
self.readers[i]._seed += part_idx
self.readers[i]._reader_queue = self.reader_queue
self.readers[i].start()
time.sleep(0.1)
# Build workers.
self.workers = []
for i in range(self.num_workers):
p = worker(**kwargs)
p.seed += (i + rank * self.num_workers)
p.reader_queue = self.reader_queue.queues[i]
p.worker_queue = self.worker_queue.queues[i]
p.start()
self.workers.append(p)
time.sleep(0.1)
# Register cleanup callbacks.
def cleanup():
def terminate(processes):
for p in processes:
p.terminate()
p.join()
terminate(self.workers)
terminate(self.readers)
import atexit
atexit.register(cleanup)
# Start batch prefetching.
self.start()
def next(self):
"""Return the next batch of data."""
return self.__next__()
def run(self):
"""Main loop."""
def __call__(self):
return self.next()
def __iter__(self):
"""Return the iterator self."""
return self
def __next__(self):
"""Return the next batch of data."""
return self.batch_queue.get()
class DataLoader(DataLoaderBase):
"""Loader to return the batch of data."""
def __init__(self, dataset, worker, **kwargs):
dataset = build_dataset(dataset)
self.contiguous = kwargs.get('contiguous', True)
self.prefetch_count = kwargs.get('prefetch_count', 50)
args = {'dataset': dataset.type,
'source': dataset.source,
'classes': dataset.classes,
'shuffle': kwargs.get('shuffle', True),
'batch_size': kwargs.get('batch_size', 1),
'num_workers': kwargs.get('num_workers', 1),
'stick_to_part': kwargs.get('stick_to_part', True)}
super(DataLoader, self).__init__(worker, **args)
def run(self):
"""Main loop."""
logging.info('Prefetch batches...')
prev_inputs = self.worker_queue.get_n(
self.prefetch_count * self.batch_size)
next_inputs = []
while True:
# Use cached buffer for next N inputs.
if len(next_inputs) == 0:
next_inputs = prev_inputs
prev_inputs = []
# Collect the next batch.
outputs = collections.defaultdict(list)
for _ in range(self.batch_size):
inputs = next_inputs.pop(0)
for k, v in inputs.items():
outputs[k].extend(v)
prev_inputs += self.worker_queue.get_n(1)
# Stack batch data.
if self.contiguous:
outputs['img'] = blob_vstack(outputs['img'])
outputs['seg'] = blob_vstack(outputs['seg'])
# Send batch data to consumer.
self.batch_queue.put(outputs)
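# A minimal consumer-side sketch, assuming 'dataset_spec' is whatever
# build_dataset accepts (e.g. a registered dataset) and SegTrainWorker
# is the train pipeline registered in the pipelines module:
#
#   loader = DataLoader(dataset_spec, worker=SegTrainWorker,
#                       batch_size=2, num_workers=2)
#   batch = next(loader)  # {'img': stacked images, 'seg': stacked labels}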
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Data loading pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import cv2
import numpy as np
from seetaseg.core.config import cfg
from seetaseg.data import transforms
from seetaseg.data.build import LOADERS
from seetaseg.data.loader import DataLoader
from seetaseg.utils.image import im_decode
class WorkerBase(multiprocessing.Process):
"""Base class of data worker."""
def __init__(self):
super(WorkerBase, self).__init__(daemon=True)
self.seed = cfg.RNG_SEED
self.reader_queue = None
self.worker_queue = None
def get_outputs(self, inputs):
"""Return the processed outputs."""
return inputs
def run(self):
"""Main prefetch loop."""
# Disable the opencv threading.
cv2.setNumThreads(1)
# Fix the process-local random seed.
np.random.seed(self.seed)
while True:
inputs = self.reader_queue.get()
outputs = self.get_outputs(inputs)
self.worker_queue.put(outputs)
class SegTrainWorker(WorkerBase):
"""Worker that defines a generic train pipeline."""
def __init__(self, **kwargs):
super(SegTrainWorker, self).__init__()
self.classes = kwargs.get('classes', None)
self.reduce_zero_label = self.classes[0] != '__background__'
self.resize = transforms.RandomResize(
scales=cfg.TRAIN.SCALES,
scales_range=cfg.TRAIN.SCALES_RANGE,
max_size=cfg.TRAIN.MAX_SIZE)
self.flip = transforms.RandomFlip()
self.crop = transforms.RandomCrop(cfg.TRAIN.CROP_SIZE)
self.distort = transforms.ColorJitter()
self.pad = transforms.Pad(cfg.TRAIN.CROP_SIZE)
def get_outputs(self, inputs):
img = im_decode(inputs['data'], mode='color')
seg = im_decode(inputs['label'], mode='gray')
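# Remap labels when class 0 is a real category rather than background:
# pixel value 0 (unlabeled) becomes the ignore index 255, the remaining
# labels shift down by one, and former ignore pixels (255 - 1 = 254
# after the uint8 shift) are restored to 255. E.g. raw labels
# {0, 1, 2, 255} become {255, 0, 1, 255}.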
if self.reduce_zero_label:
seg[seg == 0] = 255
seg -= 1
seg[seg == 254] = 255
img, seg = self.resize(img, seg)
img, seg = self.crop(img, seg)
img, seg = self.flip(img, seg)
img = self.distort(img)
img, seg = self.pad(img, seg)
seg = seg[None, :, :] # [1, H, W]
outputs = {'img': [img], 'seg': [seg]}
return outputs
class SegTestWorker(WorkerBase):
"""Worker that defines a generic test pipeline."""
def __init__(self, **kwargs):
super(SegTestWorker, self).__init__()
def get_outputs(self, inputs):
img = im_decode(inputs['data'], mode='color')
outputs = {'img': [img], 'img_meta': [{'id': inputs['id']}]}
if 'label' in inputs:
seg = im_decode(inputs['label'], mode='gray')
outputs['seg'] = [seg]
return outputs
LOADERS.register('seg_train', DataLoader, worker=SegTrainWorker)
LOADERS.register('seg_test', DataLoader, worker=SegTestWorker)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import numpy.random as npr
from seetaseg.core.config import cfg
from seetaseg.utils.image import im_pad
from seetaseg.utils.image import im_resize
from seetaseg.utils.image import color_jitter
class ColorJitter(object):
"""Distort the brightness, contrast and color of image."""
def __init__(self):
self.brightness_range = (0.875, 1.125)
self.contrast_range = (0.5, 1.5)
self.saturation_range = (0.5, 1.5)
def __call__(self, img):
brightness = contrast = saturation = None
if npr.rand() < 0.5:
brightness = self.brightness_range
if npr.rand() < 0.5:
contrast = self.contrast_range
if npr.rand() < 0.5:
saturation = self.saturation_range
return color_jitter(img, brightness=brightness,
contrast=contrast, saturation=saturation)
class RandomFlip(object):
"""Flip the image randomly."""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, img, seg=None):
if npr.rand() < self.prob:
img = img[:, ::-1]
if seg is not None:
seg = seg[:, ::-1]
return img, seg
class RandomResize(object):
"""Resize the image randomly."""
def __init__(self, scales=(512,), scales_range=(1.0, 1.0), max_size=2048):
self.scales = scales
self.scales_range = scales_range
self.max_size = max_size
self.im_scale = 1.0
self.im_scale_factor = 1.0
def __call__(self, img, seg=None):
im_shape = img.shape
target_size = npr.choice(self.scales)
max_size = self.max_size if self.max_size > 0 else target_size
# Scale along the shortest side.
im_size_min = np.min(im_shape[:2])
im_size_max = np.max(im_shape[:2])
self.im_scale = float(target_size) / float(im_size_min)
# Prevent the biggest axis from being more than *MAX_SIZE*.
if np.round(self.im_scale * im_size_max) > max_size:
self.im_scale = float(max_size) / float(im_size_max)
# Apply the scale jitter to get a range of dynamic scales.
r = self.scales_range
self.im_scale_factor = r[0] + npr.rand() * (r[1] - r[0])
self.im_scale *= self.im_scale_factor
img = im_resize(img, scale=self.im_scale, mode='linear')
if seg is not None:
seg = im_resize(seg, scale=self.im_scale, mode='nearest')
return img, seg
class RandomCrop(object):
"""Crop the image randomly."""
def __init__(self, crop_size=512, cat_max_ratio=0.75):
self.crop_size = crop_size
self.cat_max_ratio = cat_max_ratio
def get_crop_bbox(self, img):
h, w = img.shape[:2]
out_h, out_w = (self.crop_size,) * 2
y1 = npr.randint(max(h - out_h, 0) + 1)
x1 = npr.randint(max(w - out_w, 0) + 1)
return x1, y1, x1 + out_w, y1 + out_h
def __call__(self, img, seg):
if self.crop_size > 0:
x1, y1, x2, y2 = self.get_crop_bbox(img)
crop_seg = seg[y1:y2, x1:x2]
if self.cat_max_ratio < 1.:
for _ in range(10):
labels, cnt = np.unique(crop_seg, return_counts=True)
cnt = cnt[labels != 255]
if len(cnt) > 1:
if np.max(cnt) / np.sum(cnt) < self.cat_max_ratio:
break
x1, y1, x2, y2 = self.get_crop_bbox(img)
crop_seg = seg[y1:y2, x1:x2]
img, seg = img[y1:y2, x1:x2], crop_seg
return img, seg
class Pad(object):
"""Pad the image randomly."""
def __init__(self, pad_size=512):
self.pad_size = pad_size
self.pixel_mean = cfg.MODEL.PIXEL_MEAN
def __call__(self, img, seg=None):
img = im_pad(img, self.pad_size, value=self.pixel_mean)
if seg is not None:
seg = im_pad(seg, self.pad_size, value=255)
return img, seg
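# A toy composition sketch mirroring the train pipeline
# (parameter values are illustrative):
#
#   resize = RandomResize(scales=(512,), scales_range=(0.5, 2.0))
#   crop, flip = RandomCrop(512), RandomFlip()
#   distort, pad = ColorJitter(), Pad(512)
#   img, seg = resize(img, seg)
#   img, seg = crop(img, seg)
#   img, seg = flip(img, seg)
#   img = distort(img)
#   img, seg = pad(img, seg)  # img padded with PIXEL_MEAN, seg with 255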
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.models import backbones
from seetaseg.models import necks
from seetaseg.models import segmentors
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Backbones."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Modules
from seetaseg.models.backbones import resnet
from seetaseg.models.backbones import vit
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""ResNet backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.core.training.utils import freeze_module
from seetaseg.ops.normalization import get_norm
from seetaseg.models.build import BACKBONES
class BasicBlock(nn.Module):
"""The basic resnet block."""
expansion = 1
def __init__(self, dim_in, dim, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(dim_in, dim, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = get_norm(cfg.BACKBONE.NORM, dim)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
self.bn2 = get_norm(cfg.BACKBONE.NORM, dim)
self.downsample = downsample
def forward(self, x):
shortcut = x
x = self.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
if self.downsample is not None:
shortcut = self.downsample(shortcut)
return self.relu(x.add_(shortcut))
class Bottleneck(nn.Module):
"""The bottleneck resnet block."""
expansion = 4
groups, width_per_group = 1, 64
def __init__(self, dim_in, dim, stride=1, downsample=None):
super(Bottleneck, self).__init__()
width = int(dim * (self.width_per_group / 64.)) * self.groups
self.conv1 = nn.Conv2d(dim_in, width, kernel_size=1, bias=False)
self.bn1 = get_norm(cfg.BACKBONE.NORM, width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn2 = get_norm(cfg.BACKBONE.NORM, width)
self.conv3 = nn.Conv2d(width, dim * self.expansion,
kernel_size=1, bias=False)
self.bn3 = get_norm(cfg.BACKBONE.NORM, dim * self.expansion)
self.relu = nn.ReLU(True)
self.downsample = downsample
def forward(self, x):
shortcut = x
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
if self.downsample is not None:
shortcut = self.downsample(shortcut)
return self.relu(x.add_(shortcut))
class ResNet(nn.Module):
"""ResNet."""
def __init__(self, block, depths, deep_stem=False):
super(ResNet, self).__init__()
dim_in, dims, blocks = 64, [64, 128, 256, 512], []
self.out_indices = [v - 1 for v in itertools.accumulate(depths)]
self.out_dims = [v * block.expansion for v in dims]
if deep_stem:
self.conv1 = nn.Sequential(
nn.Conv2d(3, dim_in // 2, 3, 2, padding=1, bias=False),
get_norm(cfg.BACKBONE.NORM, dim_in // 2), nn.ReLU(True),
nn.Conv2d(dim_in // 2, dim_in // 2, 3, padding=1, bias=False),
get_norm(cfg.BACKBONE.NORM, dim_in // 2), nn.ReLU(True),
nn.Conv2d(dim_in // 2, dim_in, 3, padding=1, bias=False))
else:
self.conv1 = nn.Conv2d(3, dim_in, 7, 2, padding=3, bias=False)
self.bn1 = get_norm(cfg.BACKBONE.NORM, dim_in)
self.relu = nn.ReLU(True)
self.maxpool = nn.MaxPool2d(3, 2, padding=1)
# Blocks.
for i, depth, dim in zip(range(4), depths, dims):
stride = 1 if i == 0 else 2
downsample = None
if stride != 1 or dim_in != dim * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(dim_in, dim * block.expansion, 1, stride, bias=False),
get_norm(cfg.BACKBONE.NORM, dim * block.expansion))
blocks.append(block(dim_in, dim, stride, downsample))
dim_in = dim * block.expansion
for _ in range(depth - 1):
blocks.append(block(dim_in, dim))
setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
self.blocks = blocks
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
num_freeze_stages = cfg.BACKBONE.FREEZE_AT
if num_freeze_stages > 0:
self.conv1.apply(freeze_module)
self.bn1.apply(freeze_module)
for i in range(num_freeze_stages - 1, 0, -1):
getattr(self, 'layer%d' % i).apply(freeze_module)
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self.out_indices:
outputs.append(x)
return outputs
class ResNetV1c(ResNet):
"""ResNet with deep 3x3 stem instead of 7x7."""
def __init__(self, block, depths):
super(ResNetV1c, self).__init__(block, depths, deep_stem=True)
BACKBONES.register('resnet18', ResNet, block=BasicBlock, depths=[2, 2, 2, 2])
BACKBONES.register('resnet34', ResNet, block=BasicBlock, depths=[3, 4, 6, 3])
BACKBONES.register('resnet50', ResNet, block=Bottleneck, depths=[3, 4, 6, 3])
BACKBONES.register('resnet101', ResNet, block=Bottleneck, depths=[3, 4, 23, 3])
BACKBONES.register('resnet50_v1c', ResNetV1c, block=Bottleneck, depths=[3, 4, 6, 3])
BACKBONES.register('resnet101_v1c', ResNetV1c, block=Bottleneck, depths=[3, 4, 23, 3])
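# A registry usage sketch (assuming registration binds the keyword
# arguments above, as models/build.py implies by calling the result):
#
#   resnet50 = BACKBONES.get('resnet50')()  # Bottleneck, depths [3, 4, 6, 3]
#   feats = resnet50(x)  # 4 stage outputs with dims [256, 512, 1024, 2048]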
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""ViT backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
from dragon.vm import torch
from dragon.vm.torch import nn
import numpy as np
from seetaseg.core.config import cfg
from seetaseg.models.build import BACKBONES
from seetaseg.ops.normalization import get_norm
from seetaseg.ops.random import batch_permutation
def space_to_depth(input, block_size):
"""Rearrange blocks of spatial data into depth."""
if input.dim() == 3:
hXw, c = input.size()[1:]
h = w = int(hXw ** 0.5)
else:
h, w, c = input.size()[1:]
h1, w1 = h // block_size, w // block_size
c1 = (block_size ** 2) * c
input.reshape_((-1, h1, block_size, w1, block_size, c))
out = input.permute(0, 1, 3, 2, 4, 5)
input.reshape_((-1, h, w, c))
return out.reshape_((-1, h1, w1, c1))
def depth_to_space(input, block_size):
"""Rearrange blocks of depth data into spatial."""
h1, w1, c1 = input.size()[1:]
h, w = h1 * block_size, w1 * block_size
c = c1 // (block_size ** 2)
input.reshape_((-1, h1, w1, block_size, block_size, c))
out = input.permute(0, 1, 3, 2, 4, 5)
input.reshape_((-1, h1, w1, c1))
return out.reshape_((-1, h, w, c))
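# Shape sketch: with block_size = 2, a (N, 4, 4, C) input becomes
# (N, 2, 2, 4C) under space_to_depth, and depth_to_space inverts the
# mapping back to (N, 4, 4, C). For 3-d inputs, a (N, 16, C) token
# sequence is first treated as a 4x4 spatial grid.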
class MLP(nn.Module):
"""Two layers MLP."""
def __init__(self, dim, mlp_ratio=4):
super(MLP, self).__init__()
self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
self.activation = nn.GELU()
def forward(self, x):
return self.fc2(self.activation(self.fc1(x)))
class Attention(nn.Module):
"""Multihead attention."""
def __init__(self, dim, num_heads, qkv_bias=True):
super(Attention, self).__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
qkv = self.qkv(x).reshape_(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0, copy=False)
attn = q @ k.transpose(-2, -1).mul_(self.scale)
attn = nn.functional.softmax(attn, dim=-1, inplace=True)
return self.proj((attn @ v).transpose(1, 2).flatten_(2))
class Block(nn.Module):
"""Transformer block."""
def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0):
super(Block, self).__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
self.drop_path = nn.DropPath(p=drop_path, inplace=True)
def forward(self, x):
x = self.drop_path(self.attn(self.norm1(x))).add_(x)
return self.drop_path(self.mlp(self.norm2(x))).add_(x)
class Bottleneck(nn.Module):
"""The bottleneck block."""
def __init__(self, dim, expansion=4):
super(Bottleneck, self).__init__()
width = dim // expansion
self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
self.bn1 = get_norm(cfg.BACKBONE.NORM, width)
self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
self.bn2 = get_norm(cfg.BACKBONE.NORM, width)
self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
self.bn3 = get_norm(cfg.BACKBONE.NORM, dim)
self.relu = nn.ReLU(True)
def forward(self, x):
shortcut = x
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
return self.bn3(self.conv3(x)).add_(shortcut)
class PatchEmbed(nn.Module):
"""Patch embedding layer."""
def __init__(self, dim=768, patch_size=16):
super(PatchEmbed, self).__init__()
self.proj = nn.Conv2d(3, dim, patch_size, patch_size)
def forward(self, x):
return self.proj(x)
class PosEmbed(nn.Module):
"""Position embedding layer."""
def __init__(self, dim, num_patches):
super(PosEmbed, self).__init__()
self.dim = dim
self.num_patches = num_patches
self.weight = nn.Parameter(torch.zeros(num_patches, dim))
self.weight.no_weight_decay = True
nn.init.normal_(self.weight, std=0.02)
def _load_from_state_dict(
self,
state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict[prefix + 'weight']
num_patches, dim = weight.shape
if num_patches != self.num_patches:
h = w = int(num_patches ** 0.5)
new_h = new_w = int(self.num_patches ** 0.5)
if not isinstance(weight, torch.Tensor):
weight = torch.from_numpy(weight)
weight = weight.reshape_(1, h, w, dim).permute(0, 3, 1, 2)
weight = nn.functional.interpolate(
weight, size=(new_h, new_w), mode='bilinear')
weight = weight.flatten_(2).transpose(1, 2).squeeze_(0)
state_dict[prefix + 'weight'] = weight
super(PosEmbed, self)._load_from_state_dict(
state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
return x.add_(self.weight)
class RelPosEmbed(nn.Module):
"""Relative position embedding layer."""
def __init__(self, num_heads, window_size):
super(RelPosEmbed, self).__init__()
num_pos = (2 * window_size - 1) ** 2 + 3
grid = np.arange(window_size)
pos = np.stack(np.meshgrid(grid, grid, indexing='ij'))
pos = pos.reshape((2, -1))
pos = pos[:, :, None] - pos[:, None, :]
pos += window_size - 1
pos[0] *= 2 * window_size - 1
index = np.zeros(((window_size ** 2) + 1,) * 2, 'int64')
index[0, 0], index[1:, 1:] = num_pos - 1, pos.sum(0)
index[:, 0], index[0, :] = num_pos - 2, num_pos - 3
self.register_buffer('index', torch.from_numpy(index))
self.weight = nn.Parameter(torch.zeros(num_heads, num_pos))
def forward(self, x):
return self.weight[:, self.index].expand(x.size(0), -1, -1, -1)
class FPN(nn.Module):
"""Feature Pyramid Network."""
def __init__(self, dim, patch_size):
super(FPN, self).__init__()
self.output_conv = nn.ModuleList()
patch_strides = int(math.log2(patch_size))
for i in range(4):
if i + 2 < patch_strides:
stride, layers = 2 ** (patch_strides - i - 2), []
while stride > 1:
layers += [nn.ConvTranspose2d(dim, dim, 2, 2)]
if stride > 2:
layers += [get_norm(cfg.BACKBONE.NORM, dim), nn.GELU()]
stride /= 2
self.output_conv.append(nn.Sequential(*layers))
elif i + 2 == patch_strides:
self.output_conv += [nn.Identity()]
elif i + 2 > patch_strides:
stride = 2 ** (i + 2 - patch_strides)
self.output_conv += [nn.MaxPool2d(stride, stride)]
def forward(self, inputs):
return [conv(x) for conv, x in zip(self.output_conv, inputs)]
class VisionTransformer(nn.Module):
"""Vision Transformer."""
def __init__(self, depths, dims, num_heads, patch_size,
out_indices=(3, 5, 7, 11)):
super(VisionTransformer, self).__init__()
drop_path = cfg.BACKBONE.DROP_PATH_RATE
drop_path = (torch.linspace(
0, drop_path, sum(depths), dtype=torch.float32).tolist()
if drop_path > 0 else [drop_path] * sum(depths))
self.img_size = cfg.TRAIN.CROP_SIZE
self.num_patches = (self.img_size // patch_size) ** 2
self.patch_embed = PatchEmbed(dims[0], patch_size)
self.pos_embed = PosEmbed(dims[0], self.num_patches)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0]))
self.blocks = nn.ModuleList([Block(
dim=dims[0], num_heads=num_heads[0],
mlp_ratio=4, qkv_bias=True, drop_path=drop_path[i])
for i in range(depths[0])])
self.out_indices = out_indices
self.out_dims = (dims[0],) * len(out_indices)
self.fpn = FPN(dims[0], patch_size)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.cls_token, std=.02)
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten_(2).transpose(1, 2)
x = self.pos_embed(x)
cls_tokens = self.cls_token.expand(x.size(0), 1, -1)
x = torch.cat((cls_tokens, x), dim=1)
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self.out_indices:
outputs.append(x)
for i, x in enumerate(outputs):
x = x[:, 1:, :]
hXw, dim = x.shape[1:]
outputs[i] = x.transpose(1, 2).reshape_(
-1, dim, int(hXw ** 0.5), int(hXw ** 0.5))
return self.fpn(outputs)
def get_lr_scale(self, name, decay):
values = list(decay ** (len(self.blocks) + 1 - i)
for i in range(len(self.blocks) + 2))
if name.startswith('backbone.pos_embed') or 'cls_token' in name:
return values[0]
elif name.startswith('backbone.patch_embed'):
return values[0]
elif name.startswith('backbone.blocks'):
return values[int(name.split('.')[2]) + 1]
return values[-1]
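# Layer-wise decay sketch: with 12 blocks and decay d, the patch/pos
# embeddings get scale d^13, block j gets d^(12 - j), and every other
# (head) parameter gets d^0 = 1, so earlier layers learn more slowly.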
class VisionTransformerV2(VisionTransformer):
"""Vision Transformer with W-MSA."""
def __init__(self, depths, dims, num_heads, patch_size, window_size):
super(VisionTransformerV2, self).__init__(
depths, dims, num_heads, patch_size)
self.window_size = window_size or self.img_size // patch_size
self.mask_ratio = 0.15
self.cross_conv = nn.ModuleList([Bottleneck(dims[0]) for _ in range(4)])
self.cross_indices = list(range(depths[0] // 4 - 1, depths[0], depths[0] // 4))
for m in self.cross_conv.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
self.register_buffer('mask_token', torch.zeros(1, 1, dims[0]))
self.buffers = collections.defaultdict(lambda: torch.empty(1))
def random_masking(self, x):
batch_size, num_patches, dim = x.size()
num_encodes = int(num_patches * (1. - self.mask_ratio))
patch_index = self.buffers['patch_index']
batch_permutation(batch_size, num_patches, num_encodes,
patch_index, device=x.device)
patch_index = patch_index.unsqueeze_(-1)
vis_indices = patch_index.expand(-1, -1, dim)
return x.gather(1, vis_indices), vis_indices
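# Masking sketch: with 8 tokens per window and mask_ratio = 0.15,
# num_encodes = int(8 * 0.85) = 6, so each window keeps a random
# subset of 6 tokens; forward() later restores the dropped positions
# from 'mask_token' via scatter.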
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten_(2).transpose(1, 2)
x = self.pos_embed(x)
x = space_to_depth(x, self.window_size)
wmsa_shape = (-1,) + x.shape[1:]
msa_shape = (-1, self.window_size ** 2, self.out_dims[0])
x = x.reshape_(msa_shape)
vis_indices = mask_tokens = None
if self.training and self.mask_ratio > 0:
mask_tokens = self.mask_token.expand_as(x)
x, vis_indices = self.random_masking(x)
cls_tokens = self.cls_token.expand(x.size(0), 1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self.cross_indices or i == len(self.blocks) - 1:
if i < len(self.blocks) - 1:
cls_tokens, x = x.split((1, x.size(1) - 1), dim=1)
else:
x = x[:, 1:]
if mask_tokens is not None:
x = mask_tokens.scatter(1, vis_indices, x)
x = depth_to_space(x.reshape_(wmsa_shape), self.window_size)
x = x.permute(0, 3, 1, 2)
if i in self.cross_indices:
x = self.cross_conv[self.cross_indices.index(i)](x)
if i in self.cross_indices and i < len(self.blocks) - 1:
x = x.permute(0, 2, 3, 1)
x = space_to_depth(x, self.window_size).reshape_(msa_shape)
if mask_tokens is not None:
if self.cross_indices.index(i) < 2:
x = x.gather(1, vis_indices)
else:
mask_tokens = vis_indices = None
x = torch.cat((cls_tokens, x), dim=1)
return self.fpn([x, x, x, x])
BACKBONES.register('vit_small_patch16', VisionTransformer,
depths=(12,), dims=(384,), num_heads=(6,),
out_indices=(3, 5, 7, 11), patch_size=16)
BACKBONES.register('vit_base_patch16', VisionTransformer,
depths=(12,), dims=(768,), num_heads=(12,),
out_indices=(3, 5, 7, 11), patch_size=16)
BACKBONES.register('vit_large_patch16', VisionTransformer,
depths=(24,), dims=(1024,), num_heads=(16,),
out_indices=(7, 11, 15, 23), patch_size=16)
BACKBONES.register('vit_base_patch16_windowX', VisionTransformerV2,
depths=(12,), dims=(768,), num_heads=(12,),
patch_size=16, window_size=0)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.core.registry import Registry
from seetaseg.utils.profiler import Timer
BACKBONES = Registry('backbones')
NECKS = Registry('necks')
SEGMENTORS = Registry('segmentors')
def build_backbone():
backbone_types = cfg.BACKBONE.TYPE.split('.')
backbone = BACKBONES.get(backbone_types[0])()
backbone_dims = backbone.out_dims
neck = nn.Identity()
if len(backbone_types) > 1:
neck = NECKS.get(backbone_types[1])(backbone_dims)
else:
neck.out_dims = backbone_dims
return backbone, neck
def build_segmentor(device=None, weights=None, training=False):
"""Create a segmentor instance.
Parameters
----------
device : int, optional
The index of compute device.
weights : str, optional
The path of weight file.
training : bool, optional, default=False
Return a training segmentor or not.
"""
model = SEGMENTORS.get(cfg.MODEL.TYPE)()
if model is None:
raise ValueError('Unknown segmentor: ' + cfg.MODEL.TYPE)
if weights is not None:
model.load_weights(weights, strict=True)
if device is not None:
model.cuda(device)
if not training:
model.eval()
model.optimize_for_inference()
model.timers = collections.defaultdict(Timer)
return model
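# A minimal inference-side sketch, assuming cfg.MODEL.TYPE names a
# registered segmentor (e.g. 'upernet') and the weights path is
# illustrative:
#
#   model = build_segmentor(device=0, weights='/path/to/weights.pkl')
#   outputs = model({'img': im_batch})  # {'cls_seg': ...} in eval mode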
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Head for FCN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.ops.conv import ConvNorm2d
from seetaseg.ops.vision import resize
class FCNHead(nn.Module):
"""FCN head."""
def __init__(self, dim_in, dim_out=None):
super(FCNHead, self).__init__()
conv_module = functools.partial(
ConvNorm2d,
norm_type=cfg.DECODER.NORM,
activation_type=cfg.DECODER.ACTIVATION)
self.dim = dim_out or cfg.DECODER.DIM
self.output_conv = conv_module(dim_in, self.dim, 3)
self.cls_conv = nn.Conv2d(self.dim, len(cfg.MODEL.CLASSES), 1)
self.dropout = nn.Dropout(cfg.DECODER.DROPOUT_RATE)
def get_outputs(self, inputs):
outputs = self.output_conv(inputs['features'][2])
return self.cls_conv(self.dropout(outputs))
def get_losses(self, cls_seg, seg):
cls_seg = resize(cls_seg.float(), seg.shape[2:])
seg = seg.long()
cls_loss = nn.functional.cross_entropy(
cls_seg, seg, ignore_index=255, reduction='none').mean()
cls_acc = cls_seg.argmax(1, True).eq(seg).float().mean()
return {'cls_loss': cls_loss, 'cls_acc': cls_acc}
def forward(self, inputs):
cls_seg = self.get_outputs(inputs)
outputs = collections.OrderedDict()
if self.training:
outputs.update(self.get_losses(cls_seg, inputs['seg']))
else:
outputs['cls_seg'] = cls_seg
return outputs
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Head for UperNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from dragon.vm import torch
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.models.build import NECKS
from seetaseg.ops.conv import ConvNorm2d
from seetaseg.ops.vision import resize
class UPerNetHead(nn.Module):
"""Head for UperNet."""
def __init__(self, in_dims):
super(UPerNetHead, self).__init__()
conv_module = functools.partial(
ConvNorm2d,
norm_type=cfg.DECODER.NORM,
activation_type=cfg.DECODER.ACTIVATION)
self.dim = cfg.DECODER.DIM
self.ppm = NECKS.get('PPM')(in_dims[-1])
self.fpn = NECKS.get(cfg.FPN.TYPE)(in_dims[:-1])
self.cat_conv = conv_module(len(in_dims) * self.dim, self.dim, 3)
self.cls_conv = nn.Conv2d(self.dim, len(cfg.MODEL.CLASSES), 1)
self.dropout = nn.Dropout(cfg.DECODER.DROPOUT_RATE)
def get_outputs(self, inputs):
features = list(inputs['features'])
features[-1] = self.ppm(features[-1])
features = self.fpn(features)
output_size = features[0].shape[2:]
outputs = [features[0]] + [resize(x, output_size) for x in features[1:]]
outputs = self.cat_conv(torch.cat(outputs, dim=1))
return self.cls_conv(self.dropout(outputs))
def get_losses(self, cls_seg, seg):
cls_seg = resize(cls_seg.float(), seg.shape[2:])
seg = seg.long()
cls_loss = nn.functional.cross_entropy(
cls_seg, seg, ignore_index=255, reduction='none').mean()
cls_acc = cls_seg.argmax(1, True).eq(seg).float().mean()
return {'cls_loss': cls_loss, 'cls_acc': cls_acc}
def forward(self, inputs):
cls_seg = self.get_outputs(inputs)
outputs = collections.OrderedDict()
if self.training:
outputs.update(self.get_losses(cls_seg, inputs['seg']))
else:
outputs['cls_seg'] = cls_seg
return outputs
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Necks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.models.necks import fpn
from seetaseg.models.necks import ppm
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""FPN neck."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.models.build import NECKS
from seetaseg.ops.conv import ConvNorm2d
from seetaseg.ops.vision import resize
@NECKS.register(['simfpn', 'SimFPN'])
class SimFPN(nn.Module):
"""Simple Feature Pyramid Network."""
def __init__(self, in_dims):
super(SimFPN, self).__init__()
conv_module = functools.partial(
ConvNorm2d,
norm_type=cfg.FPN.NORM,
activation_type=cfg.FPN.ACTIVATION)
self.dim = cfg.FPN.DIM
self.lateral_conv = nn.ModuleList()
self.output_conv = nn.ModuleList()
for dim in in_dims:
self.lateral_conv += [conv_module(dim, self.dim, 1)]
self.output_conv += [conv_module(self.dim, self.dim, 3)]
def forward(self, features):
laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
laterals += features[len(self.lateral_conv):]
outputs = [conv(x) for conv, x in zip(self.output_conv, laterals)]
return outputs + laterals[len(self.output_conv):]
@NECKS.register(['fpn', 'FPN'])
class FPN(nn.Module):
"""Feature Pyramid Network."""
def __init__(self, in_dims):
super(FPN, self).__init__()
conv_module = functools.partial(
ConvNorm2d,
norm_type=cfg.FPN.NORM,
activation_type=cfg.FPN.ACTIVATION)
self.dim = cfg.FPN.DIM
self.lateral_conv = nn.ModuleList()
self.output_conv = nn.ModuleList()
for dim in in_dims:
self.lateral_conv += [conv_module(dim, self.dim, 1)]
self.output_conv += [conv_module(self.dim, self.dim, 3)]
def forward(self, features):
laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
laterals += features[len(self.lateral_conv):]
for i in range(len(features) - 1, 0, -1):
laterals[i - 1] = laterals[i - 1] + \
resize(laterals[i], laterals[i - 1].shape[2:])
outputs = [conv(x) for conv, x in zip(self.output_conv, laterals)]
return outputs + laterals[len(self.output_conv):]
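# Top-down pathway sketch: given laterals at strides (4, 8, 16, 32),
# the loop runs i = 3..1, upsampling each coarser level to the next
# finer resolution and summing, e.g.
#   laterals[2] = laterals[2] + resize(laterals[3], laterals[2].shape[2:])
# before the 3x3 output convolutions smooth the merged maps.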
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""PPM neck."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from dragon.vm import torch
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.models.build import NECKS
from seetaseg.ops.conv import ConvNorm2d
from seetaseg.ops.vision import resize
@NECKS.register(['ppm', 'PPM'])
class PPM(nn.Module):
"""Pooling Pyramid Module."""
def __init__(self, dim_in, pool_scales=(1, 2, 3, 6)):
super(PPM, self).__init__()
conv_module = functools.partial(
ConvNorm2d, norm_type=cfg.FPN.NORM,
activation_type=cfg.FPN.ACTIVATION)
self.dim = cfg.FPN.DIM
self.pool_scales = pool_scales
self.lateral_conv = nn.ModuleList([conv_module(
dim_in, self.dim, 1) for _ in range(len(pool_scales))])
self.cat_conv = conv_module(dim_in + len(pool_scales) * self.dim,
self.dim, 3)
def forward(self, x):
outputs, output_size = [x], x.shape[2:]
for i, scale in enumerate(self.pool_scales):
lateral = nn.functional.adaptive_avg_pool2d(x, scale)
lateral = self.lateral_conv[i](lateral)
outputs.append(resize(lateral, output_size))
return self.cat_conv(torch.cat(outputs, dim=1))
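# Pooling sketch: for a (N, C, 32, 32) input and the default scales
# (1, 2, 3, 6), adaptive average pooling yields 1x1, 2x2, 3x3 and 6x6
# summaries, each projected to FPN.DIM, resized back to 32x32, and
# concatenated with the input before the final 3x3 fusion conv.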
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Segmentors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Modules.
from seetaseg.models.segmentors import upernet
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Base segmentor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.models.build import build_backbone
from seetaseg.ops.fusion import get_fusion
from seetaseg.ops.normalization import ToTensor
from seetaseg.utils import logging
class Segmentor(nn.Module):
"""Class to build and compute the segmentation."""
def __init__(self):
super(Segmentor, self).__init__()
self.to_tensor = ToTensor()
self.backbone, self.neck = build_backbone()
self.backbone_dims = self.neck.out_dims
def get_inputs(self, inputs):
"""Return the segmentation inputs.
Parameters
----------
inputs : dict
The inputs.
"""
inputs['img'] = self.to_tensor(inputs['img'], normalize=True)
if 'seg' in inputs:
inputs['seg'] = self.to_tensor(inputs['seg'])
return inputs
def get_features(self, inputs):
"""Return the segmentation features.
Parameters
----------
inputs : dict
The inputs.
"""
return self.neck(self.backbone(inputs['img']))
def get_outputs(self, inputs):
"""Return the segmentation outputs.
Parameters
----------
inputs : dict
The inputs.
"""
return inputs
def forward(self, inputs):
"""Define the computation performed at every call.
Parameters
----------
inputs : dict
The inputs.
"""
return self.get_outputs(inputs)
def load_weights(self, weights, strict=False):
"""Load the state dict of this detector.
Parameters
----------
weights : str
The path of the weights file.
"""
self.load_state_dict(torch.load(weights), strict=strict)
def optimize_for_inference(self):
"""Optimize the graph for the inference."""
# Set precision.
precision = cfg.MODEL.PRECISION.lower()
self.half() if precision == 'float16' else self.float()
logging.info('Set precision: ' + precision)
# Fuse modules.
fusion_memo, last_module = set(), None
for module in self.modules():
key, fn = get_fusion(last_module, module)
if fn is not None:
fusion_memo.add(key)
fn(last_module, module)
last_module = module
for key in fusion_memo:
logging.info('Fuse modules: ' + key)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""UperNet segmentor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.models.decode_heads.fcn import FCNHead
from seetaseg.models.decode_heads.upernet import UPerNetHead
from seetaseg.models.segmentors.segmentor import Segmentor
from seetaseg.models.build import SEGMENTORS
@SEGMENTORS.register('upernet')
class UperNet(Segmentor):
"""Unified Perceptual Parsing for Scene Understanding."""
def __init__(self):
super(UperNet, self).__init__()
self.decode_head = UPerNetHead(self.backbone_dims)
self.auxiliary_head = FCNHead(self.backbone_dims[-2], 256)
def get_outputs(self, inputs):
"""Compute segmentation outputs."""
inputs = self.get_inputs(inputs)
inputs['features'] = self.get_features(inputs)
outputs = self.decode_head(inputs)
if self.training:
aux_outputs = self.auxiliary_head(inputs)
for k, v in aux_outputs.items():
v = v.mul_(0.4) if 'loss' in k else v
outputs['aux_' + k] = v
return outputs
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.modules import generic
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import types
from dragon.vm import torch
from seetaseg.core.config import cfg
from seetaseg.core.registry import Registry
from seetaseg.utils.profiler import Timer
MODEL_INFERENCE = Registry('model_inference')
def build_model_inference(model):
"""Build the model inference."""
return MODEL_INFERENCE.get(cfg.MODEL.TYPE)(model)
class ModelInference(object):
"""Model inference module."""
def __init__(self, model):
self.model = model
self.timers = collections.defaultdict(Timer)
@torch.no_grad()
def get_results(self, imgs):
"""Return the detection results."""
def get_time_diffs(self):
"""Return the time differences."""
return dict((k, v.average_time)
for k, v in self.timers.items())
def trace(self, name, func, example_inputs=None):
"""Trace the function and bound to model."""
if not hasattr(self.model, name):
setattr(self.model, name, torch.jit.trace(
func=types.MethodType(func, self.model),
example_inputs=example_inputs))
return getattr(self.model, name)
@staticmethod
def register(model_type, **kwargs):
"""Register a model inference function."""
def decorated(func):
return MODEL_INFERENCE.register(model_type, func, **kwargs)
return decorated
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Generic modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from dragon.vm import torch
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
from seetaseg.modules.build import ModelInference
from seetaseg.ops.vision import resize
from seetaseg.utils.blob import blob_vstack
from seetaseg.utils.image import im_rescale
@ModelInference.register(['upernet'])
class GenericInference(ModelInference):
"""RetinaNet inference module."""
def __init__(self, model):
super(GenericInference, self).__init__(model)
self.forward_model = self.trace(
'forward_eval', lambda self, img:
self.forward({'img': img}))
@torch.no_grad()
def get_results(self, imgs):
"""Return the inference results."""
results = self.forward_seg(imgs)
seg_type = 'byte' if len(cfg.MODEL.CLASSES) <= 255 else 'int'
with self.timers['misc'].tic_and_toc(n=len(imgs)):
segs = [getattr(x.argmax(1), seg_type)()
.numpy().copy() for x in results]
return [{'seg': seg.squeeze(0)} for seg in segs]
@torch.no_grad()
def forward_data(self, imgs):
"""Return the inference data."""
im_batch, im_shapes = [], []
for img in imgs:
scaled_imgs, _ = im_rescale(
img, scales=cfg.TEST.SCALES, max_size=cfg.TEST.MAX_SIZE)
im_batch += scaled_imgs
im_shapes += [x.shape[:2] for x in scaled_imgs]
im_batch = blob_vstack(im_batch, fill_value=cfg.MODEL.PIXEL_MEAN)
return im_batch, im_shapes
@torch.no_grad()
def forward_seg(self, imgs):
"""Run segmentation inference."""
im_batch, im_shapes = self.forward_data(imgs)
self.timers['im_segment'].tic()
im_batch = self.model.to_tensor(im_batch, normalize=True)
crop_size = cfg.TEST.CROP_SIZE
if crop_size > 0:
cls_seg = self.forward_slide(im_batch, crop_size)
else:
cls_seg = self.model.forward({'img': im_batch})['cls_seg']
cls_seg = resize(cls_seg, im_batch.shape[2:])
imgs_per_batch, num_scales = len(imgs), len(cfg.TEST.SCALES)
results = [None for _ in range(imgs_per_batch)]
preds = cls_seg.split(1, dim=0, copy=False)
if not isinstance(preds, (tuple, list)):
preds = [preds]
for i in range(imgs_per_batch * num_scales):
index = i // num_scales
pred, (pred_h, pred_w) = preds[i], im_shapes[i]
if pred.shape[2:4] != (pred_h, pred_w):
pred = pred[:, :, :pred_h, :pred_w]
pred = resize(pred, imgs[index].shape[:2])
if results[index] is None:
results[index] = pred
else:
results[index] += pred
self.timers['im_segment'].toc(n=imgs_per_batch)
return results
@torch.no_grad()
def forward_slide(self, im_batch, crop_size):
"""Run slide inference."""
if not isinstance(crop_size, (tuple, list)):
crop_size = (crop_size, crop_size)
crop_h, crop_w = crop_size
stride_h, stride_w = int(crop_h / 1.5), int(crop_w / 1.5)
batch_size, _, h, w = im_batch.size()
grid_h = max(h - crop_h + stride_h - 1, 0) // stride_h + 1
grid_w = max(w - crop_w + stride_w - 1, 0) // stride_w + 1
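# Worked example of the grid arithmetic: with h = 1024, crop_h = 512
# and stride_h = int(512 / 1.5) = 341, grid_h = max(1024 - 512 + 341
# - 1, 0) // 341 + 1 = 852 // 341 + 1 = 3 rows of windows, whose
# start offsets are clamped below so the last window ends exactly at
# the image border.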
cls_seg = None
for i, j in itertools.product(range(grid_h), range(grid_w)):
y1, x1 = i * stride_h, j * stride_w
y2, x2 = min(y1 + crop_h, h), min(x1 + crop_w, w)
y1, x1 = max(y2 - crop_h, 0), max(x2 - crop_w, 0)
outputs = self.forward_model(im_batch[:, :, y1:y2, x1:x2])
cls_seg_ij = resize(outputs['cls_seg'], crop_size)
if cls_seg is None:
cls_seg = cls_seg_ij.new_zeros(batch_size, cls_seg_ij.size(1), h, w)
cls_seg += nn.functional.pad(cls_seg_ij, (x1, w - x2, y1, h - y2))
return cls_seg
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Convolution ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import nn
from seetaseg.ops.normalization import get_norm
class ConvNorm2d(nn.Sequential):
"""2d convolution followed by norm and activation."""
def __init__(
self,
dim_in,
dim_out,
kernel_size,
stride=1,
padding=None,
dilation=1,
bias=True,
conv_type='Conv2d',
norm_type='',
activation_type='',
inplace=True,
):
super(ConvNorm2d, self).__init__()
if padding is None:
padding = kernel_size // 2
if conv_type == 'Conv2d':
layers = [nn.Conv2d(dim_in, dim_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias and (not norm_type))]
elif conv_type == 'SepConv2d':
layers = [nn.Conv2d(dim_in, dim_in,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=dim_in,
bias=False),
nn.Conv2d(dim_in, dim_out,
kernel_size=1,
bias=bias and (not norm_type))]
else:
raise ValueError('Unknown conv type: ' + conv_type)
if norm_type:
layers += [get_norm(norm_type, dim_out)]
if activation_type:
layers += [getattr(nn, activation_type)()]
layers[-1].inplace = inplace
for i, layer in enumerate(layers):
self.add_module(str(i), layer)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
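# Usage sketch: ConvNorm2d(64, 128, 3, norm_type='BN',
# activation_type='ReLU') builds Conv2d (bias disabled by the norm)
# -> BatchNorm2d -> ReLU inside a single nn.Sequential.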
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Operator fusions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from seetaseg.core.registry import Registry
FUSIONS = Registry('fusions')
@FUSIONS.register([
'Conv2d+BatchNorm2d',
'Conv2d+FrozenBatchNorm2d',
'Conv2d+SyncBatchNorm',
'ConvTranspose2d+BatchNorm2d',
'ConvTranspose2d+FrozenBatchNorm2d',
'ConvTranspose2d+SyncBatchNorm',
'DepthwiseConv2d+BatchNorm2d',
'DepthwiseConv2d+FrozenBatchNorm2d',
'DepthwiseConv2d+SyncBatchNorm'])
def fuse_conv_bn(conv, bn):
"""Fuse Conv and BatchNorm."""
with torch.no_grad():
m = bn.running_mean
if conv.bias is not None:
m.sub_(conv.bias.float())
else:
delattr(conv, 'bias')
bn.forward = lambda x: x
t = bn.weight.div((bn.running_var + bn.eps).sqrt_())
conv._parameters['bias'] = bn.bias.sub(t * m)
t_conv_shape = [1, -1] if conv.transposed else [-1, 1]
t_conv_shape += [1] * len(conv.kernel_size)
if conv.weight.dtype == 'float16' and t.dtype == 'float32':
conv.bias.half_()
weight = conv.weight.float()
weight.mul_(t.reshape_(t_conv_shape)).half_()
conv.weight.copy_(weight)
else:
conv.weight.mul_(t.reshape_(t_conv_shape))
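# The folding above follows the usual BN algebra: for
# y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta,
# let t = gamma / sqrt(var + eps); the fused layer then uses
# weight' = weight * t (broadcast over output channels) and
# bias' = beta + t * (bias - mean), so BN becomes an identity
# at inference time.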
def get_fusion(*modules):
"""Return the fusion between modules."""
key = '+'.join(m.__class__.__name__ for m in modules)
return key, FUSIONS.try_get(key)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Normalization ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from dragon.vm import torch
from dragon.vm.torch import nn
from seetaseg.core.config import cfg
class FrozenBatchNorm2d(nn.Module):
"""BatchNorm2d where statistics or affine parameters are fixed."""
def __init__(self, num_features, eps=1e-5, affine=False, inplace=True):
super(FrozenBatchNorm2d, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.inplace = inplace and (not affine)
if self.affine:
self.weight = torch.nn.Parameter(torch.ones(num_features))
self.bias = torch.nn.Parameter(torch.zeros(num_features))
else:
self.register_buffer('weight', torch.ones(num_features))
self.register_buffer('bias', torch.zeros(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features) - eps)
def extra_repr(self):
affine_str = '{num_features}, eps={eps}, affine={affine}' \
.format(**self.__dict__)
inplace_str = ', inplace' if self.inplace else ''
return affine_str + inplace_str
def forward(self, input):
return nn.functional.affine(
input, self.weight, self.bias,
dim=1, out=input if self.inplace else None)
def _load_from_state_dict(
self,
state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
# Fuse the running stats into weight and bias.
# Note that this resets the stored stats to zero mean and
# unit variance, so they no longer reflect training.
with torch.no_grad():
self.running_var.float_().add_(self.eps).sqrt_()
self.weight.float_().div_(self.running_var)
self.bias.float_().sub_(self.running_mean.float_() * self.weight)
self.running_mean.zero_()
self.running_var.one_().sub_(self.eps)
class ToTensor(nn.Module):
"""Convert input to tensor."""
def __init__(self):
super(ToTensor, self).__init__()
self.device = torch.device('cpu')
self.tensor = torch.ones(1)
self.norm = functools.partial(
nn.functional.channel_norm,
mean=cfg.MODEL.PIXEL_MEAN,
std=cfg.MODEL.PIXEL_STD,
dim=1, dims=(0, 3, 1, 2),
dtype=cfg.MODEL.PRECISION.lower())
def _apply(self, fn):
fn(self.tensor)
def cpu(self):
self.device = torch.device('cpu')
def cuda(self, device=None):
self.device = torch.device('cuda', device)
def forward(self, input, normalize=False):
if input is None:
return input
if not isinstance(input, torch.Tensor):
input = torch.from_numpy(input)
input = input.to(self.tensor.device)
if normalize and not input.is_floating_point():
input = self.norm(input)
return input
# Getters
def get_norm(norm_type, dim_in):
"""Return a normalization module."""
if isinstance(norm_type, str):
if len(norm_type) == 0:
return nn.Identity()
norm_type = {
'BN': nn.BatchNorm2d,
'FrozenBN': FrozenBatchNorm2d,
'SyncBN': nn.SyncBatchNorm,
'GN': lambda dim: nn.GroupNorm(32, dim),
'Affine': lambda dim: FrozenBatchNorm2d(dim, affine=True),
}[norm_type]
return norm_type(dim_in)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Random ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
import numpy as np
def batch_permutation(batch_size, n, m, out=None, device=None):
"""Return the first m indices of a random permutation of n, per batch row."""
noise = np.random.rand(batch_size, n)
perm = noise.argsort(1)[:, :m]
if out is not None:
out._impl.FromNumpy(perm, False)
return out.to(device=device)
return torch.from_numpy(perm).to(device=device)
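# Usage sketch (illustrative, not part of the original module): draw, for each
# of 4 rows, 16 distinct indices from range(100) without replacement.
if __name__ == '__main__':
    perm = batch_permutation(batch_size=4, n=100, m=16)
    print(perm.shape)  # (4, 16), with unique indices per row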
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Vision ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import nn
def resize(input, size, scale=None, mode='bilinear'):
"""Resize the input tensor."""
return nn.functional.interpolate(input, size, scale, mode=mode)
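# Usage sketch (illustrative, not part of the original module; assumes
# Dragon's interpolate mirrors torch.nn.functional.interpolate):
if __name__ == '__main__':
    from dragon.vm import torch
    x = torch.ones(1, 3, 32, 32)
    y = resize(x, size=(64, 64))         # resize to an explicit size
    z = resize(x, size=None, scale=2.0)  # or by a scale factor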
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Blob utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def blob_vstack(arrays, fill_value=None, dtype=None, size=None, align=None):
"""Stack arrays in sequence vertically."""
if any(arr is None for arr in arrays):
return None
if fill_value is None:
return np.stack(arrays)
# Compute the max stack shape.
max_shape = np.max(np.stack([arr.shape for arr in arrays]), 0)
if size is not None and min(size) > 0:
max_shape[:len(size)] = size
if align is not None and min(align) > 0:
align_size = np.ceil(max_shape[:len(align)] / align)
max_shape[:len(align)] = align_size.astype('int64') * align
# Fill output with the given value.
output_dtype = dtype or arrays[0].dtype
output_shape = [len(arrays)] + list(max_shape)
output = np.empty(output_shape, output_dtype)
output[:] = fill_value
# Copy arrays.
for i, arr in enumerate(arrays):
copy_slices = (slice(0, d) for d in arr.shape)
output[(i,) + tuple(copy_slices)] = arr
return output
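# Usage sketch (illustrative, not part of the original module): batch two
# images of different sizes, padding with 127 and aligning the spatial
# dimensions to multiples of 32.
if __name__ == '__main__':
    a = np.zeros((20, 30, 3), 'uint8')
    b = np.zeros((24, 17, 3), 'uint8')
    blob = blob_vstack([a, b], fill_value=127, align=(32, 32))
    print(blob.shape)  # (2, 32, 32, 3)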
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Image utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import numpy as np
import PIL.Image
import PIL.ImageEnhance
def im_decode(img_bytes, mode='color'):
"""Decode image from raw bytes."""
mode = {'color': cv2.IMREAD_COLOR,
'gray': cv2.IMREAD_GRAYSCALE}[mode]
return cv2.imdecode(np.frombuffer(img_bytes, 'uint8'), mode)
def im_pad(img, size, mode='constant', value=0):
"""Pad image to the given size."""
if not isinstance(size, (tuple, list)):
size = (size, size)
h, w = img.shape[:2]
bottom, right = size[1] - h, size[0] - w
mode = {'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_REFLECT_101,
'symmetric': cv2.BORDER_REFLECT}[mode]
return cv2.copyMakeBorder(img, 0, bottom, 0, right, mode, value=value)
def im_resize(img, size=None, scale=None, mode='linear'):
"""Resize image by the scale or size."""
if size is None:
if not isinstance(scale, (tuple, list)):
scale = (scale, scale)
h, w = img.shape[:2]
size = int(w * scale[1] + .5), int(h * scale[0] + .5)
else:
if not isinstance(size, (tuple, list)):
size = (size, size)
mode = {'linear': cv2.INTER_LINEAR,
'nearest': cv2.INTER_NEAREST}[mode]
return cv2.resize(img, size, interpolation=mode)
def im_rescale(img, scales, max_size=0):
"""Rescale image to match the detecting scales."""
im_shape = img.shape
img_list, img_scales = [], []
size_min = np.min(im_shape[:2])
size_max = np.max(im_shape[:2])
for target_size in scales:
im_scale = float(target_size) / float(size_min)
target_size_max = max_size if max_size > 0 else target_size
if np.round(im_scale * size_max) > target_size_max:
im_scale = float(target_size_max) / float(size_max)
img_list.append(im_resize(img, scale=im_scale))
img_scales.append(im_scale)
return img_list, img_scales
def color_jitter(img, brightness=None, contrast=None, saturation=None):
"""Distort the color of image."""
def add_transform(transforms, type, range):
if range is not None:
if not isinstance(range, (tuple, list)):
range = (1. - range, 1. + range)
transforms.append((type, range))
transforms = []
contrast_first = np.random.rand() < 0.5
add_transform(transforms, PIL.ImageEnhance.Brightness, brightness)
if contrast_first:
add_transform(transforms, PIL.ImageEnhance.Contrast, contrast)
add_transform(transforms, PIL.ImageEnhance.Color, saturation)
if not contrast_first:
add_transform(transforms, PIL.ImageEnhance.Contrast, contrast)
for transform, jitter_range in transforms:
if isinstance(img, np.ndarray):
img = PIL.Image.fromarray(img)
img = transform(img)
img = img.enhance(np.random.uniform(*jitter_range))
return np.asarray(img)
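# Usage sketch (illustrative, not part of the original module): build a
# two-scale pyramid capped at 1333 pixels, then jitter brightness, contrast
# and saturation by up to +/-40% each.
if __name__ == '__main__':
    img = np.random.randint(0, 256, (480, 640, 3), dtype='uint8')
    img_list, img_scales = im_rescale(img, scales=[640, 800], max_size=1333)
    print([im.shape for im in img_list], img_scales)
    distorted = color_jitter(img, brightness=0.4, contrast=0.4, saturation=0.4)
    print(distorted.shape, distorted.dtype)  # (480, 640, 3) uint8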
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Logging utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import logging as _logging
import os
import sys as _sys
import threading
_logger = None
_logger_lock = threading.Lock()
def get_logger():
global _logger
# Use double-checked locking to avoid taking lock unnecessarily.
if _logger:
return _logger
_logger_lock.acquire()
try:
if _logger:
return _logger
logger = _logging.getLogger('seetaseg')
logger.setLevel('INFO')
logger.propagate = False
logger._is_root = True
if True:
# Determine whether we are in an interactive environment.
_interactive = False
try:
# This is only defined in interactive shells.
if _sys.ps1:
_interactive = True
except AttributeError:
# Even now, we may be in an interactive shell with `python -i`.
_interactive = _sys.flags.interactive
# If we are in an interactive environment (like Jupyter), set loglevel
# to INFO and pipe the output to stdout.
if _interactive:
logger.setLevel('INFO')
_logging_target = _sys.stdout
else:
_logging_target = _sys.stderr
# Add the output handler.
_handler = _logging.StreamHandler(_logging_target)
_handler.setFormatter(_logging.Formatter('%(levelname)s %(message)s'))
logger.addHandler(_handler)
_logger = logger
return _logger
finally:
_logger_lock.release()
def _detailed_msg(msg):
file, lineno = inspect.stack()[2][1:3]
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
def log(level, msg, *args, **kwargs):
get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
def debug(msg, *args, **kwargs):
if is_root():
get_logger().debug(_detailed_msg(msg), *args, **kwargs)
def error(msg, *args, **kwargs):
get_logger().error(_detailed_msg(msg), *args, **kwargs)
assert 0
def fatal(msg, *args, **kwargs):
get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
assert 0
def info(msg, *args, **kwargs):
if is_root():
get_logger().info(_detailed_msg(msg), *args, **kwargs)
def warning(msg, *args, **kwargs):
if is_root():
get_logger().warning(_detailed_msg(msg), *args, **kwargs)
def get_verbosity():
"""Return how much logging output will be produced."""
return get_logger().getEffectiveLevel()
def set_verbosity(v):
"""Set the threshold for what messages will be logged."""
get_logger().setLevel(v)
def set_formatter(fmt=None, datefmt=None):
"""Set the formatter."""
handler = _logging.StreamHandler(_sys.stderr)
handler.setFormatter(_logging.Formatter(fmt, datefmt))
logger = get_logger()
logger.removeHandler(logger.handlers[0])
logger.addHandler(handler)
def set_root(is_root=True):
"""Set logger to the root."""
get_logger()._is_root = is_root
def is_root():
"""Return logger is the root."""
return get_logger()._is_root
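# Usage sketch (illustrative, not part of the original module): only the root
# process writes INFO logs; the training entrypoints call set_root() per rank.
if __name__ == '__main__':
    set_verbosity('INFO')
    info('visible on the root process')
    set_root(False)
    info('suppressed on non-root processes')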
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Profiler utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.utils.profiler.stats import SmoothedValue
from seetaseg.utils.profiler.timer import Timer
from seetaseg.utils.profiler.timer import get_progress
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Trackable statistics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
class SmoothedValue(object):
"""Track values and provide smoothed report."""
def __init__(self, window_size=None):
self.deque = collections.deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def update(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def mean(self):
return np.mean(self.deque)
def median(self):
return np.median(self.deque)
def average(self):
return self.total / self.count
class ExponentialMovingAverage(object):
"""Track values and provide EMA report."""
def __init__(self, decay=0.9):
self.value = None
self.decay = decay
self.total = 0.0
self.count = 0
def update(self, value):
if self.value is None:
self.value = value
else:
self.value = (self.decay * self.value +
(1.0 - self.decay) * value)
self.total += value
self.count += 1
def global_average(self):
return self.total / self.count
def running_average(self):
return float(self.value)
def __float__(self):
return self.running_average()
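# Usage sketch (illustrative, not part of the original module): smooth a noisy
# loss curve with a 20-step window and an EMA of decay 0.9.
if __name__ == '__main__':
    smoothed = SmoothedValue(window_size=20)
    ema = ExponentialMovingAverage(decay=0.9)
    for step in range(100):
        loss = 1. / (step + 1) + np.random.rand() * 0.01
        smoothed.update(loss)
        ema.update(loss)
    print(smoothed.median(), smoothed.mean(), float(ema))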
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Timing functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import datetime
import time
class Timer(object):
"""Simple timer."""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
def add_diff(self, diff, n=1, average=True):
self.total_time += diff
self.calls += n
self.average_time = self.total_time / self.calls
return self.average_time if average else self.diff
@contextlib.contextmanager
def tic_and_toc(self, n=1, average=True):
try:
yield self.tic()
finally:
self.toc(n, average)
def tic(self):
self.start_time = time.time()
return self
def toc(self, n=1, average=True):
self.diff = time.time() - self.start_time
return self.add_diff(self.diff, n, average)
def get_progress(timer, step, max_steps):
"""Return the progress information."""
eta_seconds = timer.average_time * (max_steps - step)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
progress = (step + 1.) / max_steps
return ('< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >'
.format(progress, timer.average_time, eta))
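# Usage sketch (illustrative, not part of the original module): time each
# iteration and report smoothed progress.
if __name__ == '__main__':
    timer = Timer()
    for step in range(10):
        with timer.tic_and_toc():
            time.sleep(0.01)  # stands in for one training iteration
        print(get_progress(timer, step, max_steps=10))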
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Segmentation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.utils.seg.metrics import seg_overlap
from seetaseg.utils.seg.metrics import seg_overlaps
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Segmentation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def seg_overlap(
seg,
gt_seg,
num_classes,
ignore_index=255,
reduce_zero_label=True,
):
"""Compute the overlap of two segmentations."""
seg = seg.astype(gt_seg.dtype)
if ignore_index is not None:
if reduce_zero_label:
# Map label 0 to the ignored index and shift the rest down by one
# (1..C -> 0..C-1); note that gt_seg is modified in-place.
gt_seg[gt_seg == 0] = ignore_index
gt_seg -= 1
gt_seg[gt_seg == ignore_index - 1] = ignore_index
mask = (gt_seg != ignore_index)
seg, gt_seg = seg[mask], gt_seg[mask]
inter = seg[seg == gt_seg]
bins, bins_range = num_classes, (0, num_classes - 1)
inter = np.histogram(inter, bins, bins_range)[0]
seg = np.histogram(seg, bins, bins_range)[0]
gt_seg = np.histogram(gt_seg, bins, bins_range)[0]
union = seg + gt_seg - inter
return inter, union, seg, gt_seg
def seg_overlaps(
segs,
gt_segs,
num_classes,
ignore_index=255,
reduce_zero_label=True,
):
"""Compute the overlap two segmentation sequences."""
output_cnts = seg_overlap(
segs[0], gt_segs[0], num_classes,
ignore_index, reduce_zero_label)
for i in range(1, len(segs)):
cnts = seg_overlap(
segs[i], gt_segs[i], num_classes,
ignore_index, reduce_zero_label)
for output_cnt, cnt in zip(output_cnts, cnts):
output_cnt += cnt
return output_cnts
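# Usage sketch (illustrative, not part of the original module): reduce the
# accumulated per-class counts to mIoU and pixel accuracy; the max(..., 1)
# guards are assumptions to avoid division by zero on empty classes.
if __name__ == '__main__':
    segs = [np.random.randint(0, 4, (32, 32))]
    gt_segs = [np.random.randint(1, 5, (32, 32))]  # label 0 is reduced
    inter, union, _, gt = seg_overlaps(segs, gt_segs, num_classes=4)
    miou = np.nanmean(inter / np.maximum(union, 1))
    acc = inter.sum() / max(gt.sum(), 1)
    print('mIoU: {:.4f}, Acc: {:.4f}'.format(miou, acc))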
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Visualization utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetaseg.utils.vis.palette import get_palette
from seetaseg.utils.vis.visualizer import Visualizer
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Pattle for visualizations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
palettes = {}
# RNG42 palette.
state = np.random.get_state()
np.random.seed(42)
palettes['rng42'] = np.random.randint(0, 255, size=(255, 3), dtype='uint8')
np.random.set_state(state)
# ADE20K palette.
palettes['ade20k'] = np.array([
[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]], dtype='uint8')
def get_palette(name='rng42', rgb=False):
"""Return the palette by name."""
global palettes
palette = palettes.get(name, palettes['rng42'])
return palette if rgb else palette[:, ::-1]
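# Usage sketch (illustrative, not part of the original module): map class
# indices to BGR colors for drawing with OpenCV.
if __name__ == '__main__':
    palette = get_palette('ade20k')        # BGR by default
    labels = np.array([0, 1, 2])
    print(palette[labels % len(palette)])  # three uint8 BGR triplets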
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Visualizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import numpy as np
from seetaseg.utils.vis.palette import get_palette
class VisImage(object):
"""VisImage."""
def __init__(self, img, alpha=0.5):
self.img = img
self.alpha = alpha
self.shape = (h, w) = img.shape[:2]
self.mask = np.zeros((h, w, 3), 'uint8')
def save(self, filepath):
cv2.imwrite(filepath, self.get_image())
def get_image(self, rgb=False):
img_bgr = self.img * (1. - self.alpha) + self.mask * self.alpha
img_bgr = img_bgr.astype('uint8')  # cast back so cv2.imwrite accepts it
return img_bgr[:, :, ::-1] if rgb else img_bgr
class Visualizer(object):
""""Visualizer."""
def __init__(self, palette='rng42'):
self.colormap = get_palette(palette, rgb=False)
self.output = None
def draw_seg(self, img, seg, alpha=0.5):
"""Draw semantic segmentation labels."""
self.output = VisImage(img, alpha=alpha)
labels = np.unique(seg)
colors = self.colormap[labels % len(self.colormap)]
for i, label in enumerate(labels):
self.output.mask[seg == label, :] = colors[i]
return self.output
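# Usage sketch (illustrative, not part of the original module): overlay a
# random segmentation on a gray image and save the blend; 'vis.png' is a
# hypothetical output path.
if __name__ == '__main__':
    img = np.full((64, 64, 3), 128, 'uint8')
    seg = np.random.randint(0, 4, (64, 64))
    vis = Visualizer(palette='ade20k')
    vis.draw_seg(img, seg, alpha=0.5).save('vis.png')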
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import subprocess
import sys
import setuptools
import setuptools.command.build_py
import setuptools.command.install
version = git_version = None
if os.path.exists('version.txt'):
with open('version.txt', 'r') as f:
version = f.read().strip()
if os.path.exists('.git'):
try:
git_version = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], cwd='./')
git_version = git_version.decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
pass
def clean_builds():
for path in ['build', 'seeta_seg.egg-info']:
if os.path.exists(path):
shutil.rmtree(path)
def find_packages(top):
"""Return the python sources installed to package."""
packages = []
for root, _, _ in os.walk(top):
if os.path.exists(os.path.join(root, '__init__.py')):
packages.append(root)
return packages
def find_package_data(top):
"""Return the external data installed to package."""
headers, libraries = [], []
if sys.platform == 'win32':
dylib_suffix = '.pyd'
elif sys.platform == 'darwin':
dylib_suffix = '.dylib'
else:
dylib_suffix = '.so'
for root, _, files in os.walk(top):
root = root[len(top + '/'):]
for file in files:
if file.endswith(dylib_suffix):
libraries.append(os.path.join(root, file))
return headers + libraries
class BuildPyCommand(setuptools.command.build_py.build_py):
"""Enhanced 'build_py' command."""
def build_packages(self):
clean_builds()
with open('seetaseg/version.py', 'w') as f:
f.write("from __future__ import absolute_import\n"
"from __future__ import division\n"
"from __future__ import print_function\n\n"
"version = '{}'\n"
"git_version = '{}'\n".format(version, git_version))
super(BuildPyCommand, self).build_packages()
def build_package_data(self):
self.package_data = {'seetaseg': find_package_data('seetaseg')}
super(BuildPyCommand, self).build_package_data()
class InstallCommand(setuptools.command.install.install):
"""Enhanced 'install' command."""
user_options = setuptools.command.install.install.user_options
user_options += [('parallel=', 'j', "number of parallel build jobs")]
def initialize_options(self):
self.parallel = None
super(InstallCommand, self).initialize_options()
self.old_and_unmanageable = True
setuptools.setup(
name='seeta-seg',
version=version,
description='SeetaSeg: A platform implementing popular segmentation algorithms.',
url='https://gitlab.seetatech.com/seetaresearch/seetaseg',
author='SeetaTech',
license='BSD 2-Clause',
packages=find_packages('seetaseg'),
package_dir={'seetaseg': 'seetaseg'},
cmdclass={'build_py': BuildPyCommand, 'install': InstallCommand},
install_requires=[
'opencv-python',
'Pillow>=7.1',
'prettytable',
'matplotlib',
],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: BSD License',
'Programming Language :: C++',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
)
clean_builds()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Train a segmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import dragon
import numpy
from seetaseg.core.config import cfg
from seetaseg.core.coordinator import Coordinator
from seetaseg.core.training import train_engine
from seetaseg.data.build import build_dataset
from seetaseg.utils import logging
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Train a segmentation model')
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='config file')
parser.add_argument(
'--exp_dir',
default='',
help='experiment dir')
parser.add_argument(
'--tensorboard',
action='store_true',
help='write metrics to tensorboard or not')
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir)
checkpoint, start_iter = coordinator.get_checkpoint()
if checkpoint is not None:
cfg.TRAIN.WEIGHTS = checkpoint
# Setup the distributed environment.
world_rank = dragon.distributed.get_rank()
world_size = dragon.distributed.get_world_size()
if cfg.NUM_GPUS != world_size:
raise ValueError(
'Expected to start {} processes, got {}.'
.format(cfg.NUM_GPUS, world_size))
# Setup the logging modules.
logging.set_root(world_rank == 0)
# Select the GPU depending on the rank of process.
cfg.GPU_ID = world_rank
# Fix the random seed for reproducibility.
numpy.random.seed(cfg.RNG_SEED + world_rank)
dragon.random.set_seed(cfg.RNG_SEED)
# Inspect the dataset.
dataset = build_dataset(cfg.TRAIN.DATASET)
logging.info('Dataset({}): {} images will be used to train.'
.format(cfg.TRAIN.DATASET, dataset.num_images))
# Run training.
logging.info('Checkpoints will be saved to `{:s}`'
.format(coordinator.path_at('checkpoints')))
with dragon.distributed.new_group(
ranks=list(range(cfg.NUM_GPUS)),
verbose=True).as_default():
train_engine.run_train(
coordinator, start_iter,
enable_tensorboard=args.tensorboard)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test a segmentation network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import multiprocessing
import os
from seetaseg.core.config import cfg
from seetaseg.core.coordinator import Coordinator
from seetaseg.core.testing import test_engine
from seetaseg.data.build import build_dataset
from seetaseg.utils import logging
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Test a segmentation network')
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='config file')
parser.add_argument(
'--exp_dir',
default='',
help='experiment dir')
parser.add_argument(
'--model_dir',
default='',
help='final model dir')
parser.add_argument(
'--gpu',
nargs='+',
type=int,
default=None,
help='index of GPUs to use')
parser.add_argument(
'--iter',
nargs='+',
type=int,
default=None,
help='iteration step of checkpoints')
parser.add_argument(
'--last',
type=int,
default=1,
help='checkpoint of N last steps')
parser.add_argument(
'--read_every',
type=int,
default=100,
help='read every-n images for testing')
parser.add_argument(
'--palette',
type=str,
default=None,
help='palette for visualization')
parser.add_argument(
'--precision',
default='',
help='compute precision')
parser.add_argument(
'--deterministic',
action='store_true',
help='set cudnn deterministic or not')
return parser.parse_args()
def find_weights(args, coordinator):
"""Return the weights for testing."""
weights_list = []
if args.model_dir:
for file in os.listdir(args.model_dir):
if not file.endswith('.pkl'):
continue
weights_list.append(os.path.join(args.model_dir, file))
return weights_list
if args.iter is not None:
for iter_step in args.iter:  # avoid shadowing the builtin iter
checkpoint, _ = coordinator.get_checkpoint(iter_step, wait=True)
weights_list.append(checkpoint)
return weights_list
for i in range(1, args.last + 1):
checkpoint, _ = coordinator.get_checkpoint(last_idx=i)
if checkpoint is None:
break
weights_list.append(checkpoint)
return weights_list
if __name__ == '__main__':
args = parse_args()
logging.info('Called with args:\n' + str(args))
coordinator = Coordinator(args.cfg_file, args.exp_dir or args.model_dir)
cfg.MODEL.PRECISION = args.precision or cfg.MODEL.PRECISION
logging.info('Using config:\n' + str(cfg))
# Inspect dataset.
dataset = build_dataset(cfg.TEST.DATASET)
logging.info('Dataset({}): {} images will be used to test.'
.format(cfg.TEST.DATASET, dataset.num_images))
for weights in find_weights(args, coordinator):
weights_name = os.path.splitext(os.path.basename(weights))[0]
output_dir = coordinator.path_at('results/' + weights_name)
logging.info('Results will be saved to ' + output_dir)
process = multiprocessing.Process(
target=test_engine.run_test,
kwargs={'weights': weights,
'output_dir': output_dir,
'devices': args.gpu,
'deterministic': args.deterministic,
'read_every': args.read_every,
'palette': args.palette})
process.start()
process.join()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Train a segmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import dragon
import numpy
from seetaseg.core.config import cfg
from seetaseg.core.coordinator import Coordinator
from seetaseg.core.training import train_engine
from seetaseg.data.build import build_dataset
from seetaseg.utils import logging
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Train a segmentation model.')
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='config file')
parser.add_argument(
'--exp_dir',
default=None,
help='experiment dir')
parser.add_argument(
'--tensorboard',
action='store_true',
help='write metrics to tensorboard or not')
return parser.parse_args()
def run_distributed(args, coordinator):
"""Run distributed training."""
import subprocess
cmd = 'mpirun --allow-run-as-root -n {} --bind-to none '.format(cfg.NUM_GPUS)
cmd += '{} {}'.format(sys.executable, 'distributed/train.py')
cmd += ' --cfg {}'.format(os.path.abspath(args.cfg_file))
cmd += ' --exp_dir {}'.format(coordinator.exp_dir)
cmd += ' --tensorboard' if args.tensorboard else ''
return subprocess.call(cmd, shell=True)
if __name__ == '__main__':
args = parse_args()
logging.info('Called with args:\n' + str(args))
coordinator = Coordinator(args.cfg_file, args.exp_dir)
logging.info('Using config:\n' + str(cfg))
if cfg.NUM_GPUS > 1:
# Run a distributed task.
run_distributed(args, coordinator)
else:
# Resume training?
checkpoint, start_iter = coordinator.get_checkpoint()
if checkpoint is not None:
cfg.TRAIN.WEIGHTS = checkpoint
# Fix the random seed for reproducibility.
numpy.random.seed(cfg.RNG_SEED)
dragon.random.set_seed(cfg.RNG_SEED)
# Inspect the dataset.
dataset = build_dataset(cfg.TRAIN.DATASET)
logging.info('Dataset({}): {} images will be used to train.'
.format(cfg.TRAIN.DATASET, dataset.num_images))
# Run training.
logging.info('Checkpoints will be saved to `{:s}`'
.format(coordinator.path_at('checkpoints')))
train_engine.run_train(coordinator, start_iter,
enable_tensorboard=args.tensorboard)