Commit 01602fc9 by Ting PAN

Add MPS backend

Summary:
This commit implements most kernels and operators for the MPS device.
1 parent 6720373b
Showing with 594 additions and 2966 deletions
# The minimum version must be declared before the first project() call,
# otherwise project() runs with unset policy defaults.
cmake_minimum_required(VERSION 3.0.2)
project(dragon)

# ---[ Build Options
option(BUILD_PYTHON "Build Python binding library" ON)
option(BUILD_RUNTIME "Build C++ runtime library" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)

# ---[ Library Options
option(USE_CUDA "Use CUDA" ON)
option(USE_CUDNN "Use CUDNN" ON)
option(USE_MPS "Use MPS" OFF)
option(USE_BLAS "Use BLAS" OFF)
option(USE_OPENMP "Use OpenMP" ON)
option(USE_MPI "Use MPI" OFF)
option(USE_NCCL "Use NCCL" OFF)
option(USE_AVX "Use AVX instructions" ON)
option(USE_AVX2 "Use AVX2 instructions" ON)
option(USE_FMA "Use FMA instructions" ON)
option(USE_NATIVE_ARCH "Use all native instructions" OFF)
option(USE_SHARED_LIBS "Use shared libraries" ON)

# ---[ Project Variables
# Set the directory of third party.
if (NOT THIRD_PARTY_DIR)
  set(THIRD_PARTY_DIR ${PROJECT_SOURCE_DIR}/third_party)
endif()

# Set the CUDA target architectures.
# If not, common architectures (>= 5.0) will be used.
if (NOT CUDA_ARCH)
  set(CUDA_ARCH Common)
endif()

# Set the custom protobuf sdk if necessary.
# If not, "${THIRD_PARTY_DIR}/protobuf" will be used.
# set(PROTOBUF_SDK_ROOT_DIR <sdk_root_dir>)

# Set the protobuf compiler (i.e., protoc) if necessary.
# If not, a compiler in the sdk or environment will be used.
# set(PROTOBUF_PROTOC_EXECUTABLE <executable>)

# Set the python interpreter if necessary.
# If not, a searched interpreter will be used.
# set(PYTHON_EXECUTABLE <executable>)

# ---[ CMake Modules
include(${PROJECT_SOURCE_DIR}/cmake/MiscCheck.cmake)
include(${PROJECT_SOURCE_DIR}/cmake/LinkLibrary.cmake)
include(${PROJECT_SOURCE_DIR}/cmake/StripDebugInfo.cmake)
include(${PROJECT_SOURCE_DIR}/cmake/Dependencies.cmake)
include(${PROJECT_SOURCE_DIR}/cmake/Codegen.cmake)

# ---[ CMake Variables
# NOTE(review): the build type is forced to Release and overrides any
# user-provided value; kept as-is because the project appears to rely on it.
set(CMAKE_BUILD_TYPE Release CACHE INTERNAL "" FORCE)
set(CMAKE_CONFIGURATION_TYPES Release CACHE INTERNAL "" FORCE)

# Redirect the default install prefix next to the build directory.
if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/../targets/native
      CACHE INTERNAL "" FORCE)
endif()
if (NOT LIBRARY_INSTALL_PREFIX)
  set(LIBRARY_INSTALL_PREFIX "")
endif()

# ---[ Subdirectories
if (BUILD_PYTHON)
  add_subdirectory(dragon/modules/python)
endif()
if (BUILD_RUNTIME)
  add_subdirectory(dragon/modules/runtime)
endif()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The base layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
from dragon.core.autograph import context
from dragon.core.framework.tensor import Tensor
from dragon.vm.caffe.core.proto import caffe_pb2
class Layer(object):
    """Base class mirroring ``caffe.Layer``.

    A layer keeps its ``LayerParameter`` proto, the names of its
    bottoms/tops and a list of parameter blobs. Subclasses implement
    ``__call__`` to define the forward computation.

    """

    def __init__(self, layer_param):
        """Create a ``Layer``.

        Parameters
        ----------
        layer_param : LayerParameter
            The parameter containing layer arguments.

        """
        self._proto = layer_param
        self._bottom_names = list(layer_param.bottom)
        self._top_names = list(layer_param.top)
        self._blobs = []
        self._call_layer = None

    @property
    def blobs(self):
        """Return the list of blob dicts (``'data'`` and ``'diff'`` keys)."""
        return self._blobs

    @property
    def bottom(self):
        """Return the names of the bottoms."""
        return self._bottom_names

    @property
    def name(self):
        """Return the name stored in the layer proto."""
        return self._proto.name

    @property
    def top(self):
        """Return the names of the tops."""
        return self._top_names

    def add_blob(self, shape, filler, requires_grad=True):
        """Create a tensor, initialize it from ``filler`` and append it.

        Parameters
        ----------
        shape : Sequence[int]
            The shape passed to the ``Tensor`` constructor.
        filler : FillerParameter
            The filler whose ``type`` selects the initializer.
        requires_grad : bool, optional, default=True
            The value assigned to ``requires_grad`` of the new tensor.

        Raises
        ------
        ValueError
            If the filler type is not one of the handled types.

        """
        blob = Tensor(shape, name='blob%d' % (len(self._blobs) + 1))
        kind = filler.type
        if kind == 'constant':
            blob.fill(filler.value)
        elif kind == 'gaussian':
            blob.normal(filler.mean, filler.std)
        elif kind == 'uniform':
            blob.uniform(filler.min, filler.max)
        elif kind in ('xavier', 'msra'):
            # Both fillers share the same variance-norm enum mapping.
            mode = {0: 'fan_in', 1: 'fan_out', 2: 'fan_avg'}[filler.variance_norm]
            if kind == 'xavier':
                blob.glorot_uniform(mode)
            else:
                blob.glorot_normal(mode)
        else:
            raise ValueError('Unknown filler type: ' + filler.type)
        blob.requires_grad = requires_grad
        self._blobs.append({'data': blob, 'diff': None})

    def from_proto(self, proto):
        """Load the blob values from a proto.

        Blobs are matched by position; blobs without a counterpart in
        ``proto.blobs`` are left untouched.

        Parameters
        ----------
        proto : LayerParameter
            The ``LayerParameter`` protocol buffer.

        Raises
        ------
        ValueError
            If a blob proto carries neither ``data`` nor ``double_data``.

        """
        for blob, blob_proto in zip(self._blobs, proto.blobs):
            if len(blob_proto.data) > 0:
                array = numpy.array(blob_proto.data, dtype='float32')
            elif len(blob_proto.double_data) > 0:
                array = numpy.array(blob_proto.double_data, dtype='float64')
            else:
                raise ValueError('Neither <data> or <double_data> in blob proto.')
            if len(blob_proto.shape.dim) > 0:
                array = array.reshape(list(blob_proto.shape.dim))
            blob['data']._impl.FromNumpy(array, False)

    def setup(self, bottom):
        """Run the forward definition in graph mode.

        A single-element ``bottom`` is unwrapped before the call.

        """
        if len(bottom) == 1:
            bottom = bottom[0]
        with context.graph_mode():
            target = self._call_layer or self
            return target.__call__(bottom)

    def to_proto(self):
        """Serialize the layer and its blobs to a proto.

        Returns
        -------
        LayerParameter
            The ``LayerParameter`` protocol buffer.

        Raises
        ------
        ValueError
            If a blob is neither float32 nor float64.

        """
        result = caffe_pb2.LayerParameter()
        result.CopyFrom(self._proto)
        for blob in self._blobs:
            array = blob['data'].numpy()
            dtype_name = str(array.dtype)
            # Pick the BlobProto field matching the precision.
            if dtype_name == 'float32':
                field = 'data'
            elif dtype_name == 'float64':
                field = 'double_data'
            else:
                raise ValueError('Either float32 or float64 blob is required.')
            fields = {field: array.flatten(),
                      'shape': caffe_pb2.BlobShape(dim=array.shape)}
            result.blobs.extend([caffe_pb2.BlobProto(**fields)])
        return result

    def __call__(self, bottom):
        """Define the forward pipeline; must be overridden."""
        raise NotImplementedError
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.caffe.core.layers.common import Accuracy
from dragon.vm.caffe.core.layers.common import ArgMax
from dragon.vm.caffe.core.layers.common import BatchNorm
from dragon.vm.caffe.core.layers.common import Concat
from dragon.vm.caffe.core.layers.common import Crop
from dragon.vm.caffe.core.layers.common import Eltwise
from dragon.vm.caffe.core.layers.common import Flatten
from dragon.vm.caffe.core.layers.common import InnerProduct
from dragon.vm.caffe.core.layers.common import Input
from dragon.vm.caffe.core.layers.common import Normalize
from dragon.vm.caffe.core.layers.common import Permute
from dragon.vm.caffe.core.layers.common import Python
from dragon.vm.caffe.core.layers.common import Reduction
from dragon.vm.caffe.core.layers.common import Reshape
from dragon.vm.caffe.core.layers.common import Scale
from dragon.vm.caffe.core.layers.common import Slice
from dragon.vm.caffe.core.layers.common import Softmax
from dragon.vm.caffe.core.layers.common import Tile
from dragon.vm.caffe.core.layers.data import Data
from dragon.vm.caffe.core.layers.loss import EuclideanLoss
from dragon.vm.caffe.core.layers.loss import SigmoidCrossEntropyLoss
from dragon.vm.caffe.core.layers.loss import SmoothL1Loss
from dragon.vm.caffe.core.layers.loss import SoftmaxWithLoss
from dragon.vm.caffe.core.layers.neuron import Dropout
from dragon.vm.caffe.core.layers.neuron import ELU
from dragon.vm.caffe.core.layers.neuron import Power
from dragon.vm.caffe.core.layers.neuron import PReLU
from dragon.vm.caffe.core.layers.neuron import ReLU
from dragon.vm.caffe.core.layers.neuron import Sigmoid
from dragon.vm.caffe.core.layers.neuron import TanH
from dragon.vm.caffe.core.layers.vision import Convolution
from dragon.vm.caffe.core.layers.vision import Deconvolution
from dragon.vm.caffe.core.layers.vision import LRN
from dragon.vm.caffe.core.layers.vision import Pooling
__all__ = [_s for _s in dir() if not _s.startswith('_')]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Data layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.framework import workspace
from dragon.core.io.kpl_record import KPLRecordDataset
from dragon.core.ops import framework_ops
from dragon.core.ops import normalization_ops
from dragon.utils import vision
from dragon.vm.caffe.core.layer import Layer
class _DataPlugin(object):
"""Embedded plugin for data layer."""
def setup(self, inputs, outputs):
kwargs = eval(self.kwargs_str)
default_ws = workspace.get_workspace()
self.outputs = [default_ws.get_tensor(output) for output in outputs]
self.iterator = vision.DataIterator(dataset=KPLRecordDataset, **kwargs)
def forward(self, inputs, outputs):
blobs = self.iterator.next()
for i, blob in enumerate(blobs):
self.outputs[i].FromNumpy(blob)
class Data(Layer):
    r"""Load batches of data for image classification.

    Examples:

    ```python
    layer {
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      data_param {
        source: "/data/train"
        batch_size: 128
        prefetch: 4
      }
      image_data_param {
        shuffle: true
      }
      transform_param {
        mirror: true
        crop_size: 224
        mean_value: 104.00698793
        mean_value: 116.66876762
        mean_value: 122.67891434
      }
    }
    layer {
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TEST
      }
      data_param {
        source: "/data/val"
        batch_size: 64
      }
      transform_param {
        crop_size: 224
        mean_value: 104.00698793
        mean_value: 116.66876762
        mean_value: 122.67891434
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(Data, self).__init__(layer_param)
        data_param = layer_param.data_param
        image_data_param = layer_param.image_data_param
        transform_param = layer_param.transform_param
        # Phase 0 maps to training, phase 1 to evaluation.
        is_training = {0: True, 1: False}[int(layer_param.phase)]
        self.data_args = {
            'source': data_param.source,
            'batch_size': data_param.batch_size,
            'prefetch_depth': data_param.prefetch,
            'shuffle': image_data_param.shuffle,
            'training': is_training,
            'crop_size': transform_param.crop_size,
            'mirror': transform_param.mirror,
        }
        mean_values = list(transform_param.mean_value)
        # The proto ``scale`` is a multiplier; the op takes a divisor.
        inv_scale = 1. / transform_param.scale
        self.norm_args = {
            'axis': 1,
            'perm': (0, 3, 1, 2),
            'mean': mean_values,
            'std': [inv_scale] * len(transform_param.mean_value),
            'dtype': 'float32',
        }

    def __call__(self, bottom):
        """Create the plugin op and normalize its image output."""
        plugin_args = {
            'module_name': __name__,
            'class_name': '_DataPlugin',
            'kwargs_str': str(self.data_args),
            'num_outputs': 2,
        }
        data, label = framework_ops.python_plugin([], **plugin_args)
        batch_size = self.data_args['batch_size']
        num_channels = len(self.norm_args['mean'])
        data._shape = (batch_size, None, None, num_channels)
        label._shape = (batch_size, None)
        data = normalization_ops.channel_norm(data, **self.norm_args)
        return data, label
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Loss layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import loss_ops
from dragon.vm.caffe.core.layer import Layer
class EuclideanLoss(Layer):
    r"""Compute the element-wise squared error.

    The ``EuclideanLoss`` function is defined as:

    .. math:: \text{L2Loss}(x, y) = 0.5(x - y)^{2}

    Examples:

    ```python
    layer {
      type: "EuclideanLoss"
      bottom: "bbox_pred"
      bottom: "bbox_target"
      top: "bbox_loss"
      loss_param {
        normalization: BATCH_SIZE
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(EuclideanLoss, self).__init__(layer_param)
        param = layer_param.loss_param
        # An explicit ``normalize`` takes precedence over ``normalization``.
        if param.HasField('normalize'):
            reduction = 'mean' if param.normalize else 'batch_mean'
        else:
            reduction = {
                0: 'mean', 1: 'mean', 2: 'batch_mean', 3: 'sum',
            }[param.normalization]
        self.call_args = {'reduction': reduction}
        self.loss_weight = (layer_param.loss_weight or [1])[0]

    def __call__(self, bottom):
        """Return the L2 loss scaled by half of the loss weight."""
        loss = loss_ops.l2_loss(bottom, **self.call_args)
        if self.loss_weight is None:
            weight = 1.
        else:
            weight = self.loss_weight
        return loss * (weight * 0.5)
class SigmoidCrossEntropyLoss(Layer):
    """Compute the loss of sigmoid cross entropy.

    Examples:

    ```python
    layer {
      type: "SigmoidCrossEntropyLoss"
      bottom: "rpn_cls_score"
      bottom: "rpn_labels"
      top: "rpn_loss"
      loss_param {
        normalization: VALID
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(SigmoidCrossEntropyLoss, self).__init__(layer_param)
        param = layer_param.loss_param
        # An explicit ``normalize`` takes precedence over ``normalization``.
        if param.HasField('normalize'):
            reduction = 'valid' if param.normalize else 'batch_mean'
        else:
            reduction = {
                0: 'mean', 1: 'valid', 2: 'batch_mean', 3: 'sum',
            }[param.normalization]
        self.call_args = {'reduction': reduction}
        self.loss_weight = (layer_param.loss_weight or [1])[0]

    def __call__(self, bottom):
        """Return the loss, scaled when the loss weight is not 1."""
        loss = loss_ops.sigmoid_cross_entropy_loss(bottom, **self.call_args)
        if self.loss_weight == 1:
            return loss
        loss *= self.loss_weight
        return loss
class SmoothL1Loss(Layer):
    r"""Compute the element-wise error transited from L1 and L2.
    `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.

    Examples:

    ```python
    layer {
      type: "SmoothL1Loss"
      bottom: "bbox_pred"
      bottom: "bbox_targets"
      bottom: "bbox_inside_weights"
      bottom: "bbox_outside_weights"
      top: "bbox_loss"
      loss_param {
        normalization: BATCH_SIZE
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(SmoothL1Loss, self).__init__(layer_param)
        param = layer_param.loss_param
        sigma = layer_param.smooth_l1_loss_param.sigma
        # An explicit ``normalize`` takes precedence over ``normalization``.
        if param.HasField('normalize'):
            reduction = 'mean' if param.normalize else 'batch_mean'
        else:
            reduction = {
                0: 'mean', 1: 'mean', 2: 'batch_mean', 3: 'sum',
            }[param.normalization]
        # The op takes ``beta``, the reciprocal of sigma squared.
        self.call_args = {
            'beta': float(1. / (sigma * sigma)),
            'reduction': reduction,
        }
        self.loss_weight = (layer_param.loss_weight or [1])[0]

    def __call__(self, bottom):
        """Return the loss, scaled when the loss weight is not 1."""
        loss = loss_ops.smooth_l1_loss(bottom, **self.call_args)
        if self.loss_weight == 1:
            return loss
        loss *= self.loss_weight
        return loss
class SoftmaxWithLoss(Layer):
    """Compute the loss of softmax cross entropy.

    Examples:

    ```python
    layer {
      type: "SoftmaxWithLoss"
      bottom: "cls_score"
      bottom: "labels"
      top: "cls_loss"
      softmax_param {
        axis: 1
      }
      loss_param {
        ignore_label: -1
        normalization: VALID
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(SoftmaxWithLoss, self).__init__(layer_param)
        param = layer_param.loss_param
        softmax_param = layer_param.softmax_param
        # An explicit ``normalize`` takes precedence over ``normalization``.
        if param.HasField('normalize'):
            reduction = 'valid' if param.normalize else 'batch_mean'
        else:
            reduction = {
                0: 'mean', 1: 'valid', 2: 'batch_mean', 3: 'sum',
            }[param.normalization]
        # Only forward the ignore label when the proto sets it explicitly.
        if param.HasField('ignore_label'):
            ignore_index = param.ignore_label
        else:
            ignore_index = None
        self.call_args = {
            'axis': softmax_param.axis,
            'reduction': reduction,
            'ignore_index': ignore_index,
        }
        self.loss_weight = (layer_param.loss_weight or [1])[0]

    def __call__(self, bottom):
        """Return the loss, scaled when the loss weight is not 1."""
        loss = loss_ops.softmax_cross_entropy_loss(bottom, **self.call_args)
        if self.loss_weight == 1:
            return loss
        loss *= self.loss_weight
        return loss
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Neuron layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import activation_ops
from dragon.core.ops import math_ops
from dragon.vm.caffe.core.layer import Layer
from dragon.vm.caffe.core.proto import caffe_pb2
class Dropout(Layer):
    r"""Set the elements of the input to zero randomly.
    `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.

    The **Dropout** function is defined as:

    .. math:: \text{Dropout}(x) = x * \text{Bernoulli}(p=1 - prob)

    Examples:

    ```python
    layer {
      type: "Dropout"
      bottom: "fc6"
      top: "fc6"
      dropout_param {
        dropout_ratio: 0.5
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(Dropout, self).__init__(layer_param)
        dropout_param = layer_param.dropout_param
        # Only the scaled-at-train variant is supported.
        if not dropout_param.scale_train:
            raise ValueError('Unscaled dropout is not supported.')
        self.call_args = {'ratio': dropout_param.dropout_ratio}

    def __call__(self, bottom):
        """Apply dropout with the configured ratio."""
        return activation_ops.dropout(bottom, **self.call_args)
class ELU(Layer):
    r"""Apply the exponential linear unit.
    `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.

    The **ELU** function is defined as:

    .. math::
        \text{ELU}(x) =
            \begin{cases}
                x, & \text{ if } x \geq 0 \\
                \alpha * (\exp(x) - 1), & \text{ otherwise }
            \end{cases}

    Examples:

    ```python
    layer {
      type: "ELU"
      bottom: "conv2"
      top: "conv2"
      elu_param {
        alpha: 1.
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(ELU, self).__init__(layer_param)
        alpha = layer_param.elu_param.alpha
        self.call_args = {'alpha': float(alpha)}

    def __call__(self, bottom):
        """Apply the ELU with the configured alpha."""
        return activation_ops.elu(bottom, **self.call_args)
class Power(Layer):
    r"""Compute the power of input.

    .. math:: y = (scale * x + shift)^{power}

    Examples:

    ```python
    layer {
      type: "Power"
      bottom: "x"
      top: "y"
      power_param {
        scale: 1.
        shift: 0.
        power: 2.
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(Power, self).__init__(layer_param)
        power_param = layer_param.power_param
        self.scale = power_param.scale
        self.shift = power_param.shift
        self.power = power_param.power

    def __call__(self, bottom):
        """Scale and shift the input when needed, then raise it to the power."""
        base = bottom
        # Skip the identity scale/shift so no extra op is created.
        if self.scale != 1:
            base = base * self.scale
        if self.shift != 0:
            base = base + self.shift
        return math_ops.pow([base, self.power])
class PReLU(Layer):
    r"""Apply the parametric rectified linear unit.
    `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.

    The **PReLU** function is defined as:

    .. math::
        \text{PReLU}(x) =
            \begin{cases}
                x, & \text{ if } x \geq 0 \\
                weight * x, & \text{ otherwise }
            \end{cases}

    Examples:

    ```python
    layer {
      type: "PReLU"
      bottom: "conv2"
      top: "conv2/relu"
      prelu_param {
        channel_shared: false
        filler {
          type: "constant"
          value: 0.25
        }
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(PReLU, self).__init__(layer_param)
        prelu_param = layer_param.prelu_param
        # Fall back to a constant 0.25 filler when none is given.
        if prelu_param.HasField('filler'):
            self.filler = prelu_param.filler
        else:
            self.filler = caffe_pb2.FillerParameter(type='constant', value=0.25)
        self.channel_shared = prelu_param.channel_shared

    def build(self, bottom):
        """Create the weight blob from the shape of ``bottom``."""
        if self.channel_shared:
            num_weights = 1
        elif len(bottom.shape) > 1:
            num_weights = bottom.shape[1]
        else:
            num_weights = bottom.shape[0]
        self.add_blob([num_weights], self.filler)

    def __call__(self, bottom):
        """Apply the PReLU, creating the weight on the first call."""
        if len(self.blobs) == 0:
            self.build(bottom)
        weights = [blob['data'] for blob in self._blobs]
        return activation_ops.prelu([bottom] + weights)
class ReLU(Layer):
    r"""Apply the rectified linear unit.
    `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.

    The **ReLU** function is defined as:

    .. math::
        \text{ReLU}(x) =
            \begin{cases}
                x, & \text{ if } x \geq 0 \\
                0, & \text{ otherwise }
            \end{cases}

    Examples:

    ```python
    layer {
      type: "ReLU"
      bottom: "conv2"
      top: "conv2/relu"
      relu_param {
        negative_slope: 0.
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(ReLU, self).__init__(layer_param)
        self.negative_slope = layer_param.relu_param.negative_slope

    def __call__(self, bottom):
        """Apply the leaky variant for a positive slope, else the plain ReLU."""
        if not self.negative_slope > 0:
            return activation_ops.relu(bottom)
        return activation_ops.leaky_relu(bottom, self.negative_slope)
class Sigmoid(Layer):
    r"""Apply the sigmoid function.

    The **Sigmoid** function is defined as:

    .. math:: \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    Examples:

    ```python
    layer {
      type: "Sigmoid"
      bottom: "rpn_cls_score"
      top: "rpn_cls_prob"
    }
    ```

    """

    def __init__(self, layer_param):
        # No layer-specific arguments are read from the proto.
        super(Sigmoid, self).__init__(layer_param)

    def __call__(self, bottom):
        """Apply the sigmoid to ``bottom``."""
        output = activation_ops.sigmoid(bottom)
        return output
class TanH(Layer):
    r"""Apply the tanh function.

    The **Tanh** function is defined as:

    .. math:: \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

    Examples:

    ```python
    layer {
      type: "TanH"
      bottom: "g/conv5"
      top: "g/image"
    }
    ```

    """

    def __init__(self, layer_param):
        # No layer-specific arguments are read from the proto.
        super(TanH, self).__init__(layer_param)

    def __call__(self, bottom):
        """Apply the tanh to ``bottom``."""
        output = activation_ops.tanh(bottom)
        return output
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Vision layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import normalization_ops
from dragon.core.ops import vision_ops
from dragon.vm.caffe.core.layer import Layer
class Convolution(Layer):
    r"""Apply the n-dimension convolution.

    Examples:

    ```python
    layer {
      type: "Convolution"
      bottom: "input"
      top: "conv1"
      convolution_param {
        num_output: 32
        bias_term: true
        kernel_size: 3
        pad: 1
        stride: 1
        dilation: 1
        group: 1
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
          value: 0
        }
      }
    }
    ```

    """

    def __init__(self, layer_param):
        """Read the convolution arguments from ``convolution_param``."""
        super(Convolution, self).__init__(layer_param)
        param = layer_param.convolution_param
        # Empty repeated fields fall back to single-element defaults.
        self.kernel_shape = param.kernel_size or [1]
        self.strides = param.stride or [1]
        self.pads = param.pad or [0]
        self.dilations = param.dilation or [1]
        self.out_channels = param.num_output
        self.weight_filler = param.weight_filler
        self.bias_filler = param.bias_filler
        self.bias_term = param.bias_term
        self.call_args = {'group': param.group}

    def build(self, bottom):
        """Finalize the spatial arguments and create the weight/bias blobs.

        Raises
        ------
        ValueError
            If ``bottom`` has fewer than three dimensions.

        """
        num_axes = len(bottom.shape) - 2
        if num_axes < 1:
            raise ValueError(
                'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
                .format(self.name, len(bottom.shape)))
        in_channels = bottom.shape[1]
        for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
            values = [int(d) for d in getattr(self, k)]
            # Repeat the last value to cover every spatial axis.
            if len(values) < num_axes:
                values += [values[-1]] * (num_axes - len(values))
            self.call_args[k] = values
        weight_shape = [self.out_channels, in_channels] + self.call_args['kernel_shape']
        self.add_blob(weight_shape, self.weight_filler)
        if self.bias_term:
            self.add_blob([self.out_channels], self.bias_filler)

    def __call__(self, bottom):
        """Apply the convolution, creating the blobs on the first call."""
        if len(self.blobs) == 0:
            self.build(bottom)
        inputs = [bottom] + [blob['data'] for blob in self._blobs]
        # The op name depends on the number of spatial axes, e.g. ``conv2d``.
        conv_op = 'conv{}d'.format(len(self.call_args['kernel_shape']))
        return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class Deconvolution(Convolution):
    r"""Apply the n-dimension deconvolution.

    Examples:

    ```python
    layer {
      type: "Deconvolution"
      bottom: "conv5"
      top: "conv5/upscale"
      convolution_param {
        num_output: 256
        bias_term: true
        kernel_size: 2
        pad: 0
        stride: 2
        dilation: 1
        group: 1
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
          value: 0
        }
      }
    }
    ```

    """

    def __init__(self, layer_param):
        """Read the arguments exactly as ``Convolution`` does."""
        super(Deconvolution, self).__init__(layer_param)

    def build(self, bottom):
        """Finalize the spatial arguments and create the weight/bias blobs.

        The weight shape starts with ``[in_channels, out_channels]``.

        Raises
        ------
        ValueError
            If ``bottom`` has fewer than three dimensions.

        """
        num_axes = len(bottom.shape) - 2
        if num_axes < 1:
            raise ValueError(
                'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
                .format(self.name, len(bottom.shape)))
        in_channels = bottom.shape[1]
        for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
            values = [int(d) for d in getattr(self, k)]
            # Repeat the last value to cover every spatial axis.
            if len(values) < num_axes:
                values += [values[-1]] * (num_axes - len(values))
            self.call_args[k] = values
        weight_shape = [in_channels, self.out_channels] + self.call_args['kernel_shape']
        self.add_blob(weight_shape, self.weight_filler)
        if self.bias_term:
            self.add_blob([self.out_channels], self.bias_filler)

    def __call__(self, bottom):
        """Apply the deconvolution, creating the blobs on the first call."""
        if len(self.blobs) == 0:
            self.build(bottom)
        inputs = [bottom] + [blob['data'] for blob in self._blobs]
        # The op name depends on the number of spatial axes.
        conv_op = 'conv{}d_transpose'.format(len(self.call_args['kernel_shape']))
        return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class LRN(Layer):
    r"""Apply the local response normalization.
    `[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.

    Examples:

    ```python
    layer {
      type: "LRN"
      bottom: "conv2"
      top: "conv2/norm"
      lrn_param {
        local_size: 5
        alpha: 1.
        beta: 0.75
        k: 1.
      }
    }
    ```

    """

    def __init__(self, layer_param):
        super(LRN, self).__init__(layer_param)
        lrn_param = layer_param.lrn_param
        # Any non-zero region value is rejected.
        if lrn_param.norm_region > 0:
            raise NotImplementedError('<WITHIN_CHANNEL> mode is not implemented.')
        self.op_args = {
            'size': lrn_param.local_size,
            'alpha': lrn_param.alpha,
            'beta': lrn_param.beta,
            'bias': lrn_param.k,
        }

    def __call__(self, bottom):
        """Apply the normalization with the configured arguments."""
        return normalization_ops.local_response_norm(bottom, **self.op_args)
class Pooling(Layer):
    r"""Apply the n-dimension pooling.

    Examples:

    ```python
    layer {
      type: "Pooling"
      bottom: "conv2"
      top: "pool2"
      pooling_param {
        kernel_size: 3
        stride: 2
        pool: AVG
      }
    }
    ```

    """

    def __init__(self, layer_param):
        """Read the pooling arguments from ``pooling_param``."""
        super(Pooling, self).__init__(layer_param)
        param = layer_param.pooling_param
        # Unset (zero) values fall back to the defaults.
        self.kernel_shape = [param.kernel_size or 1]
        self.strides = [param.stride or 1]
        self.pads = [param.pad or 0]
        self.call_args = {
            'ceil_mode': True,
            'mode': {0: 'MAX', 1: 'AVG'}[param.pool],
            'global_pool': param.global_pooling,
        }

    def __call__(self, bottom):
        """Apply the pooling op matching the number of spatial axes.

        Raises
        ------
        ValueError
            If ``bottom`` has fewer than three dimensions.

        """
        num_axes = len(bottom.shape) - 2
        if num_axes < 1:
            raise ValueError(
                'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
                .format(self.name, len(bottom.shape)))
        # Work on a copy so repeated calls do not mutate the stored arguments.
        call_args = self.call_args.copy()
        for k in ('kernel_shape', 'strides', 'pads'):
            values = [int(d) for d in getattr(self, k)]
            # Repeat the last value to cover every spatial axis.
            if len(values) < num_axes:
                values += [values[-1]] * (num_axes - len(values))
            call_args[k] = values
        pool_op = 'pool{}d'.format(num_axes)
        return getattr(vision_ops, pool_op)(bottom, **call_args)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/BVLC/caffe/blob/master/python/caffe/net_spec.py>
#
# ------------------------------------------------------------
"""Net proto maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from dragon.vm.caffe.core.proto import caffe_pb2
def param_name_dict():
    """Map parameter type names to the matching layer field names.

    Every ``LayerParameter`` field ending in ``_param`` contributes one
    entry: the key is the field's message type name without the trailing
    'Parameter', the value is the field name without the trailing '_param'.
    Not every such field corresponds to a layer; that is ignored here.

    """
    layer = caffe_pb2.LayerParameter()
    field_suffix, type_suffix = '_param', 'Parameter'
    mapping = {}
    for field in layer.DESCRIPTOR.fields:
        if not field.name.endswith(field_suffix):
            continue
        type_name = type(getattr(layer, field.name)).__name__
        mapping[type_name[:-len(type_suffix)]] = field.name[:-len(field_suffix)]
    return mapping
def to_proto(*tops):
    """Generate a NetParameter that contains all layers needed to compute
    all arguments.

    Each top is serialized with a fresh name mapping, while the layer
    collection and the auto-name counter are shared between the tops.

    """
    collected = collections.OrderedDict()
    counters = collections.Counter()
    for item in tops:
        item.fn._to_proto(collected, {}, counters)
    net_param = caffe_pb2.NetParameter()
    net_param.layer.extend(collected.values())
    return net_param
def assign_proto(proto, name, val):
    """Assign a Python object to a protobuf message, based on the Python
    type (in recursive fashion).

    Lists become repeated fields/messages, dicts become messages, and other
    types are assigned directly. For convenience, repeated fields whose
    values are not lists are converted to single-element lists; e.g.,
    `my_repeated_int_field=3` is converted to `my_repeated_int_field=[3]`.
    An empty list leaves the repeated field unchanged.

    Parameters
    ----------
    proto : Message
        The message owning the field.
    name : str
        The name of the field to assign.
    val : Union[list, dict, object]
        The value to assign.

    """
    field = getattr(proto, name)
    # Repeated fields are recognized by their ``extend`` method.
    if hasattr(field, 'extend') and not isinstance(val, list):
        val = [val]
    if isinstance(val, list):
        # Guard the ``val[0]`` probe: an empty list has nothing to inspect.
        if val and isinstance(val[0], dict):
            for item in val:
                proto_item = getattr(proto, name).add()
                for k, v in item.items():
                    assign_proto(proto_item, k, v)
        else:
            getattr(proto, name).extend(val)
    elif isinstance(val, dict):
        for k, v in val.items():
            assign_proto(getattr(proto, name), k, v)
    else:
        setattr(proto, name, val)
class Top(object):
    """A single output blob of a layer function.

    ``fn`` is the producing function and ``n`` is the index of this
    output among the outputs of ``fn``.

    """

    def __init__(self, fn, n):
        self.fn, self.n = fn, n

    def to_proto(self):
        """Generate a NetParameter that contains all layers needed to compute
        this top."""
        return to_proto(self)

    def _update(self, params):
        """Forward the parameter update to the producing function."""
        self.fn._update(params)

    def _to_proto(self, layers, names, autonames):
        """Forward the serialization to the producing function."""
        producer = self.fn
        return producer._to_proto(layers, names, autonames)
class Function(object):
    """A layer specification: its type, its parameters and its inputs
    (which are Tops from other layers)."""

    def __init__(self, type_name, inputs, params):
        self.type_name = type_name
        self.inputs = inputs
        self.params = params
        # Pop the control keys so they are not processed again as layer params.
        self.ntop = self.params.pop('ntop', 1)
        self.in_place = self.params.pop('in_place', False)
        self.tops = tuple(Top(self, index) for index in range(self.ntop))

    def _get_name(self, names, autonames):
        """Return the layer name, registering one on first use."""
        if self not in names:
            if self.ntop > 0:
                # Reuse the name of the first top.
                names[self] = self._get_top_name(self.tops[0], names, autonames)
            else:
                autonames[self.type_name] += 1
                names[self] = self.type_name + str(autonames[self.type_name])
        return names[self]

    def _get_top_name(self, top, names, autonames):
        """Return the name of ``top``, generating ``<type><count>`` if unset."""
        if top not in names:
            type_name = top.fn.type_name
            autonames[type_name] += 1
            names[top] = type_name + str(autonames[type_name])
        return names[top]

    def _update(self, params):
        """Merge ``params`` into the stored layer parameters."""
        self.params.update(params)

    def _to_proto(self, layers, names, autonames):
        """Serialize the inputs, then this layer, into ``layers``."""
        if self in layers:
            return
        bottom_names = []
        for source in self.inputs:
            source._to_proto(layers, names, autonames)
            bottom_names.append(layers[source.fn].top[source.n])
        layer = caffe_pb2.LayerParameter()
        layer.type = self.type_name
        layer.bottom.extend(bottom_names)
        if self.in_place:
            layer.top.extend(layer.bottom)
        else:
            for top in self.tops:
                layer.top.append(self._get_top_name(top, names, autonames))
        layer.name = self._get_name(names, autonames)
        for key, value in self.params.items():
            if key.endswith('param'):
                # Generic ``*param`` keys are assigned on the layer itself.
                assign_proto(layer, key, value)
                continue
            try:
                # Prefer the parameter message matching the layer type.
                assign_proto(getattr(
                    layer, _param_names[self.type_name] + '_param'), key, value)
            except (AttributeError, KeyError):
                assign_proto(layer, key, value)
        layers[self] = layer
class NetSpec(object):
    """A set of Tops assigned directly as attributes.

    Calling ``NetSpec.to_proto`` generates a NetParameter containing all
    of the layers needed to produce all of the assigned Tops, using the
    assigned names.

    """

    def __init__(self):
        # Bypass the overridden ``__setattr__`` to create the storage itself.
        super(NetSpec, self).__setattr__('tops', collections.OrderedDict())

    def __setattr__(self, name, value):
        self.tops[name] = value

    def __getattr__(self, name):
        return self.tops[name]

    def __setitem__(self, key, value):
        self.__setattr__(key, value)

    def __getitem__(self, item):
        return self.__getattr__(item)

    def __delitem__(self, name):
        del self.tops[name]

    def keys(self):
        """Return the assigned names as a list, in assignment order."""
        return list(self.tops.keys())

    def vals(self):
        """Return the assigned tops as a list, in assignment order."""
        return list(self.tops.values())

    def update(self, name, params):
        """Merge ``params`` into the function producing the named top."""
        self.tops[name]._update(params)

    def to_proto(self):
        """Serialize every assigned top into a single NetParameter."""
        # Seed the name mapping with the user-assigned attribute names.
        names = dict((top, name) for name, top in self.tops.items())
        counters = collections.Counter()
        collected = collections.OrderedDict()
        for top in self.tops.values():
            top._to_proto(collected, names, counters)
        net_param = caffe_pb2.NetParameter()
        net_param.layer.extend(collected.values())
        return net_param
class Layers(object):
    """Pseudo-module whose attributes are layer constructors.

    For example, ``Layers().Convolution(bottom, kernel_size=3)`` yields a
    Top describing a 3x3 convolution applied to ``bottom``.
    """

    def __getattr__(self, name):
        def layer_fn(*args, **kwargs):
            fn = Function(name, args, kwargs)
            # Shape the return value by the number of tops:
            # none -> the function itself, one -> that top, many -> the tuple.
            if fn.ntop == 0:
                return fn
            if fn.ntop == 1:
                return fn.tops[0]
            return fn.tops
        return layer_fn
class Parameters(object):
    """Pseudo-module exposing constants used in layer parameters.

    For example, ``Parameters().Pooling.MAX`` is the value that selects
    max pooling.
    """

    def __getattr__(self, name):
        class Param:
            def __getattr__(self, param_name):
                # Resolve against caffe_pb2.<name>Parameter on every access.
                message_cls = getattr(caffe_pb2, name + 'Parameter')
                return getattr(message_cls, param_name)
        return Param()
# Lookup table built once by param_name_dict(); Function._to_proto indexes it
# by layer type name to locate the "<name>_param" field of a layer.
_param_names = param_name_dict()
# Module-level instances of the Layers and Parameters pseudo-modules.
layers = Layers()
params = Parameters()
# ---[ Protobuf # ---[ Protobuf
file(GLOB PROTO_FILES ${PROJECT_SOURCE_DIR}/proto/*.proto) file(GLOB PROTO_FILES ${PROJECT_SOURCE_DIR}/dragon/proto/*.proto)
protobuf_generate_cpp(${PROTO_FILES}) protobuf_generate_cpp(${PROTO_FILES})
# ---[ Runtime if (BUILD_PYTHON)
if (PYTHON_EXECUTABLE AND BUILD_RUNTIME) file(GLOB_RECURSE PROTO_FILES ${PROJECT_SOURCE_DIR}/dragon/python/*.proto)
set(HAS_RUNTIME_CODEGEN ON) protobuf_generate_python(${PROTO_FILES})
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PROJECT_SOURCE_DIR}/../tools/codegen_runtime.py
${PROJECT_SOURCE_DIR} "REMOVE_GRADIENT")
else()
set(HAS_RUNTIME_CODEGEN OFF)
endif() endif()
...@@ -7,19 +7,28 @@ else() ...@@ -7,19 +7,28 @@ else()
endif() endif()
# ---[ Packages # ---[ Packages
include(${PROJECT_SOURCE_DIR}/../cmake/FindProtobuf.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindProtobuf.cmake)
if (BUILD_PYTHON) if (BUILD_PYTHON)
include(${PROJECT_SOURCE_DIR}/../cmake/FindPythonLibs.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindPythonLibs.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/FindNumPy.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindNumPy.cmake)
endif()
if (USE_BLAS)
include(${PROJECT_SOURCE_DIR}/cmake/FindBLAS.cmake)
endif()
if (USE_OPENMP)
include(${PROJECT_SOURCE_DIR}/cmake/FindOpenMP.cmake)
endif()
if (USE_MPI)
include(${PROJECT_SOURCE_DIR}/cmake/FindMPI.cmake)
endif() endif()
if (USE_CUDA) if (USE_CUDA)
include(${PROJECT_SOURCE_DIR}/../cmake/FindCUDA.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindCUDA.cmake)
endif() endif()
if (USE_CUDNN) if (USE_CUDNN)
include(${PROJECT_SOURCE_DIR}/../cmake/FindCUDNN.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindCUDNN.cmake)
endif() endif()
if (USE_MPI) if (USE_MPS)
include(${PROJECT_SOURCE_DIR}/../cmake/FindMPI.cmake) include(${PROJECT_SOURCE_DIR}/cmake/FindMPS.cmake)
endif() endif()
if (USE_TENSORRT) if (USE_TENSORRT)
if (NOT TENSORRT_SDK_ROOT_DIR) if (NOT TENSORRT_SDK_ROOT_DIR)
...@@ -28,7 +37,7 @@ if (USE_TENSORRT) ...@@ -28,7 +37,7 @@ if (USE_TENSORRT)
endif() endif()
# ---[ Directories # ---[ Directories
include_directories(${PROJECT_SOURCE_DIR}/../) include_directories(${PROJECT_SOURCE_DIR})
include_directories(${THIRD_PARTY_DIR}/eigen) include_directories(${THIRD_PARTY_DIR}/eigen)
include_directories(${PROTOBUF_SDK_ROOT_DIR}/include) include_directories(${PROTOBUF_SDK_ROOT_DIR}/include)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib)
...@@ -46,6 +55,9 @@ endif() ...@@ -46,6 +55,9 @@ endif()
if (USE_CUDNN) if (USE_CUDNN)
include_directories(${CUDNN_INCLUDE_DIR}) include_directories(${CUDNN_INCLUDE_DIR})
endif() endif()
if (USE_OPENMP)
include_directories(${OPENMP_INCLUDE_DIR})
endif()
if (USE_MPI) if (USE_MPI)
include_directories(${MPI_INCLUDE_DIR}) include_directories(${MPI_INCLUDE_DIR})
endif() endif()
...@@ -75,6 +87,14 @@ if (USE_CUDNN) ...@@ -75,6 +87,14 @@ if (USE_CUDNN)
add_definitions(-DUSE_CUDNN) add_definitions(-DUSE_CUDNN)
message(STATUS "Use CUDNN.") message(STATUS "Use CUDNN.")
endif() endif()
if (USE_MPS)
add_definitions(-DUSE_MPS)
message(STATUS "Use MPS.")
endif()
if (USE_BLAS)
add_definitions(-DEIGEN_USE_BLAS)
message(STATUS "Use BLAS.")
endif()
if (USE_OPENMP) if (USE_OPENMP)
add_definitions(-DUSE_OPENMP) add_definitions(-DUSE_OPENMP)
message(STATUS "Use OpenMP.") message(STATUS "Use OpenMP.")
......
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt or https://cmake.org/licensing for details.
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# - Find the BLAS libraries
#
# Following variables can be set and are optional:
#
# BLAS_VENDOR - search for the specific BLAS vendor
# BLAS_FOUND_VENDOR - vendor implementing the BLAS interface is found
# BLAS_LIBRARIES - path to the BLAS library
#
# Load the link-test helper: the Fortran flavour when a Fortran compiler is
# loaded, the C flavour otherwise.
if(CMAKE_Fortran_COMPILER_LOADED)
  include(CheckFortranFunctionExists)
else()
  include(CheckFunctionExists)
endif()

# Resolve the vendor to search for. Precedence: the BLAS_VENDOR environment
# variable, then a pre-set BLAS_VENDOR variable, then "All".
# The environment expansion is quoted so that an unset/empty value still
# yields a well-formed comparison instead of `if(NOT STREQUAL "")`.
if(NOT "$ENV{BLAS_VENDOR}" STREQUAL "")
  set(BLAS_VENDOR "$ENV{BLAS_VENDOR}")
elseif(NOT BLAS_VENDOR)
  set(BLAS_VENDOR "All")
endif()

# Filled in by the vendor probes that follow; empty means "nothing found yet".
set(BLAS_FOUND_VENDOR "")
# CHECK_BLAS_LIBRARIES(<LIBRARIES> <_prefix> <_name> <_flags> <_list> <_deps> <_addlibdir> <_subdirs>)
#
#   LIBRARIES  - name of the variable that receives the result in the caller's scope
#   _prefix    - prefix of the find_library() result variables (<_prefix>_<lib>_LIBRARY)
#   _name      - routine used for the link test
#   _flags     - linker flags placed in front of the libraries for the link test
#   _list      - libraries to locate; entries starting with "-" are kept as raw flags
#   _deps      - extra libraries added to the link test and appended to the result
#   _addlibdir - additional directories searched by find_library()
#   _subdirs   - PATH_SUFFIXES passed to find_library()
#
# All eight arguments are named parameters, so every call has to supply all of them.
function(CHECK_BLAS_LIBRARIES LIBRARIES _prefix _name _flags _list _deps _addlibdir _subdirs)
# This function checks for the existence of the combination of libraries
# given by _list. If the combination is found, this checks whether can link
# against that library combination using the name of a routine given by _name
# using the linker flags given by _flags. If the combination of libraries is
# found and passes the link test, ${LIBRARIES} is set to the list of complete
# library paths that have been found. Otherwise, ${LIBRARIES} is set to FALSE.
set(_libraries_work TRUE)
set(_libraries)
set(_combined_name)
# Bias the library search: static suffixes go first when USE_SHARED_LIBS is
# off; otherwise an extra shared suffix is appended on Linux.
if(NOT USE_SHARED_LIBS)
if(WIN32)
set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES})
else()
set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES})
endif()
else()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# for ubuntu's libblas3gf and liblapack3gf packages
set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf)
endif()
endif()
# Search path: caller-provided directories, the platform's library path
# environment variable, then the C compiler's implicit link directories.
set(_extaddlibdir "${_addlibdir}")
if(WIN32)
list(APPEND _extaddlibdir ENV LIB)
elseif(APPLE)
list(APPEND _extaddlibdir ENV DYLD_LIBRARY_PATH)
else()
list(APPEND _extaddlibdir ENV LD_LIBRARY_PATH)
endif()
list(APPEND _extaddlibdir "${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}")
foreach(_library ${_list})
if(_library MATCHES "^-")
# Respect linker flags as-is (required by MKL)
list(APPEND _libraries "${_library}")
else()
# Sanitize the library name so it can be embedded in variable names.
string(REGEX REPLACE "[^A-Za-z0-9]" "_" _lib_var "${_library}")
string(APPEND _combined_name "_${_lib_var}")
if(NOT "${_deps}" STREQUAL "")
string(APPEND _combined_name "_deps")
endif()
# Once one library is missing, the remaining lookups are skipped.
if(_libraries_work)
find_library(${_prefix}_${_lib_var}_LIBRARY
NAMES ${_library}
NAMES_PER_DIR
PATHS ${_extaddlibdir}
PATH_SUFFIXES ${_subdirs})
mark_as_advanced(${_prefix}_${_lib_var}_LIBRARY)
list(APPEND _libraries ${${_prefix}_${_lib_var}_LIBRARY})
set(_libraries_work ${${_prefix}_${_lib_var}_LIBRARY})
endif()
endif()
endforeach()
# The flags take part in the name of the link-test result variable as well.
foreach(_flag ${_flags})
string(REGEX REPLACE "[^A-Za-z0-9]" "_" _flag_var "${_flag}")
string(APPEND _combined_name "_${_flag_var}")
endforeach()
# The link test only runs when USE_SHARED_LIBS is on; otherwise locating
# every library is treated as success.
if(_libraries_work AND USE_SHARED_LIBS)
# Test this combination of libraries.
set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${_libraries} ${_deps})
set(CMAKE_REQUIRED_QUIET ${BLAS_FIND_QUIETLY})
if(CMAKE_Fortran_COMPILER_LOADED)
check_fortran_function_exists("${_name}" ${_prefix}${_combined_name}_WORKS)
else()
# Without Fortran the symbol is probed with a trailing underscore appended.
check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS)
endif()
set(CMAKE_REQUIRED_LIBRARIES)
set(_libraries_work ${${_prefix}${_combined_name}_WORKS})
endif()
if(_libraries_work)
if("${_list}" STREQUAL "")
# An empty _list that still passed is reported through a placeholder value.
set(_libraries "${LIBRARIES}-PLACEHOLDER-FOR-EMPTY-LIBRARIES")
else()
list(APPEND _libraries ${_deps})
endif()
else()
set(_libraries FALSE)
endif()
# Hand the result (library list, placeholder, or FALSE) back to the caller.
set(${LIBRARIES} "${_libraries}" PARENT_SCOPE)
endfunction()
# Vendor probes. They run in the order below and each one is skipped as soon
# as BLAS_LIBRARIES holds a usable value, so the first vendor that is found
# wins. Every check_blas_libraries() call passes all eight arguments, because
# the function declares eight named parameters.

# OpenBLAS? (http://www.openblas.net)
# Tried three times with a growing set of companion libraries.
if(BLAS_VENDOR STREQUAL "OpenBLAS" OR BLAS_VENDOR STREQUAL "All")
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "openblas" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "OpenBLAS")
    endif()
  endif()
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "openblas;pthread;m" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "OpenBLAS")
    endif()
  endif()
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "openblas;pthread;m;gomp" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "OpenBLAS")
    endif()
  endif()
endif()

# Apple BLAS library?
if(BLAS_VENDOR STREQUAL "Apple" OR BLAS_VENDOR STREQUAL "All")
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "Accelerate" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "Apple")
    endif()
  endif()
endif()

# BLAS in acml library?
if(BLAS_VENDOR MATCHES "ACML" OR BLAS_VENDOR STREQUAL "All")
  if(NOT BLAS_LIBRARIES)
    # The trailing "" "" "" (deps, extra dirs, sub-dirs) were missing here;
    # with only five arguments CMake aborts the configure step with
    # "Function invoked with incorrect arguments".
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "acml;gfortran" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "ACML")
    endif()
  endif()
endif()

# BLAS in the ATLAS library? (http://math-atlas.sourceforge.net/)
if(BLAS_VENDOR STREQUAL "ATLAS" OR BLAS_VENDOR STREQUAL "All")
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "ptf77blas;atlas;gfortran" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "ATLAS")
    endif()
  endif()
endif()

# Generic BLAS library?
if(BLAS_VENDOR STREQUAL "Generic" OR BLAS_VENDOR STREQUAL "All")
  if(NOT BLAS_LIBRARIES)
    check_blas_libraries(BLAS_LIBRARIES BLAS sgemm "" "blas" "" "" "")
    if(BLAS_LIBRARIES)
      set(BLAS_FOUND_VENDOR "Generic")
    endif()
  endif()
endif()

# Check libraries.
# Configuration stops here when none of the probes above recorded a vendor.
if (NOT BLAS_FOUND_VENDOR)
  message(FATAL_ERROR "Search for BLAS vendor: ${BLAS_VENDOR}. Not found.")
else()
  message(STATUS "Found BLAS: ${BLAS_LIBRARIES} (found vendor: ${BLAS_FOUND_VENDOR})")
endif()
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# #
find_package(CUDA REQUIRED) find_package(CUDA REQUIRED)
include(${PROJECT_SOURCE_DIR}/../cmake/SelectCudaArch.cmake) include(${PROJECT_SOURCE_DIR}/cmake/SelectCudaArch.cmake)
# Set NVCC flags. # Set NVCC flags.
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH ${CUDA_ARCH}) CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH ${CUDA_ARCH})
......
...@@ -58,6 +58,12 @@ elseif (EXISTS "${CUDNN_INCLUDE_DIR}/../lib") ...@@ -58,6 +58,12 @@ elseif (EXISTS "${CUDNN_INCLUDE_DIR}/../lib")
endif() endif()
set(CUDNN_LIBRARIES_SHARED cudnn) set(CUDNN_LIBRARIES_SHARED cudnn)
set(CUDNN_LIBRARIES_STATIC cudnn_static) set(CUDNN_LIBRARIES_STATIC cudnn_static)
if (CUDNN_VERSION VERSION_GREATER "7.6.5")
set(CUDNN_LIBRARIES_SHARED ${CUDNN_LIBRARIES_SHARED}
cudnn_adv_infer cudnn_adv_train
cudnn_cnn_infer cudnn_cnn_train
cudnn_ops_infer cudnn_ops_train)
endif()
if (CUDNN_VERSION VERSION_GREATER "8.2.4") if (CUDNN_VERSION VERSION_GREATER "8.2.4")
set(CUDNN_LIBRARIES_STATIC cudnn_adv_infer_static cudnn_adv_train_static set(CUDNN_LIBRARIES_STATIC cudnn_adv_infer_static cudnn_adv_train_static
cudnn_cnn_infer_static cudnn_cnn_train_static cudnn_cnn_infer_static cudnn_cnn_train_static
......
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd. # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License. # Licensed under the BSD 2-Clause License.
# - Find the MPI libraries
#
# Following variables can be set and are optional: # Following variables can be set and are optional:
# #
# MPI_INCLUDE_DIR - path to the MPI headers # MPI_INCLUDE_DIR - path to the MPI headers
......
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# - Find the MPS libraries
#
# Following variables can be set and are optional:
#
# FRAMEWORK_FOUNDATION - path to the Foundation.framework
# FRAMEWORK_METAL - path to the Metal.framework
# FRAMEWORK_MPS - path to the MetalPerformanceShaders.framework
# FRAMEWORK_MPSGRAPH - path to the MetalPerformanceShadersGraph.framework
# MPS_OSX_VERSION_DEFINES - OSX major/minor version definitions parsed from the MPS framework path
# MPS_LIBRARIES - path to the MPS library
#
# Check frameworks.
# Every framework is expected below the SDK pointed to by CMAKE_OSX_SYSROOT.
set(_mps_frameworks_root ${CMAKE_OSX_SYSROOT}/System/Library/Frameworks)
set(FRAMEWORK_FOUNDATION ${_mps_frameworks_root}/Foundation.framework)
set(FRAMEWORK_METAL ${_mps_frameworks_root}/Metal.framework)
set(FRAMEWORK_MPS ${_mps_frameworks_root}/MetalPerformanceShaders.framework)
set(FRAMEWORK_MPSGRAPH ${_mps_frameworks_root}/MetalPerformanceShadersGraph.framework)

# Parallel lists: FRAMEWORK_<key> variable suffix and the label used in messages.
set(_mps_keys FOUNDATION METAL MPS MPSGRAPH)
set(_mps_labels Foundation Metal MPS MPSGraph)
foreach(_mps_index RANGE 3)
  list(GET _mps_keys ${_mps_index} _mps_key)
  list(GET _mps_labels ${_mps_index} _mps_label)
  if (EXISTS ${FRAMEWORK_${_mps_key}})
    get_filename_component(_dir "${FRAMEWORK_${_mps_key}}" ABSOLUTE)
    message(STATUS "Found ${_mps_label}: ${_dir}")
  else()
    # A missing framework aborts the configuration immediately.
    message(FATAL_ERROR "${_mps_label} is not found.")
  endif()
endforeach()

# Drop the loop helpers so only the FRAMEWORK_* variables and _dir remain.
unset(_mps_frameworks_root)
unset(_mps_keys)
unset(_mps_labels)
unset(_mps_key)
unset(_mps_label)
# Set defines.
# The first "<digits>.<digits>" found in the MPS framework path is taken as
# the OSX major/minor version (CMAKE_MATCH_1 / CMAKE_MATCH_2).
# NOTE(review): both captures are empty when the path holds no such pattern,
# and the first list element ends with a space - confirm the consumer copes.
string(REGEX MATCH "([0-9]+)\\.([0-9]+)" _version "${FRAMEWORK_MPS}")
set(MPS_OSX_VERSION_DEFINES "MPS_OSX_VERSION_MAJOR=${CMAKE_MATCH_1} "
"MPS_OSX_VERSION_MINOR=${CMAKE_MATCH_2}")
# Set libraries.
# A single string of linker flags; each framework is passed with -weak_framework.
set(MPS_LIBRARIES "-weak_framework Foundation \
-weak_framework Metal \
-weak_framework MetalPerformanceShaders \
-weak_framework MetalPerformanceShadersGraph")
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# - Find the OpenMP libraries
#
# Following variables can be set and are optional:
#
# OPENMP_INCLUDE_DIR - path to the OpenMP headers
# OPENMP_LIBRARIES - path to the OpenMP library
#
# Set include directory.
# The bundled headers are used when present; otherwise OPENMP_INCLUDE_DIR
# keeps whatever value it already had (possibly none).
if (EXISTS "${THIRD_PARTY_DIR}/openmp/include/omp.h")
  set(OPENMP_INCLUDE_DIR ${THIRD_PARTY_DIR}/openmp/include)
endif()

# Set libraries.
# Only derive the library directory from a non-empty include directory:
# with an empty OPENMP_INCLUDE_DIR the probe would degenerate to "/../lib",
# which resolves to the system "/lib" and would be appended by mistake.
if (OPENMP_INCLUDE_DIR AND EXISTS "${OPENMP_INCLUDE_DIR}/../lib")
  list(APPEND THIRD_PARTY_LIBRARY_DIRS ${OPENMP_INCLUDE_DIR}/../lib)
endif()
set(OPENMP_LIBRARIES omp)
...@@ -48,6 +48,9 @@ else() # GNU, Clang, AppleClang ...@@ -48,6 +48,9 @@ else() # GNU, Clang, AppleClang
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfma") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfma")
endif() endif()
endif() endif()
if (USE_MPS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-objc-arc -Wno-unguarded-availability-new")
endif()
if (USE_OPENMP) if (USE_OPENMP)
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xpreprocessor -fopenmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xpreprocessor -fopenmp")
......
...@@ -41,7 +41,7 @@ from dragon.vm.dali.core.ops.image_ops import WarpAffine ...@@ -41,7 +41,7 @@ from dragon.vm.dali.core.ops.image_ops import WarpAffine
from dragon.vm.dali.core.ops.math_ops import Normalize from dragon.vm.dali.core.ops.math_ops import Normalize
from dragon.vm.dali.core.ops.random_ops import CoinFlip from dragon.vm.dali.core.ops.random_ops import CoinFlip
from dragon.vm.dali.core.ops.random_ops import Uniform from dragon.vm.dali.core.ops.random_ops import Uniform
from dragon.vm.dali.core.ops.reader_ops import KPLRecordReader from dragon.vm.dali.core.ops.reader_ops import CGRecordReader
from dragon.vm.dali.core.ops.reader_ops import TFRecordReader from dragon.vm.dali.core.ops.reader_ops import TFRecordReader
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -25,15 +25,17 @@ except ImportError: ...@@ -25,15 +25,17 @@ except ImportError:
from dragon.core.util import deprecation from dragon.core.util import deprecation
ops = deprecation.NotInstalled('nvidia.dali') ops = deprecation.NotInstalled('nvidia.dali')
tfrecord = deprecation.NotInstalled('nvidia.dali') tfrecord = deprecation.NotInstalled('nvidia.dali')
try:
import codewithgpu
except ImportError:
codewithgpu = deprecation.NotInstalled('codewithgpu')
from dragon.core.io import reader
from dragon.core.io import kpl_record
from dragon.vm.dali.core.framework import context from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core.ops.builtin_ops import ExternalSource from dragon.vm.dali.core.ops.builtin_ops import ExternalSource
class KPLRecordReader(object): class CGRecordReader(object):
"""Read examples from the KPLRecord. """Read examples from the CGRecord.
Examples: Examples:
...@@ -42,20 +44,18 @@ class KPLRecordReader(object): ...@@ -42,20 +44,18 @@ class KPLRecordReader(object):
def __init__(): def __init__():
super(MyPipeline, self).__init__() super(MyPipeline, self).__init__()
# Assume the we have the following data: # Assume that we have the following files:
# /data/root.data # /path/to/records/00000.data
# /data/root.index # /path/to/records/00000.index
# /data/root.meta # /path/to/records/METADATA
self.reader = dali.ops.KPLRecordReader( self.reader = dali.ops.CGRecordReader(
path='/data' path='/path/to/records'
features=('image', 'label'), features=('image', 'label'),
pipeline=self, pipeline=self,
# Shuffle locally in the next ``initial_fill`` examples # Shuffle locally in the next ``initial_fill`` examples
# It turns to be weak with the decreasing of ``initial_fill`` # It turns to be weak with the decreasing of ``initial_fill``
# and disabled if ``initial_fill`` is set to **1** # and disabled if ``initial_fill`` is set to **1**
random_shuffle=True, random_shuffle=True, initial_fill=1024)
initial_fill=1024,
)
def iter_step(self): def iter_step(self):
self.reader.feed_inputs() self.reader.feed_inputs()
...@@ -100,19 +100,12 @@ class KPLRecordReader(object): ...@@ -100,19 +100,12 @@ class KPLRecordReader(object):
self._pipe = pipeline self._pipe = pipeline
self._batch_size = pipeline.batch_size self._batch_size = pipeline.batch_size
self._prefetch_depth = pipeline._prefetch_queue_depth self._prefetch_depth = pipeline._prefetch_queue_depth
self._reader = reader.DataReader( self._buffer = mp.Queue(self._prefetch_depth * self._batch_size)
dataset=kpl_record.KPLRecordDataset, self._dataset_reader = codewithgpu.DatasetReader(
source=path, path=path, output_queue=self._buffer,
part_idx=shard_id, partition_idx=shard_id, num_partitions=num_shards,
num_parts=num_shards, shuffle=random_shuffle, initial_fill=initial_fill, **kwargs)
shuffle=random_shuffle, self._dataset_reader.start()
initial_fill=initial_fill,
**kwargs
)
self._buffer = self._reader.q_out = mp.Queue(
self._prefetch_depth * self._batch_size)
self._reader.start()
with context.device('cpu'): with context.device('cpu'):
self.features = dict((k, ExternalSource()) for k in features) self.features = dict((k, ExternalSource()) for k in features)
...@@ -146,8 +139,8 @@ class KPLRecordReader(object): ...@@ -146,8 +139,8 @@ class KPLRecordReader(object):
def terminate(self): def terminate(self):
"""Terminate the reader.""" """Terminate the reader."""
self._reader.terminate() self._dataset_reader.terminate()
self._reader.join() self._dataset_reader.join()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Create the edge references for features. """Create the edge references for features.
...@@ -170,19 +163,16 @@ class TFRecordReader(object): ...@@ -170,19 +163,16 @@ class TFRecordReader(object):
Examples: Examples:
```python ```python
# Assume the we have the following data: # Assume that we have the following files:
# /data/00001.data # /path/to/records/00000.data
# /data/00001.index # /path/to/records/00000.index
# /data/FEATURES # /path/to/records/METADATA
database = '/data'
input = dali.ops.TFRecordReader( input = dali.ops.TFRecordReader(
path=database, path='/path/to/records',
# Shuffle locally in the next ``initial_fill`` examples # Shuffle locally in the next ``initial_fill`` examples
# It turns to be weak with the decreasing of ``initial_fill`` # It turns to be weak with the decreasing of ``initial_fill``
# and disabled if ``initial_fill`` is set to **1** # and disabled if ``initial_fill`` is set to **1**
random_shuffle=True, random_shuffle=True, initial_fill=1024)
initial_fill=1024,
)
``` ```
""" """
...@@ -231,17 +221,17 @@ class TFRecordReader(object): ...@@ -231,17 +221,17 @@ class TFRecordReader(object):
@staticmethod @staticmethod
def check_files(path): def check_files(path):
data_files, index_files, features_file = [], [], None data_files, index_files, meta_data_file = [], [], None
for file in os.listdir(path): for file in os.listdir(path):
if file.endswith('.data'): if file.endswith('.data'):
data_files.append(file) data_files.append(file)
elif file.endswith('.index'): elif file.endswith('.index'):
index_files.append(file) index_files.append(file)
elif file == 'FEATURES': elif file == 'METADATA':
features_file = file meta_data_file = file
if features_file is None: if meta_data_file is None:
raise FileNotFoundError('File <FEATURES> is missing.') raise FileNotFoundError('Excepted meta data file: %s' % meta_data_file)
with open(os.path.join(path, features_file), 'r') as f: with open(os.path.join(path, meta_data_file), 'r') as f:
features = f.read() features = f.read()
features = features.replace('tf.', 'tfrecord.') features = features.replace('tf.', 'tfrecord.')
features = features.replace('tf.io.', 'tfrecord.') features = features.replace('tf.io.', 'tfrecord.')
......
# Image that builds and installs dragon with CUDA and cuDNN switched off.
FROM ubuntu:18.04
# Install the build toolchain, protobuf and the Python 3 packages through apt,
# then drop the apt package lists.
RUN \
apt-get update && apt-get install -y \
--no-install-recommends \
--allow-change-held-packages \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libudev-dev \
libz-dev \
libnuma-dev \
libprotobuf-dev \
protobuf-compiler \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
&& rm -rf /var/lib/apt/lists/*
# Upgrade setuptools/wheel, then install the Python dependencies from the
# package index given with -i.
RUN \
pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
kpl-dataset \
opencv-python \
Pillow
# Clone with submodules only to move the third_party contents to /opt.
RUN \
git clone --recursive https://github.com/seetaresearch/dragon.git && \
mv dragon/third_party/* /opt && rm -rf dragon
# Clone again, configure against /opt as THIRD_PARTY_DIR, build and install,
# remove the build tree and install the Python package.
RUN \
git clone https://github.com/seetaresearch/dragon.git && \
cd dragon/dragon && mkdir build && cd build && \
cmake .. \
-DTHIRD_PARTY_DIR=/opt \
-DPYTHON_EXECUTABLE=/usr/bin/python3 \
-DUSE_CUDA=OFF \
-DUSE_CUDNN=OFF \
-DUSE_AVX2=ON \
-DUSE_FMA=ON && \
make install -j $(nproc) && \
cd .. && rm -rf build && \
python3 setup.py install
# Point `python` and `pip` at the Python 3 executables.
# NOTE(review): `rm /usr/bin/python` fails this step if the file is absent -
# confirm that an earlier step always provides it.
RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip
# Image that builds and installs dragon with MPI and NCCL enabled on top of
# the CUDA 10.2 / cuDNN 8 development image.
FROM nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04
# Remove the CUDA apt source list, install the build toolchain, protobuf,
# Python 3 and NCCL packages, then drop the apt package lists.
RUN \
rm /etc/apt/sources.list.d/cuda.list && \
apt-get update && apt-get install -y \
--no-install-recommends \
--allow-change-held-packages \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libudev-dev \
libz-dev \
libnuma-dev \
libprotobuf-dev \
protobuf-compiler \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
libnccl2 \
libnccl-dev \
&& rm -rf /var/lib/apt/lists/*
# Upgrade setuptools/wheel, then install the Python dependencies from the
# package index given with -i.
RUN \
pip3 install --no-cache-dir --upgrade setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
kpl-dataset \
opencv-python \
Pillow
# Clone with submodules only to move the third_party contents to /opt.
RUN \
git clone --recursive https://github.com/seetaresearch/dragon.git && \
mv dragon/third_party/* /opt && rm -rf dragon
# Run the MPI build script shipped in /opt/mpi, clean up its sources and
# copy mpirun to /usr/bin.
RUN cd /opt/mpi && bash build.sh && rm -rf src *.gz && cp bin/mpirun /usr/bin
# Clone again, configure against /opt as THIRD_PARTY_DIR with MPI and NCCL
# on, build and install, remove the build tree and install the Python package.
RUN \
git clone https://github.com/seetaresearch/dragon.git && \
cd dragon/dragon && mkdir build && cd build && \
cmake .. \
-DTHIRD_PARTY_DIR=/opt \
-DPYTHON_EXECUTABLE=/usr/bin/python3 \
-DUSE_MPI=ON \
-DUSE_NCCL=ON \
-DUSE_AVX2=ON \
-DUSE_FMA=ON && \
make install -j $(nproc) && \
cd .. && rm -rf build && \
python3 setup.py install
# Point `python` and `pip` at the Python 3 executables.
# NOTE(review): `rm /usr/bin/python` fails this step if the file is absent -
# confirm that an earlier step always provides it.
RUN rm /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip
# Image that builds and installs dragon with CUDA and cuDNN switched off.
FROM ubuntu:20.04
# Keep apt from prompting during package installation.
ENV DEBIAN_FRONTEND=noninteractive
# Install the build toolchain, protobuf and Python 3, then drop the apt lists.
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libz-dev \
libprotobuf-dev \
protobuf-compiler \
python3-pip \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
# Make `python` resolve to Python 3.
RUN ln -s /usr/bin/python3 /usr/bin/python
RUN pip install --no-cache-dir numpy protobuf
# Clone with submodules; the clone is kept because the later steps build from it.
RUN git clone --recursive https://github.com/seetaresearch/dragon.git
# Configure from the repository root with the CUDA backends disabled.
RUN cd dragon && mkdir build && cd build && \
cmake .. -DUSE_CUDA=OFF -DUSE_CUDNN=OFF
RUN cd dragon/build && make install -j $(nproc)
# Install the Python package from the source tree, then remove the build tree.
RUN pip install ./dragon && rm -rf dragon/build
# Image that builds and installs dragon with CUDA, cuDNN, MPI and NCCL enabled.
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
# Keep apt from prompting during package installation.
ENV DEBIAN_FRONTEND=noninteractive
# Remove the CUDA apt source list before running apt-get update.
RUN rm /etc/apt/sources.list.d/cuda.list
# Install the build toolchain, protobuf, Python 3 and NCCL, then drop the apt lists.
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libz-dev \
libprotobuf-dev \
protobuf-compiler \
python3-pip \
python3-dev \
libnccl2 \
libnccl-dev \
&& rm -rf /var/lib/apt/lists/*
# Make `python` resolve to Python 3.
RUN ln -s /usr/bin/python3 /usr/bin/python
RUN pip install --no-cache-dir numpy protobuf
# Clone with submodules; the clone is kept because the later steps build from it.
RUN git clone --recursive https://github.com/seetaresearch/dragon.git
# Run the MPI build script from third_party, clean up its sources and copy
# mpirun to /usr/bin.
RUN cd dragon/third_party/mpi && bash build.sh && rm -rf src *.gz && cp bin/mpirun /usr/bin
# Link the cuDNN headers and libraries into the CUDA toolkit directories.
RUN ln -s /usr/include/cudnn* /usr/local/cuda/include
RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64
# Configure from the repository root with all GPU/distributed backends on.
RUN cd dragon && mkdir build && cd build && \
cmake .. -DUSE_CUDA=ON -DUSE_CUDNN=ON -DUSE_MPI=ON -DUSE_NCCL=ON
RUN cd dragon/build && make install -j $(nproc)
# Install the Python package from the source tree, then remove the build tree.
RUN pip install ./dragon && rm -rf dragon/build

64.8 KB | W: | H:

14.1 KB | W: | H:

docs/api/_static/images/dragon.png
docs/api/_static/images/dragon.png
docs/api/_static/images/dragon.png
docs/api/_static/images/dragon.png
  • 2-up
  • Swipe
  • Onion skin

26.4 KB | W: | H:

8.7 KB | W: | H:

docs/api/_static/images/logo.png
docs/api/_static/images/logo.png
docs/api/_static/images/logo.png
docs/api/_static/images/logo.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -40,7 +40,7 @@ extensions = ['sphinx.ext.autodoc', 'sphinxcontrib.katex', 'breathe'] ...@@ -40,7 +40,7 @@ extensions = ['sphinx.ext.autodoc', 'sphinxcontrib.katex', 'breathe']
project = 'dragon' project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd' copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'SeetaTech, Co.,Ltd' author = 'SeetaTech, Co.,Ltd'
with open('../../../dragon/version.txt', 'r') as f: with open('../../../version.txt', 'r') as f:
version = f.read().strip() version = f.read().strip()
# Sphinx # Sphinx
......
...@@ -47,11 +47,7 @@ New ...@@ -47,11 +47,7 @@ New
SwitchToDevice SwitchToDevice
############## ##############
.. doxygenfunction:: dragon::CPUContext::SwitchToDevice() .. doxygenfunction:: dragon::CPUContext::SwitchToDevice
SwitchToDevice
##############
.. doxygenfunction:: dragon::CPUContext::SwitchToDevice(int stream)
device device
###### ######
......
...@@ -51,11 +51,7 @@ New ...@@ -51,11 +51,7 @@ New
SwitchToDevice SwitchToDevice
############## ##############
.. doxygenfunction:: dragon::CUDAContext::SwitchToDevice() .. doxygenfunction:: dragon::CUDAContext::SwitchToDevice
SwitchToDevice
##############
.. doxygenfunction:: dragon::CUDAContext::SwitchToDevice(int stream)
SynchronizeStream SynchronizeStream
################# #################
......
vm.caffe
========
.. only:: html
Classes
#######
`class AdamSolver <caffe/AdamSolver.html>`_
: The Adam solver.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
`class NesterovSolver <caffe/NesterovSolver.html>`_
: The Nesterov-SGD solver.
`[Sutskever et al., 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
`class Net <caffe/Net.html>`_
: The base net class to connect layers.
`class RMSPropSolver <caffe/RMSPropSolver.html>`_
: The RMSProp solver.
`[Hinton et al., 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
`class SGDSolver <caffe/SGDSolver.html>`_
: The Momentum-SGD solver.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
`class Solver <caffe/Solver.html>`_
: The base solver class to optimize parameters.
.. toctree::
:hidden:
caffe/AdamSolver
caffe/NesterovSolver
caffe/Net
caffe/RMSPropSolver
caffe/SGDSolver
caffe/Solver
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
AdamSolver
==========
.. autoclass:: dragon.vm.caffe.AdamSolver
__init__
--------
.. automethod:: dragon.vm.caffe.AdamSolver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step
########
.. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
NesterovSolver
==============
.. autoclass:: dragon.vm.caffe.NesterovSolver
__init__
--------
.. automethod:: dragon.vm.caffe.NesterovSolver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step
########
.. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
Net
===
.. autoclass:: dragon.vm.caffe.Net
__init__
--------
.. automethod:: dragon.vm.caffe.Net.__init__
Properties
----------
blobs
#####
.. autoattribute:: dragon.vm.caffe.Net.blobs
inputs
######
.. autoattribute:: dragon.vm.caffe.Net.inputs
outputs
#######
.. autoattribute:: dragon.vm.caffe.Net.outputs
params
######
.. autoattribute:: dragon.vm.caffe.Net.params
Methods
-------
copy_from
#########
.. automethod:: dragon.vm.caffe.Net.copy_from
forward
#########
.. automethod:: dragon.vm.caffe.Net.forward
save
####
.. automethod:: dragon.vm.caffe.Net.save
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
RMSPropSolver
=============
.. autoclass:: dragon.vm.caffe.RMSPropSolver
__init__
--------
.. automethod:: dragon.vm.caffe.RMSPropSolver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step
########
.. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
SGDSolver
=========
.. autoclass:: dragon.vm.caffe.SGDSolver
__init__
--------
.. automethod:: dragon.vm.caffe.SGDSolver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
:noindex:
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
:noindex:
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
:noindex:
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
:noindex:
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
:noindex:
step
########
.. automethod:: dragon.vm.caffe.Solver.step
:noindex:
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
Solver
======
.. autoclass:: dragon.vm.caffe.Solver
__init__
--------
.. automethod:: dragon.vm.caffe.Solver.__init__
Properties
----------
base_lr
#######
.. autoattribute:: dragon.vm.caffe.Solver.base_lr
iter
####
.. autoattribute:: dragon.vm.caffe.Solver.iter
net
###
.. autoattribute:: dragon.vm.caffe.Solver.net
test_nets
#########
.. autoattribute:: dragon.vm.caffe.Solver.test_nets
Methods
-------
snapshot
########
.. automethod:: dragon.vm.caffe.Solver.snapshot
step
########
.. automethod:: dragon.vm.caffe.Solver.step
.. raw:: html
<style>
h1:before {
content: "caffe.";
color: #103d3e;
}
</style>
vm.caffe.layers
===============
.. only:: html
Classes
-------
`class Accuracy <layers/Accuracy.html>`_
: Compute the top-k accuracy.
`class ArgMax <layers/ArgMax.html>`_
: Compute the index of maximum elements along the given axis.
`class BatchNorm <layers/BatchNorm.html>`_
: Apply the batch normalization.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`class Concat <layers/Concat.html>`_
: Concatenate the inputs along the given axis.
`class Convolution <layers/Convolution.html>`_
: Apply the n-dimension convolution.
`class Crop <layers/Crop.html>`_
: Select the elements according to the dimensions of the second bottom.
`class Data <layers/Data.html>`_
: Load batch of data for image classification.
`class Deconvolution <layers/Deconvolution.html>`_
: Apply the n-dimension deconvolution.
`class Dropout <layers/Dropout.html>`_
: Set the elements of the input to zero randomly.
`[Srivastava et al., 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`class Eltwise <layers/Eltwise.html>`_
: Compute the element-wise operation on the sequence of inputs.
`class ELU <layers/ELU.html>`_
: Apply the exponential linear unit.
`[Clevert et al., 2015] <https://arxiv.org/abs/1511.07289>`_.
`class EuclideanLoss <layers/EuclideanLoss.html>`_
: Compute the element-wise squared error.
`class Flatten <layers/Flatten.html>`_
: Flatten the input along the given axes.
`class InnerProduct <layers/InnerProduct.html>`_
: Compute the dense matrix multiplication along the given axes.
`class Input <layers/Input.html>`_
: Produce input blobs with shape and dtype.
`class LRN <layers/LRN.html>`_
: Apply the local response normalization.
`[Krizhevsky et al., 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
`class Normalize <layers/Normalize.html>`_
: Apply the fused L2 normalization.
`[Liu et al., 2015] <https://arxiv.org/abs/1506.04579>`_.
`class Permute <layers/Permute.html>`_
: Permute the dimensions of input.
`class Pooling <layers/Pooling.html>`_
: Apply the n-dimension pooling.
`class Power <layers/Power.html>`_
: Compute the power of input.
`class PReLU <layers/PReLU.html>`_
: Apply the parametric rectified linear unit.
`[He et al., 2015] <https://arxiv.org/abs/1502.01852>`_.
`class Python <layers/Python.html>`_
: Wrap a python class into a layer.
`class Reduction <layers/Reduction.html>`_
: Compute the reduction value along the given axis.
`class ReLU <layers/ReLU.html>`_
: Apply the rectified linear unit.
`[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
`class Reshape <layers/Reshape.html>`_
: Change the dimensions of input.
`class Scale <layers/Scale.html>`_
: Compute the affine transformation along the given axes.
`class Sigmoid <layers/Sigmoid.html>`_
: Apply the sigmoid function.
`class SigmoidCrossEntropyLoss <layers/SigmoidCrossEntropyLoss.html>`_
: Compute the sigmoid cross entropy with contiguous targets.
`class SmoothL1Loss <layers/SmoothL1Loss.html>`_
: Compute the element-wise error that transitions between L1 and L2.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`class Softmax <layers/Softmax.html>`_
: Apply the softmax function.
`class SoftmaxWithLoss <layers/SoftmaxWithLoss.html>`_
: Compute the softmax cross entropy with sparse labels.
`class TanH <layers/TanH.html>`_
: Apply the tanh function.
`class Tile <layers/Tile.html>`_
: Repeat the input according to the given axis.
.. toctree::
:hidden:
layers/Accuracy
layers/ArgMax
layers/BatchNorm
layers/Concat
layers/Convolution
layers/Crop
layers/Data
layers/Deconvolution
layers/Dropout
layers/Eltwise
layers/ELU
layers/EuclideanLoss
layers/Flatten
layers/InnerProduct
layers/Input
layers/LRN
layers/Normalize
layers/Permute
layers/Pooling
layers/Power
layers/PReLU
layers/Python
layers/Reduction
layers/ReLU
layers/Reshape
layers/Scale
layers/Sigmoid
layers/SigmoidCrossEntropyLoss
layers/SmoothL1Loss
layers/Softmax
layers/SoftmaxWithLoss
layers/TanH
layers/Tile
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
EuclideanLoss
=============
.. autoclass:: dragon.vm.caffe.core.layers.EuclideanLoss
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Flatten
========
.. autoclass:: dragon.vm.caffe.core.layers.Flatten
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Input
=====
.. autoclass:: dragon.vm.caffe.core.layers.Input
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
LRN
====
.. autoclass:: dragon.vm.caffe.core.layers.LRN
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Normalize
=========
.. autoclass:: dragon.vm.caffe.core.layers.Normalize
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
PReLU
=====
.. autoclass:: dragon.vm.caffe.core.layers.PReLU
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Permute
=======
.. autoclass:: dragon.vm.caffe.core.layers.Permute
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Pooling
=======
.. autoclass:: dragon.vm.caffe.core.layers.Pooling
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Power
=====
.. autoclass:: dragon.vm.caffe.core.layers.Power
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Python
======
.. autoclass:: dragon.vm.caffe.core.layers.Python
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
ReLU
====
.. autoclass:: dragon.vm.caffe.core.layers.ReLU
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Reduction
=========
.. autoclass:: dragon.vm.caffe.core.layers.Reduction
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Reshape
=======
.. autoclass:: dragon.vm.caffe.core.layers.Reshape
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Scale
=====
.. autoclass:: dragon.vm.caffe.core.layers.Scale
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Sigmoid
=======
.. autoclass:: dragon.vm.caffe.core.layers.Sigmoid
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
SigmoidCrossEntropyLoss
=======================
.. autoclass:: dragon.vm.caffe.core.layers.SigmoidCrossEntropyLoss
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
SmoothL1Loss
============
.. autoclass:: dragon.vm.caffe.core.layers.SmoothL1Loss
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Softmax
=======
.. autoclass:: dragon.vm.caffe.core.layers.Softmax
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
TanH
====
.. autoclass:: dragon.vm.caffe.core.layers.TanH
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
Tile
====
.. autoclass:: dragon.vm.caffe.core.layers.Tile
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
...@@ -34,7 +34,7 @@ extensions = [ ...@@ -34,7 +34,7 @@ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinxcontrib.katex', 'sphinxcontrib.katex',
# 'sphinx_seeta_theme.ext.viewcode', 'sphinx_seeta_theme.ext.viewcode',
] ]
napoleon_use_rtype = False napoleon_use_rtype = False
...@@ -42,7 +42,7 @@ napoleon_use_rtype = False ...@@ -42,7 +42,7 @@ napoleon_use_rtype = False
project = 'dragon' project = 'dragon'
copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd' copyright = 'Copyright (c) 2017-present, SeetaTech, Co.,Ltd'
author = 'SeetaTech, Co.,Ltd' author = 'SeetaTech, Co.,Ltd'
with open('../../../dragon/version.txt', 'r') as f: with open('../../../version.txt', 'r') as f:
version = f.read().strip() version = f.read().strip()
# HTML # HTML
...@@ -77,8 +77,6 @@ html_theme_options = { ...@@ -77,8 +77,6 @@ html_theme_options = {
} }
html_sidebars = { html_sidebars = {
'index': ['localtoc.html'], 'index': ['localtoc.html'],
'caffe': ['localtoc.html'],
'caffe/**': ['localtoc.html'],
'dali': ['localtoc.html'], 'dali': ['localtoc.html'],
'dali/**': ['localtoc.html'], 'dali/**': ['localtoc.html'],
'dragon': ['localtoc.html'], 'dragon': ['localtoc.html'],
...@@ -87,7 +85,6 @@ html_sidebars = { ...@@ -87,7 +85,6 @@ html_sidebars = {
'onnx/**': ['localtoc.html'], 'onnx/**': ['localtoc.html'],
'tensorflow': ['localtoc.html'], 'tensorflow': ['localtoc.html'],
'tensorflow/**': ['localtoc.html'], 'tensorflow/**': ['localtoc.html'],
'tensorlayer/**': ['localtoc.html'],
'tensorrt': ['localtoc.html'], 'tensorrt': ['localtoc.html'],
'tensorrt/**': ['localtoc.html'], 'tensorrt/**': ['localtoc.html'],
'torch': ['localtoc.html'], 'torch': ['localtoc.html'],
......
...@@ -21,6 +21,9 @@ vm.dali.ops ...@@ -21,6 +21,9 @@ vm.dali.ops
`class Cast <ops/Cast.html>`_ `class Cast <ops/Cast.html>`_
: Cast the data type of input. : Cast the data type of input.
`class CGRecordReader <ops/CGRecordReader.html>`_
: Read examples from the cg-record file.
`class CoinFlip <ops/CoinFlip.html>`_ `class CoinFlip <ops/CoinFlip.html>`_
: Sample values from a bernoulli distribution. : Sample values from a bernoulli distribution.
...@@ -81,9 +84,6 @@ vm.dali.ops ...@@ -81,9 +84,6 @@ vm.dali.ops
`class Slice <ops/Slice.html>`_ `class Slice <ops/Slice.html>`_
: Select an interval of elements from input. : Select an interval of elements from input.
`class KPLRecordReader <ops/KPLRecordReader.html>`_
: Read examples from the kpl-record file.
`class TFRecordReader <ops/TFRecordReader.html>`_ `class TFRecordReader <ops/TFRecordReader.html>`_
: Read examples from the tf-record file. : Read examples from the tf-record file.
...@@ -101,6 +101,7 @@ vm.dali.ops ...@@ -101,6 +101,7 @@ vm.dali.ops
ops/Brightness ops/Brightness
ops/BrightnessContrast ops/BrightnessContrast
ops/Cast ops/Cast
ops/CGRecordReader
ops/CoinFlip ops/CoinFlip
ops/ColorSpaceConversion ops/ColorSpaceConversion
ops/ColorTwist ops/ColorTwist
...@@ -121,7 +122,6 @@ vm.dali.ops ...@@ -121,7 +122,6 @@ vm.dali.ops
ops/Resize ops/Resize
ops/Rotate ops/Rotate
ops/Slice ops/Slice
ops/KPLRecordReader
ops/TFRecordReader ops/TFRecordReader
ops/Uniform ops/Uniform
ops/WarpAffine ops/WarpAffine
......
KPLRecordReader CGRecordReader
=============== ===============
.. autoclass:: dragon.vm.dali.ops.KPLRecordReader .. autoclass:: dragon.vm.dali.ops.CGRecordReader
__init__ __init__
-------- --------
.. automethod:: dragon.vm.dali.ops.KPLRecordReader.__init__ .. automethod:: dragon.vm.dali.ops.CGRecordReader.__init__
Methods Methods
------- -------
example_to_data example_to_data
############### ###############
.. automethod:: dragon.vm.dali.ops.KPLRecordReader.example_to_data .. automethod:: dragon.vm.dali.ops.CGRecordReader.example_to_data
feed_inputs feed_inputs
########### ###########
.. automethod:: dragon.vm.dali.ops.KPLRecordReader.feed_inputs .. automethod:: dragon.vm.dali.ops.CGRecordReader.feed_inputs
__call__ __call__
######## ########
.. automethod:: dragon.vm.dali.ops.KPLRecordReader.__call__ .. automethod:: dragon.vm.dali.ops.CGRecordReader.__call__
.. raw:: html .. raw:: html
......
dragon.io
=========
.. only:: html
Classes
-------
`class DataReader <io/DataReader.html>`_
: Read examples from a dataset.
`class KPLRecordDataset <io/KPLRecordDataset.html>`_
: Dataset to load the KPLRecord.
`class KPLRecordWriter <io/KPLRecordWriter.html>`_
: Write examples into the KPLRecord.
`class TFRecordExample <io/TFRecordExample.html>`_
: Describe an example of the TFRecord.
`class TFRecordWriter <io/TFRecordWriter.html>`_
: Write examples into the TFRecord.
.. toctree::
:hidden:
io/DataReader
io/KPLRecordDataset
io/KPLRecordWriter
io/TFRecordExample
io/TFRecordWriter
.. raw:: html
<style>
h1:before {
content: "Module: ";
color: #103d3e;
}
</style>
DataReader
==========
.. autoclass:: dragon.io.DataReader
__init__
--------
.. automethod:: dragon.io.DataReader.__init__
Methods
-------
before_first
############
.. automethod:: dragon.io.DataReader.before_first
next_example
############
.. automethod:: dragon.io.DataReader.next_example
reset
#####
.. automethod:: dragon.io.DataReader.reset
run
###
.. automethod:: dragon.io.DataReader.run
.. raw:: html
<style>
h1:before {
content: "dragon.io.";
color: #103d3e;
}
</style>
KPLRecordDataset
================
.. autoclass:: dragon.io.KPLRecordDataset
__init__
--------
.. automethod:: dragon.io.KPLRecordDataset.__init__
Properties
----------
protocol
########
.. autoattribute:: dragon.io.KPLRecordDataset.protocol
size
####
.. autoattribute:: dragon.io.KPLRecordDataset.size
Methods
-------
get
###
.. automethod:: dragon.io.KPLRecordDataset.get
redirect
########
.. automethod:: dragon.io.KPLRecordDataset.redirect
.. raw:: html
<style>
h1:before {
content: "dragon.io.";
color: #103d3e;
}
</style>
KPLRecordWriter
===============
.. autoclass:: dragon.io.KPLRecordWriter
__init__
--------
.. automethod:: dragon.io.KPLRecordWriter.__init__
Methods
-------
close
#####
.. automethod:: dragon.io.KPLRecordWriter.close
write
#####
.. automethod:: dragon.io.KPLRecordWriter.write
.. raw:: html
<style>
h1:before {
content: "dragon.io.";
color: #103d3e;
}
</style>
TFRecordExample
===============
.. autoclass:: dragon.io.TFRecordExample
__init__
--------
.. automethod:: dragon.io.TFRecordExample.__init__
Methods
-------
add_floats
##########
.. automethod:: dragon.io.TFRecordExample.add_floats
add_ints
########
.. automethod:: dragon.io.TFRecordExample.add_ints
add_strings
###########
.. automethod:: dragon.io.TFRecordExample.add_strings
parse_from
##########
.. automethod:: dragon.io.TFRecordExample.parse_from
serialize_to
############
.. automethod:: dragon.io.TFRecordExample.serialize_to
.. raw:: html
<style>
h1:before {
content: "dragon.io.";
color: #103d3e;
}
</style>
TFRecordWriter
===============
.. autoclass:: dragon.io.TFRecordWriter
__init__
--------
.. automethod:: dragon.io.TFRecordWriter.__init__
Methods
-------
close
#####
.. automethod:: dragon.io.TFRecordWriter.close
write
#####
.. automethod:: dragon.io.TFRecordWriter.write
.. raw:: html
<style>
h1:before {
content: "dragon.io.";
color: #103d3e;
}
</style>
...@@ -154,7 +154,10 @@ dragon.math ...@@ -154,7 +154,10 @@ dragon.math
: Compute the tanh of input. : Compute the tanh of input.
`top_k(...) <math/top_k.html>`_ `top_k(...) <math/top_k.html>`_
: Return the top-K largest or smallest elements along the given axis. : Return the top k-largest or k-smallest elements along the given axis.
`var(...) <math/var.html>`_
: Compute the variance value of elements along the given axis.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -209,6 +212,7 @@ dragon.math ...@@ -209,6 +212,7 @@ dragon.math
math/sum math/sum
math/tanh math/tanh
math/top_k math/top_k
math/var
.. raw:: html .. raw:: html
......
ELU var
=== ===
.. autoclass:: dragon.vm.caffe.core.layers.ELU .. autofunction:: dragon.math.var
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
dragon.mps
===========
.. only:: html
Functions
---------
`current_device(...) <mps/current_device.html>`_
: Return the index of the currently selected device.
`get_device_family(...) <mps/get_device_family.html>`_
: Return the supported families of the specified device.
`is_available(...) <mps/is_available.html>`_
: Return a bool reporting if runtime is available.
`set_default_device(...) <mps/set_default_device.html>`_
: Set the default device.
`set_device(...) <mps/set_device.html>`_
: Set the current device.
`synchronize(...) <mps/synchronize.html>`_
: Synchronize a specified MPS stream.
.. toctree::
:hidden:
mps/current_device
mps/get_device_family
mps/is_available
mps/set_default_device
mps/set_device
mps/synchronize
.. raw:: html
<style>
h1:before {
content: "Module: ";
color: #103d3e;
}
</style>
Eltwise current_device
======== ==============
.. autoclass:: dragon.vm.caffe.core.layers.Eltwise .. autofunction:: dragon.mps.current_device
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
ArgMax get_device_family
====== =================
.. autoclass:: dragon.vm.caffe.core.layers.ArgMax .. autofunction:: dragon.mps.get_device_family
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
InnerProduct is_available
============ ============
.. autoclass:: dragon.vm.caffe.core.layers.InnerProduct .. autofunction:: dragon.mps.is_available
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
BatchNorm set_default_device
========= ==================
.. autoclass:: dragon.vm.caffe.core.layers.BatchNorm .. autofunction:: dragon.mps.set_default_device
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Concat set_device
====== ==========
.. autoclass:: dragon.vm.caffe.core.layers.Concat .. autofunction:: dragon.mps.set_device
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Convolution synchronize
=========== ===========
.. autoclass:: dragon.vm.caffe.core.layers.Convolution .. autofunction:: dragon.mps.synchronize
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.mps.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -3,12 +3,6 @@ dragon.vision ...@@ -3,12 +3,6 @@ dragon.vision
.. only:: html .. only:: html
Classes
-------
`class DataIterator <vision/DataIterator.html>`_
: Iterator to return the batch of data for image classification.
Functions Functions
--------- ---------
...@@ -29,7 +23,6 @@ dragon.vision ...@@ -29,7 +23,6 @@ dragon.vision
.. toctree:: .. toctree::
:hidden: :hidden:
vision/DataIterator
vision/extract_patches vision/extract_patches
vision/resize vision/resize
vision/roi_align vision/roi_align
......
DataIterator
============
.. autoclass:: dragon.vision.DataIterator
__init__
--------
.. automethod:: dragon.vision.DataIterator.__init__
Methods
-------
next
####
.. automethod:: dragon.vision.DataIterator.next
.. raw:: html
<style>
h1:before {
content: "dragon.vision.";
color: #103d3e;
}
</style>
...@@ -32,11 +32,11 @@ Dragon ...@@ -32,11 +32,11 @@ Dragon
* `dragon.cuda <dragon/cuda.html>`_ * `dragon.cuda <dragon/cuda.html>`_
* `dragon.distributed <dragon/distributed.html>`_ * `dragon.distributed <dragon/distributed.html>`_
* `dragon.dlpack <dragon/dlpack.html>`_ * `dragon.dlpack <dragon/dlpack.html>`_
* `dragon.io <dragon/io.html>`_
* `dragon.logging <dragon/logging.html>`_ * `dragon.logging <dragon/logging.html>`_
* `dragon.losses <dragon/losses.html>`_ * `dragon.losses <dragon/losses.html>`_
* `dragon.math <dragon/math.html>`_ * `dragon.math <dragon/math.html>`_
* `dragon.metrics <dragon/metrics.html>`_ * `dragon.metrics <dragon/metrics.html>`_
* `dragon.mps <dragon/mps.html>`_
* `dragon.nn <dragon/nn.html>`_ * `dragon.nn <dragon/nn.html>`_
* `dragon.onnx <dragon/onnx.html>`_ * `dragon.onnx <dragon/onnx.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_ * `dragon.optimizers <dragon/optimizers.html>`_
...@@ -44,20 +44,6 @@ Dragon ...@@ -44,20 +44,6 @@ Dragon
* `dragon.sysconfig <dragon/sysconfig.html>`_ * `dragon.sysconfig <dragon/sysconfig.html>`_
* `dragon.vision <dragon/vision.html>`_ * `dragon.vision <dragon/vision.html>`_
Caffe
#####
*Caffe* is the most famous framework for vision.
Our work is very different from the official Python wrapper, a.k.a.
*PyCaffe*, which is exported through *BoostPython*
from the C++ implementation.
This style involves the following components:
* `caffe <caffe.html>`_
* `caffe.layers <caffe/layers.html>`_
TensorFlow TensorFlow
########## ##########
...@@ -79,22 +65,6 @@ TensorFlow ...@@ -79,22 +65,6 @@ TensorFlow
* `tensorflow.nn <tensorflow/nn.html>`_ * `tensorflow.nn <tensorflow/nn.html>`_
* `tensorflow.random <tensorflow/random.html>`_ * `tensorflow.random <tensorflow/random.html>`_
TensorLayer
###########
*TensorLayer* takes a high-level layer abstraction to build complex models.
The original *TensorLayer* project is restricted to executing *TensorFlow* operations,
and is also known as a competitive tf-wrapper compared to *tf.keras*.
We transplant and remake the abstractions to match our engine, while keeping
compatibility with tf-based code as far as possible.
This style involves the following components:
* `tensorlayer.initializers <tensorlayer/initializers.html>`_
* `tensorlayer.layers <tensorlayer/layers.html>`_
* `tensorlayer.models <tensorlayer/models.html>`_
PyTorch PyTorch
####### #######
...@@ -187,9 +157,6 @@ Modules ...@@ -187,9 +157,6 @@ Modules
`Module dlpack <dragon/dlpack.html>`_ `Module dlpack <dragon/dlpack.html>`_
: Native API for ``dragon.dlpack`` namespace. : Native API for ``dragon.dlpack`` namespace.
`Module io <dragon/io.html>`_
: Native API for ``dragon.io`` namespace.
`Module logging <dragon/logging.html>`_ `Module logging <dragon/logging.html>`_
: Native API for ``dragon.logging`` namespace. : Native API for ``dragon.logging`` namespace.
...@@ -202,6 +169,9 @@ Modules ...@@ -202,6 +169,9 @@ Modules
`Module metrics <dragon/metrics.html>`_ `Module metrics <dragon/metrics.html>`_
: Native API for ``dragon.metrics`` namespace. : Native API for ``dragon.metrics`` namespace.
`Module mps <dragon/mps.html>`_
: Native API for ``dragon.mps`` namespace.
`Module nn <dragon/nn.html>`_ `Module nn <dragon/nn.html>`_
: Native API for ``dragon.nn`` namespace. : Native API for ``dragon.nn`` namespace.
...@@ -220,12 +190,6 @@ Modules ...@@ -220,12 +190,6 @@ Modules
`Module vision <dragon/vision.html>`_ `Module vision <dragon/vision.html>`_
: Native API for ``dragon.vision`` namespace. : Native API for ``dragon.vision`` namespace.
`Module vm.caffe <caffe.html>`_
: Virtual API for ``caffe`` namespace.
`Module vm.caffe.layers <caffe/layers.html>`_
: Virtual API for ``caffe.layers`` namespace.
`Module vm.dali <dali.html>`_ `Module vm.dali <dali.html>`_
: Virtual API for ``dali`` namespace. : Virtual API for ``dali`` namespace.
...@@ -256,15 +220,6 @@ Modules ...@@ -256,15 +220,6 @@ Modules
`Module vm.tensorflow.random <tensorflow/random.html>`_ `Module vm.tensorflow.random <tensorflow/random.html>`_
: Virtual API for ``tensorflow.random`` namespace. : Virtual API for ``tensorflow.random`` namespace.
`Module vm.tensorlayer.initializers <tensorlayer/initializers.html>`_
: Virtual API for ``tensorlayer.initializers`` namespace.
`Module vm.tensorlayer.layers <tensorlayer/layers.html>`_
: Virtual API for ``tensorlayer.layers`` namespace.
`Module vm.tensorlayer.models <tensorlayer/models.html>`_
: Virtual API for ``tensorlayer.models`` namespace.
`Module vm.tensorrt <tensorrt.html>`_ `Module vm.tensorrt <tensorrt.html>`_
: Virtual API for ``tensorrt`` namespace. : Virtual API for ``tensorrt`` namespace.
...@@ -322,19 +277,17 @@ Modules ...@@ -322,19 +277,17 @@ Modules
dragon/cuda dragon/cuda
dragon/distributed dragon/distributed
dragon/dlpack dragon/dlpack
dragon/io
dragon/logging dragon/logging
dragon/losses dragon/losses
dragon/math dragon/math
dragon/metrics dragon/metrics
dragon/mps
dragon/nn dragon/nn
dragon/onnx dragon/onnx
dragon/optimizers dragon/optimizers
dragon/random dragon/random
dragon/sysconfig dragon/sysconfig
dragon/vision dragon/vision
caffe
caffe/layers
dali dali
dali/ops dali/ops
tensorflow tensorflow
...@@ -345,9 +298,6 @@ Modules ...@@ -345,9 +298,6 @@ Modules
tensorflow/math tensorflow/math
tensorflow/nn tensorflow/nn
tensorflow/random tensorflow/random
tensorlayer/initializers
tensorlayer/layers
tensorlayer/models
tensorrt tensorrt
tensorrt/onnx tensorrt/onnx
torch torch
......
...@@ -99,6 +99,9 @@ vm.tensorflow.math ...@@ -99,6 +99,9 @@ vm.tensorflow.math
`reduce_sum(...) <math/reduce_sum.html>`_ `reduce_sum(...) <math/reduce_sum.html>`_
: Compute the sum value of elements along the given axis. : Compute the sum value of elements along the given axis.
`reduce_variance(...) <math/reduce_variance.html>`_
: Compute the variance value of elements along the given axis.
`round(...) <math/round.html>`_ `round(...) <math/round.html>`_
: Compute the nearest integer of input. : Compute the nearest integer of input.
...@@ -127,7 +130,7 @@ vm.tensorflow.math ...@@ -127,7 +130,7 @@ vm.tensorflow.math
: Compute the tanh of input. : Compute the tanh of input.
`top_k(...) <math/top_k.html>`_ `top_k(...) <math/top_k.html>`_
: Return the top-K largest elements along the last axis. : Return the top k-largest elements along the last axis.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -163,6 +166,7 @@ vm.tensorflow.math ...@@ -163,6 +166,7 @@ vm.tensorflow.math
math/reduce_mean math/reduce_mean
math/reduce_min math/reduce_min
math/reduce_sum math/reduce_sum
math/reduce_variance
math/round math/round
math/rsqrt math/rsqrt
math/sigmoid math/sigmoid
......
SoftmaxWithLoss reduce_variance
=============== ===============
.. autoclass:: dragon.vm.caffe.core.layers.SoftmaxWithLoss .. autofunction:: dragon.vm.tensorflow.math.reduce_variance
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "tf.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -68,6 +68,9 @@ vm.tensorflow.nn ...@@ -68,6 +68,9 @@ vm.tensorflow.nn
: Apply the gaussian error linear unit. : Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_. `[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`l2_loss(...) <nn/l2_loss.html>`_
: Compute the loss of element-wise squared error.
`leaky_relu(...) <nn/leaky_relu.html>`_ `leaky_relu(...) <nn/leaky_relu.html>`_
: Apply the leaky rectified linear unit. : Apply the leaky rectified linear unit.
...@@ -105,6 +108,9 @@ vm.tensorflow.nn ...@@ -105,6 +108,9 @@ vm.tensorflow.nn
: Apply the scaled exponential linear unit. : Apply the scaled exponential linear unit.
`[Klambauer et al., 2017] <https://arxiv.org/abs/1706.02515>`_. `[Klambauer et al., 2017] <https://arxiv.org/abs/1706.02515>`_.
`sigmoid_cross_entropy_with_logits(...) <nn/sigmoid_cross_entropy_with_logits.html>`_
: Compute the loss of sigmoid cross entropy.
`silu(...) <nn/silu.html>`_ `silu(...) <nn/silu.html>`_
: Apply the sigmoid linear unit. : Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_. `[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
...@@ -113,13 +119,13 @@ vm.tensorflow.nn ...@@ -113,13 +119,13 @@ vm.tensorflow.nn
: Apply the softmax function. : Apply the softmax function.
`softmax_cross_entropy_with_logits(...) <nn/softmax_cross_entropy_with_logits.html>`_ `softmax_cross_entropy_with_logits(...) <nn/softmax_cross_entropy_with_logits.html>`_
: Compute the softmax cross entropy with contiguous labels. : Compute the loss of softmax cross entropy.
`space_to_depth(...) <nn/space_to_depth.html>`_ `space_to_depth(...) <nn/space_to_depth.html>`_
: Rearrange blocks of spatial data into depth. : Rearrange blocks of spatial data into depth.
`sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_ `sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the loss of softmax cross entropy with sparse labels.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -143,6 +149,7 @@ vm.tensorflow.nn ...@@ -143,6 +149,7 @@ vm.tensorflow.nn
nn/elu nn/elu
nn/fused_batch_norm nn/fused_batch_norm
nn/gelu nn/gelu
nn/l2_loss
nn/leaky_relu nn/leaky_relu
nn/local_response_normalization nn/local_response_normalization
nn/log_softmax nn/log_softmax
...@@ -154,6 +161,7 @@ vm.tensorflow.nn ...@@ -154,6 +161,7 @@ vm.tensorflow.nn
nn/relu nn/relu
nn/relu6 nn/relu6
nn/selu nn/selu
nn/sigmoid_cross_entropy_with_logits
nn/silu nn/silu
nn/softmax nn/softmax
nn/softmax_cross_entropy_with_logits nn/softmax_cross_entropy_with_logits
......
Dropout l2_loss
======= =======
.. autoclass:: dragon.vm.caffe.core.layers.Dropout .. autofunction:: dragon.vm.tensorflow.nn.l2_loss
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "tf.nn.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Crop sigmoid_cross_entropy_with_logits
==== =================================
.. autoclass:: dragon.vm.caffe.core.layers.Crop .. autofunction:: dragon.vm.tensorflow.nn.sigmoid_cross_entropy_with_logits
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "tf.nn.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
vm.tensorlayer.initializers
===========================
.. only:: html
Classes
-------
`class Constant <initializers/Constant.html>`_
: Fill tensor with a scalar value.
`class GlorotNormal <initializers/GlorotNormal.html>`_
: Fill tensor from a glorot normal distribution.
`class GlorotUniform <initializers/GlorotUniform.html>`_
: Fill tensor from a glorot uniform distribution.
`class Initializer <initializers/Initializer.html>`_
: The basic Initializer.
`class Ones <initializers/Ones.html>`_
: Fill tensor with ones.
`class RandomNormal <initializers/RandomNormal.html>`_
: Fill tensor from a normal distribution.
`class RandomUniform <initializers/RandomUniform.html>`_
: Fill tensor from a uniform distribution.
`class TruncatedNormal <initializers/TruncatedNormal.html>`_
: Fill tensor from a truncated normal distribution.
`class Zeros <initializers/Zeros.html>`_
: Fill tensor with zeros.
.. toctree::
:hidden:
initializers/Constant
initializers/GlorotNormal
initializers/GlorotUniform
initializers/Initializer
initializers/Ones
initializers/RandomNormal
initializers/RandomUniform
initializers/TruncatedNormal
initializers/Zeros
.. raw:: html
<style>
h1:before {
content: "Module: ";
color: #103d3e;
}
</style>
Constant
========
.. autoclass:: dragon.vm.tensorlayer.initializers.Constant
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.Constant.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
GlorotNormal
============
.. autoclass:: dragon.vm.tensorlayer.initializers.GlorotNormal
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.GlorotNormal.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
GlorotUniform
=============
.. autoclass:: dragon.vm.tensorlayer.initializers.GlorotUniform
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.GlorotUniform.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
Initializer
===========
.. autoclass:: dragon.vm.tensorlayer.initializers.Initializer
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
Ones
====
.. autoclass:: dragon.vm.tensorlayer.initializers.Ones
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.Ones.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
RandomNormal
============
.. autoclass:: dragon.vm.tensorlayer.initializers.RandomNormal
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.RandomNormal.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
RandomUniform
=============
.. autoclass:: dragon.vm.tensorlayer.initializers.RandomUniform
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.RandomUniform.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
TruncatedNormal
===============
.. autoclass:: dragon.vm.tensorlayer.initializers.TruncatedNormal
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.TruncatedNormal.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
Zeros
=====
.. autoclass:: dragon.vm.tensorlayer.initializers.Zeros
__init__
--------
.. automethod:: dragon.vm.tensorlayer.initializers.Zeros.__init__
Methods
-------
__call__
########
.. automethod:: dragon.vm.tensorlayer.initializers.Initializer.__call__
:noindex:
.. raw:: html
<style>
h1:before {
content: "tl.initializers.";
color: #103d3e;
}
</style>
vm.tensorlayer.layers
=====================
.. only:: html
Classes
-------
`class BatchNorm <layers/BatchNorm.html>`_
: Batch normalization layer.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`class Concat <layers/Concat.html>`_
: Layer to concatenate tensors according to the given axis.
`class Conv2d <layers/Conv2d.html>`_
: 2d convolution layer.
`class Dense <layers/Dense.html>`_
: Fully connected layer.
`class Elementwise <layers/Elementwise.html>`_
: Layer to combine inputs by applying an element-wise operation.
`class Flatten <layers/Flatten.html>`_
: Layer to reshape input into a matrix.
`class GlobalMaxPool2d <layers/GlobalMaxPool2d.html>`_
: 2d global max pooling layer.
`class GlobalMeanPool2d <layers/GlobalMeanPool2d.html>`_
: 2d global mean pooling layer.
`class MaxPool2d <layers/MaxPool2d.html>`_
: 2d max pooling layer.
`class MeanPool2d <layers/MeanPool2d.html>`_
: 2d mean pooling layer.
`class Layer <layers/Layer.html>`_
: The base layer class.
`class LayerList <layers/LayerList.html>`_
: Layer to stack a group of layers.
`class Relu <layers/Relu.html>`_
: Layer to apply the rectified linear unit.
`[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
`class Reshape <layers/Reshape.html>`_
: Layer to change the dimensions of input.
`class Transpose <layers/Transpose.html>`_
: Layer to permute the dimensions of input.
Functions
---------
`Input(...) <layers/Input.html>`_
: Create a placeholder as input.
.. toctree::
:hidden:
layers/BatchNorm
layers/Concat
layers/Conv2d
layers/Dense
layers/Elementwise
layers/Flatten
layers/GlobalMaxPool2d
layers/GlobalMeanPool2d
layers/MaxPool2d
layers/MeanPool2d
layers/Input
layers/Layer
layers/LayerList
layers/Relu
layers/Reshape
layers/Transpose
.. raw:: html
<style>
h1:before {
content: "Module: ";
color: #103d3e;
}
</style>
BatchNorm
=========
.. autoclass:: dragon.vm.tensorlayer.layers.BatchNorm
__init__
--------
.. automethod:: dragon.vm.tensorlayer.layers.BatchNorm.__init__
.. raw:: html
<style>
h1:before {
content: "tl.layers.";
color: #103d3e;
}
</style>
Concat
======
.. autoclass:: dragon.vm.tensorlayer.layers.Concat
__init__
--------
.. automethod:: dragon.vm.tensorlayer.layers.Concat.__init__
.. raw:: html
<style>
h1:before {
content: "tl.layers.";
color: #103d3e;
}
</style>
This diff is collapsed. Click to expand it.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!