Commit 06e7ac2e by Ting PAN

Initial repository

0 parents
Showing with 4139 additions and 0 deletions
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
[flake8]
max-line-length = 120
ignore = E741, # ambiguous variable name
F403, # ‘from module import *’ used; unable to detect undefined names
F405, # name may be undefined, or defined from star imports: module
F811, # redefinition of unused name from line N
F821, # undefined name
W503, # line break before binary operator
W504 # line break after binary operator
# module imported but unused
per-file-ignores = __init__.py: F401
# Compiled Object files
*.slo
*.lo
*.o
*.cuo
# Compiled Dynamic libraries
*.so
*.dll
*.dylib
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Compiled python
*.pyc
__pycache__
# Compiled MATLAB
*.mex*
# IPython notebook checkpoints
.ipynb_checkpoints
# Editor temporaries
*.swp
*~
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# QtCreator files
*.user
# VSCode files
.vscode
# IDEA files
.idea
# OSX dir files
.DS_Store
# Android files
.gradle
*.iml
local.properties
Copyright (c) 2017, SeetaTech, Co.,Ltd. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Benchmark and Model Zoo
## Introduction
## Baselines
### ResNet
Refer to [ResNet](configs/resnet) for details.
# SeetaBase
SeetaDet is a platform implementing popular base vision models.
This repository is based on [seeta-dragon](https://github.com/seetaresearch/dragon),
while the style of codes is torch.
## Installation
### Build From Source
If you prefer to develop modules as well as running experiments,
following commands will build but not install to ***site-packages***:
```bash
cd seetabase && python setup.py build
```
### Install From Source
Clone this repository to local disk and install:
```bash
cd seetabase && python setup.py install
```
### Install From Git
You can also install it from remote repository:
```bash
pip install git+https://gitlab.seetatech.com/seetaresearch/seetabase.git
```
## Quick Start
### Train a model
```bash
cd tools
python train.py --cfg <MODEL_YAML>
```
We have provided the default YAML examples into [configs](configs).
### Test a model
```bash
cd tools
python test.py --cfg <MODEL_YAML> --exp_dir <EXP_DIR> --iter <ITERATION>
```
## Benchmark and Model Zoo
Results and models are available in the [Model Zoo](MODEL_ZOO.md).
## License
[BSD 2-Clause license](LICENSE)
# Deep Residual Learning for Image Recognition
## Introduction
```
@article{He2015,
author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
title = {Deep Residual Learning for Image Recognition},
journal = {arXiv preprint arXiv:1512.03385},
year = {2015}
}
```
## ImageNet-1K Classification Baselines
| Model | Lr sched | Throughout | Acc@1 | Acc@5 | Download |
| :---: | :------: | :--------: | :---: | :---: | :------: |
| [R50](in1k_cls_R_50_90e.yml) | 90e | 2080 | 76.52 | 93.17 | [model](https://dragon.seetatech.com/download/seetabase/resnet/in1k_cls_R_50_90e/model_eecb35fc.pkl) &#124; [log](https://dragon.seetatech.com/download/seetabase/resnet/in1k_cls_R_50_90e/logs.json) |
| [R50](in1k_cls_R_50_200e.yml) | 200e | 2080 | 78.65 | 94.3 | [model](https://dragon.seetatech.com/download/seetabase/resnet/in1k_cls_R_50_200e/model_ed34c3a5.pkl) &#124; [log](https://dragon.seetatech.com/download/seetabase/resnet/in1k_cls_R_50_200e/logs.json) |
NUM_GPUS: 8
MODEL:
TYPE: 'cls'
PRECISION: 'float16'
NUM_CLASSES: 1000
AUGREG:
MIXUP: 0.2
CUTMIX: 1.0
LABEL_SMOOTHING: 0.1
BACKBONE:
TYPE: 'resnet50'
SOLVER:
BASE_LR: 0.4
LR_POLICY: 'cosine_decay'
MAX_STEPS: 250000
WARM_UP_STEPS: 6250
EVAL_STEPS: 500
SNAPSHOT_EVERY: 1250
SNAPSHOT_PREFIX: in1k_cls_R_50
TRAIN:
DATASET: '../data/datasets/in1k_train'
BATCH_SIZE: 128
CROP_SIZE: 224
TEST:
DATASET: '../data/datasets/in1k_val'
BATCH_SIZE: 100
CROP_SIZE: 224
NUM_GPUS: 8
MODEL:
TYPE: 'cls'
PRECISION: 'float16'
NUM_CLASSES: 1000
BACKBONE:
TYPE: 'resnet50'
SOLVER:
BASE_LR: 0.4
LR_POLICY: 'steps_with_decay'
DECAY_STEPS: [37500, 75000, 100000]
MAX_STEPS: 112500
WARM_UP_STEPS: 6250
EVAL_STEPS: 500
SNAPSHOT_EVERY: 1250
SNAPSHOT_PREFIX: in1k_cls_R_50
TRAIN:
DATASET: '../data/datasets/in1k_train'
BATCH_SIZE: 128
CROP_SIZE: 224
TEST:
DATASET: '../data/datasets/in1k_val'
BATCH_SIZE: 100
CROP_SIZE: 224
#include <dragon/utils/device/common_eigen.h>
#include "../utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _SoftmaxV2(
const int N,
const int S,
const int C,
const float tau,
const T* x,
T* y) {
if (S == 1) {
ConstEigenArrayMap<T> X(x, C, N);
EigenArrayMap<T> Y(y, C, N);
Y = ((X.rowwise() - X.colwise().maxCoeff()) * tau).exp();
Y = Y.rowwise() / Y.colwise().sum();
return;
}
for (int i = 0; i < N; ++i) {
const auto offset = i * C * S;
for (int j = 0; j < S; ++j) {
ConstEigenStridedVectorArrayMap<T> X_vec(
x + offset + j, 1, C, EigenInnerStride(S));
EigenStridedVectorArrayMap<T> Y_vec(
y + offset + j, 1, C, EigenInnerStride(S));
Y_vec = ((X_vec - X_vec.maxCoeff()) * tau).exp();
Y_vec /= Y_vec.sum();
}
}
}
template <typename T>
void _SoftmaxV2Grad(
const int N,
const int S,
const int C,
const float tau,
const T* dy,
const T* y,
T* dx) {
if (S == 1) {
ConstEigenArrayMap<T> dY(dy, C, N);
ConstEigenArrayMap<T> Y(y, C, N);
EigenArrayMap<T> dX(dx, C, N);
dX = ((dY.rowwise() - (dY * Y).colwise().sum()) * Y) * tau;
return;
}
for (int i = 0; i < N; ++i) {
const auto offset = i * C * S;
for (int j = 0; j < S; ++j) {
ConstEigenStridedVectorArrayMap<T> dY_vec(
dy + offset + j, 1, C, EigenInnerStride(S));
ConstEigenStridedVectorArrayMap<T> Y_vec(
y + offset + j, 1, C, EigenInnerStride(S));
EigenStridedVectorArrayMap<T> dX_vec(
dx + offset + j, 1, C, EigenInnerStride(S));
dX_vec = ((dY_vec - (dY_vec * Y_vec).sum()) * Y_vec) * tau;
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const float tau, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(N, S, C, tau, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const float tau, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_##name(N, S, C, tau, dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(SoftmaxV2, float);
DEFINE_KERNEL_LAUNCHER(SoftmaxV2, double);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxV2Grad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxV2Grad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const float tau, \
const T* x, \
T* y, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const float tau, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
}
DEFINE_KERNEL_LAUNCHER(SoftmaxV2, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxV2Grad, float16);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#include "../operators/softmax_op.h"
#include "../utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void SoftmaxV2Op<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
kernels::SoftmaxV2(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
tau_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void SoftmaxV2GradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
GET_OP_AXIS_ARG(axis, Y.ndim(), -1);
kernels::SoftmaxV2Grad(
Y.count(0, axis),
Y.count(axis + 1),
Y.dim(axis),
tau_,
dY.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(SoftmaxV2);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SoftmaxV2);
#endif
DEPLOY_CPU_OPERATOR(SoftmaxV2Gradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SoftmaxV2Gradient);
#endif
OPERATOR_SCHEMA(SoftmaxV2)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SoftmaxV2Gradient)
/* Y, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(SoftmaxV2, InplaceGradientMaker);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_OPERATORS_SOFTMAX_OP_H_
#define DRAGON_EXTENSION_OPERATORS_SOFTMAX_OP_H_
#include <dragon/core/operator.h>
namespace dragon {
template <class Context>
class SoftmaxV2Op : public Operator<Context> {
public:
SoftmaxV2Op(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), tau_(OP_SINGLE_ARG(float, "tau", 1.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
float tau_;
};
template <class Context>
class SoftmaxV2GradientOp : public Operator<Context> {
public:
SoftmaxV2GradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), tau_(OP_SINGLE_ARG(float, "tau", 1.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
float tau_;
};
} // namespace dragon
#endif // DRAGON_EXTENSION_OPERATORS_SOFTMAX_OP_H_
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build cpp extensions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
from dragon.utils import cpp_extension
from setuptools import setup
Extension = cpp_extension.CppExtension
if cpp_extension.CUDA_HOME is not None:
if cpp_extension._cuda.is_available():
Extension = cpp_extension.CUDAExtension
def find_sources(*dirs):
ext_suffixes = ['.cc']
if Extension is cpp_extension.CUDAExtension:
ext_suffixes.append('.cu')
sources = []
for path in dirs:
for ext_suffix in ext_suffixes:
sources += glob.glob(path + '/*' + ext_suffix, recursive=True)
return sources
ext_modules = [
Extension(
name='seetabase.ops._C',
sources=find_sources('**'),
),
]
setup(
name='seetabase',
ext_modules=ext_modules,
cmdclass={'build_ext': cpp_extension.BuildExtension},
)
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_EXTENSION_UTILS_OP_KERNELS_H_
#define DRAGON_EXTENSION_UTILS_OP_KERNELS_H_
#include <dragon/core/context.h>
namespace dragon {
namespace kernels {
template <typename T, class Context>
void SoftmaxV2(
const int N,
const int S,
const int C,
const float tau,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void SoftmaxV2Grad(
const int N,
const int S,
const int C,
const float tau,
const T* dy,
const T* y,
T* dx,
Context* ctx);
} // namespace kernels
} // namespace dragon
#endif // DRAGON_EXTENSION_UTILS_OP_KERNELS_H_
# Datasets
## Introduction
This folder is kept for the record datasets.
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""A platform implementing popular base vision models."""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
# Version
from seetabase.version import version as __version__
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Platform backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib.machinery
import os
from dragon.core.framework import backend
def load_library(library_prefix):
"""Load a shared library."""
loader_details = (importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES)
library_prefix = os.path.abspath(library_prefix)
lib_dir, fullname = os.path.split(library_prefix)
finder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = finder.find_spec(fullname)
if ext_specs is None:
raise ImportError('Could not find the pre-built library '
'for <%s>.' % library_prefix)
backend.load_library(ext_specs.origin)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Platform configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Classes.
from seetabase.core.config.yacs import CfgNode
# Variables.
from seetabase.core.config.defaults import cfg
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Default configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.core.config.yacs import CfgNode
_C = cfg = CfgNode()
# ------------------------------------------------------------
# Model options
# ------------------------------------------------------------
_C.MODEL = CfgNode()
# The model type
_C.MODEL.TYPE = 'cls'
# The float precision for training and inference
# Values supported: 'float32', 'float16'
_C.MODEL.PRECISION = 'float32'
# The number of classes for classification task
_C.MODEL.NUM_CLASSES = 0
# ------------------------------------------------------------
# Backbone options
# ------------------------------------------------------------
_C.BACKBONE = CfgNode()
# The backbone type
_C.BACKBONE.TYPE = ''
# Freeze given modules in the backbone
_C.BACKBONE.FREEZE_AT = 0
# The drop path rate in backbone
_C.BACKBONE.DROP_PATH_RATE = 0.0
# ------------------------------------------------------------
# Decoder options
# ------------------------------------------------------------
_C.DECODER = CfgNode()
# The number of layers stacked in the decoder
_C.DECODER.DEPTH = 4
# The dimension of the decoder features
_C.DECODER.DIM = 512
# The normalization in decoder modules
_C.DECODER.NORM = ''
# The activation in decoder modules
_C.DECODER.ACTIVATION = ''
# ------------------------------------------------------------
# Training options
# ------------------------------------------------------------
_C.TRAIN = CfgNode()
# Initialize network with weights from this file
_C.TRAIN.WEIGHTS = ''
# The dataset to train
_C.TRAIN.DATASET = ''
# The loader type for training
_C.TRAIN.LOADER = 'cls_train'
# The number of workers to load train data
_C.TRAIN.NUM_WORKERS = 4
# Scales to use during training (can list multiple scales)
# Each scale is the pixel size of an image shortest side
_C.TRAIN.SCALES = (-1,)
# Size to crop the input image
_C.TRAIN.CROP_SIZE = 224
# The training batch size
_C.TRAIN.BATCH_SIZE = 128
# ------------------------------------------------------------
# Testing options
# ------------------------------------------------------------
_C.TEST = CfgNode()
# The dataset to test
_C.TEST.DATASET = ''
# The loader type for testing
_C.TEST.LOADER = 'cls_test'
# The number of threads to load data
_C.TEST.NUM_WORKERS = 4
# The testing batch size
_C.TEST.BATCH_SIZE = 100
# Scales to use during testing (can list multiple scales)
# Each scale is the pixel size of an image's shortest side
_C.TEST.SCALES = (256,)
# Size to crop the input image
_C.TEST.CROP_SIZE = 224
# ------------------------------------------------------------
# Augmentation & Regularization options
# ------------------------------------------------------------
_C.AUGREG = CfgNode()
# The minimum scale for random cropping
_C.AUGREG.CROP_MIN = 0.08
# The color jitter factor
_C.AUGREG.COLOR_JITTER = 0.4
# The auto augment config
_C.AUGREG.AUTO_AUGMENT = ''
# The alpha value for mixup
_C.AUGREG.MIXUP = 0.0
# The alpha value for cutmix
_C.AUGREG.CUTMIX = 0.0
# The drop rate for dropout
_C.AUGREG.DROPOUT_RATE = 0.0
# The drop rate for stochastic depth
_C.AUGREG.DROP_PATH_RATE = 0.0
# The prob for random erasing
_C.AUGREG.RANDOM_ERASING = 0.0
# The epsilon value for label smoothing
_C.AUGREG.LABEL_SMOOTHING = 0.0
# ------------------------------------------------------------
# Solver options
# ------------------------------------------------------------
_C.SOLVER = CfgNode()
# The interval to display logs
_C.SOLVER.DISPLAY = 20
# The interval to snapshot a model
_C.SOLVER.SNAPSHOT_EVERY = 5000
# Prefix to yield the path: <prefix>_iter_XYZ.pth
_C.SOLVER.SNAPSHOT_PREFIX = ''
# The maximum number of train steps
_C.SOLVER.MAX_STEPS = 2147483647
# The interval to evaluate the train model
_C.SOLVER.EVAL_EVERY = 2147483647
# The number of steps under the evaluating stage
_C.SOLVER.EVAL_STEPS = 500
# Scale factor to training loss
_C.SOLVER.LOSS_SCALE = 1024.0
# Base learning rate for the specified scheduler
_C.SOLVER.BASE_LR = 0.1
# Minimal learning rate for the specified scheduler
_C.SOLVER.MIN_LR = 0.0
# The custom intervals for LRScheduler
_C.SOLVER.DECAY_STEPS = []
# The decay factor for exponential LRScheduler
_C.SOLVER.DECAY_GAMMA = 0.1
# Warm up to ``LR_MAX`` over this number of steps
_C.SOLVER.WARM_UP_STEPS = 0
# Start the warm up from ``LR_MAX`` * ``FACTOR``
_C.SOLVER.WARM_UP_FACTOR = 0.0
# The type of Optimizier
_C.SOLVER.OPTIMIZER = 'SGD'
# The type of LRScheduler
_C.SOLVER.LR_POLICY = 'cosine_decay'
# The layer-wise lr decay
_C.SOLVER.LAYER_LR_DECAY = 1.0
# Momentum to use with SGD
_C.SOLVER.MOMENTUM = 0.9
# Beta values to use with Adam
_C.SOLVER.ADAM_BETAS = (0.9, 0.999)
# L2 regularization for weight parameters
_C.SOLVER.WEIGHT_DECAY = 0.0001
# L2 norm factor for clipping gradients
_C.SOLVER.CLIP_NORM = 0.0
# ------------------------------------------------------------
# MAE options
# ------------------------------------------------------------
_C.MAE = CfgNode()
# The ratio of masked patches
_C.MAE.MASK_RATIO = 0.75
# Weight for pixel regression loss
_C.MAE.PIXEL_LOSS_WEIGHT = 4.0
# ------------------------------------------------------------
# Misc options
# ------------------------------------------------------------
# Number of GPUs to use (applies to training)
_C.NUM_GPUS = 1
# Default GPU device id
_C.GPU_ID = 0
# For reproducibility
_C.RNG_SEED = 3
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/rbgirshick/yacs/blob/master/yacs/config.py>
#
# ------------------------------------------------------------
"""Yet Another Configuration System (YACS)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import yaml
class CfgNode(dict):
"""Node for configuration options."""
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(CfgNode, self).__init__(*args, **kwargs)
self.__dict__[CfgNode.IMMUTABLE] = False
def clone(self):
"""Recursively copy this CfgNode."""
return copy.deepcopy(self)
def freeze(self):
"""Make this CfgNode and all of its children immutable."""
self._immutable(True)
def is_frozen(self):
"""Return mutability."""
return self.__dict__[CfgNode.IMMUTABLE]
def merge_from_file(self, cfg_filename):
"""Load a yaml config file and merge it into this CfgNode."""
with open(cfg_filename, 'r') as f:
other_cfg = CfgNode(yaml.safe_load(f))
self.merge_from_other_cfg(other_cfg)
def merge_from_list(self, cfg_list):
"""Merge config (keys, values) in a list into this CfgNode."""
assert len(cfg_list) % 2 == 0
from ast import literal_eval
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
key_list = k.split('.')
d = self
for sub_key in key_list[:-1]:
assert sub_key in d
d = d[sub_key]
sub_key = key_list[-1]
assert sub_key in d
try:
value = literal_eval(v)
except: # noqa
# Handle the case when v is a string literal
value = v
if type(value) != type(d[sub_key]): # noqa
raise TypeError('Type {} does not match original type {}'
.format(type(value), type(d[sub_key])))
d[sub_key] = value
def merge_from_other_cfg(self, other_cfg):
"""Merge ``other_cfg`` into this CfgNode."""
_merge_a_into_b(other_cfg, self)
def _immutable(self, is_immutable):
"""Set immutability recursively to all nested CfgNode."""
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
for v in self.__dict__.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
for v in self.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __repr__(self):
return "{}({})".format(self.__class__.__name__,
super(CfgNode, self).__repr__())
def __setattr__(self, name, value):
if not self.__dict__[CfgNode.IMMUTABLE]:
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else:
raise AttributeError(
'Attempted to set "{}" to "{}", but CfgNode is immutable'
.format(name, value))
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
r = ""
s = []
for k, v in sorted(self.items()):
seperator = "\n" if isinstance(v, CfgNode) else " "
attr_str = "{}:{}{}".format(str(k), seperator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
return r
def _merge_a_into_b(a, b):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a."""
if not isinstance(a, dict):
return
for k, v in a.items():
# a must specify keys that are in b
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))
# The types must match, too
v = _check_and_coerce_cfg_value_type(v, b[k], k)
# Recursively merge dicts
if type(v) is CfgNode:
try:
_merge_a_into_b(a[k], b[k])
except: # noqa
print('Error under config key: {}'.format(k))
raise
else:
b[k] = v
def _check_and_coerce_cfg_value_type(value_a, value_b, key):
"""Check if the value type matched."""
type_a, type_b = type(value_a), type(value_b)
if type_a is type_b:
return value_a
if type_b is float and type_a is int:
return float(value_a)
# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, np.ndarray):
value_a = np.array(value_a, dtype=value_b.dtype)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
elif isinstance(value_a, dict) and isinstance(value_b, CfgNode):
value_a = CfgNode(value_a)
else:
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(type_b, type_a, value_b, value_a, key))
return value_a
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import time
import numpy as np
from seetabase.core.config import cfg
from seetabase.utils import logging
class Coordinator(object):
"""Manage the unique experiments."""
def __init__(self, cfg_file, exp_dir=None):
cfg.merge_from_file(cfg_file)
if exp_dir is None:
self.exp_dir = '../experiments/{}'.format(
time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())))
if not osp.exists(self.exp_dir):
os.makedirs(self.exp_dir)
else:
if not osp.exists(exp_dir):
raise ValueError('Experiment ({}) does not exist.'.format(exp_dir))
self.exp_dir = exp_dir
def path_at(self, file, auto_create=True):
path = osp.abspath(osp.join(self.exp_dir, file))
if auto_create and not osp.exists(path):
os.makedirs(path)
return path
def get_checkpoint(self, step=None, last_idx=1, wait=False):
path = self.path_at('checkpoints')
def locate(last_idx=None):
files = os.listdir(path)
files = list(filter(lambda x: '_iter_' in x and
x.endswith('.pkl'), files))
file_steps = []
for i, file in enumerate(files):
file_step = int(file.split('_iter_')[-1].split('.')[0])
if step == file_step:
return osp.join(path, files[i]), step
file_steps.append(file_step)
if step is None:
if len(files) == 0:
return None, 0
if last_idx > len(files):
return None, 0
file = files[np.argsort(file_steps)[-last_idx]]
file_step = file_steps[np.argsort(file_steps)[-last_idx]]
return osp.join(path, file), file_step
return None, 0
file, file_step = locate(last_idx)
while file is None and wait:
logging.info('Wait for checkpoint at {}.'.format(step))
time.sleep(10)
file, file_step = locate(last_idx)
return file, file_step
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Registry class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
class Registry(object):
"""Registry class."""
def __init__(self, name):
self._name = name
self._registry = collections.OrderedDict()
def has(self, key):
return key in self._registry
def register(self, name, func=None, **kwargs):
def decorated(inner_function):
for key in (name if isinstance(
name, (tuple, list)) else [name]):
self._registry[key] = \
functools.partial(inner_function, **kwargs)
return inner_function
if func is not None:
return decorated(func)
return decorated
def get(self, name, default=None):
if name is None:
return None
if not self.has(name):
if default is not None:
return default
raise KeyError("`%s` is not registered in <%s>."
% (name, self._name))
return self._registry[name]
def try_get(self, name):
if self.has(name):
return self.get(name)
return None
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Testing engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import multiprocessing as mp
import os
from dragon.vm import torch
import prettytable
from seetabase.core.config import cfg
from seetabase.data.build import build_loader_test
from seetabase.models.build import build_model
from seetabase.modules.build import build_model_inference
from seetabase.utils import logging
from seetabase.utils import profiler
def test_model(test_cfg, queues, device, deterministic=False, verbose=True):
"""Test a model."""
cfg.merge_from_other_cfg(test_cfg)
cfg.GPU_ID = device
cfg.freeze()
logging.set_root(verbose)
if deterministic:
torch.backends.cudnn.deterministic = True
loader = build_loader_test()
model = build_model(device)
module = build_model_inference(model)
input_queue, output_queue = queues
must_stop = False
while not must_stop:
index, weights = input_queue.get()
if index < 0:
must_stop = True
break
model.load_state_dict(torch.load(weights))
model.eval()
model.optimize_for_inference()
results = module.get_results(loader)
time_diffs = module.get_time_diffs()
output_queue.put((index, time_diffs, results))
def run_test(weights_list, devices, deterministic=False):
"""Run a model testing.
Parameters
----------
weights_list : Sequence[str]
A list with the path of network weights file.
devices : Sequence[int]
The index of computing devices.
deterministic : bool, optional, default=False
Set cudnn deterministic or not.
"""
devices = devices if devices else [cfg.GPU_ID]
num_devices = len(devices)
num_weights = len(weights_list)
queues = [mp.Queue() for _ in range(num_devices + 1)]
actors = [mp.Process(
target=test_model,
kwargs={'test_cfg': cfg,
'queues': [queues[i], queues[-1]],
'device': devices[i],
'deterministic': deterministic,
'verbose': i == 0}) for i in range(num_devices)]
for actor in actors:
actor.start()
timers = collections.defaultdict(profiler.Timer)
logging.info('Start testing on devices: {}'.format(devices))
for count in range(num_weights):
weights = weights_list[count]
queues[count % num_devices].put((count, weights))
for i in range(num_devices):
queues[i].put((-1, None))
for count in range(1, num_weights + 1):
index, time_diffs, outputs = queues[-1].get()
weights_name = os.path.splitext(os.path.basename(weights_list[index]))[0]
for name, diff in time_diffs.items():
timers[name].add_diff(diff)
summary_table = prettytable.PrettyTable()
for k, v in outputs.items():
summary_table.add_column(k, [round(v * 100, 2)])
throughout = int(1.0 / timers['im_model'].average_time)
summary_table.add_column('Throughput', [throughout])
summary_title = '[{:d}/{:d}] '.format(count, num_weights)
summary_title += weights_name + ' results:\n'
print(summary_title + '\nSummary:\n' + summary_table.get_string())
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for training library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from seetabase.core.config import cfg
from seetabase.core.training import lr_scheduler
from seetabase.core.training.utils import get_param_groups
def build_optimizer(params, **kwargs):
"""Build the optimizer."""
args = {'lr': cfg.SOLVER.BASE_LR,
'weight_decay': cfg.SOLVER.WEIGHT_DECAY,
'clip_norm': cfg.SOLVER.CLIP_NORM,
'grad_scale': 1.0 / cfg.SOLVER.LOSS_SCALE}
optimizer = kwargs.pop('optimizer', cfg.SOLVER.OPTIMIZER)
if isinstance(params, torch.nn.Module):
params = get_param_groups(params)
if optimizer == 'SGD':
args['momentum'] = cfg.SOLVER.MOMENTUM
elif optimizer == 'AdamW' or optimizer == 'AdamW':
args['betas'] = cfg.SOLVER.ADAM_BETAS
elif optimizer == 'Nesterov':
args['momentum'] = cfg.SOLVER.MOMENTUM
optimizer, args['nesterov'] = 'SGD', True
elif optimizer == 'LARS':
args['momentum'] = cfg.SOLVER.MOMENTUM
args['trust_coef'] = 0.001
args.update(kwargs)
return getattr(torch.optim, optimizer)(params, **args)
def build_lr_scheduler(**kwargs):
"""Build the LR scheduler."""
args = {'lr_max': cfg.SOLVER.BASE_LR,
'lr_min': cfg.SOLVER.MIN_LR,
'warmup_steps': cfg.SOLVER.WARM_UP_STEPS,
'warmup_factor': cfg.SOLVER.WARM_UP_FACTOR}
policy = kwargs.pop('policy', cfg.SOLVER.LR_POLICY)
args.update(kwargs)
if policy == 'step':
return lr_scheduler.StepLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
decay_gamma=cfg.SOLVER.DECAY_GAMMA, **args)
elif policy == 'steps_with_decay':
return lr_scheduler.MultiStepLR(
decay_steps=cfg.SOLVER.DECAY_STEPS,
decay_gamma=cfg.SOLVER.DECAY_GAMMA, **args)
elif policy == 'cosine_decay':
return lr_scheduler.CosineLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
elif policy == 'linear_decay':
return lr_scheduler.LinearLR(
decay_step=(cfg.SOLVER.DECAY_STEPS or [1])[0],
max_steps=cfg.SOLVER.MAX_STEPS, **args)
return lr_scheduler.ConstantLR(**args)
def build_tensorboard(log_dir):
"""Build the tensorboard."""
try:
from dragon.utils.tensorboard import tf
from dragon.utils.tensorboard import TensorBoard
# Avoid using of GPUs by TF API.
if tf is not None:
tf.config.set_visible_devices([], 'GPU')
return TensorBoard(log_dir)
except ImportError:
return None
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""LR schedulers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
class ConstantLR(object):
"""Constant LR scheduler."""
def __init__(self, **kwargs):
self._lr_max = kwargs.pop('lr_max')
self._lr_min = kwargs.pop('lr_min', 0)
self._warmup_steps = kwargs.pop('warmup_steps', 0)
self._warmup_factor = kwargs.pop('warmup_factor', 0)
if kwargs:
raise ValueError('Unexpected arguments: ' + ','.join(v for v in kwargs))
self._step_count = 0
self._last_decay = 1.
def step(self):
self._step_count += 1
def get_lr(self):
if self._step_count < self._warmup_steps:
alpha = (self._step_count + 1.) / self._warmup_steps
return self._lr_max * (alpha + (1. - alpha) * self._warmup_factor)
return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay()
def get_decay(self):
return self._last_decay
class StepLR(ConstantLR):
"""LR scheduler with step decay."""
def __init__(self, lr_max, decay_step, decay_gamma, **kwargs):
super(StepLR, self).__init__(lr_max=lr_max, **kwargs)
self._decay_step = decay_step
self._decay_gamma = decay_gamma
def get_decay(self):
step_count = self._step_count - self._warmup_steps
if step_count % self._decay_step == 0:
decay_count = step_count // self._decay_step
self._last_decay = self._decay_gamma ** decay_count
return self._last_decay
class MultiStepLR(ConstantLR):
"""LR scheduler with multi-steps decay."""
def __init__(self, lr_max, decay_steps, decay_gamma, **kwargs):
super(MultiStepLR, self).__init__(lr_max=lr_max, **kwargs)
self._decay_steps = decay_steps
self._decay_gamma = decay_gamma
self._stage_count = 0
self._num_stages = len(decay_steps)
def get_decay(self):
if self._stage_count < self._num_stages:
k = self._decay_steps[self._stage_count]
while self._step_count >= k:
self._stage_count += 1
if self._stage_count >= self._num_stages:
break
k = self._decay_steps[self._stage_count]
self._last_decay = self._decay_gamma ** self._stage_count
return self._last_decay
class CosineLR(ConstantLR):
"""LR scheduler with cosine decay."""
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t > 0 and t % self._decay_step == 0:
self._last_decay = .5 * (1. + math.cos(math.pi * t / t_max))
return self._last_decay
class LinearLR(ConstantLR):
"""LR scheduler with linear decay."""
def __init__(self, lr_max, max_steps, decay_step=1, **kwargs):
super(LinearLR, self).__init__(lr_max=lr_max, **kwargs)
self._decay_step = decay_step
self._max_steps = max_steps
def get_decay(self):
t = self._step_count - self._warmup_steps
t_max = self._max_steps - self._warmup_steps
if t % self._decay_step == 0:
self._last_decay = 1. - float(t) / float(t_max)
return self._last_decay
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import os
from dragon.vm import torch
from seetabase.core.config import cfg
from seetabase.core.training.build import build_lr_scheduler
from seetabase.core.training.build import build_optimizer
from seetabase.core.training.build import build_tensorboard
from seetabase.core.training.utils import count_params
from seetabase.core.training.utils import get_param_groups
from seetabase.data.build import build_loader_train
from seetabase.models.build import build_model
from seetabase.utils import logging
from seetabase.utils import profiler
class Trainer(object):
"""Model trainer."""
def __init__(self, coordinator):
# Build loader.
self.loader = build_loader_train()
# Build model.
self.model = build_model()
if cfg.TRAIN.WEIGHTS:
state_dict = torch.load(cfg.TRAIN.WEIGHTS)
self.model.load_state_dict(state_dict, strict=False)
self.model.cuda(cfg.GPU_ID)
if cfg.MODEL.PRECISION.lower() == 'float16':
self.model.half()
# Build optimizer.
self.loss_scale = cfg.SOLVER.LOSS_SCALE
param_groups_getter = get_param_groups
if cfg.SOLVER.LAYER_LR_DECAY < 1.0:
lr_scale_getter = getattr(self.model, 'get_lr_scale', None)
if lr_scale_getter is None:
if hasattr(self.model, 'encoder'):
lr_scale_getter = self.model.encoder.get_lr_scale
elif hasattr(self.model, 'backbone'):
lr_scale_getter = self.model.backbone.get_lr_scale
param_groups_getter = functools.partial(
param_groups_getter, lr_scale_getter=functools.partial(
lr_scale_getter, decay=cfg.SOLVER.LAYER_LR_DECAY))
self.optimizer = build_optimizer(get_param_groups(self.model))
self.scheduler = build_lr_scheduler()
# Build monitor.
self.coordinator = coordinator
self.metrics = collections.OrderedDict()
self.board = None
@property
def iter(self):
return self.scheduler._step_count
def snapshot(self):
f = cfg.SOLVER.SNAPSHOT_PREFIX
f += '_iter_{}.pkl'.format(self.iter)
f = os.path.join(self.coordinator.path_at('checkpoints'), f)
if logging.is_root() and not os.path.exists(f):
torch.save(self.model.state_dict(), f, pickle_protocol=4)
logging.info('Wrote snapshot to: {:s}'.format(f))
def add_metrics(self, stats):
for k, v in stats['metrics'].items():
if k not in self.metrics:
self.metrics[k] = profiler.SmoothedValue()
self.metrics[k].update(v)
def display_metrics(self, stats):
logging.info('Iteration %d, lr = %.8f, time = %.2fs'
% (stats['iter'], stats['lr'], stats['time']))
for k, v in self.metrics.items():
logging.info(' ' * 4 + 'Train net output({}): {:.4f} ({:.4f})'
.format(k, stats['metrics'][k], v.average()))
if self.board is not None:
self.board.scalar_summary('lr', stats['lr'], stats['iter'])
self.board.scalar_summary('time', stats['time'], stats['iter'])
for k, v in self.metrics.items():
self.board.scalar_summary(k, v.average(), stats['iter'])
def step(self):
stats = {'iter': self.iter}
metrics = collections.defaultdict(float)
# Run forward.
timer = profiler.Timer().tic()
inputs = self.loader()
outputs, losses = self.model(inputs), []
for k, v in outputs.items():
if 'loss' in k:
losses.append(v)
metrics[k] += float(v)
# Run backward.
losses = sum(losses[1:], losses[0])
if self.loss_scale != 1:
losses *= self.loss_scale
losses.backward()
# Apply Update.
stats['lr'] = self.scheduler.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = stats['lr']
self.optimizer.step()
self.scheduler.step()
stats['time'] = timer.toc()
stats['metrics'] = collections.OrderedDict(sorted(metrics.items()))
return stats
def train_model(self, start_iter=0):
"""Network training loop."""
timer = profiler.Timer()
max_steps = cfg.SOLVER.MAX_STEPS
display_every = cfg.SOLVER.DISPLAY
progress_every = display_every * 10
snapshot_every = cfg.SOLVER.SNAPSHOT_EVERY
self.scheduler._step_count = start_iter
while self.iter < max_steps:
with timer.tic_and_toc():
stats = self.step()
self.add_metrics(stats)
if stats['iter'] % display_every == 0:
self.display_metrics(stats)
if self.iter % progress_every == 0:
logging.info(profiler.get_progress(timer, self.iter, max_steps))
if self.iter % snapshot_every == 0:
self.snapshot()
self.metrics.clear()
def run_train(coordinator, start_iter=0, enable_tensorboard=False):
"""Start a network training task."""
trainer = Trainer(coordinator)
if enable_tensorboard and logging.is_root():
trainer.board = build_tensorboard(coordinator.path_at('logs'))
logging.info('#Params: %.2fM' % count_params(trainer.model))
logging.info('Start training..')
trainer.train_model(start_iter)
trainer.snapshot()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Training utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
def count_params(module):
"""Return the number of parameters in MB."""
return sum([v.size().numel() for v in module.parameters()]) / 1e6
def get_param_groups(module, lr_scale_getter=None, weight_decay=False):
"""Separate parameters into groups."""
groups = collections.OrderedDict()
for name, param in module.named_parameters():
if not param.requires_grad:
continue
attrs = collections.OrderedDict()
if lr_scale_getter:
attrs['lr_scale'] = lr_scale_getter(name)
no_weight_decay = not (name.endswith('.weight') and param.dim() > 1)
no_weight_decay = getattr(param, 'no_weight_decay', no_weight_decay)
if no_weight_decay and not weight_decay:
attrs['weight_decay'] = 0
if hasattr(param, 'no_lr_decay'):
attrs['lr_fixed'] = True
group_name = '/'.join(['%s:%s' % (v[0], v[1]) for v in list(attrs.items())])
if group_name not in groups:
groups[group_name] = {'params': []}
groups[group_name].update(attrs)
groups[group_name]['params'].append(param)
return list(groups.values())
def unfreeze_module(module, patterns):
"""Unfreeze module by patterns for training."""
unfreezed_keys, unfreezed_modules = [], []
if patterns is None:
return unfreezed_keys
patterns = [re.compile(p) for p in patterns]
for name, submodule in module.named_modules():
if submodule is module:
continue
for pattern in patterns:
if pattern.match(name):
unfreezed_keys.append(name)
unfreezed_modules.append(submodule)
break
for module in unfreezed_modules:
module.training = True
for name, param in module.named_parameters():
param.requires_grad = True
return unfreezed_keys
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from dragon.vm import dali
from seetabase.core.config import cfg
from seetabase.core.registry import Registry
from seetabase.data import loaders
from seetabase.data import pipelines
LOADERS = Registry('loaders')
def build_loader_train(**kwargs):
args = {'dataset': cfg.TRAIN.DATASET,
'batch_size': cfg.TRAIN.BATCH_SIZE,
'num_threads': cfg.TRAIN.NUM_WORKERS,
'seed': cfg.RNG_SEED + cfg.GPU_ID,
'shuffle': True}
args.update(kwargs)
return LOADERS.get(cfg.TRAIN.LOADER)(**args)
def build_loader_test(**kwargs):
"""Build the test loader."""
args = {'dataset': cfg.TEST.DATASET,
'batch_size': cfg.TEST.BATCH_SIZE,
'num_threads': cfg.TEST.NUM_WORKERS,
'seed': cfg.RNG_SEED,
'shuffle': False}
args.update(kwargs)
return LOADERS.get(cfg.TEST.LOADER)(**args)
@LOADERS.register('cls_train')
def loader_cls_train(**kwargs):
args = {'resize': cfg.TRAIN.SCALES[-1],
'crop_size': cfg.TRAIN.CROP_SIZE,
'crop_scale': (cfg.AUGREG.CROP_MIN, 1.0),
'horizontal_flip': True, 'training': True}
if not cfg.AUGREG.AUTO_AUGMENT:
args.update({'dtype': cfg.MODEL.PRECISION,
'color_jitter': cfg.AUGREG.COLOR_JITTER})
args.update(kwargs)
with dali.device('cuda', cfg.GPU_ID):
return loaders.DALILoader(
functools.partial(
pipelines.InceptionPipeline, **args),
batch_tags=['img', 'label'])
else:
args.update(kwargs)
return loaders.DALIHybridLoader(
functools.partial(
pipelines.InceptionPipeline,
dtype=None, **args),
functools.partial(
pipelines.AutoAugmentWorker,
augment=cfg.AUGREG.AUTO_AUGMENT,
reprob=cfg.AUGREG.RANDOM_ERASING,
seed=args['seed']),
batch_size=args['batch_size'],
num_workers=args['num_threads'] // 2,
batch_tags=['img', 'label'])
@LOADERS.register('cls_test')
def loader_cls_test(**kwargs):
args = {'resize': cfg.TEST.SCALES[-1],
'crop_size': cfg.TEST.CROP_SIZE,
'dtype': cfg.MODEL.PRECISION}
args.update(kwargs)
with dali.device('cuda', cfg.GPU_ID):
return loaders.DALILoader(
functools.partial(
pipelines.InceptionPipeline, **args),
batch_tags=['img', 'label'])
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import multiprocessing as mp
import time
import threading
import queue
import numpy as np
from dragon.vm import dali
from dragon.vm.dali.plugin.pytorch import DALIGenericIterator
class BalancedQueues(object):
"""Balanced queues."""
def __init__(self, base_queue, num=1):
self.queues = [base_queue]
self.queues += [mp.Queue(base_queue._maxsize) for _ in range(num - 1)]
self.index = 0
def put(self, obj, block=True, timeout=None):
q = self.queues[self.index]
q.put(obj, block=block, timeout=timeout)
self.index = (self.index + 1) % len(self.queues)
def get(self, block=True, timeout=None):
q = self.queues[self.index]
obj = q.get(block=block, timeout=timeout)
self.index = (self.index + 1) % len(self.queues)
return obj
def get_n(self, num=1):
outputs = []
while len(outputs) < num:
obj = self.get()
if obj is not None:
outputs.append(obj)
return outputs
class DataLoaderBase(threading.Thread):
"""Base class of data loader."""
def __init__(self, **kwargs):
super(DataLoaderBase, self).__init__(daemon=True)
self.batch_size = kwargs.get('batch_size', 64)
self.queue_depth = kwargs.get('queue_depth', 2)
self.batch_queue = queue.Queue(self.queue_depth)
def next(self):
"""Return the next batch of data."""
return self.__next__()
def run(self):
"""Main loop."""
def __call__(self):
return self.next()
def __iter__(self):
"""Return the iterator self."""
return self
def __next__(self):
"""Return the next batch of data."""
return self.batch_queue.get()
class DALIReader(mp.Process):
"""Read data from the DALI pipeline."""
def __init__(self, **kwargs):
super(DALIReader, self).__init__()
self.pipeline = kwargs.get('pipeline', None)
self.device_type = 'cuda'
self.device_index = kwargs.get('device', -1)
if self.device_index < 0:
self.device_type = 'cpu'
self.device_index = 0
self.reader_queue = None
def run(self):
"""Start the process."""
with dali.device(self.device_type, self.device_index):
pipe = self.pipeline()
api_type = dali.types.PIPELINE_API_SCHEDULED
with pipe._check_api_type_scope(api_type):
pipe.build()
pipe.schedule_run()
while True:
with pipe._check_api_type_scope(api_type):
pipe_returns = pipe.share_outputs()
batch_outputs = []
for output in pipe_returns:
if hasattr(output, 'as_cpu'):
output = output.as_cpu()
batch_outputs.append(output.as_array())
for i in range(len(batch_outputs[0])):
outputs = []
for j in range(len(batch_outputs)):
outputs.append(batch_outputs[j][i])
self.reader_queue.put(outputs)
with pipe._check_api_type_scope(api_type):
pipe.release_outputs()
pipe.schedule_run()
class DALILoader(DataLoaderBase):
"""DALI loader."""
def __init__(self, pipeline, **kwargs):
super(DALILoader, self).__init__(**kwargs)
self.iterator = DALIGenericIterator(pipeline())
self.batch_tags = kwargs.get('batch_tags', None)
def __next__(self):
"""Return the next batch of data."""
outputs = collections.OrderedDict()
for i, v in enumerate(self.iterator.next()):
outputs[self.batch_tags[i]] = v
return outputs
class DALIHybridLoader(DataLoaderBase):
"""Hybrid loader with DALI pipeline and custom worker."""
def __init__(self, pipeline, worker, **kwargs):
super(DALIHybridLoader, self).__init__(**kwargs)
self.num_workers = kwargs.get('num_workers', 4)
self.batch_tags = kwargs.get('batch_tags', None)
# Build queues.
self.reader_queue = mp.Queue(self.queue_depth * self.batch_size)
self.worker_queue = mp.Queue(self.queue_depth * self.batch_size)
self.reader_queue = BalancedQueues(self.reader_queue, self.num_workers)
self.worker_queue = BalancedQueues(self.worker_queue, self.num_workers)
# Build readers.
self.readers = [DALIReader(pipeline=pipeline, **kwargs)]
self.readers[0].reader_queue = self.reader_queue
self.readers[0].start()
time.sleep(0.1)
# Build workers.
self.workers = []
for i in range(self.num_workers):
p = worker(**kwargs)
p.seed += i * 1024
p.reader_queue = self.reader_queue
p.worker_queue = self.worker_queue
p.start()
self.workers.append(p)
time.sleep(0.1)
# Register cleanup callbacks.
def cleanup():
def terminate(processes):
for p in processes:
p.terminate()
p.join()
terminate(self.workers)
terminate(self.readers)
import atexit
atexit.register(cleanup)
# Start batch prefetching.
self.start()
def run(self):
"""Main loop."""
while True:
outputs = collections.OrderedDict()
for k, v in zip(self.batch_tags, self.worker_queue.get()):
batch_shape = (self.batch_size,) + v.shape[:]
outputs[k] = np.empty(batch_shape, v.dtype)
outputs[k][0] = v
for i in range(1, self.batch_size):
for k, v in zip(self.batch_tags, self.worker_queue.get()):
outputs[k][i] = v
self.batch_queue.put(outputs)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Modules.
from seetabase.models import backbones
from seetabase.models import classifiers
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Backbones."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Modules
from seetabase.models.backbones import resnet
from seetabase.models.backbones import swin
from seetabase.models.backbones import vit
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import nn
from seetabase.models.build import BACKBONES
class BasicBlock(nn.Module):
"""Basic ResNet block."""
expansion = 1
def __init__(self, dim_in, dim, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(dim_in, dim, 3, stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(dim)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(dim)
self.downsample = downsample
def forward(self, x):
shortcut = x
x = self.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
if self.downsample is not None:
shortcut = self.downsample(shortcut)
return self.relu(x.add_(shortcut))
class Bottleneck(nn.Module):
"""Bottleneck ResNet block."""
expansion = 4
groups, width_per_group = 1, 64
def __init__(self, dim_in, dim, stride=1, downsample=None):
super(Bottleneck, self).__init__()
width = int(dim * (self.width_per_group / 64.)) * self.groups
self.conv1 = nn.Conv2d(dim_in, width, 1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, 3, stride, 1, bias=False)
self.bn2 = nn.BatchNorm2d(dim)
self.conv3 = nn.Conv2d(width, dim * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(dim * self.expansion)
self.relu = nn.ReLU(True)
self.downsample = downsample
def forward(self, x):
shortcut = x
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
if self.downsample is not None:
shortcut = self.downsample(shortcut)
return self.relu(x.add_(shortcut))
# out = torch.utils.checkpoint.checkpoint_sequential(
# [self.conv1, self.bn1, self.relu, self.conv2], x)
# out = self.bn2(out)
# out = self.relu(out)
# out = torch.utils.checkpoint.checkpoint_sequential(
# [self.conv3, self.bn3], out)
class ResNet(nn.Module):
"""ResNet."""
def __init__(self, block, depths, deep_stem=False):
super(ResNet, self).__init__()
dim_in, dims, blocks = 64, [64, 128, 256, 512], []
if deep_stem:
self.conv1 = nn.Sequential(
nn.Conv2d(3, dim_in // 2, 3, 2, padding=1, bias=False),
nn.BatchNorm2d(dim_in // 2), nn.ReLU(True),
nn.Conv2d(dim_in // 2, dim_in // 2, 3, padding=1, bias=False),
nn.BatchNorm2d(dim_in // 2), nn.ReLU(True),
nn.Conv2d(dim_in // 2, dim_in, 3, padding=1, bias=False))
else:
self.conv1 = nn.Conv2d(3, dim_in, 7, 2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(dim_in)
self.relu = nn.ReLU(True)
self.maxpool = nn.MaxPool2d(3, 2, padding=1)
for i, depth, dim in zip(range(4), depths, dims):
stride = 1 if i == 0 else 2
downsample = None
if stride != 1 or dim_in != dim * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(dim_in, dim * block.expansion, 1, stride, bias=False),
nn.BatchNorm2d(dim * block.expansion))
blocks.append(block(dim_in, dim, stride, downsample))
dim_in = dim * block.expansion
for _ in range(depth - 1):
blocks.append(block(dim_in, dim))
setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
self.blocks = blocks
self.num_features = dim_in
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
for blk in self.blocks:
x = blk(x)
return x
class ResNetV1c(ResNet):
"""ResNet with deep 3x3 stem instead of 7x7."""
def __init__(self, block, depths):
super(ResNetV1c, self).__init__(block, depths, deep_stem=True)
BACKBONES.register('resnet18', ResNet, block=BasicBlock, depths=[2, 2, 2, 2])
BACKBONES.register('resnet34', ResNet, block=BasicBlock, depths=[3, 4, 6, 3])
BACKBONES.register('resnet50', ResNet, block=Bottleneck, depths=[3, 4, 6, 3])
BACKBONES.register('resnet101', ResNet, block=Bottleneck, depths=[3, 4, 23, 3])
BACKBONES.register('resnet152', ResNet, block=Bottleneck, depths=[3, 8, 36, 3])
BACKBONES.register('resnet50_v1c', ResNetV1c, block=Bottleneck, depths=[3, 4, 6, 3])
BACKBONES.register('resnet101_v1c', ResNetV1c, block=Bottleneck, depths=[3, 4, 23, 3])
BACKBONES.register('resnet152_v1c', ResNetV1c, block=Bottleneck, depths=[3, 8, 36, 3])
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Swin Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from dragon.vm import torch
from dragon.vm.torch import nn
import numpy as np
from seetabase.core.config import cfg
from seetabase.models.build import BACKBONES
from seetabase.models.backbones.vit import MLP
from seetabase.models.backbones.vit import PatchEmbed
def space_to_depth(input, block_size):
"""Rearrange blocks of spatial data into depth."""
h, w, c = input.size()[1:]
h1, w1 = h // block_size, w // block_size
c1 = (block_size ** 2) * c
input.reshape_((-1, h1, block_size, w1, block_size, c))
out = input.permute(0, 1, 3, 2, 4, 5)
input.reshape_((-1, h, w, c))
return out.reshape_((-1, h1, w1, c1))
def depth_to_space(input, block_size):
"""Rearrange blocks of depth data into spatial."""
h1, w1, c1 = input.size()[1:]
h, w = h1 * block_size, w1 * block_size
c = c1 // (block_size ** 2)
input.reshape_((-1, h1, w1, block_size, block_size, c))
out = input.permute(0, 1, 3, 2, 4, 5)
input.reshape_((-1, h1, w1, c1))
return out.reshape_((-1, h, w, c))
class RelPosEmbed(nn.Module):
"""Relative position embedding layer."""
def __init__(self, num_heads, window_size):
super(RelPosEmbed, self).__init__()
num_pos = (2 * window_size - 1) ** 2 + 3
grid = np.arange(window_size)
pos = np.stack(np.meshgrid(grid, grid, indexing='ij'))
pos = pos.reshape((2, -1))
pos = pos[:, :, None] - pos[:, None, :]
pos += window_size - 1
pos[0] *= 2 * window_size - 1
index = pos.sum(0).astype('int64')
self.register_buffer('index', torch.from_numpy(index))
self.weight = nn.Parameter(torch.zeros(num_heads, num_pos))
nn.init.normal_(self.weight, std=.02)
def forward(self, x):
return x.add_(self.weight[:, self.index])
class PatchMerging(nn.Module):
"""Merge patches to downsample the input."""
def __init__(self, dim_in, dim_out):
super(PatchMerging, self).__init__()
self.norm = nn.LayerNorm(4 * dim_in)
self.reduction = nn.Linear(4 * dim_in, dim_out, bias=False)
def forward(self, x):
x = space_to_depth(x, 2)
return self.reduction(self.norm(x))
class Attention(nn.Module):
"""Multihead attention."""
def __init__(self, dim, num_heads, window_size, qkv_bias=True):
super(Attention, self).__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.relative_position = RelPosEmbed(num_heads, window_size)
def forward(self, x, mask=None):
num_patches = x.size(1)
qkv_shape = (-1, num_patches, 3, self.num_heads, self.head_dim)
qkv = self.qkv(x).reshape_(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0, copy=False)
attn = q @ k.transpose(-2, -1).mul_(self.scale)
attn = self.relative_position(attn)
if mask is not None:
attn.reshape_(-1, mask.size(1), self.num_heads,
num_patches, num_patches).add_(mask)
attn.reshape_(-1, self.num_heads, num_patches, num_patches)
attn = nn.functional.softmax(attn, dim=-1, inplace=True)
return self.proj((attn @ v).transpose(1, 2).flatten_(2))
class Block(nn.Module):
"""Transformer block."""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4,
qkv_bias=False,
drop_path=0,
downsample=None,
):
super(Block, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads, window_size, qkv_bias=qkv_bias)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
self.drop_path = nn.DropPath(drop_path, inplace=True)
self.downsample = downsample
def get_mask(self, resolution):
index, (height, width) = 0, resolution
img_mask = np.zeros([1, height, width, 1], 'float32')
for h, w in itertools.product(
*[(slice(0, resolution[i] - self.window_size),
slice(resolution[i] - self.window_size,
resolution[i] - self.shift_size),
slice(resolution[i] - self.shift_size, None))
for i in range(len(resolution))]):
img_mask[:, h, w, :] = index
index += 1
img_shape = [1]
for size in resolution:
img_shape += [size // self.window_size, self.window_size]
img_mask = img_mask.reshape(img_shape)
img_mask = img_mask.transpose((0, 1, 3, 2, 4))
img_mask = img_mask.reshape((-1, self.window_size ** 2))
mask = np.expand_dims(img_mask, 1) - np.expand_dims(img_mask, 2)
mask[mask != 0] = -100.0
mask = np.expand_dims(mask, (0, 2))
return torch.from_numpy(mask)
def forward(self, x, mask=None):
if self.downsample is not None:
x = self.downsample(x)
shortcut = x
x = self.norm1(x)
if self.shift_size > 0 and mask is not None:
x = x.roll((-self.shift_size,) * 2, dims=(1, 2))
x = space_to_depth(x, self.window_size)
msa_shape = (-1, self.window_size ** 2, self.dim)
wmsa_shape = (-1,) + x.shape[1:-1] + (self.window_size ** 2 * self.dim,)
x = self.attn(x.reshape_(msa_shape), mask)
x = depth_to_space(x.reshape_(wmsa_shape), self.window_size)
if self.shift_size > 0 and mask is not None:
x = x.roll((self.shift_size,) * 2, dims=(1, 2))
x = self.drop_path(x).add_(shortcut)
x = self.drop_path(self.mlp(self.norm2(x))).add_(x)
return x
class SwinTransformer(nn.Module):
"""SwinTransformer."""
def __init__(self, depths, dims, num_heads, mlp_ratios,
patch_size=4, window_size=7):
super(SwinTransformer, self).__init__()
drop_path = cfg.AUGREG.DROP_PATH_RATE
drop_path = (torch.linspace(
0, drop_path, sum(depths), dtype=torch.float32).tolist()
if drop_path > 0 else [drop_path] * sum(depths))
self.patch_embed = PatchEmbed(dims[0], patch_size)
self.blocks = nn.ModuleList()
for i, depth in enumerate(depths):
downsample = PatchMerging(dims[i - 1], dims[i]) if i > 0 else None
self.blocks += [Block(
dim=dims[i], num_heads=num_heads[i],
window_size=window_size,
shift_size=(0 if j % 2 == 0 else window_size // 2),
mlp_ratio=mlp_ratios[i], qkv_bias=True,
drop_path=drop_path[len(self.blocks) - 1],
downsample=downsample if j == 0 else None)
for j in range(depth)]
self.masks = dict()
self.num_features = dims[-1]
self.norm = nn.LayerNorm(dims[-1])
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, y=None):
x = self.patch_embed(x)
x = x.permute(0, 2, 3, 1)
for blk in self.blocks:
resolution, mask = list(x.shape[1:-1]), None
if blk.shift_size > 0 and min(resolution) > blk.window_size:
mask = self.masks.get(str(resolution), None)
if mask is None:
mask = blk.get_mask(resolution)
self.masks[str(resolution)] = mask
mask = mask.to(x)
x = blk(x)
return self.norm(x).permute(0, 3, 1, 2)
BACKBONES.register('swin_tiny_patch4_window7', SwinTransformer,
depths=(2, 2, 6, 2), dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
patch_size=4, window_size=7)
BACKBONES.register('swin_small_patch4_window7', SwinTransformer,
depths=(2, 2, 18, 2), dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
patch_size=4, window_size=7)
BACKBONES.register('swin_base_patch4_window7', SwinTransformer,
depths=(2, 2, 18, 2), dims=(128, 256, 512, 1024),
num_heads=(4, 8, 16, 32), mlp_ratios=(4, 4, 4, 4),
patch_size=4, window_size=7)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Vision Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from dragon.vm import torch
from dragon.vm.torch import nn
import numpy as np
from seetabase.core.config import cfg
from seetabase.models.build import BACKBONES
from seetabase.ops.activation import softmax
from seetabase.utils import logging
class MLP(nn.Module):
"""Two layers MLP."""
def __init__(self, dim, mlp_ratio=4):
super(MLP, self).__init__()
self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
self.activation = nn.GELU()
def forward(self, x):
return self.fc2(self.activation(self.fc1(x)))
class Attention(nn.Module):
"""Multihead attention."""
def __init__(self, dim, num_heads, qkv_bias=True):
super(Attention, self).__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.qk_scale = self.head_dim ** -0.5
self.qk_rescale = 1.0
def forward(self, x):
qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
qkv = self.qkv(x).reshape_(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0, copy=False)
attn = q @ k.transpose(-2, -1).mul_(self.qk_scale / self.qk_rescale)
attn = softmax(attn, -1, tau=self.qk_rescale, inplace=True)
return self.proj((attn @ v).transpose(1, 2).flatten_(2))
class Block(nn.Module):
"""Transformer block."""
def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0):
super(Block, self).__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
self.drop_path = nn.DropPath(p=drop_path, inplace=True)
def forward(self, x):
x = self.drop_path(self.attn(self.norm1(x))).add_(x)
return self.drop_path(self.mlp(self.norm2(x))).add_(x)
class PatchEmbed(nn.Module):
"""Patch embedding layer."""
def __init__(self, dim=768, patch_size=16, patch_type=None):
super(PatchEmbed, self).__init__()
if patch_type == 'conv':
num_strides = int(math.log2(patch_size))
dims = [3] + [dim // (2 ** p) for p in range(num_strides - 1, -1, -1)]
layers = []
for i in range(num_strides):
layers += [nn.Conv2d(dims[i], dims[i + 1], 3, 2, 1, bias=False),
nn.BatchNorm2d(dims[i + 1]),
nn.ReLU(inplace=True)]
layers.append(nn.Conv2d(dims[-1], dims[-1], 1, 1))
self.proj = nn.Sequential(*layers)
else:
self.proj = nn.Conv2d(3, dim, patch_size, patch_size)
def freeze(self):
if isinstance(self.proj, nn.Conv2d):
for param in self.proj.parameters():
param.requires_grad = False
logging.info('Freeze %s.' % self.__class__.__name__)
def forward(self, x):
return self.proj(x)
class PosEmbed(nn.Module):
"""Position embedding layer."""
def __init__(self, dim, num_patches):
super(PosEmbed, self).__init__()
self.dim = dim
self.num_patches = num_patches
self.weight = nn.Parameter(torch.zeros(num_patches, dim))
self.weight.no_weight_decay = True
nn.init.normal_(self.weight, std=0.02)
def freeze(self):
pos_dim = self.dim // 4
omega = np.arange(pos_dim, dtype='float64') / pos_dim
omega = 1. / (10000. ** omega)
h = w = int(math.sqrt(self.num_patches))
grid_h = np.arange(h, dtype='float64')
grid_w = np.arange(w, dtype='float64')
pos_w, pos_h = np.meshgrid(grid_w, grid_h)
out_w = np.einsum('m,d->md', pos_w.flatten(), omega)
out_h = np.einsum('m,d->md', pos_h.flatten(), omega)
emb = (np.sin(out_w), np.cos(out_w), np.sin(out_h), np.cos(out_h))
emb = np.concatenate(emb, axis=1)
self.weight = nn.Parameter(torch.tensor(emb.astype('float32')))
self.weight.requires_grad = False
logging.info('Freeze 2D %s.' % (self.__class__.__name__))
def forward(self, x):
return x.add_(self.weight)
class VisionTransformer(nn.Module):
"""Vision Transformer."""
def __init__(self, depths, dims, num_heads, mlp_ratios,
patch_size=16):
super(VisionTransformer, self).__init__()
drop_path = cfg.BACKBONE.DROP_PATH_RATE
drop_path = (torch.linspace(
0, drop_path, sum(depths), dtype=torch.float32).tolist()
if drop_path > 0 else [drop_path] * sum(depths))
self.img_size = cfg.TRAIN.CROP_SIZE
self.patch_size = patch_size
self.num_patches = (self.img_size // patch_size) ** 2
self.num_features = dims[-1]
self.patch_embed = PatchEmbed(dims[0], patch_size)
self.pos_embed = PosEmbed(dims[0], self.num_patches)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0]))
self.blocks = nn.ModuleList([Block(
dim=dims[0], num_heads=num_heads[0],
mlp_ratio=mlp_ratios[0], qkv_bias=True, drop_path=drop_path[i])
for i in range(depths[0])])
self.norm = nn.LayerNorm(self.num_features)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.cls_token, std=.02)
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten_(2).transpose(1, 2)
x = self.pos_embed(x)
cls_tokens = self.cls_token.expand(x.size(0), 1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
return x[:, 1:]
def get_lr_scale(self, name, decay):
values = list(decay ** (len(self.blocks) + 1 - i)
for i in range(len(self.blocks) + 2))
if name.startswith('cls_token'):
return values[0]
elif name.startswith('pos_embed'):
return values[0]
elif name.startswith('patch_embed'):
return values[0]
elif name.startswith('blocks'):
return values[int(name.split('.')[1]) + 1]
return values[-1]
BACKBONES.register('vit_small_patch16', VisionTransformer,
depths=(12,), dims=(384,), num_heads=(6,),
mlp_ratios=(4,), patch_size=16)
BACKBONES.register('vit_base_patch16', VisionTransformer,
depths=(12,), dims=(768,), num_heads=(12,),
mlp_ratios=(4,), patch_size=16)
BACKBONES.register('vit_large_patch16', VisionTransformer,
depths=(24,), dims=(1024,), num_heads=(16,),
mlp_ratios=(4,), patch_size=16)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.core.config import cfg
from seetabase.core.registry import Registry
BACKBONES = Registry('backbones')
MODELS = Registry('models')
def build_backbone():
"""Build the backbone."""
backbone_type = cfg.BACKBONE.TYPE.lower()
return BACKBONES.get(backbone_type)()
def build_model(device=None, weights=None):
"""Build the model."""
model = MODELS.get(cfg.MODEL.TYPE.lower())()
if weights is not None:
model.load_weights(weights, strict=True)
if device is not None:
model.cuda(device)
return model
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Classifiers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.models.classifiers.classifier import Classifier
from seetabase.models.classifiers.classifier import LinearCLS
# from seetabase.models.classifiers.classifier import FinetuneCLS
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import nn
from seetabase.core.config import cfg
from seetabase.models.build import MODELS
from seetabase.models.build import build_backbone
from seetabase.ops.loss import CrossEntropyLoss
from seetabase.ops.mixup import MixUp
from seetabase.ops.normalization import ToTensor
@MODELS.register('cls')
class Classifier(nn.Module):
"""Base classifier."""
def __init__(self):
super(Classifier, self).__init__()
self.to_tensor = ToTensor()
self.backbone = build_backbone()
self.num_features = self.backbone.num_features
self.norm = nn.Identity()
if hasattr(self.backbone, 'norm'):
self.norm = type(self.backbone.norm)(self.num_features)
fc = nn.Linear if cfg.MODEL.NUM_CLASSES > 0 else nn.Identity
self.fc = fc(self.num_features, cfg.MODEL.NUM_CLASSES)
self.mixup = MixUp()
self.criterion = CrossEntropyLoss(mixup=self.mixup)
def forward(self, inputs):
x = self.to_tensor(inputs['img'], normalize=True)
x = self.mixup(x)
x = self.backbone(x)
if x.dim() == 3:
x = x.mean(1)
else:
x = nn.functional.adaptive_avg_pool2d(x, 1).flatten_(1)
x = self.fc(self.norm(x)).float()
if self.training and 'label' in inputs:
y = self.to_tensor(inputs['label'])
return {'cls_loss': self.criterion(x, y)}
else:
return {'cls_score': x}
def optimize_for_inference(self):
"""Optimize the graph for the inference."""
# Set precision.
precision = cfg.MODEL.PRECISION.lower()
self.half() if precision == 'float16' else self.float()
@MODELS.register('lincls')
class LinearCLS(Classifier):
"""Linear classifier."""
def __init__(self):
super(LinearCLS, self).__init__()
self.backbone.eval()
for param in self.backbone.parameters():
param.requires_grad = False
self.norm = nn.BatchNorm1d(self.num_features, affine=False)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Masked Autoencoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import inspect
from dragon.vm import torch
from dragon.vm.torch import nn
from seetabase.core.config import cfg
from seetabase.models.build import MODELS
from seetabase.models.build import build_backbone
from seetabase.models.backbones.vit import Block
from seetabase.models.backbones.vit import PosEmbed
from seetabase.ops.normalization import ToTensor
from seetabase.ops.random import batch_permutation
class PatchTarget(nn.Module):
"""Patch regression target for MAE."""
def __init__(self, patch_size, eps=1e-6):
super(PatchTarget, self).__init__()
self.patch_size = patch_size
self.dim = patch_size ** 2 * 3
self.eps = eps
self.weight = torch.ones(self.dim)
self.bias = torch.zeros(self.dim)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
lambda_source = inspect.getsource(fn)
if 'half_()' in lambda_source:
return self
for buf in (self.weight, self.bias, self.pos_bias):
fn(buf)
def forward(self, y):
y = nn.functional.pixel_unshuffle(y, self.patch_size)
y = y.flatten_(2).transpose(1, 2).float_()
return nn.functional.layer_norm(
y, (self.dim,), self.weight, self.bias, self.eps)
class DecodeHead(nn.Module):
"""MAE decode head."""
def __init__(self, encoder):
super(DecodeHead, self).__init__()
self.dim = cfg.DECODER.DIM
num_patches = encoder.num_patches
num_features = encoder.num_features
self.num_pixels = encoder.patch_size ** 2 * 3
self.pos_embed = PosEmbed(self.dim, num_patches)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.dim))
self.norm = nn.LayerNorm(self.dim)
self.fc1 = nn.Linear(num_features, self.dim)
self.fc2 = nn.Linear(self.dim, self.num_pixels)
self.blocks = nn.ModuleList()
for i in range(cfg.DECODER.DEPTH):
self.blocks.append(Block(self.dim, self.dim // 32))
self.blocks[i].attn.qk_rescale = 1024.0
nn.init.normal_(self.mask_token, std=.02)
@MODELS.register('mae')
class MAE(nn.Module):
"""Masked autoencoder."""
def __init__(self):
super(MAE, self).__init__()
self.to_tensor = ToTensor()
self.encoder = build_backbone()
self.encoder.norm = nn.LayerNorm(self.encoder.num_features)
self.decoder = DecodeHead(self.encoder)
self.mask_ratio = cfg.MAE.MASK_RATIO
self.num_patches = self.encoder.num_patches
self.patch_target = PatchTarget(self.encoder.patch_size)
self.patch_merge = nn.PixelShuffle(self.encoder.patch_size)
self.encoder.pos_embed.freeze()
self.decoder.pos_embed.freeze()
self.buffers = collections.defaultdict(lambda: torch.empty(1))
def random_masking(self, x):
batch_size, num_patches, dim = x.size()
num_encodes = round(num_patches * (1. - self.mask_ratio))
patch_index = self.buffers['patch_index']
batch_permutation(batch_size, num_patches, num_encodes,
out=patch_index, device=x.device)
patch_index = patch_index.unsqueeze_(-1)
vis_indices = patch_index.expand(-1, -1, dim)
x = x.gather(1, vis_indices)
vis_indices = patch_index.expand(-1, -1, self.decoder.dim)
mask_tokens = self.decoder.mask_token.expand(batch_size, num_patches, -1)
return x, mask_tokens, vis_indices
def get_outputs(self, x):
x = self.encoder.patch_embed(x)
x = x.flatten_(2).transpose(1, 2)
x = self.encoder.pos_embed(x)
x, mask_tokens, vis_indices = self.random_masking(x)
cls_tokens = self.encoder.cls_token.expand(x.size(0), 1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.encoder.blocks:
x = blk(x)
x = self.decoder.fc1(self.encoder.norm(x))
cls_tokens, x = x.split((1, x.size(1) - 1), dim=1)
x = mask_tokens.scatter_(1, vis_indices, x)
x = self.decoder.pos_embed(x)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.decoder.blocks:
x = blk(x)
return self.decoder.fc2(self.decoder.norm(x[:, 1:]))
def get_losses(self, x, y):
"""Return the losses."""
patch_index = self.buffers['patch_index']
loss = nn.functional.mse_loss(x, y, reduction='none')
loss = loss.mean(-1, True).scatter_(1, patch_index, 0)
mask_ratio = 1.0 - patch_index.size(1) / self.num_patches
loss_weight = cfg.MAE.PIXEL_LOSS_WEIGHT / mask_ratio
return {'pixel_loss': loss.mean().mul_(loss_weight)}
def forward(self, inputs):
x = self.to_tensor(inputs['img'], normalize=True)
y = self.patch_target(x)
x = self.get_outputs(x)
if self.training:
return self.get_losses(x, y)
else:
patch_index = self.buffers['patch_index']
mask = x.new_zeros(x.shape[:-1] + (1,))
y = y.mul_(mask.scatter_(1, patch_index, 1))
h = w = int(self.num_patches ** 0.5)
num_pixels = self.decoder.num_pixels
x = x.transpose(1, 2).reshape_((-1, num_pixels, h, w))
y = y.transpose(1, 2).reshape_((-1, num_pixels, h, w))
return self.patch_merge(x), self.patch_merge(y)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.modules import cls
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import types
from dragon.vm import torch
from seetabase.core.config import cfg
from seetabase.core.registry import Registry
from seetabase.utils.profiler import Timer
MODEL_INFERENCE = Registry('model_inference')
def build_model_inference(model):
"""Build the model inference."""
return MODEL_INFERENCE.get(cfg.MODEL.TYPE)(model)
class ModelInference(object):
"""Model inference module."""
def __init__(self, model):
self.model = model
self.timers = collections.defaultdict(Timer)
@torch.no_grad()
def get_results(self, imgs):
"""Return the detection results."""
def get_time_diffs(self):
"""Return the time differences."""
return dict((k, v.average_time)
for k, v in self.timers.items())
def trace(self, name, func, example_inputs=None):
"""Trace the function and bound to model."""
if not hasattr(self.model, name):
setattr(self.model, name, torch.jit.trace(
func=types.MethodType(func, self.model),
example_inputs=example_inputs))
return getattr(self.model, name)
@staticmethod
def register(model_type, **kwargs):
"""Register a model inference function."""
def decorated(func):
return MODEL_INFERENCE.register(model_type, func, **kwargs)
return decorated
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""CLS modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from dragon.vm import torch
from seetabase.core.config import cfg
from seetabase.modules.build import ModelInference
@ModelInference.register(['cls', 'lincls'])
class CLSInference(ModelInference):
"""RCNN inference module."""
def __init__(self, model):
super(CLSInference, self).__init__(model)
self.forward_model = self.trace(
'forward_eval', lambda self, img:
self.forward({'img': img}))
self.eval_steps = cfg.SOLVER.EVAL_STEPS
@torch.no_grad()
def forward_acc(self, score, label, k=1):
return score.topk(k)[1].eq(label).float().sum().mul_(1. / label.size(0))
@torch.no_grad()
def get_results(self, loader):
metrics = collections.defaultdict(float)
for _ in range(self.eval_steps):
inputs = loader()
label = self.model.to_tensor(inputs['label'])
with self.timers['im_model'].tic_and_toc(n=inputs['img'].size(0)):
score = self.forward_model(inputs['img'])['cls_score']
torch.cuda.synchronize(score.device)
with self.timers['misc'].tic_and_toc():
metrics['Acc@1'] += float(self.forward_acc(score, label, 1))
metrics['Acc@5'] += float(self.forward_acc(score, label, 5))
for k in metrics.keys():
metrics[k] /= self.eval_steps
return metrics
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Operators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from seetabase.core.backend import load_library as _load_library
_load_library(os.path.join(os.path.dirname(__file__), '_C'))
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Activation ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import autograd
def softmax(input, dim, tau=1, inplace=False):
op_type = 'SoftmaxV2' if tau != 1 else 'Softmax'
return autograd.Function.apply(
op_type, input.device, [input],
outputs=[input if inplace else None], axis=dim, tau=float(tau))
autograd.Function.register(
'SoftmaxV2', lambda **kwargs: {
'tau': kwargs.get('tau', 1.0)})
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Build for ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.core.registry import Registry
FUSIONS = Registry('fusions')
def build_fusion(*modules):
"""Return the fusion between modules."""
key = '+'.join(m.__class__.__name__ for m in modules)
return key, FUSIONS.try_get(key)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Loss ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch import nn
from seetabase.core.config import cfg
class CrossEntropyLoss(nn.CrossEntropyLoss):
"""Cross entropy loss with label smoothing."""
def __init__(self, mixup=None):
super(CrossEntropyLoss, self).__init__()
self.mixup = mixup
self.epsilon = cfg.AUGREG.LABEL_SMOOTHING
def forward(self, input, target):
lambdas = getattr(self.mixup, 'lambdas', None)
if self.epsilon > 0 or lambdas is not None:
dim, target = input.shape[-1], target.squeeze_().float()
off_val = self.epsilon / dim
on_val = 1.0 - self.epsilon + off_val
x = nn.functional.log_softmax(input, -1, inplace=True)
y = nn.functional.one_hot(target, dim, on_val, off_val)
if lambdas is not None:
y2 = nn.functional.one_hot(target.flip(0), dim, on_val, off_val)
y.mul_(lambdas[0].float_()).add_(y2.mul_(lambdas[1].float_()))
return x.mul(y.mul_(-1)).sum(-1).mean()
return nn.functional.cross_entropy(input, target)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Mixup ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
from dragon.vm.torch import nn
import numpy as np
from seetabase.core.config import cfg
class MixUp(nn.Module):
"""Mixup module."""
def __init__(self):
super(MixUp, self).__init__()
self.mixup_alpha = cfg.AUGREG.MIXUP
self.cutmix_alpha = cfg.AUGREG.CUTMIX
self.values = (torch.empty(1), torch.empty(1))
self.lambdas = None
def forward(self, input):
self.lambdas = None
mixup, cutmix = self.mixup_alpha > 0, self.cutmix_alpha > 0
if self.training and (mixup or cutmix):
if cutmix and mixup:
cutmix = np.random.rand() < 0.5
mixup = not cutmix
alpha = self.mixup_alpha if mixup else 0
alpha = self.cutmix_alpha if cutmix else alpha
r = float(np.random.beta(alpha, alpha, (1,)).astype('float32'))
x1, y1, x2, y2 = 0, 0, 0, 0
if cutmix:
cut_r = np.sqrt(1. - r)
h, w = input.shape[-2:]
cut_h, cut_w = int(h * cut_r), int(w * cut_r)
y, x = np.random.randint(h), np.random.randint(w)
y1 = np.clip(y - cut_h // 2, 0, h)
y2 = np.clip(y + cut_h // 2, 0, h)
x1 = np.clip(x - cut_w // 2, 0, w)
x2 = np.clip(x + cut_w // 2, 0, w)
cut_area, area = (y2 - y1) * (x2 - x1), h * w
cutmix, r = cut_area > 0, 1. - float(cut_area) / float(area)
if mixup or cutmix:
(v1, v2), input2 = self.values, input.flip(0)
v1._impl.FromNumpy(np.array(r, input.dtype), False)
v2._impl.FromNumpy(np.array(1.0 - r, input.dtype), False)
v1, v2 = self.lambdas = v1.to(input.device), v2.to(input.device)
if cutmix:
input[:, :, y1:y2, x1:x2] = input2[:, :, y1:y2, x1:x2]
else:
input.mul_(v1).add_(input2.mul_(v2))
return input
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Normalization ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from dragon.vm import torch
from dragon.vm.torch import nn
from seetabase.core.config import cfg
class ToTensor(nn.Module):
"""Convert input to tensor."""
def __init__(self):
super(ToTensor, self).__init__()
self.device = torch.device('cpu')
self.tensor = torch.ones(1)
self.normalize = functools.partial(
nn.functional.channel_norm,
mean=(103.53, 116.28, 123.675),
std=(57.375, 57.12, 58.395),
dim=1, dims=(0, 3, 1, 2),
dtype=cfg.MODEL.PRECISION.lower())
def _apply(self, fn):
fn(self.tensor)
def cpu(self):
self.device = torch.device('cpu')
def cuda(self, device=None):
self.device = torch.device('cuda', device)
def forward(self, input, normalize=False):
if input is None:
return input
if not isinstance(input, torch.Tensor):
input = torch.from_numpy(input)
input = input.to(self.tensor.device)
if normalize and not input.is_floating_point():
input = self.normalize(input)
return input
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Random ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm import torch
import numpy as np
def batch_permutation(batch_size, n, m, out=None, device=None):
perm = np.empty((batch_size, m), 'int64')
for i in range(batch_size):
perm[i] = np.random.permutation(n)[:m]
if out is not None:
out._impl.FromNumpy(perm, False)
return out.to(device=device)
return torch.from_numpy(perm).to(device=device)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/platform/tf_logging.py>
#
# ------------------------------------------------------------
"""Logger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import inspect
import logging as _logging
import sys as _sys
import threading
_logger = None
_logger_lock = threading.Lock()
def get_logger():
global _logger
# Use double-checked locking to avoid taking lock unnecessarily.
if _logger:
return _logger
_logger_lock.acquire()
try:
if _logger:
return _logger
logger = _logging.getLogger('seetabase')
logger.setLevel(_logging.INFO)
logger.propagate = False
logger._is_root = True
if True:
# Determine whether we are in an interactive environment
_interactive = False
try:
# This is only defined in interactive shells.
if _sys.ps1:
_interactive = True
except AttributeError:
# Even now, we may be in an interactive shell with `python -i`.
_interactive = _sys.flags.interactive
# If we are in an interactive environment (like Jupyter), set loglevel
# to INFO and pipe the output to stdout.
if _interactive:
logger.setLevel(_logging.INFO)
_logging_target = _sys.stdout
else:
_logging_target = _sys.stderr
# Add the output handler.
_handler = _logging.StreamHandler(_logging_target)
_handler.setFormatter(_logging.Formatter('%(levelname)s %(message)s'))
logger.addHandler(_handler)
_logger = logger
return _logger
finally:
_logger_lock.release()
def _detailed_msg(msg):
file, lineno = inspect.stack()[:3][2][1:3]
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
def log(level, msg, *args, **kwargs):
get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
def debug(msg, *args, **kwargs):
if is_root():
get_logger().debug(_detailed_msg(msg), *args, **kwargs)
def error(msg, *args, **kwargs):
get_logger().error(_detailed_msg(msg), *args, **kwargs)
assert 0
def fatal(msg, *args, **kwargs):
get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
assert 0
def info(msg, *args, **kwargs):
if is_root():
get_logger().info(_detailed_msg(msg), *args, **kwargs)
def warn(msg, *args, **kwargs):
if is_root():
get_logger().warn(_detailed_msg(msg), *args, **kwargs)
def warning(msg, *args, **kwargs):
if is_root():
get_logger().warning(_detailed_msg(msg), *args, **kwargs)
def get_verbosity():
"""Return how much logging output will be produced."""
return get_logger().getEffectiveLevel()
def set_verbosity(v):
"""Sets the threshold for what messages will be logged."""
get_logger().setLevel(v)
def set_formatter(fmt=None, datefmt=None):
"""Set the formatter."""
handler = _logging.StreamHandler(_sys.stderr)
handler.setFormatter(_logging.Formatter(fmt, datefmt))
logger = get_logger()
logger.removeHandler(logger.handlers[0])
logger.addHandler(handler)
def set_root(is_root=True):
"""Set logger to the root."""
get_logger()._is_root = is_root
def is_root():
"""Return logger is the root."""
return get_logger()._is_root
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Profiler utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from seetabase.utils.profiler.stats import SmoothedValue
from seetabase.utils.profiler.timer import Timer
from seetabase.utils.profiler.timer import get_progress
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Trackable statistics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
class SmoothedValue(object):
"""Track values and provide smoothed rsteport."""
def __init__(self, window_size=None):
self.deque = collections.deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def update(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def mean(self):
return np.mean(self.deque)
def median(self):
return np.median(self.deque)
def average(self):
return self.total / self.count
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Timing functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import datetime
import time
class Timer(object):
"""Simple timer."""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
def add_diff(self, diff, n=1, average=True):
self.total_time += diff
self.calls += n
self.average_time = self.total_time / self.calls
return self.average_time if average else self.diff
@contextlib.contextmanager
def tic_and_toc(self, n=1, average=True):
try:
yield self.tic()
finally:
self.toc(n, average)
def tic(self):
self.start_time = time.time()
return self
def toc(self, n=1, average=True):
self.diff = time.time() - self.start_time
return self.add_diff(self.diff, n, average)
def get_progress(timer, step, max_steps):
"""Return the progress information."""
eta_seconds = timer.average_time * (max_steps - step)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
progress = (step + 1.) / max_steps
return ('< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >'
.format(progress, timer.average_time, eta))
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import subprocess
import sys
import setuptools
import setuptools.command.build_py
import setuptools.command.install
version = git_version = None
if os.path.exists('version.txt'):
with open('version.txt', 'r') as f:
version = f.read().strip()
if os.path.exists('.git'):
try:
git_version = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], cwd='./')
git_version = git_version.decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
pass
def build_extensions(parallel=4):
"""Prepare the package files."""
# Compile cxx sources.
py_exec = sys.executable
if subprocess.call(
'cd csrc && '
'{} setup.py build_ext -b ../ -f --no-python-abi-suffix=0 -j {} &&'
'{} setup.py clean'.format(py_exec, parallel, py_exec), shell=True,
) > 0:
raise RuntimeError('Failed to build the cxx sources.')
def clean_builds():
for path in ['build', 'seeta_base.egg-info']:
if os.path.exists(path):
shutil.rmtree(path)
def find_packages(top):
"""Return the python sources installed to package."""
packages = []
for root, _, _ in os.walk(top):
if os.path.exists(os.path.join(root, '__init__.py')):
packages.append(root)
return packages
def find_package_data(top):
"""Return the external data installed to package."""
headers, libraries = [], []
if sys.platform == 'win32':
dylib_suffix = '.pyd'
elif sys.platform == 'darwin':
dylib_suffix = '.dylib'
else:
dylib_suffix = '.so'
for root, _, files in os.walk(top):
root = root[len(top + '/'):]
for file in files:
if file.endswith(dylib_suffix):
libraries.append(os.path.join(root, file))
return headers + libraries
class BuildPyCommand(setuptools.command.build_py.build_py):
"""Enhanced 'build_py' command."""
def build_packages(self):
clean_builds()
with open('seetabase/version.py', 'w') as f:
f.write("from __future__ import absolute_import\n"
"from __future__ import division\n"
"from __future__ import print_function\n\n"
"version = '{}'\n"
"git_version = '{}'\n".format(version, git_version))
super(BuildPyCommand, self).build_packages()
def build_package_data(self):
parallel = 4
for k in ('build', 'install'):
v = self.get_finalized_command(k).parallel
parallel = max(parallel, (int(v) if v else v) or 1)
build_extensions(parallel=parallel)
self.package_data = {'seetabase': find_package_data('seetabase')}
super(BuildPyCommand, self).build_package_data()
class InstallCommand(setuptools.command.install.install):
"""Enhanced 'install' command."""
user_options = setuptools.command.install.install.user_options
user_options += [('parallel=', 'j', "number of parallel build jobs")]
def initialize_options(self):
self.parallel = None
super(InstallCommand, self).initialize_options()
self.old_and_unmanageable = True
setuptools.setup(
name='seeta-base',
version=version,
description='SeetaBase: A platform implementing popular base vision models.',
url='https://gitlab.seetatech.com/seetaresearch/seetabase',
author='SeetaTech',
license='BSD 2-Clause',
packages=find_packages('seetabase'),
package_dir={'seetabase': 'seetabase'},
cmdclass={'build_py': BuildPyCommand, 'install': InstallCommand},
install_requires=['opencv-python', 'Pillow>=7.1', 'prettytable'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: BSD License',
'Programming Language :: C++',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence'],
)
clean_builds()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Train a base vision model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import dragon
import numpy
from seetabase.core.config import cfg
from seetabase.core.coordinator import Coordinator
from seetabase.core.training import train_engine
from seetabase.utils import logging
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='config file')
parser.add_argument(
'--exp_dir',
default='',
help='experiment dir')
parser.add_argument(
'--tensorboard',
action='store_true',
help='write metrics to tensorboard or not')
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
coordinator = Coordinator(args.cfg_file, exp_dir=args.exp_dir)
checkpoint, start_iter = coordinator.get_checkpoint()
if checkpoint is not None:
cfg.TRAIN.WEIGHTS = checkpoint
# Setup distributed environment.
world_rank = dragon.distributed.get_rank()
world_size = dragon.distributed.get_world_size()
if cfg.NUM_GPUS != world_size:
raise ValueError(
'Excepted staring of {} processes, got {}.'
.format(cfg.NUM_GPUS, world_size))
# Setup the logging modules.
logging.set_root(world_rank == 0)
# Select the GPU depending on the rank of process.
cfg.GPU_ID = [i for i in range(cfg.NUM_GPUS)][world_rank]
# Fix the random seed for reproducibility.
numpy.random.seed(cfg.RNG_SEED + world_rank)
dragon.random.set_seed(cfg.RNG_SEED)
# Ready to train the network.
logging.info('Checkpoints will be saved to `{:s}`'
.format(coordinator.path_at('checkpoints')))
with dragon.distributed.new_group(
ranks=[i for i in range(cfg.NUM_GPUS)],
verbose=True).as_default():
train_engine.run_train(
coordinator, start_iter,
enable_tensorboard=args.tensorboard)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test a base vision model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
from seetabase.core.config import cfg
from seetabase.core.coordinator import Coordinator
from seetabase.core.testing import test_engine
from seetabase.utils import logging
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(
description='Test an imagenet network.')
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='config file')
parser.add_argument(
'--exp_dir',
default='',
help='experiment dir')
parser.add_argument(
'--model_dir',
default='',
help='model dir')
parser.add_argument(
'--gpu',
nargs='+',
type=int,
default=None,
help='index of GPUs to use')
parser.add_argument(
'--iter',
nargs='+',
type=int,
default=None,
help='iteration step of checkpoints')
parser.add_argument(
'--last',
type=int,
default=1,
help='last N checkpoints')
parser.add_argument(
'--precision',
default='',
help='compute precision')
parser.add_argument(
'--deterministic',
action='store_true',
help='set cudnn deterministic or not')
return parser.parse_args()
def find_weights(args, coordinator):
"""Return the weights for testing."""
weights_list = []
if args.model_dir:
for file in os.listdir(args.model_dir):
if not file.endswith('.pkl'):
continue
weights_list.append(os.path.join(args.model_dir, file))
weights_list.sort()
return weights_list
if args.iter is not None:
for iter in args.iter:
checkpoint, _ = coordinator.get_checkpoint(iter, wait=True)
weights_list.append(checkpoint)
return weights_list
for i in range(1, args.last + 1):
checkpoint, _ = coordinator.get_checkpoint(last_idx=i)
if checkpoint is None:
break
weights_list.append(checkpoint)
return weights_list
if __name__ == '__main__':
args = parse_args()
logging.info('Called with args:\n' + str(args))
coordinator = Coordinator(args.cfg_file, args.exp_dir or args.model_dir)
cfg.MODEL.PRECISION = args.precision or cfg.MODEL.PRECISION
logging.info('Using config:\n' + str(cfg))
# Run testing.
weights_list = find_weights(args, coordinator)
test_engine.run_test(
weights_list=weights_list,
devices=args.gpu,
deterministic=args.deterministic)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Train a base vision model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import dragon
import numpy
from seetabase.core.config import cfg
from seetabase.core.coordinator import Coordinator
from seetabase.core.training import train_engine
from seetabase.utils import logging
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--cfg',
dest='cfg_file',
default=None,
help='optional config file')
parser.add_argument(
'--exp_dir',
default=None,
help='experiment dir')
parser.add_argument(
'--tensorboard',
action='store_true',
help='write metrics to tensorboard or not')
return parser.parse_args()
def run_distributed(args, coordinator):
"""Run distributed training."""
import subprocess
cmd = 'mpirun --allow-run-as-root -n {} --bind-to none '.format(cfg.NUM_GPUS)
cmd += '{} {} '.format(sys.executable, 'distributed/train.py')
cmd += '--cfg {} '.format(os.path.abspath(args.cfg_file))
cmd += '--exp_dir {} '.format(coordinator.exp_dir)
cmd += '--tensorboard ' if args.tensorboard else ''
return subprocess.call(cmd, shell=True)
if __name__ == '__main__':
args = parse_args()
logging.info('Called with args:\n' + str(args))
coordinator = Coordinator(args.cfg_file, args.exp_dir)
logging.info('Using config:\n' + str(cfg))
if cfg.NUM_GPUS > 1:
# Run a distributed task.
run_distributed(args, coordinator)
else:
# Resume training.
checkpoint, start_iter = coordinator.get_checkpoint()
if checkpoint is not None:
cfg.TRAIN.WEIGHTS = checkpoint
# Fix the random seed for reproducibility.
numpy.random.seed(cfg.RNG_SEED)
dragon.random.set_seed(cfg.RNG_SEED)
# Run training.
logging.info('Checkpoints will be saved to `{:s}`'
.format(coordinator.path_at('checkpoints')))
train_engine.run_train(coordinator, start_iter,
enable_tensorboard=args.tensorboard)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!