SeetaResearch / Dragon
Commit d1f714ea authored May 15, 2019 by Ting PAN
Apply the dispatcher to RunImpl
1 parent bd84b7fd
Showing 159 changed files with 1208 additions and 2363 deletions
Dragon/include/core/common.h
Dragon/include/core/graph.h
Dragon/include/core/graph_gradient.h
Dragon/include/core/operator.h
Dragon/include/core/operator_schema.h
Dragon/include/core/types.h
Dragon/include/core/workspace.h
Dragon/include/operators/arithmetic/fully_connected_op.h
Dragon/include/operators/array/multinomial_op.h
Dragon/include/utils/caffemodel.h
Dragon/python/dragon/operators/array.py
Dragon/python/dragon/vm/torch/ops/builtin.py
Dragon/python/dragon/vm/torch/ops/modules/array.py
Dragon/python/dragon/vm/torch/tensor.py
Dragon/src/contrib/rcnn/bbox_utils.cu
Dragon/src/contrib/rcnn/bbox_utils.h
Dragon/src/core/graph.cc
Dragon/src/core/graph_gradient.cc
Dragon/src/core/graph_optimizer.cc
Dragon/src/core/operator_schema.cc
Dragon/src/kernels/activation/dropout_op_kernel.cu
Dragon/src/kernels/activation/droppath_op_kernel.cu
Dragon/src/kernels/activation/elu_op_kernel.cu
Dragon/src/kernels/activation/prelu_op_kernel.cu
Dragon/src/kernels/activation/relu_op_kernel.cu
Dragon/src/kernels/activation/selu_op_kernel.cu
Dragon/src/kernels/activation/sigmoid_op_kernel.cu
Dragon/src/kernels/activation/softmax_op_kernel.cu
Dragon/src/kernels/activation/tanh_op_kernel.cu
Dragon/src/kernels/arithmetic/affine_op_kernel.cu
Dragon/src/kernels/arithmetic/clip_op_kernel.cu
Dragon/src/kernels/arithmetic/maximum_op_kernel.cu
Dragon/src/kernels/arithmetic/minimum_op_kernel.cu
Dragon/src/kernels/arithmetic/moments_op_kernel.cu
Dragon/src/kernels/array/arange_op_kernel.cu
Dragon/src/kernels/array/argreduce_op_kernel.cc
Dragon/src/kernels/array/argreduce_op_kernel.cu
Dragon/src/kernels/array/concat_op_kernel.cu
Dragon/src/kernels/array/crop_op_kernel.cu
Dragon/src/kernels/array/index_select_op_kernel.cu
Dragon/src/kernels/array/one_hot_op_kernel.cu
Dragon/src/kernels/array/pad_op_kernel.cu
Dragon/src/kernels/array/reduce_sum_op_kernel.cu
Dragon/src/kernels/array/repeat_op_kernel.cu
Dragon/src/kernels/array/slice_op_kernel.cu
Dragon/src/kernels/array/tile_op_kernel.cu
Dragon/src/kernels/array/transpose_op_kernel.cu
Dragon/src/kernels/control_flow/assign_op_kernel.cu
Dragon/src/kernels/control_flow/compare_op_kernel.cu
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cu
Dragon/src/kernels/loss/l1_loss_op_kernel.cu
Dragon/src/kernels/loss/nll_loss_op_kernel.cu
Dragon/src/kernels/loss/sigmoid_ce_loss_op_kernel.cu
Dragon/src/kernels/loss/sigmoid_focal_loss_op_kernel.cu
Dragon/src/kernels/loss/smooth_l1_loss_op_kernel.cu
Dragon/src/kernels/loss/softmax_ce_loss_op_kernel.cu
Dragon/src/kernels/loss/softmax_focal_loss_op_kernel.cu
Dragon/src/kernels/loss/sparse_softmax_ce_loss_op_kernel.cu
Dragon/src/kernels/misc/astype_op_kernel.cu
Dragon/src/kernels/misc/gradient_op_kernel.cu
Dragon/src/kernels/misc/image_data_op_kernel.cu
Dragon/src/kernels/norm/batch_norm_op_kernel.cu
Dragon/src/kernels/norm/group_norm_op_kernel.cu
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cu
Dragon/src/kernels/update/adam_update_op_kernel.cu
Dragon/src/kernels/update/mprec_update_op_kerne.cu
Dragon/src/kernels/update/nesterov_update_op_kernel.cu
Dragon/src/kernels/update/rmsprop_update_op_kernel.cu
Dragon/src/kernels/update/sgd_update_op_kernel.cu
Dragon/src/kernels/vision/bias_add_op_kernel.cu
Dragon/src/kernels/vision/bilinear_resize_op_kernel.cu
Dragon/src/kernels/vision/conv_op_kernel.cu
Dragon/src/kernels/vision/depthwise_conv_op_kernel.cu
Dragon/src/kernels/vision/drop_block_op_kernel.cu
Dragon/src/kernels/vision/nn_resize_op_kernel.cu
Dragon/src/kernels/vision/pool_op_kernel.cu
Dragon/src/kernels/vision/roi_align_op_kernel.cu
Dragon/src/kernels/vision/roi_align_op_kernel.fp16.cu
Dragon/src/kernels/vision/roi_pool_op_kernel.cu
Dragon/src/onnx/onnx_backend.cc
Dragon/src/onnx/onnx_backend.h
Dragon/src/operators/activation/cudnn_dropout_op.cc
Dragon/src/operators/activation/cudnn_elu_op.cc
Dragon/src/operators/activation/cudnn_relu_op.cc
Dragon/src/operators/activation/cudnn_sigmoid_op.cc
Dragon/src/operators/activation/cudnn_softmax_op.cc
Dragon/src/operators/activation/cudnn_tanh_op.cc
Dragon/src/operators/activation/dropout_op.cc
Dragon/src/operators/activation/droppath_op.cc
Dragon/src/operators/activation/elu_op.cc
Dragon/src/operators/activation/prelu_op.cc
Dragon/src/operators/activation/relu_op.cc
Dragon/src/operators/activation/selu_op.cc
Dragon/src/operators/activation/sigmoid_op.cc
Dragon/src/operators/activation/softmax_op.cc
Dragon/src/operators/activation/tanh_op.cc
Dragon/src/operators/arithmetic/affine_op.cc
Dragon/src/operators/arithmetic/cudnn_affine_op.cc
Dragon/src/operators/arithmetic/eltwise_op.cc
Dragon/src/operators/arithmetic/exp_op.cc
Dragon/src/operators/arithmetic/fully_connected_op.cc
Dragon/src/operators/arithmetic/gram_matrix_op.cc
Dragon/src/operators/arithmetic/log_op.cc
Dragon/src/operators/arithmetic/matmul_op.cc
Dragon/src/operators/arithmetic/maximum_op.cc
Dragon/src/operators/arithmetic/minimum_op.cc
Dragon/src/operators/arithmetic/moments_op.cc
Dragon/src/operators/arithmetic/pow_op.cc
Dragon/src/operators/arithmetic/sqrt_op.cc
Dragon/src/operators/arithmetic/square_op.cc
Dragon/src/operators/array/arange_op.cc
Dragon/src/operators/array/argreduce_op.cc
Dragon/src/operators/array/concat_op.cc
Dragon/src/operators/array/crop_op.cc
Dragon/src/operators/array/index_select_op.cc
Dragon/src/operators/array/multinomial_op.cc
Dragon/src/operators/array/one_hot_op.cc
Dragon/src/operators/array/pad_op.cc
Dragon/src/operators/array/reduce_op.cc
Dragon/src/operators/array/repeat_op.cc
Dragon/src/operators/array/slice_op.cc
Dragon/src/operators/array/stack_op.cc
Dragon/src/operators/array/tile_op.cc
Dragon/src/operators/array/transpose_op.cc
Dragon/src/operators/control_flow/assign_op.cc
Dragon/src/operators/control_flow/copy_op.cc
Dragon/src/operators/control_flow/masked_assign_op.cc
Dragon/src/operators/loss/ctc_loss_op.cc
Dragon/src/operators/loss/l1_loss_op.cc
Dragon/src/operators/loss/l2_loss_op.cc
Dragon/src/operators/loss/sigmoid_ce_loss_op.cc
Dragon/src/operators/loss/smooth_l1_loss_op.cc
Dragon/src/operators/loss/softmax_ce_loss_op.cc
Dragon/src/operators/misc/accuracy_op.cc
Dragon/src/operators/misc/gradient_op.cc
Dragon/src/operators/misc/initialize_op.cc
Dragon/src/operators/misc/python_op.cc
Dragon/src/operators/mpi/mpi_broadcast_op.cc
Dragon/src/operators/mpi/mpi_gather_op.cc
Dragon/src/operators/norm/l2_norm_op.cc
Dragon/src/operators/recurrent/cudnn_recurrent_op.cc
Dragon/src/operators/recurrent/rnn_param_op.cc
Dragon/src/operators/vision/bias_add_op.cc
Dragon/src/operators/vision/bilinear_resize_op.cc
Dragon/src/operators/vision/conv2d_op.cc
Dragon/src/operators/vision/conv2d_transpose_op.cc
Dragon/src/operators/vision/cudnn_bias_add_op.cc
Dragon/src/operators/vision/cudnn_conv2d_op.cc
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc
Dragon/src/operators/vision/cudnn_depthwise_conv2d_op.cc
Dragon/src/operators/vision/cudnn_lrn_op.cc
Dragon/src/operators/vision/cudnn_pool2d_op.cc
Dragon/src/operators/vision/depthwise_conv2d_op.cc
Dragon/src/operators/vision/drop_block2d_op.cc
Dragon/src/operators/vision/nn_resize_op.cc
Dragon/src/operators/vision/roi_align_op.cc
Dragon/src/operators/vision/roi_pool_op.cc
Dragon/src/utils/math_functions.cu
Dragon/src/utils/math_functions.fp16.cu
Dragon/include/core/common.h
@@ -35,6 +35,7 @@
 #include "core/types.h"
 #include "proto/dragon.pb.h"
+#include "utils/string.h"
 #include "utils/logging.h"

 namespace dragon {
Dragon/include/core/graph.h
@@ -85,6 +85,8 @@ GraphBase* NewGraph(
     const GraphDef& def,
     Workspace* ws);

+/* Macros */
+
 DECLARE_REGISTRY(
     GraphRegistry,
     GraphBase,
Dragon/include/core/graph_gradient.h
@@ -43,7 +43,7 @@ class GraphGradientMaker {
     bool CheckGrad(
         const OperatorDef& forward_op,
         const Set<string>& targets,
-        vector<pair<string, int> >& gen_grads);
+        vector<pair<string, int>>& gen_grads);
     string GetOperatorName();
Dragon/include/core/operator.h
@@ -100,7 +100,7 @@ class OperatorBase {
     /*! \brief Return the specified argument */
     const Argument& arg(const string& name) { return *(args_[name]); }

-    typedef Map<string, vector<OperatorBase*> > SubGraph;
+    typedef Map<string, vector<OperatorBase*>> SubGraph;

     /*! \brief Return the recomputing subgraph of this operator */
     SubGraph& subgraph() { return subgraph_; }

@@ -221,7 +221,7 @@ OperatorBase* NewOperator(
     const OperatorDef& def,
     Workspace* ws);

-/*! Macros */
+/* Macros */

 #define OpArg OperatorBase::Arg
 #define OpArgs OperatorBase::Args

@@ -266,7 +266,7 @@ DECLARE_REGISTRY(
     const OperatorDef&,
     Workspace*);

-/*! NVIDIA's Accelerated Library - CUDNN */
+/* NVIDIA's Accelerated Library - CUDNN */

 DECLARE_REGISTRY(
     CUDNNOperatorRegistry,

@@ -274,7 +274,7 @@ DECLARE_REGISTRY(
     const OperatorDef&,
     Workspace*);

-/*! CAMBRICON's Accelerated Library - CNML */
+/* CAMBRICON's Accelerated Library - CNML */

 DECLARE_REGISTRY(
     CNMLOperatorRegistry,

@@ -282,13 +282,60 @@ DECLARE_REGISTRY(
     const OperatorDef&,
     Workspace*);

+/* Dispatcher for Runtime Typed-Implementation */
+
+#define XIsType(x, dtype) \
+    x.template IsType<dtype>()
+
+template <typename... Types>
+struct TensorTypes {};
+
+template <typename Sizes, typename... Args>
+struct DispatchHelper;
+
+#define DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, Impl) \
+    template <typename T, typename... Types, typename... Args> \
+    struct DispatchHelper<TensorTypes<T, Types...>, Args...> { \
+        template <typename Op> \
+        static void Call(Op* op, const TypeMeta& meta, string& types) { \
+            if (meta.Match<T>()) return op->template Impl<T, Args...>(); \
+            types += " * " + TypeToString<T>() + ",\n"; \
+            return DispatchHelper<TensorTypes<Types...>, Args...> \
+                ::Call(op, meta, types); \
+        } \
+        template <typename Op> \
+        static void Call(Op* op, const Tensor& tensor) { \
+            string types; return Call(op, tensor.meta(), types); \
+        } \
+    }; \
+    template <typename... Args> \
+    struct DispatchHelper<TensorTypes<>, Args...> { \
+        template <typename Op> \
+        static void Call(Op* op, const TypeMeta& meta, string& types) { \
+            LOG(FATAL) << "Unsupported DType: " \
+                << TypeMetaToString(meta) << "\n" \
+                << "<" << op->type() << "Op>" \
+                << " supports the following dtypes: {\n" \
+                << types << "}"; \
+        } \
+        template <typename Op> \
+        static void Call(Op* op, const Tensor& tensor) { \
+            return Call(op, tensor.meta(), ""); \
+        } \
+    };

+DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl);
+#undef DEFINE_TENSOR_TYPES_DISPATCHER
+
+/* TensorFiller */
+
 #define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
     if (tensor.count() == 0) { \
         CHECK(ws()->GetFiller(tensor.name())) \
             << "\nTensor(" << tensor.name() << ") is empty. \n" \
             << "may be specify a filler for it ?"; \
         tensor.Reshape(shape); \
-        unique_ptr< Filler<type, Context> > filler( \
+        unique_ptr<Filler<type, Context>> filler( \
             CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \
         filler->Fill(&tensor, ctx()); \
     } else { \

@@ -308,7 +355,7 @@ DECLARE_REGISTRY(
         << "\nTensor(" << tensor.name() << ") is empty. \n" \
         << "may be specify a filler for it ?"; \
     tensor.Reshape(shape); \
-    unique_ptr< Filler<T, Context> > filler( \
+    unique_ptr<Filler<T, Context>> filler( \
         CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
     filler->Fill(&tensor, ctx()); \
 } else { \

@@ -322,6 +369,8 @@ DECLARE_REGISTRY(
     tensor.Reshape(shape); \
 }

+/* Shared Multiplier */
+
 #define DECLARE_MULTIPLIER(name, size) \
     const T* name; \
     { \

@@ -335,6 +384,8 @@ DECLARE_REGISTRY(
     name = mp->template data<T, Context>(); \
 }

+/* Dynamic Arguments */
+
 #define DECLARE_ARG_WITH_DESC(type, arg) \
     type arg##_; \
     string arg##_desc_; \

@@ -393,8 +444,7 @@ DECLARE_REGISTRY(
 #define GET_ARGS_SIZE(arg) \
     (int)std::max(arg##_.size(), arg##_desc_.size())

-#define XIsType(x, dtype) \
-    x.template IsType<dtype>()
+/* Registers */

 #define INSTANTIATE_OPERATOR(name, context) \
     template class name##Op<context>;
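For context on the commit title: once DEFINE_TENSOR_TYPES_DISPATCHER(TensorTypes, RunImpl) has stamped out DispatchHelper, an operator no longer needs a hand-written XIsType(...) chain in RunOnDevice; it asks the helper to match the input's TypeMeta against a candidate list and invoke the typed RunImpl. A minimal sketch of that call pattern, with a hypothetical MyOp and an illustrative dtype list (the real operators in this commit wire up their own):

    // Sketch only; "MyOp" and the supported types are assumptions.
    template <class Context>
    void MyOp<Context>::RunOnDevice() {
        // Walks TensorTypes<...> left to right, compares each candidate T
        // against Input(0).meta(), and calls RunImpl<T>() on the first match.
        // On no match it LOG(FATAL)s, listing candidates via TypeToString<T>().
        DispatchHelper<TensorTypes<float, float16>>::Call(this, Input(0));
    }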
Dragon/include/core/operator_schema.h
@@ -42,7 +42,7 @@ class OpSchema {
         return *this;
     }

-    OpSchema& Inplace(set<pair<int, int> > inplace);
+    OpSchema& Inplace(set<pair<int, int>> inplace);
     std::function<bool(int, int)> CheckInplace;
     bool AllowInplace() const { return allow_inplace_; }
Dragon/include/core/types.h
@@ -73,6 +73,11 @@ inline const std::string TypeMetaToString(
         m2s_type_map[meta.id()] : "unknown";
 }

+template <typename T>
+inline const std::string TypeToString() {
+    return TypeMetaToString(TypeMeta::Make<T>());
+}
+
 }  // namespace dragon

 #endif  // DRAGON_CORE_TYPES_H_
\ No newline at end of file
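The TypeToString<T>() helper added here is what the dispatcher in operator.h uses to assemble its "supports the following dtypes" message. Roughly (a sketch; the actual name strings come from the registered m2s_type_map, which this diff does not show):

    // Round-trips a compile-time type through TypeMeta to its registered
    // string name; unregistered types come back as "unknown".
    std::string f32 = dragon::TypeToString<float>();
    std::string f16 = dragon::TypeToString<float16>();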
Dragon/include/core/workspace.h
@@ -13,22 +13,18 @@
 #ifndef DRAGON_CORE_WORKSPACE_H_
 #define DRAGON_CORE_WORKSPACE_H_

-#include "core/common.h"
 #include "core/graph.h"
-#include "utils/string.h"

 namespace dragon {

 class Workspace {
  public:
-    typedef Map<string, Map<string, int64_t> > DummyNameMap;
-    typedef Map<string, unique_ptr<Tensor> > TensorMap;
+    typedef Map<string, Map<string, int64_t>> DummyNameMap;
+    typedef Map<string, unique_ptr<Tensor>> TensorMap;
     typedef Map<string, string> TensorAliasMap;
     typedef Map<string, TensorFillerProto> TensorFillerMap;
-    typedef Map<string, unique_ptr<OperatorBase> > OperatorMap;
-    typedef Map<string, unique_ptr<GraphBase> > GraphMap;
+    typedef Map<string, unique_ptr<OperatorBase>> OperatorMap;
+    typedef Map<string, unique_ptr<GraphBase>> GraphMap;

     /*! \brief Constructor */
     Workspace(const string& name) : name_(name) { Initialize(); }
Dragon/include/operators/arithmetic/fully_connected_op.h
@@ -28,6 +28,7 @@ class FullyConnectedOp final : public Operator<Context> {
     USE_OPERATOR_FUNCTIONS;

     void RunOnDevice();
+    template <typename T> void RunImpl();
     template <typename T> void TransRunImpl();
     template <typename T> void NoTransRunImpl();
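FullyConnectedOp previously exposed only the transposed/non-transposed typed paths; the new RunImpl<T> gives the dispatcher a single typed entry point. A sketch of how it presumably forwards (the transW_ member name is a guess, not shown in this diff):

    template <class Context> template <typename T>
    void FullyConnectedOp<Context>::RunImpl() {
        // Assumed routing: pick the typed path based on whether W is transposed.
        if (transW_) TransRunImpl<T>();
        else NoTransRunImpl<T>();
    }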
Dragon/include/operators/array/multinomial_op.h
@@ -22,6 +22,7 @@ class MultinomialOp final : public Operator<Context> {
  public:
     MultinomialOp(const OperatorDef& def, Workspace* ws)
         : Operator<Context>(def, ws),
+          eps_(OpArg<float>("eps", 0.f)),
           normalize_(OpArg<int64_t>("normalize", 0)),
           num_samples_(OpArg<int64_t>("num_samples", 1)) {}
     USE_OPERATOR_FUNCTIONS;

@@ -32,6 +33,7 @@ class MultinomialOp final : public Operator<Context> {
     template <typename T> void RunImpl();

  protected:
+    float eps_;
     int64_t outer_dim_, axis_;
     int64_t normalize_, num_samples_;
     unique_ptr<OperatorBase> softmax_op_;
Dragon/include/utils/caffemodel.h
@@ -26,22 +26,24 @@ inline void LoadCaffeModel(
     LOG(INFO) << "Restore From Model @: " << file << "......";
     LOG(INFO) << "Model Format: CaffeModel";
     for (int i = 0; i < net_param.layer_size(); i++) {
-        const LayerParameter& layer = net_param.layer(i);
-        const string& layer_name = layer.name();
-        string prefix = layer_name + "/param:";
+        const auto& layer = net_param.layer(i);
+        const auto& layer_name = layer.name();
+        auto prefix = layer_name + "/param:";
         for (int j = 0; j < layer.blobs_size(); j++) {
-            string tensor_name = prefix + std::to_string(j);
-            if (!ws->HasTensor(tensor_name))
+            auto tensor_name = prefix + std::to_string(j);
+            if (!ws->HasTensor(tensor_name)) {
                 LOG(WARNING) << "Tensor(" << tensor_name << ") "
                     << "does not exist in any Graphs, skip.";
-            else {
-                BlobProto blob = layer.blobs(j);
-                vector<int64_t> dims;
-                for (auto dim : blob.shape().dim()) dims.push_back(dim);
-                Tensor* tensor = ws->GetTensor(tensor_name);
+            } else {
+                auto blob = layer.blobs(j);
+                vec64_t tensor_shape;
+                for (auto dim : blob.shape().dim()) tensor_shape.push_back(dim);
+                auto* tensor = ws->GetTensor(tensor_name);
                 std::stringstream DimString;
-                if (dims.size() > 0) {
-                    tensor->Reshape(dims);
+                if (tensor_shape.size() > 0) {
+                    tensor->Reshape(tensor_shape);
                     CHECK_EQ(tensor->count(), blob.data_size())
                         << "\nTensor(" << tensor_name << ") "
                         << "failed to load, except size: "

@@ -52,9 +54,9 @@ inline void LoadCaffeModel(
                     tensor->Reshape({ blob.data_size() });
                     DimString << "(missing)";
                 }
-                float* Xdata = tensor->mutable_data<float, CPUContext>();
-                for (int idx = 0; idx < blob.data_size(); idx++)
-                    Xdata[idx] = blob.data(idx);
+                auto* x = tensor->mutable_data<float, CPUContext>();
+                for (int xi = 0; xi < blob.data_size(); ++xi)
+                    x[xi] = blob.data(xi);
                 LOG(INFO) << "Tensor(" << tensor_name << ") "
                     << "loaded, shape: " << DimString.str()
                     << ", size: " << blob.data_size();

@@ -66,32 +68,33 @@ inline void LoadCaffeModel(
 inline void SavaCaffeModel(
     string file,
     const vector<Tensor*>& tensors) {
-    NetParameter net_param;
+    int j = -1;
+    NetParameter net;
     Map<string, int> layer_hash;
-    int layer_idx = -1;
     for (int i = 0; i < tensors.size(); i++) {
         if (tensors[i]->count() <= 0) continue;
-        vector<string> splits = str::split(
+        auto splits = str::split(
             tensors[i]->name(), "/param:");
         if (layer_hash.count(splits[0]) == 0) {
-            layer_hash[splits[0]] = ++layer_idx;
-            LayerParameter* layer = net_param.add_layer();
+            layer_hash[splits[0]] = ++j;
+            auto* layer = net.add_layer();
             layer->set_name(splits[0]);
         }
-        BlobProto* blob = net_param.mutable_layer(layer_idx)->add_blobs();
+        auto* blob = net.mutable_layer(j)->add_blobs();
         for (auto dim : tensors[i]->dims())
             blob->mutable_shape()->add_dim(dim);
         if (XIsType((*tensors[i]), float)) {
-            auto* Xdata = tensors[i]->data<float, CPUContext>();
-            for (int id = 0; id < tensors[i]->count(); id++)
-                blob->mutable_data()->Add(Xdata[id]);
+            auto* x = tensors[i]->data<float, CPUContext>();
+            for (int xi = 0; xi < tensors[i]->count(); ++xi)
+                blob->mutable_data()->Add(x[xi]);
         } else if (XIsType((*tensors[i]), float16)) {
-            auto* Xdata = tensors[i]->data<float16, CPUContext>();
-            for (int id = 0; id < tensors[i]->count(); id++)
-                blob->mutable_data()->Add(
-                    cast::to<float>(Xdata[id]));
+            auto* x = tensors[i]->data<float16, CPUContext>();
+            for (int xi = 0; xi < tensors[i]->count(); ++xi)
+                blob->mutable_data()->Add(
+                    cast::to<float>(x[xi]));
         }
     }
-    WriteProtoToBinaryFile(net_param, file.c_str());
+    WriteProtoToBinaryFile(net, file.c_str());
     LOG(INFO) << "Save the model @: " << file << "......";
     LOG(INFO) << "Model format: Caffe";
 }
Dragon/python/dragon/operators/array.py
@@ -748,7 +748,7 @@ def Arange(start, stop=None, step=1, dtype='float32', **kwargs):
 @OpSchema.Inputs(1)
-def Multinomial(inputs, num_samples=1, normalize=False, **kwargs):
+def Multinomial(inputs, num_samples=1, eps=0., normalize=False, **kwargs):
     """Return a tensor where each row contains ``num_samples``,
     sampled from the multinomial distribution.

@@ -765,6 +765,8 @@ def Multinomial(inputs, num_samples=1, normalize=False, **kwargs):
         The input tensor.
     num_samples : int, optional, default=1
         The number of samples.
+    eps : float, optional, default=0.
+        The prob to a uniform sampling.
     normalize : boolean, optional, default=False
         Whether to normalize the inputs.
Dragon/python/dragon/vm/torch/ops/builtin.py
@@ -987,7 +987,7 @@ def one_hot(input, depth):
     return module.forward(input)

-def multinomial(input, num_samples, out=None):
+def multinomial(input, num_samples, eps=0., out=None):
     """Return a tensor where each row contains ``num_samples``,
     sampled from the multinomial distribution.

@@ -997,8 +997,8 @@ def multinomial(input, num_samples, out=None):
         The input tensor.
     num_samples : int
         The number of samples.
-    normalize : boolean, optional, default=False
-        Whether to normalize the inputs.
+    eps : float, optional, default=0.
+        The prob to a uniform sampling.

     Returns
     -------

@@ -1008,9 +1008,11 @@ def multinomial(input, num_samples, out=None):
     """
     dev = MakeDevice(inputs=[input])
     key = 'Multinomial/{}' \
-          '/num_samples:{}'.format(dev, num_samples)
+          '/num_samples:{}' \
+          '/eps:{}'.format(dev, num_samples, eps)
     module = get_module(
         Multinomial, key, dev,
+        eps=eps,
         num_samples=num_samples,
     )
     return module.forward(input, out)
Dragon/python/dragon/vm/torch/ops/modules/array.py
@@ -377,6 +377,7 @@ class Cast(BaseModule):
 class Multinomial(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Multinomial, self).__init__(key, dev, **kwargs)
+        self.eps = kwargs.get('eps', 0)
         self.num_samples = kwargs.get('num_samples', 1)
         self.register_op()

@@ -384,6 +385,7 @@ class Multinomial(BaseModule):
         self.op_meta = {
             'op_type': 'Multinomial',
             'arguments': {
+                'eps': float(self.eps),
                 'num_samples': self.num_samples,
                 'normalize': False,
             },
Dragon/python/dragon/vm/torch/tensor.py
@@ -980,7 +980,7 @@ class Tensor(object):
         """
         raise NotImplementedError('Refer torch.ops.tensor.normal_')

-    def multinomial(self, num_samples, normalize=False):
+    def multinomial(self, num_samples, eps=0.):
         """Return a tensor where each row contains ``num_samples``,
         sampled from the multinomial distribution.

@@ -988,8 +988,8 @@ class Tensor(object):
         ----------
         num_samples : int
             The number of samples.
-        normalize : boolean, optional, default=False
-            Whether to normalize the inputs.
+        eps : float, optional, default=0.
+            The prob to a uniform sampling.

         Returns
         -------
Dragon/src/contrib/rcnn/bbox_utils.cu
@@ -81,8 +81,8 @@ void _ApplyNMS(
     CUDA_CHECK(cudaMemcpy(boxes_dev, boxes,
         boxes_nbytes, cudaMemcpyHostToDevice));
     nms_mask<T>
-        << < blocks, NMS_BLOCK_SIZE,
+        <<< blocks, NMS_BLOCK_SIZE,
             0, ctx->cuda_stream() >> > (num_boxes,
                 thresh, (T*)boxes_dev, (uint64_t*)mask_dev);
     ctx->FinishDeviceCompution();
Dragon/src/contrib/rcnn/bbox_utils.h
@@ -347,7 +347,7 @@ inline void CollectRoIs(
     const int canonical_level,
     const int canonical_scale,
     const T* rois,
-    vector<vector<int64_t> >& roi_bins) {
+    vector<vec64_t>& roi_bins) {
     const T* roi = rois;
     for (int i = 0; i < num_rois; ++i) {
         int bin_idx = roi_level(min_level, max_level,

@@ -360,7 +360,7 @@ inline void CollectRoIs(
 template <typename T>
 inline void DistributeRoIs(
-    const vector<vector<int64_t> >& roi_bins,
+    const vector<vec64_t>& roi_bins,
     const T* rois,
     vector<T*> outputs) {
     for (int i = 0; i < roi_bins.size(); i++) {
Dragon/src/core/graph.cc
@@ -123,7 +123,7 @@ Graph::Graph(const GraphDef& def, Workspace* ws)
     // Recomputing-aware
     if (subgraph_indices.size() > 0) {
-        Map<string, vector<OperatorBase*> > subgraph;
+        Map<string, vector<OperatorBase*>> subgraph;
         for (const auto& it : subgraph_indices) {
             subgraph[it.first] = vector<OperatorBase*>();
             for (const auto& idx : subgraph_indices[it.first])
Dragon/src/core/graph_gradient.cc
@@ -7,7 +7,7 @@ namespace dragon {
 bool GraphGradientMaker::CheckGrad(
     const OperatorDef& forward_op,
     const Set<string>& targets,
-    vector<pair<string, int> >& gen_grads) {
+    vector<pair<string, int>>& gen_grads) {
     if (NoGradientRegistry()->Has(forward_op.type())) {
         for (auto& input : forward_op.input())
             blacklist_set_.insert(input);

@@ -81,7 +81,7 @@ void GraphGradientMaker::Make(
     for (int i = (int)forward_def.size() - 1; i >= 0; --i) {
         // Collect inputs & outputs, generate RAW grad ops
         const OperatorDef& op = *forward_def[i];
-        vector<pair<string, int> > gen_grads;
+        vector<pair<string, int>> gen_grads;
         bool is_skip = CheckGrad(op, targets_set, gen_grads);
         vector<string> g_outputs;
         for (auto& output : op.output()) {

@@ -214,7 +214,7 @@ void GraphGradientMaker::Make(
 GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
     Set<int> invalid_ops;
     Map<string, int> ref_count;
-    Map<string, pair<int, string> > ssa_map;
+    Map<string, pair<int, string>> ssa_map;
     // Count the refs for detecting leaf nodes
     for (int i = 0; i < input_def.op_size(); ++i) {
         const OperatorDef& op = input_def.op(i);
Dragon/src/core/graph_optimizer.cc
@@ -174,7 +174,7 @@ GraphDef GraphOptimizer::MirrorStage(
     const GraphDef& input_def,
     Map<string, vec32_t>& op_indices) {
     GraphDef output_def(input_def);
-    Map<string, set<int> > fake_op_indices;
+    Map<string, set<int>> fake_op_indices;
     Map<string, string> rename_map;
     Map<string, int> versions;
Dragon/src/core/operator_schema.cc
@@ -54,7 +54,7 @@ OpSchema& OpSchema::NumOutputs(int n) {
     return NumOutputs(n, n);
 }

-OpSchema& OpSchema::Inplace(set<pair<int, int> > inplace) {
+OpSchema& OpSchema::Inplace(set<pair<int, int>> inplace) {
     CheckInplace = [inplace](int in, int out) -> bool {
         return (inplace.count(std::make_pair(in, out)) > 0);
     };
Dragon/src/kernels/activation/dropout_op_kernel.cu
@@ -37,14 +37,10 @@ template<> void Dropout<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     auto thresh = (uint32_t)(UINT_MAX * prob);
-    math::RandomUniform(
-        count,
-        0.f, (float)UINT_MAX,
-        mask32, ctx
-    );
+    math::RandomUniform(count, 0.f, 1.f, mask32, ctx);
     _Dropout
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
             0, ctx->cuda_stream() >> >(
         count,
         thresh,
         scale,

@@ -85,14 +81,10 @@ template<> void Dropout<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     auto thresh = (uint32_t)(UINT_MAX * prob);
-    math::RandomUniform(
-        count,
-        0.f, (float)UINT_MAX,
-        mask32, ctx
-    );
+    math::RandomUniform(count, 0.f, 1.f, mask32, ctx);
     _Dropout
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
             0, ctx->cuda_stream() >> >(
         count,
         thresh,
         cast::to<half>(scale),

@@ -124,8 +116,8 @@ template <> void ApplyMask<float, uint8_t, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _ApplyMask
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
             0, ctx->cuda_stream() >> >(
         count, scale, x, mask, y
     );
 }

@@ -157,8 +149,8 @@ template <> void ApplyMask<float16, uint8_t, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _ApplyMaskHalf
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
             0, ctx->cuda_stream() >> >(
         count,
         cast::to<half>(scale),
         reinterpret_cast<const half*>(x),
Dragon/src/kernels/activation/droppath_op_kernel.cu
@@ -44,8 +44,8 @@ template<> void DropPath<float, CUDAContext>(
     auto nthreads = rows * cols;
     auto thresh = 1.f - (1.f / scale);
     _DropPath
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nthreads, cols, thresh, scale, x, mask, y
     );
 }

@@ -85,8 +85,8 @@ template<> void DropPath<float16, CUDAContext>(
     auto nthreads = rows * cols;
     auto thresh = 1.f - (1.f / scale);
     _DropPath
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nthreads, cols,
         thresh,
         cast::to<half>(scale),
Dragon/src/kernels/activation/elu_op_kernel.cu
@@ -28,8 +28,8 @@ template<> void Elu<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _Elu
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, x, alpha, y
     );
 }

@@ -58,8 +58,8 @@ template<> void EluGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _EluGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, alpha, dy, y, dx
     );
 }
Dragon/src/kernels/activation/prelu_op_kernel.cu
@@ -66,21 +66,21 @@ template<> void PRelu<float, CUDAContext>(
     CUDAContext* ctx) {
     if (channel_shared) {
         _PRelu
-            << < CUDA_BLOCKS(count), CUDA_THREADS,
+            <<< CUDA_BLOCKS(count), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             count, channels, dim, x, w, y
         );
     } else {
         if (data_format == "NCHW") {
             _PReluNCHW
-                << < CUDA_BLOCKS(count), CUDA_THREADS,
+                <<< CUDA_BLOCKS(count), CUDA_THREADS,
                    0, ctx->cuda_stream() >> >(
                 count, channels, dim, x, w, y
             );
         } else if (data_format == "NHWC") {
             _PReluNHWC
-                << < CUDA_BLOCKS(count), CUDA_THREADS,
+                <<< CUDA_BLOCKS(count), CUDA_THREADS,
                    0, ctx->cuda_stream() >> >(
                 count, channels, dim, x, w, y
             );
         } else {

@@ -152,21 +152,21 @@ template<> void PReluGrad<float, CUDAContext>(
     CUDAContext* ctx) {
     if (channel_shared) {
         _PReluGrad
-            << < CUDA_BLOCKS(count), CUDA_THREADS,
+            <<< CUDA_BLOCKS(count), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             count, channels, dim, dy, x, w, dx
         );
     } else {
         if (data_format == "NCHW") {
             _PReluGradNCHW
-                << < CUDA_BLOCKS(count), CUDA_THREADS,
+                <<< CUDA_BLOCKS(count), CUDA_THREADS,
                    0, ctx->cuda_stream() >> >(
                 count, channels, dim, dy, x, w, dx
             );
         } else if (data_format == "NHWC") {
             _PReluGradNHWC
-                << < CUDA_BLOCKS(count), CUDA_THREADS,
+                <<< CUDA_BLOCKS(count), CUDA_THREADS,
                    0, ctx->cuda_stream() >> >(
                 count, channels, dim, dy, x, w, dx
             );
         } else {

@@ -210,8 +210,8 @@ template<> void PReluWGrad<float, CUDAContext>(
     CUDAContext* ctx) {
     auto cdim = channels * dim;
     _PReluWGradBcast
-        << < CUDA_BLOCKS(cdim), CUDA_THREADS,
+        <<< CUDA_BLOCKS(cdim), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         cdim, rows, row_offset, dy, x, bcast_dw
     );
     if (channel_shared) {
Dragon/src/kernels/activation/relu_op_kernel.cu
@@ -35,8 +35,8 @@ template<> void Relu<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _Relu
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, slope, x, y
     );
 }

@@ -83,8 +83,8 @@ template<> void Relu<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((count & 1) == 0) {
         _Relu
-            << < CUDA_BLOCKS(count >> 1), CUDA_THREADS,
+            <<< CUDA_BLOCKS(count >> 1), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             count >> 1,
             cast::to<half2>(slope),
             reinterpret_cast<const half2*>(x),

@@ -92,8 +92,8 @@ template<> void Relu<float16, CUDAContext>(
         );
     } else {
         _Relu
-            << < CUDA_BLOCKS(count), CUDA_THREADS,
+            <<< CUDA_BLOCKS(count), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             count,
             cast::to<half>(slope),
             reinterpret_cast<const half*>(x),

@@ -134,8 +134,8 @@ template<> void ReluGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _ReluGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, slope, dy, y, dx
     );
 }

@@ -170,8 +170,8 @@ template<> void ReluGrad<float16, CUDAContext>(
     float16* dx,
     CUDAContext* ctx) {
     _ReluGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, slope,
         reinterpret_cast<const half*>(dy),
         reinterpret_cast<const half*>(y),
Dragon/src/kernels/activation/selu_op_kernel.cu
@@ -34,8 +34,8 @@ template<> void SElu<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _SElu
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, x, y
     );
 }

@@ -63,8 +63,8 @@ template<> void SElu<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _SElu
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(x),
         reinterpret_cast<half*>(y)

@@ -99,8 +99,8 @@ template<> void SEluGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _SEluGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, dy, y, dx
     );
 }

@@ -131,8 +131,8 @@ template<> void SEluGrad<float16, CUDAContext>(
     float16* dx,
     CUDAContext* ctx) {
     _SEluGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(dy),
         reinterpret_cast<const half*>(y),
Dragon/src/kernels/activation/sigmoid_op_kernel.cu
@@ -25,8 +25,8 @@ template<> void Sigmoid<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _Sigmoid
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, x, y
     );
 }

@@ -51,8 +51,8 @@ template<> void SigmoidGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _SigmoidGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, dy, y, dx
     );
 }
Dragon/src/kernels/activation/softmax_op_kernel.cu
@@ -96,26 +96,26 @@ template<> void Softmax<float, CUDAContext>(
     auto num_preds = outer_dim * inner_dim;
     auto nelements = num_preds * axis_dim;
     _SoftmaxReduceMax
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         num_preds, axis_dim, inner_dim, x, scale
     );
     _SoftmaxSub
-        << < CUDA_BLOCKS(nelements), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nelements, axis_dim, inner_dim, scale, y
     );
     math::Exp(nelements, y, y, ctx);
     _SoftmaxReduceSum
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         num_preds, axis_dim, inner_dim, y, scale
     );
     _SoftmaxDiv
-        << < CUDA_BLOCKS(nelements), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nelements, axis_dim, inner_dim, scale, y
     );
 }

@@ -159,13 +159,13 @@ template<> void SoftmaxGrad<float, CUDAContext>(
     auto num_preds = outer_dim * inner_dim;
     auto nelements = num_preds * axis_dim;
     _SoftmaxDot
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         num_preds, axis_dim, inner_dim, dy, y, scale
     );
     _SoftmaxSub
-        << < CUDA_BLOCKS(nelements), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nelements), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nelements, axis_dim, inner_dim, scale, dx
     );
     math::Mul(nelements, dx, y, dx, ctx);
Dragon/src/kernels/activation/tanh_op_kernel.cu
@@ -25,8 +25,8 @@ template<> void Tanh<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _Tanh
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, x, y
     );
 }

@@ -51,8 +51,8 @@ template<> void TanhGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _TanhGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count, dy, y, dx
     );
 }
Dragon/src/kernels/arithmetic/affine_op_kernel.cu
@@ -60,15 +60,15 @@ template<> void Affine<float, CUDAContext>(
     auto nthreads = outer_dim * axis_dim * inner_dim;
     if (beta != nullptr) {
         _Affine
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             nthreads, axis_dim, inner_dim,
             x, alpha, beta, y
         );
     } else {
         _AffineNoBias
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             nthreads, axis_dim, inner_dim, x, alpha, y
         );
     }

@@ -124,8 +124,8 @@ template<> void Affine<float16, CUDAContext>(
     auto nthreads = outer_dim * axis_dim * inner_dim;
     if (beta != nullptr) {
         _Affine
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             nthreads, axis_dim, inner_dim,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<const half*>(alpha),

@@ -134,8 +134,8 @@ template<> void Affine<float16, CUDAContext>(
         );
     } else {
         _AffineNoBias
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
                0, ctx->cuda_stream() >> >(
             nthreads, axis_dim, inner_dim,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<const half*>(alpha),

@@ -156,8 +156,8 @@ template <> void AffineGrad<float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * axis_dim * inner_dim;
     _AffineNoBias
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nthreads, axis_dim, inner_dim, dy, alpha, dx
     );
 }

@@ -174,8 +174,8 @@ template <> void AffineGrad<float16, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * axis_dim * inner_dim;
     _AffineNoBias
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         nthreads, axis_dim, inner_dim,
         reinterpret_cast<const half*>(dy),
         reinterpret_cast<const half*>(alpha),
Dragon/src/kernels/arithmetic/clip_op_kernel.cu
@@ -83,8 +83,8 @@ template<> __global__ void _ClipGrad<half>(
         T* y, \
         CUDAContext* ctx) { \
         _Clip<T> \
-            << < CUDA_BLOCKS(count), CUDA_THREADS, \
+            <<< CUDA_BLOCKS(count), CUDA_THREADS, \
                0, ctx->cuda_stream() >> >( \
             count, \
             cast::to<T>(low), \
             cast::to<T>(high), \

@@ -102,8 +102,8 @@ template<> __global__ void _ClipGrad<half>(
         T* dx, \
         CUDAContext* ctx) { \
         _ClipGrad<T> \
-            << < CUDA_BLOCKS(count), CUDA_THREADS, \
+            <<< CUDA_BLOCKS(count), CUDA_THREADS, \
                0, ctx->cuda_stream() >> >( \
             count, \
             cast::to<T>(low), \
             cast::to<T>(high), \

@@ -133,8 +133,8 @@ template <> void Clip<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _Clip
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         cast::to<half>(low),
         cast::to<half>(high),

@@ -152,8 +152,8 @@ template <> void ClipGrad<float16, CUDAContext>(
     float16* dx,
     CUDAContext* ctx) {
     _ClipGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         cast::to<half>(low),
         cast::to<half>(high),
Dragon/src/kernels/arithmetic/maximum_op_kernel.cu
@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>(
         T* y, \
         CUDAContext* ctx) { \
         _##name \
-            << < CUDA_BLOCKS(count), CUDA_THREADS, \
+            <<< CUDA_BLOCKS(count), CUDA_THREADS, \
                0, ctx->cuda_stream() >> >( \
             count, x1, x2, y \
         ); \
     }

@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMaximumGrad<half>(
         T* dx2, \
         CUDAContext* ctx) { \
         _##name \
-            << < CUDA_BLOCKS(count), CUDA_THREADS, \
+            <<< CUDA_BLOCKS(count), CUDA_THREADS, \
                0, ctx->cuda_stream() >> >( \
             count, x1, x2, dy, dx1, dx2 \
         ); \
     }

@@ -196,8 +196,8 @@ template <> void Maximum<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _Maximum \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(x1),
         reinterpret_cast<const half*>(x2),

@@ -212,8 +212,8 @@ template <> void BroadcastMaximum<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _BroadcastMaximum \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(x1),
         cast::to<half>(x2),

@@ -230,8 +230,8 @@ template <> void MaximumGrad<float16, CUDAContext>(
     float16* dx2,
     CUDAContext* ctx) {
     _MaximumGrad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(x1),
         reinterpret_cast<const half*>(x2),

@@ -250,8 +250,8 @@ template <> void BroadcastMaximumGrad<float16, CUDAContext>(
     float16* dx2,
     CUDAContext* ctx) {
     _BroadcastMaximumGrad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
            0, ctx->cuda_stream() >> >(
         count,
         reinterpret_cast<const half*>(x1),
         cast::to<half>(x2),
Dragon/src/kernels/arithmetic/minimum_op_kernel.cu
...
@@ -139,8 +139,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>(
     T* y, \
     CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, x1, x2, y \
     ); \
 }
...
@@ -155,8 +155,8 @@ template<> __global__ void _BroadcastMinimumGrad<half>(
     T* dx2, \
     CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, x1, x2, dy, dx1, dx2 \
     ); \
 }
...
@@ -196,8 +196,8 @@ template <> void Minimum<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _Minimum \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count,
         reinterpret_cast<const half*>(x1),
         reinterpret_cast<const half*>(x2),
...
@@ -212,8 +212,8 @@ template <> void BroadcastMinimum<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _BroadcastMinimum \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count,
         reinterpret_cast<const half*>(x1),
         cast::to<half>(x2),
...
@@ -230,8 +230,8 @@ template <> void MinimumGrad<float16, CUDAContext>(
     float16* dx2,
     CUDAContext* ctx) {
     _MinimumGrad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count,
         reinterpret_cast<const half*>(x1),
         reinterpret_cast<const half*>(x2),
...
@@ -250,8 +250,8 @@ template <> void BroadcastMinimumGrad<float16, CUDAContext>(
     float16* dx2,
     CUDAContext* ctx) {
     _BroadcastMinimumGrad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count,
         reinterpret_cast<const half*>(x1),
         cast::to<half>(x2),
...
Dragon/src/kernels/arithmetic/moments_op_kernel.cu
...
@@ -251,8 +251,8 @@ void _Moments(
         ndims, x_dims, y_dims,
         &rows, &cols)) {
     _ColwiseMoments
-        << < CUDA_2D_BLOCKS(rows), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(rows), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         rows, cols, x, mean, var
     ); return;
 }
...
@@ -262,8 +262,8 @@ void _Moments(
         ndims, x_dims, y_dims,
         &rows, &cols)) {
     _RowwiseMoments
-        << < CUDA_2D_BLOCKS(cols), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(cols), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         rows, cols, x, mean, var
     ); return;
 }
...
@@ -294,8 +294,8 @@ void _Moments(
     ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data());
     _GenericMoments
-        << < CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         ndims, outer_dim, inner_dim,
         XSS, YDS, x, mean, var
     );
...
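_Moments first tries a column-wise reduction (reduced axes contiguous within each row), then a row-wise one, and otherwise falls back to the generic strided kernel. A self-contained sketch of the column-wise case as read from the call sites above, one block per output row with a shared-memory tree reduction; an illustration, not Dragon's implementation:

#include <cuda_runtime.h>

// Each block owns output rows i, i + gridDim.x, ...; threads cooperatively
// accumulate sum and sum-of-squares over `cols`, then thread 0 writes
// mean[i] and var[i] = E[x^2] - mean^2. Requires blockDim.x == kThreads.
template <int kThreads>
__global__ void ColwiseMoments(int rows, int cols, const float* x,
                               float* mean, float* var) {
    __shared__ float s_sum[kThreads], s_sqr[kThreads];
    for (int i = blockIdx.x; i < rows; i += gridDim.x) {
        float v_sum = 0.f, v_sqr = 0.f;
        for (int j = threadIdx.x; j < cols; j += blockDim.x) {
            const float v = x[i * cols + j];
            v_sum += v; v_sqr += v * v;
        }
        s_sum[threadIdx.x] = v_sum; s_sqr[threadIdx.x] = v_sqr;
        __syncthreads();
        for (int stride = kThreads / 2; stride > 0; stride >>= 1) {
            if (threadIdx.x < stride) {
                s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
                s_sqr[threadIdx.x] += s_sqr[threadIdx.x + stride];
            }
            __syncthreads();
        }
        if (threadIdx.x == 0) {
            const float mu = s_sum[0] / cols;
            mean[i] = mu;
            var[i] = s_sqr[0] / cols - mu * mu;
        }
        __syncthreads();  // protect shared arrays before the next row
    }
}
// launch: ColwiseMoments<256><<< numBlocks, 256, 0, stream >>>(rows, cols, x, mean, var);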
Dragon/src/kernels/array/arange_op_kernel.cu
...
@@ -30,8 +30,8 @@ __global__ void _Arange(
     T* y, \
     CUDAContext* ctx) { \
     _Arange \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, start, step, y \
     ); \
 }
...
@@ -64,8 +64,8 @@ template <> void Arange<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _Arange
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, start, step,
         reinterpret_cast<half*>(y)
     );
...
Dragon/src/kernels/array/argreduce_op_kernel.cc
...
@@ -20,12 +20,12 @@ void _ArgMax(
     for (int iix = 0; iix < inner_dim; ++iix) {
         const T* X = x + (oix * axis_dim * inner_dim + iix);
         const int y_offset = oix * top_k * inner_dim + iix;
         vector<pair<T, int64_t>> vec(axis_dim);
         for (int j = 0; j < axis_dim; ++j)
             vec[j] = std::make_pair(X[j * inner_dim], j);
         std::partial_sort(
             vec.begin(), vec.begin() + top_k, vec.end(),
             std::greater<pair<T, int64_t>>());
         for (int j = 0; j < top_k; ++j) {
             indices[y_offset + j * inner_dim] = vec[j].second;
             if (values) values[y_offset + j * inner_dim] = vec[j].first;
...
@@ -49,7 +49,7 @@ void _ArgMin(
     for (int iix = 0; iix < inner_dim; ++iix) {
         const T* X = x + (oix * axis_dim * inner_dim + iix);
         const int y_offset = oix * top_k * inner_dim + iix;
         vector<pair<T, int64_t>> vec(axis_dim);
         for (int j = 0; j < axis_dim; ++j)
             vec[j] = std::make_pair(X[j * inner_dim], j);
         std::partial_sort(vec.begin(), vec.begin() + top_k, vec.end());
...
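The CPU argmax/argmin above keeps just the top-k (value, index) pairs per slice with std::partial_sort, which is O(axis_dim log top_k) rather than a full sort. A standalone sketch of the same idea, with illustrative names (not Dragon's API):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Top-k argmax over a flat slice, mirroring the pattern in _ArgMax:
// pair each element with its index, partially sort descending by value,
// then read off the first k indices.
std::vector<int64_t> TopKArgMax(const std::vector<float>& x, int top_k) {
    std::vector<std::pair<float, int64_t>> vec(x.size());
    for (int64_t j = 0; j < (int64_t)x.size(); ++j)
        vec[j] = std::make_pair(x[j], j);
    std::partial_sort(
        vec.begin(), vec.begin() + top_k, vec.end(),
        std::greater<std::pair<float, int64_t>>());
    std::vector<int64_t> indices(top_k);
    for (int j = 0; j < top_k; ++j) indices[j] = vec[j].second;
    return indices;
}

int main() {
    for (int64_t i : TopKArgMax({0.1f, 0.9f, 0.4f, 0.7f}, 2))
        std::cout << i << " ";  // prints: 1 3
    std::cout << "\n";
}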
Dragon/src/kernels/array/argreduce_op_kernel.cu
...
@@ -133,8 +133,8 @@ template<> __global__ void _ArgMin<half>(
     CHECK_EQ(top_k, 1) << "\nRequired top_k == 1."; \
     auto nthreads = outer_dim * inner_dim; \
     _##name \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, inner_dim, axis_dim, \
         x, indices, values \
     ); \
...
@@ -168,8 +168,8 @@ template<> void ArgMax<float16, CUDAContext>(
     CHECK_EQ(top_k, 1) << "\nRequired top_k == 1.";
     auto nthreads = outer_dim * inner_dim;
     _ArgMax
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, inner_dim, axis_dim,
         reinterpret_cast<const half*>(x),
         indices,
...
@@ -189,8 +189,8 @@ template<> void ArgMin<float16, CUDAContext>(
     CHECK_EQ(top_k, 1) << "\nRequired top_k == 1.";
     auto nthreads = outer_dim * inner_dim;
     _ArgMin
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, inner_dim, axis_dim,
         reinterpret_cast<const half*>(x),
         indices,
...
Dragon/src/kernels/array/concat_op_kernel.cu
...
@@ -43,8 +43,8 @@ __global__ void _Concat(
     auto cols = axis_dim * inner_dim; \
     auto nthreads = outer_dim * axis_dim * inner_dim; \
     _##name \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, \
         inner_dim, \
         cols, \
...
Dragon/src/kernels/array/crop_op_kernel.cu
...
@@ -83,8 +83,8 @@ __global__ void _CropGrad(
     T* y, \
     CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, ndims, \
         x_strides, y_dims, \
         starts, x, y \
...
Dragon/src/kernels/array/index_select_op_kernel.cu
...
@@ -115,8 +115,8 @@ template <> __global__ void _IndexSelectGrad<half>(
     CUDAContext* ctx) { \
     auto nthreads = outer_dim * num_indices * inner_dim; \
     _IndexSelect \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, inner_dim, \
         axis_dim, num_indices, \
         indices, x, y \
...
@@ -135,8 +135,8 @@ template <> __global__ void _IndexSelectGrad<half>(
     CUDAContext* ctx) { \
     auto nthreads = outer_dim * inner_dim; \
     _IndexSelectGrad \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, inner_dim, \
         axis_dim, num_indices, \
         indices, dy, dx \
...
@@ -170,8 +170,8 @@ template <> void IndexSelectGrad<float16, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _IndexSelectGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, inner_dim,
         axis_dim, num_indices,
         indices,
...
Dragon/src/kernels/array/one_hot_op_kernel.cu
...
@@ -32,8 +32,8 @@ template <> void OneHot<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _OneHot
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, depth, on_value, x, y
     );
 }
...
@@ -48,8 +48,8 @@ template <> void OneHot<int, CUDAContext>(
     int* y,
     CUDAContext* ctx) {
     _OneHot
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, depth, on_value, x, y
     );
 }
...
@@ -64,8 +64,8 @@ template <> void OneHot<int64_t, CUDAContext>(
     int64_t* y,
     CUDAContext* ctx) {
     _OneHot
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, depth, on_value, x, y
     );
 }
...
Dragon/src/kernels/array/pad_op_kernel.cu
...
@@ -130,8 +130,8 @@ __global__ void _EdgePad(
     T* y, \
     CUDAContext* ctx) { \
     _ConstPad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, ndims, \
         x_dims, x_strides, \
         y_dims, l_pads, \
...
@@ -152,8 +152,8 @@ __global__ void _EdgePad(
     T* y, \
     CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, ndims, \
         x_dims, x_strides, \
         y_dims, l_pads, \
...
Dragon/src/kernels/array/reduce_sum_op_kernel.cu
...
@@ -202,8 +202,8 @@ void _ReduceSum(
         ndims, x_dims, y_dims,
         &rows, &cols)) {
     _ColwiseReduceSum
-        << < CUDA_2D_BLOCKS(rows), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(rows), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         rows, cols, scale, x, y
     ); return;
 }
...
@@ -213,8 +213,8 @@ void _ReduceSum(
         ndims, x_dims, y_dims,
         &rows, &cols)) {
     _RowwiseReduceSum
-        << < CUDA_2D_BLOCKS(cols), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(cols), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         rows, cols, scale, x, y
     ); return;
 }
...
@@ -245,8 +245,8 @@ void _ReduceSum(
     ctx->Memcpy<CUDAContext, CPUContext>(dbytes, YDS, dimsT.data());
     _GenericReduceSum
-        << < CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(outer_dim), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         ndims, outer_dim, inner_dim,
         XSS, YDS, scale, x, y
     );
...
@@ -372,8 +372,8 @@ template <> __global__ void _ReduceSumGrad<half>(
     T* dx, \
     CUDAContext* ctx) { \
     _ReduceSumGrad \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, ndim, x_dims, \
         y_dims, y_strides, \
         scale, dy, dx \
...
@@ -398,8 +398,8 @@ template<> void ReduceSumGrad<float16, CUDAContext>(
     float16* dx,
     CUDAContext* ctx) {
     _ReduceSumGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, ndim, x_dims,
         y_dims, y_strides,
         scale,
...
Dragon/src/kernels/array/repeat_op_kernel.cu
...
@@ -93,8 +93,8 @@ template<> __global__ void _RepeatGrad<half>(
     auto y_inner_dim = inner_dim * repeats; \
     auto nthreads = outer_dim * axis_dim * y_inner_dim; \
     _Repeat \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, axis_dim, \
         inner_dim, y_inner_dim, \
         x, y \
...
@@ -113,8 +113,8 @@ template<> __global__ void _RepeatGrad<half>(
     auto y_inner_dim = inner_dim * repeats; \
     auto nthreads = outer_dim * axis_dim * inner_dim; \
     _RepeatGrad \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, \
         axis_dim, \
         inner_dim, \
...
@@ -151,8 +151,8 @@ template<> void RepeatGrad<float16, CUDAContext>(
     auto y_inner_dim = inner_dim * repeats;
     auto nthreads = outer_dim * axis_dim * inner_dim;
     _RepeatGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>(
         nthreads,
         axis_dim,
         inner_dim,
...
Dragon/src/kernels/array/slice_op_kernel.cu
...
@@ -64,8 +64,8 @@ __global__ void _SliceGrad(
     auto cols = slice_dim * inner_dim; \
     auto nthreads = outer_dim * cols; \
     _##name \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, \
         inner_dim, \
         axis_dim, \
...
@@ -126,8 +126,8 @@ template <> void SliceGrad<float16, CUDAContext>(
     auto cols = slice_dim * inner_dim;
     auto nthreads = outer_dim * cols;
     _SliceGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads,
         inner_dim,
         axis_dim,
...
Dragon/src/kernels/array/tile_op_kernel.cu
...
@@ -98,8 +98,8 @@ template<> __global__ void _TileGrad<half>(
     T* y, \
     CUDAContext* ctx) { \
     _Tile \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, \
         ndims, \
         x_dims, \
...
@@ -120,8 +120,8 @@ template<> __global__ void _TileGrad<half>(
     auto nthreads = rows * cols; \
     auto tiled_cols = multiple * cols; \
     _TileGrad \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         nthreads, \
         cols, \
         tiled_cols, \
...
@@ -156,8 +156,8 @@ template<> void TileGrad<float16, CUDAContext>(
     auto nthreads = rows * cols;
     auto tiled_cols = multiple * cols;
     _TileGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads,
         cols,
         tiled_cols,
...
Dragon/src/kernels/array/transpose_op_kernel.cu
...
@@ -80,8 +80,8 @@ __global__ void _TransposeGrad(
     T* y, \
     CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, ndims, x_strides, y_dims, x, y \
     ); \
 }
...
Dragon/src/kernels/control_flow/assign_op_kernel.cu
...
@@ -55,8 +55,8 @@ __global__ void _Assign(
     T* y, \
     CUDAContext* ctx) { \
     _Assign \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, \
         ndims, \
         x_dims, \
...
Dragon/src/kernels/control_flow/compare_op_kernel.cu
...
@@ -153,8 +153,8 @@ __global__ void _GreaterEqualHalf(
     bool* y, \
     CUDAContext* ctx) { \
     IMPL \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, a, b, y \
     ); \
 }
...
@@ -167,8 +167,8 @@ __global__ void _GreaterEqualHalf(
     bool* y, \
     CUDAContext* ctx) { \
     _##OP##Half \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, \
         reinterpret_cast<const half*>(a), \
         reinterpret_cast<const half*>(b), \
...
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cu
...
@@ -30,8 +30,8 @@ __global__ void _MaskedAssign(
     T* y, \
     CUDAContext* ctx) { \
     _MaskedAssign \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, mask, x, y \
     ); \
 }
...
Dragon/src/kernels/loss/l1_loss_op_kernel.cu
...
@@ -27,8 +27,8 @@ template<> void AbsGrad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _AbsGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, dy, dx
     );
 }
...
Dragon/src/kernels/loss/nll_loss_op_kernel.cu
...
@@ -55,8 +55,8 @@ template <> void NLLLoss<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _NLLLoss
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, log_prob, target, loss, flag
     );
...
@@ -77,8 +77,8 @@ template <> void NLLLoss<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _NLLLoss
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, log_prob, target, loss, flag
     );
...
@@ -129,8 +129,8 @@ template<> void NLLLossGrad<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _NLLLossGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, log_prob, target, dx, flag
     );
...
@@ -151,8 +151,8 @@ template<> void NLLLossGrad<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _NLLLossGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, log_prob, target, dx, flag
     );
...
Dragon/src/kernels/loss/sigmoid_ce_loss_op_kernel.cu
...
@@ -42,8 +42,8 @@ template <> void SigmoidCrossEntropy<float, CUDAContext>(
     int* flag,
     CUDAContext* ctx) {
     _SigmoidCrossEntropy
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, logit, target, loss, flag
     );
 }
...
@@ -77,8 +77,8 @@ template <> void SigmoidCrossEntropyGrad<float, CUDAContext>(
     int* flag,
     CUDAContext* ctx) {
     _SigmoidCrossEntropyGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, logit, target, dlogit, flag
     );
 }
...
Dragon/src/kernels/loss/sigmoid_focal_loss_op_kernel.cu
...
@@ -71,8 +71,8 @@ template <> void SigmoidFocalLoss<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * axis_dim * inner_dim;
     _SigmoidFocalLoss
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         logits, targets, losses, flags
...
@@ -96,8 +96,8 @@ template <> void SigmoidFocalLoss<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * axis_dim * inner_dim;
     _SigmoidFocalLoss
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         logits, targets, losses, flags
...
@@ -171,8 +171,8 @@ template <> void SigmoidFocalLossGrad<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto count = outer_dim * axis_dim * inner_dim;
     _SigmoidFocalLossGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         logits, targets, dlogits, flags
...
@@ -196,8 +196,8 @@ template <> void SigmoidFocalLossGrad<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto count = outer_dim * axis_dim * inner_dim;
     _SigmoidFocalLossGrad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         logits, targets, dlogits, flags
...
Dragon/src/kernels/loss/smooth_l1_loss_op_kernel.cu
...
@@ -33,8 +33,8 @@ template<> void SmoothL1<float, CUDAContext>(
     float* y,
     CUDAContext* ctx) {
     _SmoothL1
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, beta, x, y
     );
 }
...
@@ -63,8 +63,8 @@ template<> void SmoothL1Grad<float, CUDAContext>(
     float* dx,
     CUDAContext* ctx) {
     _SmoothL1Grad
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, beta, dy, dx
     );
 }
...
Dragon/src/kernels/loss/softmax_ce_loss_op_kernel.cu
...
@@ -29,8 +29,8 @@ template <> void SoftmaxCrossEntropy<float, CUDAContext>(
     float* losses,
     CUDAContext* ctx) {
     _SoftmaxCrossEntropy
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, prob, targets, losses
     );
 }
...
Dragon/src/kernels/loss/softmax_focal_loss_op_kernel.cu
...
@@ -67,8 +67,8 @@ template <> void SoftmaxFocalLoss<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto num_preds = outer_dim * inner_dim;
     _SoftmaxFocalLoss
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         num_preds, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         nignores, ignores,
...
@@ -95,8 +95,8 @@ template <> void SoftmaxFocalLoss<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto num_preds = outer_dim * inner_dim;
     _SoftmaxFocalLoss
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         num_preds, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         nignores, ignores,
...
@@ -179,8 +179,8 @@ template<> void SoftmaxFocalLossGrad<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto num_preds = outer_dim * inner_dim;
     _SoftmaxFocalLossGrad
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         num_preds, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         nignores, ignores,
...
@@ -207,8 +207,8 @@ template<> void SoftmaxFocalLossGrad<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto num_preds = outer_dim * inner_dim;
     _SoftmaxFocalLossGrad
-        << < CUDA_BLOCKS(num_preds), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(num_preds), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         num_preds, axis_dim, inner_dim,
         pos_alpha, neg_alpha, gamma, neg_id,
         nignores, ignores,
...
Dragon/src/kernels/loss/sparse_softmax_ce_loss_op_kernel.cu
...
@@ -59,8 +59,8 @@ template <> void SparseSoftmaxCrossEntropy<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _SparseSoftmaxCrossEntropy
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, prob, target, loss, flag
     );
...
@@ -81,8 +81,8 @@ template <> void SparseSoftmaxCrossEntropy<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _SparseSoftmaxCrossEntropy
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, prob, target, loss, flag
     );
...
@@ -136,8 +136,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, float, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _SparseSoftmaxCrossEntropyGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, prob, target, dx, flag
     );
...
@@ -158,8 +158,8 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CUDAContext>(
     CUDAContext* ctx) {
     auto nthreads = outer_dim * inner_dim;
     _SparseSoftmaxCrossEntropyGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         nthreads, axis_dim, inner_dim, nignores,
         ignore, prob, target, dx, flag
     );
...
Dragon/src/kernels/misc/astype_op_kernel.cu
...
@@ -26,8 +26,8 @@ __global__ void _TypeA2B(
     Tb* b, \
     CUDAContext* ctx) { \
     _TypeA2B \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, a, b \
     ); \
 }
...
@@ -66,8 +66,8 @@ template <> void TypeA2B<float16, float, CUDAContext>(
     float* b,
     CUDAContext* ctx) {
     _TypeA2B
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, reinterpret_cast<const half*>(a), b
     );
 }
...
@@ -89,8 +89,8 @@ template <> void TypeA2B<float, float16, CUDAContext>(
     float16* b,
     CUDAContext* ctx) {
     _TypeA2B
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count, a, reinterpret_cast<half*>(b)
     );
 }
...
@@ -112,8 +112,8 @@ template <> void TypeA2B<float16, float16, CUDAContext>(
     float16* b,
     CUDAContext* ctx) {
     _TypeA2B
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         count,
         reinterpret_cast<const half*>(a),
         reinterpret_cast<half*>(b)
...
Dragon/src/kernels/misc/gradient_op_kernel.cu
...
@@ -62,8 +62,8 @@ template <> __global__ void _GradientTwoSum<half2>(
     T* dx, \
     CUDAContext* ctx) { \
     _GradientTwoSum \
-        << < CUDA_BLOCKS(count), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(count), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         count, dy1, dy2, dx \
     ); \
 }
...
@@ -83,8 +83,8 @@ template <> void GradientTwoSum<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((count & 1) == 0) {
         _GradientTwoSum
-            << < CUDA_BLOCKS(count >> 2), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(count >> 2), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             count >> 2,
             reinterpret_cast<const half2*>(dy1),
             reinterpret_cast<const half2*>(dy2),
...
@@ -92,8 +92,8 @@ template <> void GradientTwoSum<float16, CUDAContext>(
         );
     } else {
         _GradientTwoSum
-            << < CUDA_BLOCKS(count), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(count), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             count,
             reinterpret_cast<const half*>(dy1),
             reinterpret_cast<const half*>(dy2),
...
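The float16 specialization above dispatches on parity: an even count is reinterpreted as packed half2 so each thread adds two halves per instruction. A self-contained sketch of that vectorization trick, assuming simplified names and a plain ceil-div launch (not Dragon's kernel or macros):

#include <cuda_fp16.h>

// Sums two gradient streams elementwise. The half2 variant processes two
// packed halves at once via __hadd2; a scalar variant covers odd counts.
__global__ void TwoSumHalf2(int n2, const half2* dy1, const half2* dy2, half2* dx) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n2) dx[i] = __hadd2(dy1[i], dy2[i]);  // two half additions at once
}

__global__ void TwoSumHalf(int n, const half* dy1, const half* dy2, half* dx) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dx[i] = __hadd(dy1[i], dy2[i]);
}

void TwoSum(int count, const half* dy1, const half* dy2, half* dx,
            cudaStream_t stream) {
    const int threads = 256;
    if ((count & 1) == 0) {  // even: view as count/2 packed half2 values
        const int n2 = count >> 1;
        TwoSumHalf2<<<(n2 + threads - 1) / threads, threads, 0, stream>>>(
            n2, reinterpret_cast<const half2*>(dy1),
            reinterpret_cast<const half2*>(dy2),
            reinterpret_cast<half2*>(dx));
    } else {
        TwoSumHalf<<<(count + threads - 1) / threads, threads, 0, stream>>>(
            count, dy1, dy2, dx);
    }
}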
Dragon/src/kernels/misc/image_data_op_kernel.cu
...
@@ -76,14 +76,14 @@ template <> void ImageData<float, float, CUDAContext>(
     auto nthreads = N * C * H * W;
     if (data_format == "NCHW") {
         _ImageDataNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std, x, y
         );
     } else if (data_format == "NHWC") {
         _ImageDataNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std, x, y
         );
     } else {
...
@@ -107,14 +107,14 @@ template <> void ImageData<uint8_t, float, CUDAContext>(
     auto nthreads = N * C * H * W;
     if (data_format == "NCHW") {
         _ImageDataNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std, x, y
         );
     } else if (data_format == "NHWC") {
         _ImageDataNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std, x, y
         );
     } else {
...
@@ -191,15 +191,15 @@ template <> void ImageData<float, float16, CUDAContext>(
     auto nthreads = N * C * H * W;
     if (data_format == "NCHW") {
         _ImageDataHalfNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std,
             x, reinterpret_cast<half*>(y)
         );
     } else if (data_format == "NHWC") {
         _ImageDataHalfNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std,
             x, reinterpret_cast<half*>(y)
         );
...
@@ -222,15 +222,15 @@ template <> void ImageData<uint8_t, float16, CUDAContext>(
     auto nthreads = N * C * H * W;
     if (data_format == "NCHW") {
         _ImageDataHalfNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std,
             x, reinterpret_cast<half*>(y)
        );
     } else if (data_format == "NHWC") {
         _ImageDataHalfNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             nthreads, C, H, W, mean, std,
             x, reinterpret_cast<half*>(y)
         );
...
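The NCHW/NHWC kernel pairs above differ only in how a flat thread index is decoded into a channel; the normalization itself, y = (x - mean[c]) / std[c], is identical. A sketch of the two decodings with hypothetical kernel names (not Dragon's):

// Decode a flat index for NCHW vs NHWC layouts and apply per-channel
// normalization. Illustrative kernels only.
__global__ void ImageDataNCHW(int nthreads, int C, int H, int W,
                              const float* mean, const float* std,
                              const float* x, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= nthreads) return;
    const int c = (i / (H * W)) % C;  // NCHW: w fastest, then h, then c
    y[i] = (x[i] - mean[c]) / std[c];
}

__global__ void ImageDataNHWC(int nthreads, int C, int H, int W,
                              const float* mean, const float* std,
                              const float* x, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= nthreads) return;
    const int c = i % C;              // NHWC: c is the fastest axis
    y[i] = (x[i] - mean[c]) / std[c];
}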
Dragon/src/kernels/norm/batch_norm_op_kernel.cu
...
@@ -190,27 +190,27 @@ __global__ void _BatchNormInferenceGrad(
     auto nthreads = N * C * S; \
     if (data_format == "NCHW") { \
         _BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
-            << < CUDA_2D_BLOCKS(C), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             N, C, S, x, mu, rsig, gamma, \
             dy, ds, db, dgamma, dbeta \
         ); \
         _BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW> \
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             nthreads, N, C, S, x, mu, \
             rsig, gamma, ds, db, dy, dx \
         ); \
     } else if (data_format == "NHWC") { \
         _BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
-            << < CUDA_2D_BLOCKS(C), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             N, C, S, x, mu, rsig, gamma, \
             dy, ds, db, dgamma, dbeta \
         ); \
         _BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC> \
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             nthreads, N, C, S, x, mu, \
             rsig, gamma, ds, db, dy, dx \
         ); \
...
@@ -234,24 +234,24 @@ __global__ void _BatchNormInferenceGrad(
     if (data_format == "NCHW") { \
         if (dgamma != nullptr) { \
             _BatchNormWGrad<Tx, Tp, StorageOrder::NCHW> \
-                << < CUDA_2D_BLOCKS(C), CUDA_THREADS, \
-                    0, ctx->cuda_stream() >> > \
+                <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
+                    0, ctx->cuda_stream() >>> \
                 (N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
         } \
         _BatchNormInferenceGrad<Tx, Tp, StorageOrder::NCHW> \
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> > \
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>> \
             (nthreads, C, S, rsig, gamma, dy, dx); \
     } else if (data_format == "NHWC") { \
         if (dgamma != nullptr) { \
             _BatchNormWGrad<Tx, Tp, StorageOrder::NHWC> \
-                << < CUDA_2D_BLOCKS(C), CUDA_THREADS, \
-                    0, ctx->cuda_stream() >> > \
+                <<< CUDA_2D_BLOCKS(C), CUDA_THREADS, \
+                    0, ctx->cuda_stream() >>> \
                 (N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
         } \
         _BatchNormInferenceGrad<Tx, Tp, StorageOrder::NHWC> \
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> > \
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>> \
             (nthreads, C, S, rsig, gamma, dy, dx); \
     } \
 }
...
Dragon/src/kernels/norm/group_norm_op_kernel.cu
View file @
d1f714e
...
@@ -408,20 +408,20 @@ __global__ void _GroupNormGradHalf(
...
@@ -408,20 +408,20 @@ __global__ void _GroupNormGradHalf(
CUDAContext* ctx) { \
CUDAContext* ctx) { \
const int C = G * D; \
const int C = G * D; \
_GroupNormFusedParams<Tp> \
_GroupNormFusedParams<Tp> \
<<
< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, G, D, mu, rsig, gamma, beta, scale, bias \
N, G, D, mu, rsig, gamma, beta, scale, bias \
); \
); \
if (data_format == "NCHW") { \
if (data_format == "NCHW") { \
_GroupNormForwardNCHW<Tx, Tp> \
_GroupNormForwardNCHW<Tx, Tp> \
<<
< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, C, S, x, scale, bias, y \
N, C, S, x, scale, bias, y \
); \
); \
} else if (data_format == "NHWC") { \
} else if (data_format == "NHWC") { \
_GroupNormForwardNHWC<Tx, Tp> \
_GroupNormForwardNHWC<Tx, Tp> \
<<
< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, C, S, x, scale, bias, y \
N, C, S, x, scale, bias, y \
); \
); \
} \
} \
...
@@ -448,35 +448,35 @@ __global__ void _GroupNormGradHalf(
...
@@ -448,35 +448,35 @@ __global__ void _GroupNormGradHalf(
auto nthreads = N * G * D * S; \
auto nthreads = N * G * D * S; \
if (data_format == "NCHW") { \
if (data_format == "NCHW") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NCHW> \
_GroupNormWGrad<Tx, Tp, StorageOrder::NCHW> \
<<
< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
); \
); \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<<
< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, G, D, S, x, gamma, dy, ds, db \
N, G, D, S, x, gamma, dy, ds, db \
); \
); \
_GroupNormGrad<Tx, Tp, StorageOrder::NCHW> \
_GroupNormGrad<Tx, Tp, StorageOrder::NCHW> \
<<
< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
<<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
nthreads, G, D, S, x, mu, rsig, \
nthreads, G, D, S, x, mu, rsig, \
gamma, ds, db, dy, dx \
gamma, ds, db, dy, dx \
); \
); \
} else if (data_format == "NHWC") { \
} else if (data_format == "NHWC") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NHWC> \
_GroupNormWGrad<Tx, Tp, StorageOrder::NHWC> \
<<
< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
<<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS, \
0, ctx->cuda_stream() >>
>( \
0, ctx->cuda_stream() >>
>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta \
); \
); \
    _GroupNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
-        << < CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
        N, G, D, S, x, gamma, dy, ds, db \
    ); \
    _GroupNormGrad<Tx, Tp, StorageOrder::NHWC> \
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
        nthreads, G, D, S, x, mu, rsig, \
        gamma, ds, db, dy, dx \
    ); \
...
@@ -503,14 +503,14 @@ template <> void GroupNormForward<float16, float, CUDAContext>(
    CUDAContext* ctx) {
    const int C = G * D;
    _GroupNormFusedParams<float>
-        << < CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        N, G, D, mu, rsig, gamma, beta, scale, bias
    );
    if (data_format == "NCHW") {
        _GroupNormForwardNCHW<half, float>
-            << < CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, C, S,
            reinterpret_cast<const half*>(x),
            scale, bias,
...
@@ -518,8 +518,8 @@ template <> void GroupNormForward<float16, float, CUDAContext>(
    );
    } else if (data_format == "NHWC") {
        _GroupNormForwardNHWC<half, float>
-            << < CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(N * C), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, C, S,
            reinterpret_cast<const half*>(x),
            scale, bias,
...
@@ -548,8 +548,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
    auto nthreads = N * G * D * S;
    if (data_format == "NCHW") {
        _GroupNormWGradHalf<StorageOrder::NCHW>
-            << < CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, G, D, S,
            reinterpret_cast<const half*>(x),
            mu, rsig,
...
@@ -557,8 +557,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
            dgamma, dbeta
        );
        _GroupNormInternalGradHalf<StorageOrder::NCHW>
-            << < CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, G, D, S,
            reinterpret_cast<const half*>(x),
            gamma,
...
@@ -566,8 +566,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
            ds, db
        );
        _GroupNormGradHalf<StorageOrder::NCHW>
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, G, D, S,
            reinterpret_cast<const half*>(x),
            mu, rsig, gamma, ds, db,
...
@@ -576,8 +576,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
        );
    } else if (data_format == "NHWC") { \
        _GroupNormWGradHalf<StorageOrder::NHWC>
-            << < CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(G * D), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, G, D, S,
            reinterpret_cast<const half*>(x),
            mu, rsig,
...
@@ -585,8 +585,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
            dgamma, dbeta
        );
        _GroupNormInternalGradHalf<StorageOrder::NHWC>
-            << < CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_2D_BLOCKS(N * G), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            N, G, D, S,
            reinterpret_cast<const half*>(x),
            gamma,
...
@@ -594,8 +594,8 @@ template <> void GroupNormBackward<float16, float, CUDAContext>(
            ds, db
        );
        _GroupNormGradHalf<StorageOrder::NHWC>
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, G, D, S,
            reinterpret_cast<const half*>(x),
            mu, rsig, gamma, ds, db,
...
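Every hunk above makes the same mechanical change: the kernel-launch operator, previously written as the two tokens `<< <` ... `>> >`, is normalized to the standard `<<< ... >>>` form. For readers unfamiliar with the four launch parameters these kernels pass, here is a minimal sketch of the shared pattern; `kThreads`/`NumBlocks` are stand-ins for Dragon's CUDA_THREADS/CUDA_BLOCKS helpers, and the grid-stride kernel is hypothetical, not code from this commit:

#include <cuda_runtime.h>

constexpr int kThreads = 1024;  // stand-in for CUDA_THREADS

// Stand-in for CUDA_BLOCKS: enough 1-D blocks to cover `n` elements.
inline int NumBlocks(int n) { return (n + kThreads - 1) / kThreads; }

// Hypothetical element-wise kernel using a grid-stride loop, the same
// shape as the _AdamUpdate / _BiasAddNCHW style kernels in this commit.
__global__ void _Scale(int n, float alpha, const float* x, float* y) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x) {
        y[i] = alpha * x[i];
    }
}

void Scale(int n, float alpha, const float* x, float* y,
           cudaStream_t stream) {
    // <<< grid, block, shared_mem_bytes, stream >>> — the four launch
    // parameters whose spelling every hunk in this commit reflows.
    _Scale<<<NumBlocks(n), kThreads, 0, stream>>>(n, alpha, x, y);
}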
Dragon/src/kernels/recurrent/lstm_cell_op_kernel.cu
...
@@ -58,13 +58,13 @@ template <> void LSTMCell<float, CUDAContext>(
    auto o_offset = 2 * C, c_offset = 3 * C,
         x_offset = 4 * C, NC = N * C;
    _LSTMCellAct
-        << < CUDA_BLOCKS(NC * 4), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(NC * 4), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        NC * 4, c_offset, x_offset, actx
    );
    _LSTMCellGate
-        << < CUDA_BLOCKS(NC), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(NC), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        NC, C, o_offset, c_offset,
        x_offset, cx, actx, c, h
    );
...
@@ -138,14 +138,14 @@ template <> void LSTMCellGrad<float, CUDAContext>(
    auto o_offset = 2 * C, c_offset = 3 * C,
         x_offset = 4 * C, NC = N * C;
    _LSTMCellGateGrad
-        << < CUDA_BLOCKS(NC), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(NC), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        NC, C, o_offset, c_offset, x_offset,
        cx, actx, c, dc, dh, dcx, dx
    );
    _LSTMCellActGrad
-        << < CUDA_BLOCKS(NC * 4), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(NC * 4), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        NC * 4, c_offset, x_offset, actx, dx
    );
}
...
Dragon/src/kernels/update/adam_update_op_kernel.cu
...
@@ -39,8 +39,8 @@ template <> void AdamUpdate<float, CUDAContext>(
    float* v,
    CUDAContext* ctx) {
    _AdamUpdate
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count, lr, beta1, beta2, eps, g, m, v
    );
}
...
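For reference, the argument list above (lr, beta1, beta2, eps, g, m, v) matches the standard Adam recurrence. Assuming _AdamUpdate follows the usual Caffe-style formulation (its body is outside this diff), each element $i$ is updated in place as:

$$m_i \leftarrow \beta_1 m_i + (1-\beta_1)\,g_i, \qquad v_i \leftarrow \beta_2 v_i + (1-\beta_2)\,g_i^2, \qquad g_i \leftarrow \frac{\mathrm{lr}\cdot m_i}{\sqrt{v_i} + \epsilon}$$

with any bias-corrected learning rate typically folded into lr by the caller before the launch.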
Dragon/src/kernels/update/mprec_update_op_kerne.cu
...
@@ -29,8 +29,8 @@ template <> void MixedPrecL2Decay<float16, CUDAContext>(
    float* dx,
    CUDAContext* ctx) {
    _MixedPrecL2DecayHalf
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count,
        alpha,
        reinterpret_cast<const half*>(w),
...
@@ -58,8 +58,8 @@ template <> void MixedPrecUpdate<float16, CUDAContext>(
    float16* w,
    CUDAContext* ctx) {
    _MixedPrecUpdateHalf
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count,
        updates,
        reinterpret_cast<half*>(w)
...
Dragon/src/kernels/update/nesterov_update_op_kernel.cu
...
@@ -32,8 +32,8 @@ template <> void NesterovUpdate<float, CUDAContext>(
    float* h,
    CUDAContext* ctx) {
    _NesterovUpdate
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count, lr, momentum, g, h
    );
}
...
Dragon/src/kernels/update/rmsprop_update_op_kernel.cu
...
@@ -34,8 +34,8 @@ template <> void RMSPropUpdate<float, CUDAContext>(
    float* h,
    CUDAContext* ctx) {
    _RMSPropUpdate
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count, lr, decay, eps, g, h
    );
}
...
Dragon/src/kernels/update/sgd_update_op_kernel.cu
...
@@ -31,8 +31,8 @@ template <> void SGDUpdate<float, CUDAContext>(
    float* h,
    CUDAContext* ctx) {
    _SGDUpdate
-        << < CUDA_BLOCKS(count), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(count), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        count, lr, momentum, g, h
    );
}
...
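The three history-buffer kernels above (Nesterov, RMSProp, SGD) each take the gradient g and a slot buffer h. Assuming they follow the Caffe-style formulations (the kernel bodies are outside this diff), the per-element updates would be:

SGD with momentum: $$h_i \leftarrow \text{momentum}\cdot h_i + \text{lr}\cdot g_i, \qquad g_i \leftarrow h_i$$

Nesterov momentum: $$h_i' \leftarrow \text{momentum}\cdot h_i + \text{lr}\cdot g_i, \qquad g_i \leftarrow (1+\text{momentum})\,h_i' - \text{momentum}\cdot h_i$$

RMSProp: $$h_i \leftarrow \text{decay}\cdot h_i + (1-\text{decay})\,g_i^2, \qquad g_i \leftarrow \frac{\text{lr}\cdot g_i}{\sqrt{h_i} + \epsilon}$$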
Dragon/src/kernels/vision/bias_add_op_kernel.cu
...
@@ -52,14 +52,14 @@ template<> void BiasAdd<float, CUDAContext>(
    auto nthreads = outer_dim * axis_dim * inner_dim;
    if (data_format == "NCHW") {
        _BiasAddNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, axis_dim, inner_dim, bias, y
        );
    } else if (data_format == "NHWC") {
        _BiasAddNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, axis_dim, bias, y
        );
    } else {
...
Dragon/src/kernels/vision/bilinear_resize_op_kernel.cu
...
@@ -109,15 +109,15 @@ template <> void BilinearResize<float, CUDAContext>(
    auto scale_w = (float)W / (float)out_w;
    if (data_format == "NCHW") {
        _BilinearResizeNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, x, y
        );
    } else if(data_format == "NHWC") {
        _BilinearResizeNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, x, y
        );
...
@@ -224,15 +224,15 @@ template <> void BilinearResizeGrad<float, CUDAContext>(
    auto scale_w = (float)W / (float)out_w;
    if (data_format == "NCHW") {
        _BilinearResizeGradNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, dy, dx
        );
    } else if(data_format == "NHWC") {
        _BilinearResizeGradNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, dy, dx
        );
...
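Note the scales above are computed as input extent over output extent, so each output pixel is mapped back into the source image. Assuming the usual resize convention (the kernel bodies are outside this diff):

$$\text{scale}_h = \frac{H}{\text{out}_h}, \quad \text{scale}_w = \frac{W}{\text{out}_w}, \qquad h_{\text{src}} = h_{\text{dst}}\cdot\text{scale}_h, \quad w_{\text{src}} = w_{\text{dst}}\cdot\text{scale}_w$$

Nearest-neighbor resize then rounds $(h_{\text{src}}, w_{\text{src}})$ to the closest pixel, while bilinear resize blends its four integer neighbors by their fractional offsets.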
Dragon/src/kernels/vision/conv_op_kernel.cu
...
@@ -123,8 +123,8 @@ template <> void Im2Col2d<float, CUDAContext>(
    auto nthreads = C * out_h * out_w;
    if (data_format == "NCHW") {
        _Im2Col2dNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            H, W,
            out_h, out_w,
...
@@ -136,8 +136,8 @@ template <> void Im2Col2d<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _Im2Col2dNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            out_h, out_w,
...
@@ -286,8 +286,8 @@ template <> void Col2Im2d<float, CUDAContext>(
    const int nthreads = C * H * W;
    if (data_format == "NCHW") {
        _Col2Im2dNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            H, W,
            out_h, out_w,
...
@@ -299,8 +299,8 @@ template <> void Col2Im2d<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _Col2Im2dNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            out_h, out_w,
...
Dragon/src/kernels/vision/depthwise_conv_op_kernel.cu
...
@@ -144,8 +144,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
    if (data_format == "NCHW") {
        if (kernel_h == 3 && kernel_w == 3) {
            _DepthwiseConv2dNCHW<float, 3, 3>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -157,8 +157,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else if (kernel_h == 5 && kernel_w == 5) {
            _DepthwiseConv2dNCHW<float, 5, 5>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -170,8 +170,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else if (kernel_h == 7 && kernel_w == 7) {
            _DepthwiseConv2dNCHW<float, 7, 7>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -183,8 +183,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else {
            _DepthwiseConv2dNCHW<float, -1, -1>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -198,8 +198,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
    } else if (data_format == "NHWC") {
        if (kernel_h == 3 && kernel_w == 3) {
            _DepthwiseConv2dNHWC<float, 3, 3>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -211,8 +211,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else if (kernel_h == 5 && kernel_w == 5) {
            _DepthwiseConv2dNHWC<float, 5, 5>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -224,8 +224,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else if (kernel_h == 7 && kernel_w == 7) {
            _DepthwiseConv2dNHWC<float, 7, 7>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -237,8 +237,8 @@ template <> void DepthwiseConv2d<float, CUDAContext>(
            );
        } else {
            _DepthwiseConv2dNHWC<float, -1, -1>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -394,8 +394,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
    if (data_format == "NCHW") {
        if (kernel_h == 3 && kernel_w == 3) {
            _DepthwiseConv2dGradNCHW<float, 3, 3>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -407,8 +407,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else if (kernel_h == 5 && kernel_w == 5) {
            _DepthwiseConv2dGradNCHW<float, 5, 5>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -420,8 +420,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else if (kernel_h == 7 && kernel_w == 7) {
            _DepthwiseConv2dGradNCHW<float, 7, 7>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -433,8 +433,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else {
            _DepthwiseConv2dGradNCHW<float, -1, -1>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -448,8 +448,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
    } else if (data_format == "NHWC") {
        if (kernel_h == 3 && kernel_w == 3) {
            _DepthwiseConv2dGradNHWC<float, 3, 3>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -461,8 +461,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else if (kernel_h == 5 && kernel_w == 5) {
            _DepthwiseConv2dGradNHWC<float, 5, 5>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -474,8 +474,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else if (kernel_h == 7 && kernel_w == 7) {
            _DepthwiseConv2dGradNHWC<float, 7, 7>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -487,8 +487,8 @@ template <> void DepthwiseConv2dGrad<float, CUDAContext>(
            );
        } else {
            _DepthwiseConv2dGradNHWC<float, -1, -1>
-                << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                    0, ctx->cuda_stream() >> >(
+                <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                    0, ctx->cuda_stream() >>>(
                nthreads,
                C, H, W,
                out_h, out_w,
...
@@ -634,8 +634,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>(
    auto nblocks = C * kernel_h * kernel_w;
    if (data_format == "NCHW") {
        _DepthwiseConv2dWGradNCHW
-            << < nblocks, nthreads,
-                0, ctx->cuda_stream() >> >(
+            <<< nblocks, nthreads,
+                0, ctx->cuda_stream() >>>(
            N, C, H, W,
            out_h, out_w,
            kernel_h, kernel_w,
...
@@ -646,8 +646,8 @@ template <> void DepthwiseConv2dWGrad<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _DepthwiseConv2dWGradNHWC
-            << < nblocks, nthreads,
-                0, ctx->cuda_stream() >> >(
+            <<< nblocks, nthreads,
+                0, ctx->cuda_stream() >>>(
            N, C, H, W,
            out_h, out_w,
            kernel_h, kernel_w,
...
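The branching above picks a compile-time specialization (<float, 3, 3>, <float, 5, 5>, <float, 7, 7>) so the inner loops over the filter window have constant bounds and can be unrolled, with <float, -1, -1> as the runtime-sized fallback. A minimal sketch of that idea, using a hypothetical 1-D windowed-sum kernel rather than Dragon's convolution:

// Sliding-window sum over a 1-D signal. KW < 0 selects the runtime-
// sized fallback; otherwise the bound is a compile-time constant the
// compiler can fully unroll, which is the point of the <3,3>/<5,5>/
// <7,7> specializations above.
template <typename T, int KW>
__global__ void _WindowSum(int n, int kernel_w, const T* x, T* y) {
    const int kw = KW < 0 ? kernel_w : KW;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x) {
        T acc = T(0);
        for (int q = 0; q < kw; ++q)
            if (i + q < n) acc += x[i + q];
        y[i] = acc;
    }
}

// Host-side dispatch mirroring the kernel_h/kernel_w branching:
//   if (kernel_w == 3) _WindowSum<float, 3><<<grid, block>>>(n, 3, x, y);
//   else               _WindowSum<float, -1><<<grid, block>>>(n, kernel_w, x, y);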
Dragon/src/kernels/vision/drop_block_op_kernel.cu
...
@@ -77,16 +77,12 @@ template <> void DropBlock2d<CUDAContext>(
    int* mask,
    CUDAContext* ctx) {
    auto nthreads = N * C * seed_h * seed_w;
-    math::RandomUniform(
-        nthreads,
-        0.f, float(UINT_MAX),
-        seed, ctx
-    );
+    math::RandomUniform(nthreads, 0.f, 1.f, seed, ctx);
    auto mask_thresh = (uint32_t)(UINT_MAX * gamma);
    if (data_format == "NCHW") {
        _DropBlock2dNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            seed_h, seed_w,
...
@@ -96,8 +92,8 @@ template <> void DropBlock2d<CUDAContext>(
        );
    } else if(data_format == "NHWC") {
        _DropBlock2dNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            seed_h, seed_w,
...
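DropBlock2d first fills `seed` with uniform draws, then the kernel compares each draw against a threshold derived from `gamma` to decide which seed positions spawn a dropped block. The kernel body is outside this diff, so the comparison below is an assumption about that step, shown as a hypothetical sketch:

#include <cstdint>

// Hypothetical seed-thresholding step: a seed position whose uniform
// draw falls below the gamma-derived threshold becomes the origin of
// a dropped block_size x block_size region recorded in `mask`.
__global__ void _SeedToMask(int n, uint32_t thresh,
                            const uint32_t* seed, int* mask) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x) {
        if (seed[i] < thresh) mask[i] = 0;  // marks a block origin
    }
}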
Dragon/src/kernels/vision/nn_resize_op_kernel.cu
...
@@ -81,15 +81,15 @@ template <> void NNResize<float, CUDAContext>(
    auto scale_w = (float)W / (float)out_w;
    if (data_format == "NCHW") {
        _NNResizeNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, x, y
        );
    } else if(data_format == "NHWC") {
        _NNResizeNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, x, y
        );
...
@@ -116,8 +116,8 @@ template <> void NNResize<float16, CUDAContext>(
    auto scale_w = (float)W / (float)out_w;
    if (data_format == "NCHW") {
        _NNResizeNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W,
            out_h, out_w, scale_h, scale_w,
            reinterpret_cast<const half*>(x),
...
@@ -125,8 +125,8 @@ template <> void NNResize<float16, CUDAContext>(
        );
    } else if(data_format == "NHWC") {
        _NNResizeNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W,
            out_h, out_w, scale_h, scale_w,
            reinterpret_cast<const half*>(x),
...
@@ -209,15 +209,15 @@ template <> void NNResizeGrad<float, CUDAContext>(
    auto scale_w = (float)W / (float)out_w;
    if (data_format == "NCHW") {
        _NNResizeGradNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, dy, dx
        );
    } else if(data_format == "NHWC") {
        _NNResizeGradNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads, C, H, W, out_h, out_w,
            scale_h, scale_w, dy, dx
        );
...
Dragon/src/kernels/vision/pool_op_kernel.cu
...
@@ -120,8 +120,8 @@ template<> void MaxPool2d<float, CUDAContext>(
    auto nthreads = N * C * pool_h * pool_w;
    if (data_format == "NCHW") {
        _MaxPool2dNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -132,8 +132,8 @@ template<> void MaxPool2d<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _MaxPool2dNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -256,8 +256,8 @@ template<> void AvgPool2d<float, CUDAContext>(
    auto nthreads = N * C * pool_h * pool_w;
    if (data_format == "NCHW") {
        _AvgPool2dNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -268,8 +268,8 @@ template<> void AvgPool2d<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _AvgPool2dNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -392,8 +392,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>(
    auto nthreads = N * C * H * W;
    if (data_format == "NCHW") {
        _MaxPool2dGrad_NCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -404,8 +404,8 @@ template<> void MaxPool2dGrad<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _MaxPool2dGradNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -531,8 +531,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>(
    auto nthreads = N * C * H * W;
    if (data_format == "NCHW") {
        _AvgPool2dGradNCHW
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
@@ -543,8 +543,8 @@ template<> void AvgPool2dGrad<float, CUDAContext>(
        );
    } else if (data_format == "NHWC") {
        _AvgPool2dGradNHWC
-            << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
            nthreads,
            C, H, W,
            pool_h, pool_w,
...
Dragon/src/kernels/vision/roi_align_op_kernel.cu
...
@@ -132,8 +132,8 @@ template<> void ROIAlign<float, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = num_rois * C * pool_h * pool_w;
    _ROIAlign
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        nthreads,
        C, H, W,
        pool_h, pool_w,
...
@@ -283,8 +283,8 @@ template<> void ROIAlignGrad<float, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = num_rois * C * pool_h * pool_w;
    _ROIAlignGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        nthreads,
        C, H, W,
        pool_h, pool_w,
...
Dragon/src/kernels/vision/roi_align_op_kernel.fp16.cu
...
@@ -134,8 +134,8 @@ template<> void ROIAlign<float16, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = num_rois * C * pool_h * pool_w;
    _ROIAlignHalf
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>
        (nthreads, C, H, W, pool_h, pool_w,
         sampling_ratio, spatial_scale,
         reinterpret_cast<const half*>(x), rois,
...
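Throughout the float16 kernels, host-side float16 buffers are handed to device code by reinterpreting them as CUDA's half rather than converting element by element. A minimal sketch of that bridging pattern; the float16 wrapper here is an assumed 16-bit POD stand-in and the kernel is illustrative, not Dragon's:

#include <cuda_fp16.h>

struct float16 { unsigned short x; };  // assumed 16-bit POD wrapper

__global__ void _HalfScale(int n, float alpha, const half* x, half* y) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n; i += blockDim.x * gridDim.x) {
        // Compute in float, store back as half.
        y[i] = __float2half(alpha * __half2float(x[i]));
    }
}

void HalfScale(int n, float alpha, const float16* x, float16* y,
               cudaStream_t stream) {
    // Same memory layout, different type: reinterpret, don't convert.
    _HalfScale<<<(n + 255) / 256, 256, 0, stream>>>(
        n, alpha,
        reinterpret_cast<const half*>(x),
        reinterpret_cast<half*>(y));
}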
Dragon/src/kernels/vision/roi_pool_op_kernel.cu
...
@@ -92,8 +92,8 @@ template<> void ROIPool<float, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = num_rois * C * pool_h * pool_w;
    _ROIPool
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        nthreads,
        C, H, W,
        pool_h, pool_w,
...
@@ -185,8 +185,8 @@ template<> void ROIPool<float16, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = num_rois * C * pool_h * pool_w;
    _ROIPoolHalf
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        nthreads,
        C, H, W,
        pool_h, pool_w,
...
@@ -286,8 +286,8 @@ template<> void ROIPoolGrad<float, CUDAContext>(
    CUDAContext* ctx) {
    auto nthreads = N * C * H * W;
    _ROIPoolGrad
-        << < CUDA_BLOCKS(nthreads), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(nthreads), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
        nthreads,
        num_rois,
        C, H, W,
...
Dragon/src/onnx/onnx_backend.cc
...
@@ -180,9 +180,9 @@ ONNXBackend::get_special_nodes() const {
    };
    return kSpecialNodes;
}

const Map<string, Map<string, string>>&
ONNXBackend::get_node_renamed_attrs() const {
    const static Map<string, Map<string, string>>
        kPerNodeRenamedAttrs = {
            { "Gemm", { { "transB", "transW" } } },
            { "BatchNormalization", { { "epsilon", "eps" } } },
...
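The table above maps an ONNX op type to per-attribute renames applied during import (e.g. Gemm's transB becomes Dragon's transW, BatchNormalization's epsilon becomes eps). A sketch of how such a table is typically consulted; the lookup helper is illustrative, not from this commit:

#include <map>
#include <string>

using std::string;
template <class K, class V> using Map = std::map<K, V>;

// Returns the backend's name for `attr` on node type `op_type`,
// falling back to the ONNX name when no rename is registered.
string RenamedAttr(
    const Map<string, Map<string, string>>& table,
    const string& op_type, const string& attr) {
    auto it = table.find(op_type);
    if (it == table.end()) return attr;
    auto jt = it->second.find(attr);
    return jt == it->second.end() ? attr : jt->second;
}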
Dragon/src/onnx/onnx_backend.h
...
@@ -221,7 +221,7 @@ class ONNXBackend {
    const Map<string, SpecialNodeConverter>& get_special_nodes() const;
    const Map<string, string>& get_renamed_attrs() const;
    const Map<string, Map<string, string>>& get_node_renamed_attrs() const;
};

}  // namespace onnx
...
Dragon/src/operators/activation/cudnn_dropout_op.cc
...
@@ -77,15 +77,8 @@ template <class Context>
void CuDNNDropoutOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -147,15 +140,8 @@ template <class Context>
void CuDNNDropoutGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Dropout);
...
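This hunk is the commit's core change, repeated across the operator files below: the hand-written XIsType chains in RunOnDevice give way to a single DispatchHelper<TensorTypes<...>>::Call(this, X(0)), which walks a compile-time type list and invokes the matching RunImpl<T>. A minimal sketch of how such a dispatcher can be built; this illustrates the pattern under simplified assumptions, and Dragon's actual DispatchHelper may differ:

#include <cstdio>
#include <typeinfo>

// Compile-time list of candidate element types.
template <typename... Types> struct TensorTypes {};

// Minimal stand-in for a tensor that knows its element type at runtime.
struct Tensor {
    const std::type_info* meta;
    template <typename T> bool IsType() const { return *meta == typeid(T); }
};

template <typename TypeList> struct DispatchHelper;

// Try the head of the list; recurse on the tail when it doesn't match.
template <typename T, typename... Rest>
struct DispatchHelper<TensorTypes<T, Rest...>> {
    template <typename Op>
    static void Call(Op* op, const Tensor& x) {
        if (x.IsType<T>()) op->template RunImpl<T>();
        else DispatchHelper<TensorTypes<Rest...>>::Call(op, x);
    }
};

// Exhausted the list: the runtime type is unsupported (the analogue
// of the LOG(FATAL) branch the commit deletes).
template <> struct DispatchHelper<TensorTypes<>> {
    template <typename Op>
    static void Call(Op*, const Tensor&) {
        std::printf("FATAL: unsupported tensor type\n");
    }
};

// Usage, mirroring the new RunOnDevice bodies:
struct MyOp {
    Tensor input{&typeid(float)};
    template <typename T> void RunImpl() { /* typed kernel here */ }
    void RunOnDevice() {
        DispatchHelper<TensorTypes<float, double>>::Call(this, input);
    }
};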
Dragon/src/operators/activation/cudnn_elu_op.cc
...
@@ -26,15 +26,8 @@ template <class Context>
void CuDNNEluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -60,15 +53,8 @@ template <class Context>
void CuDNNEluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Elu);
...
Dragon/src/operators/activation/cudnn_relu_op.cc
...
@@ -40,15 +40,8 @@ void CuDNNReluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -92,15 +85,8 @@ void CuDNNReluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Relu);
...
Dragon/src/operators/activation/cudnn_sigmoid_op.cc
...
@@ -35,15 +35,8 @@ template <class Context>
void CuDNNSigmoidOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -82,15 +75,8 @@ template <class Context>
void CuDNNSigmoidGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Sigmoid);
...
Dragon/src/operators/activation/cudnn_softmax_op.cc
...
@@ -45,15 +45,8 @@ void CuDNNSoftmaxOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -91,15 +84,8 @@ void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Softmax);
...
Dragon/src/operators/activation/cudnn_tanh_op.cc
...
@@ -35,15 +35,8 @@ template <class Context>
void CuDNNTanhOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -82,15 +75,8 @@ template <class Context>
void CuDNNTanhGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CUDNN(Tanh);
...
Dragon/src/operators/activation/dropout_op.cc
...
@@ -44,15 +44,8 @@ template <class Context>
void DropoutOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -83,15 +76,8 @@ template <class Context>
void DropoutGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CPU(Dropout);
...
Dragon/src/operators/activation/droppath_op.cc
...
@@ -52,15 +52,8 @@ void DropPathOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -97,15 +90,8 @@ void DropPathGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CPU(DropPath);
...
Dragon/src/operators/activation/elu_op.cc
...
@@ -20,13 +20,8 @@ template <class Context>
void EluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -46,13 +41,8 @@ template <class Context>
void EluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

DEPLOY_CPU(Elu);
...
Dragon/src/operators/activation/prelu_op.cc
...
@@ -40,13 +40,8 @@ void PReluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -98,13 +93,8 @@ void PReluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));
    Y(1)->ReshapeLike(X(1));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

DEPLOY_CPU(PRelu);
...
Dragon/src/operators/activation/relu_op.cc
...
@@ -20,15 +20,8 @@ template <class Context>
void ReluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -48,15 +41,8 @@ template <class Context>
void ReluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CPU(Relu);
...
Dragon/src/operators/activation/selu_op.cc
...
@@ -19,15 +19,8 @@ template <class Context>
void SEluOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -47,15 +40,8 @@ template <class Context>
void SEluGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

DEPLOY_CPU(SElu);
...
Dragon/src/operators/activation/sigmoid_op.cc
...
@@ -15,13 +15,8 @@ template <class Context>
void SigmoidOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -41,13 +36,8 @@ template <class Context>
void SigmoidGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

DEPLOY_CPU(Sigmoid);
...
Dragon/src/operators/activation/softmax_op.cc
...
@@ -43,13 +43,8 @@ void SoftmaxOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -86,13 +81,8 @@ void SoftmaxGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

DEPLOY_CPU(Softmax);
...
Dragon/src/operators/activation/tanh_op.cc
...
@@ -15,13 +15,8 @@ template <class Context>
void TanhOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -41,13 +36,8 @@ template <class Context>
void TanhGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
}

DEPLOY_CPU(Tanh);
...
Dragon/src/operators/arithmetic/affine_op.cc
...
@@ -46,15 +46,8 @@ void AffineOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(0));

-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
}

template <class Context> template <typename T>
...
@@ -111,9 +104,7 @@ void AffineGradientOp<Context>::RunImpl() {
}

template <class Context> template <typename T>
void AffineGradientOp<Context>::Reduce(
    T* x, T* y) {
    vec32_t dims = {
        (int)outer_dim_,
        (int)scale_dim_,
...
@@ -138,15 +129,8 @@ void AffineGradientOp<Context>::RunOnDevice() {
    Y(0)->ReshapeLike(X(-1));

-    if (XIsType(X(-1), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(-1), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(-1),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(-1));
}

DEPLOY_CPU(Affine);
...
Dragon/src/operators/arithmetic/cudnn_affine_op.cc
...
@@ -108,13 +108,6 @@ void CuDNNAffineOp<Context>::RunOnDevice() {
template <class Context> template <typename DT, typename CT>
void CuDNNAffineGradientOp<Context>::RunImpl() {
    this->template ResetDesc<DT>(X(-1));
-    scale_dim_ = X(1).count();
-    outer_dim_ = X(-1).count(0, axis_);
-    inner_dim_ = X(-1).count(axis_ + num_axes_);
-    dim_ = scale_dim_ * inner_dim_;
-    reduce_dim_ = std::max(outer_dim_, inner_dim_);
-    Y(0)->ReshapeLike(X(-1));
    auto* alpha = X(1).template data<DT, Context>();
    auto* dy = X(-1).template mutable_data<DT, Context>();
...
@@ -230,9 +223,7 @@ void CuDNNAffineGradientOp<Context>::CuDNNReduce(
}

template <class Context> template <typename T>
void CuDNNAffineGradientOp<Context>::Reduce(
    T* x, T* y) {
    vec32_t dims = {
        (int)outer_dim_,
        (int)scale_dim_,
...
@@ -248,6 +239,14 @@ void CuDNNAffineGradientOp<Context>::Reduce(
template <class Context>
void CuDNNAffineGradientOp<Context>::RunOnDevice() {
+    scale_dim_ = X(1).count();
+    outer_dim_ = X(-1).count(0, axis_);
+    inner_dim_ = X(-1).count(axis_ + num_axes_);
+    dim_ = scale_dim_ * inner_dim_;
+    reduce_dim_ = std::max(outer_dim_, inner_dim_);
+    Y(0)->ReshapeLike(X(-1));
    if (XIsType(X(-1), float)) {
        RunImpl<float, float>();
    } else if (XIsType(X(-1), float16)) {
...
Dragon/src/operators/arithmetic/eltwise_op.cc
View file @ d1f714e
...
@@ -36,6 +36,13 @@ void EltwiseOp<Context>::ProdRunImpl() {
 template <class Context> template <typename T>
 void EltwiseOp<Context>::RunImpl() {
+    if (operation_ == "SUM") SumRunImpl<T>();
+    else if (operation_ == "PROD") ProdRunImpl<T>();
+    else LOG(FATAL) << "Unknwon Operation: " << operation_;
+}
+
+template <class Context>
+void EltwiseOp<Context>::RunOnDevice() {
     for (int i = 1; i < XSize(); i++) {
         CHECK(X(i).dims() == X(0).dims())
             << "\nExcepted Input(" << i << ")'s dims as "
...
@@ -45,33 +52,10 @@ void EltwiseOp<Context>::RunImpl() {
     Y(0)->ReshapeLike(X(0));
-    if (operation_ == "SUM") SumRunImpl<T>();
-    else if (operation_ == "PROD") ProdRunImpl<T>();
-    else LOG(FATAL) << "Unknwon Operation: " << operation_;
-}
-
-template <class Context>
-void EltwiseOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -133,26 +117,10 @@ void EltwiseGradientOp<Context>::RunImpl() {
 template <class Context>
 void EltwiseGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Eltwise);
...
...
Dragon/src/operators/arithmetic/exp_op.cc
View file @ d1f714e
...
@@ -15,17 +15,9 @@ template <class Context>
 void ExpOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -40,17 +32,9 @@ template <class Context>
 void ExpGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Exp);
...
...
Dragon/src/operators/arithmetic/fully_connected_op.cc
View file @ d1f714e
...
@@ -84,6 +84,12 @@ void FullyConnectedOp<Context>::NoTransRunImpl() {
     }
 }

+template <class Context> template <typename T>
+void FullyConnectedOp<Context>::RunImpl() {
+    if (transW_) TransRunImpl<T>();
+    else NoTransRunImpl<T>();
+}
+
 template <class Context>
 void FullyConnectedOp<Context>::RunOnDevice() {
     DETERMINE_RUNTIME_ARGS(X(0));
...
@@ -101,31 +107,12 @@ void FullyConnectedOp<Context>::RunOnDevice() {
     for (int i = 0; i < axis_ + 1; i++) {
         out_shape[i] = i < axis_ ? X(0).dim(i) : N_;
     }

     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), float16)) {
-        if (transW_) {
-            TransRunImpl<float16>();
-        } else {
-            NoTransRunImpl<float16>();
-        }
-    } else if (XIsType(X(0), float)) {
-        if (transW_) {
-            TransRunImpl<float>();
-        } else {
-            NoTransRunImpl<float>();
-        }
-    } else if (XIsType(X(0), double)) {
-        if (transW_) {
-            TransRunImpl<double>();
-        } else {
-            NoTransRunImpl<double>();
-        }
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -209,17 +196,9 @@ void FullyConnectedGradientOp<Context>::RunOnDevice() {
             << X(1).DimString();
     }

-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(FullyConnected);
...
...
Dragon/src/operators/arithmetic/gram_matrix_op.cc
View file @ d1f714e
...
@@ -35,17 +35,9 @@ void GramMatrixOp<Context>::RunOnDevice() {
         { outer_dim_, axis_dim_, axis_dim_ }
     );
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -79,17 +71,9 @@ void GramMatrixGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(GramMatrix);
...
...
Dragon/src/operators/arithmetic/log_op.cc
View file @ d1f714e
...
@@ -14,17 +14,9 @@ template <class Context>
 void LogOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -39,17 +31,9 @@ template <class Context>
 void LogGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Log);
...
...
Dragon/src/operators/arithmetic/matmul_op.cc
View file @ d1f714e
...
@@ -65,17 +65,9 @@ void MatmulOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -182,17 +174,9 @@ void MatmulGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
     Y(1)->ReshapeLike(X(1));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Matmul);
...
...
Dragon/src/operators/arithmetic/maximum_op.cc
View file @ d1f714e
...
@@ -58,26 +58,10 @@ void MaximumOp<Context>::RunImpl() {
 template <class Context>
 void MaximumOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -135,9 +119,6 @@ void MaximumGradientOp<Context>::BroadcastRunImpl() {
 template <class Context> template <typename T>
 void MaximumGradientOp<Context>::RunImpl() {
-    Y(0)->ReshapeLike(X(0));
-    Y(1)->ReshapeLike(X(1));
     if (X(0).dims() == X(1).dims()) {
         EltwiseRunImpl<T>();
     } else {
...
@@ -147,26 +128,13 @@ void MaximumGradientOp<Context>::RunImpl() {
 template <class Context>
 void MaximumGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    Y(0)->ReshapeLike(X(0));
+    Y(1)->ReshapeLike(X(1));
+
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Maximum);
...
...
Dragon/src/operators/arithmetic/minimum_op.cc
View file @ d1f714e
...
@@ -58,26 +58,10 @@ void MinimumOp<Context>::RunImpl() {
 template <class Context>
 void MinimumOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -135,9 +119,6 @@ void MinimumGradientOp<Context>::BroadcastRunImpl() {
 template <class Context> template <typename T>
 void MinimumGradientOp<Context>::RunImpl() {
-    Y(0)->ReshapeLike(X(0));
-    Y(1)->ReshapeLike(X(1));
     if (X(0).dims() == X(1).dims()) {
         EltwiseRunImpl<T>();
     } else {
...
@@ -147,26 +128,13 @@ void MinimumGradientOp<Context>::RunImpl() {
 template <class Context>
 void MinimumGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    Y(0)->ReshapeLike(X(0));
+    Y(1)->ReshapeLike(X(1));
+
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Minimum);
...
...
Dragon/src/operators/arithmetic/moments_op.cc
View file @ d1f714e
...
@@ -7,6 +7,31 @@ namespace dragon {
 template <class Context>
 template <typename Tx, typename Ty>
 void MomentsOp<Context>::RunImpl() {
+    auto* x = X(0).template data<Tx, Context>();
+    auto* mean = Y(0)->template mutable_data<Ty, Context>();
+    auto* var = Y(1)->template mutable_data<Ty, Context>();
+    if (X(0).count() == 1) {
+        kernel::TypeA2B(Y(0)->count(), x, mean, ctx());
+        math::Set(Y(0)->count(), cast::to<Ty>(0.f), var, ctx());
+    } else {
+        kernel::Moments(
+            (int)dims32_.size(), dims32_.data(),
+            (int)axes32_.size(), axes32_.data(),
+            x, mean, var, ctx());
+    }
+}
+
+template <class Context>
+void MomentsOp<Context>::RunOnDevice() {
     dims_ = X(0).dims(); axes32_.clear();
     dims32_.assign(dims_.begin(), dims_.end());
     axes32_.assign(axes_.begin(), axes_.end());
...
@@ -35,31 +60,6 @@ void MomentsOp<Context>::RunImpl() {
     Y(0)->Reshape(out_shape);
     Y(1)->Reshape(out_shape);
-    auto* x = X(0).template data<Tx, Context>();
-    auto* mean = Y(0)->template mutable_data<Ty, Context>();
-    auto* var = Y(1)->template mutable_data<Ty, Context>();
-    if (X(0).count() == 1) {
-        kernel::TypeA2B(Y(0)->count(), x, mean, ctx());
-        math::Set(Y(0)->count(), cast::to<Ty>(0.f), var, ctx());
-    } else {
-        kernel::Moments(
-            (int)dims32_.size(), dims32_.data(),
-            (int)axes32_.size(), axes32_.data(),
-            x, mean, var, ctx());
-    }
-}
-
-template <class Context>
-void MomentsOp<Context>::RunOnDevice() {
     if (XIsType(X(0), int8_t)) {
         RunImpl<int8_t, float>();
     } else if (XIsType(X(0), uint8_t)) {
...
...
Dragon/src/operators/arithmetic/pow_op.cc
View file @ d1f714e
...
@@ -32,17 +32,9 @@ template <class Context>
 void PowOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -88,17 +80,9 @@ template <class Context>
 void PowGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Pow);
...
...
Dragon/src/operators/arithmetic/sqrt_op.cc
View file @ d1f714e
...
@@ -14,17 +14,9 @@ template <class Context>
 void SqrtOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -40,17 +32,9 @@ template <class Context>
 void SqrtGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Sqrt);
...
...
Dragon/src/operators/arithmetic/square_op.cc
View file @ d1f714e
...
@@ -14,26 +14,10 @@ template <class Context>
 void SquareOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -49,26 +33,10 @@ template <class Context>
 void SquareGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Square);
...
...
Dragon/src/operators/array/arange_op.cc
View file @ d1f714e
...
@@ -6,20 +6,24 @@ namespace dragon {
 template <class Context> template <typename T>
 void ArangeOp<Context>::RunImpl() {
-    astart_ = start(), astop_ = stop(), astep_ = step();
-    if (astop_ == 0) { astop_ = astart_; astart_ = 0; }
-    dim_ = (astop_ - astart_ - 1) / astep_ + 1;
-    CHECK_GT(dim_, 0)
-        << "\nInvalid arguments:\n"
-        << "start = " << start() << ", "
-        << "stop = " << stop() << ", "
-        << "step = " << step() << ".";
-    Y(0)->Reshape({ dim_ });
     auto* y = Y(0)->template mutable_data<T, Context>();
     kernel::Arange(dim_, astart_, astep_, y, ctx());
 }

 template <class Context>
 void ArangeOp<Context>::RunOnDevice() {
+    astart_ = start(), astop_ = stop(), astep_ = step();
+    if (astop_ == 0) { astop_ = astart_; astart_ = 0; }
+    dim_ = (astop_ - astart_ - 1) / astep_ + 1;
+    CHECK_GT(dim_, 0)
+        << "\nInvalid arguments:\n"
+        << "start = " << start() << ", "
+        << "stop = " << stop() << ", "
+        << "step = " << step() << ".";
+    Y(0)->Reshape({ dim_ });
     if (dtype() == "int8") {
         RunImpl<int8_t>();
     } else if (dtype() == "uint8") {
...
...
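A note on the dim_ formula that moved into RunOnDevice above: for a positive step, (stop - start - 1) / step + 1 in integer arithmetic equals ceil((stop - start) / step), i.e. the number of values in the half-open range [start, stop). A quick standalone check (the helper name and test values here are only for illustration):

```cpp
#include <cassert>

// Element count of arange(start, stop, step) for step > 0,
// written the same way as dim_ above.
static int arange_dim(int start, int stop, int step) {
    return (stop - start - 1) / step + 1;
}

int main() {
    assert(arange_dim(0, 10, 1) == 10);  // 0,1,...,9
    assert(arange_dim(0, 10, 3) == 4);   // 0,3,6,9
    assert(arange_dim(5, 6, 2) == 1);    // 5
    return 0;
}
```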
Dragon/src/operators/array/argreduce_op.cc
View file @ d1f714e
...
@@ -101,28 +101,10 @@ void ArgReduceOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
     Y(1)->Reshape(out_shape);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(ArgReduce);
...
...
Dragon/src/operators/array/concat_op.cc
View file @ d1f714e
...
@@ -56,28 +56,10 @@ void ConcatOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -116,28 +98,10 @@ void ConcatGradientOp<Context>::RunOnDevice() {
     for (int i = 0; i < YSize(); i++)
         Y(i)->ReshapeLike(X(i));
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Concat);
...
...
Dragon/src/operators/array/crop_op.cc
View file @ d1f714e
...
@@ -145,28 +145,10 @@ void CropOp<Context>::RunOnDevice() {
     TENSOR_FROM_VEC(X_strides_, X(0).strides(), int);
     TENSOR_FROM_VEC(Y_dims_, Y_dims, int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -209,28 +191,10 @@ void CropGradientOp<Context>::RunOnDevice() {
     TENSOR_FROM_VEC(X_strides_, X(0).strides(), int);
     TENSOR_FROM_VEC(Y_dims_, Y_dims, int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Crop);
...
...
Dragon/src/operators/array/index_select_op.cc
View file @ d1f714e
...
@@ -55,28 +55,10 @@ void IndexSelectOp<Context>::RunOnDevice() {
     CHECK(X(1).template IsType<int64_t>())
         << "\nThe type of indices should be int64.";
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -115,26 +97,10 @@ void IndexSelectGradientOp<Context>::RunOnDevice() {
     CHECK(X(1).template IsType<int64_t>())
         << "\nThe type of indices should be int64.";
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(IndexSelect);
...
...
Dragon/src/operators/array/multinomial_op.cc
View file @ d1f714e
...
@@ -35,17 +35,26 @@ void MultinomialOp<Context>::RunImpl() {
     double running_total, r;
     int yi = 0, num_classes = X(0).dim(axis_);
+    double uniform_p = 1. / (double)num_classes;
     auto* rng = ctx()->rand_generator();
+    std::uniform_real_distribution<float> eps_dist;
     for (int i = 0; i < outer_dim_; ++i) {
         running_total = 0.;
-        for (int j = 0; j < num_classes; ++j) {
-            running_total += (double)x[j];
-            cdf[j] = running_total;
-        }
+        if (eps_ > 0.f && eps_dist(*rng) < eps_) {
+            for (int j = 0; j < num_classes; ++j) {
+                running_total += uniform_p;
+                cdf[j] = running_total;
+            }
+        } else {
+            for (int j = 0; j < num_classes; ++j) {
+                running_total += (double)x[j];
+                cdf[j] = running_total;
+            }
+        }
         std::uniform_real_distribution<double>
-            dist(0.f, running_total);
+            dist(0., running_total);
         for (int j = 0; j < (int)num_samples_; ++j) {
             r = dist(*rng);
             auto found_iter = std::upper_bound(
...
@@ -75,24 +84,10 @@ void MultinomialOp<Context>::RunOnDevice() {
     // Normalize the logits if necessary
     if (normalize_) SoftmaxRun();
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int,
+        int64_t, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Multinomial);
...
...
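The multinomial hunk above adds an eps_-probability branch that builds the CDF from the constant mass uniform_p = 1/num_classes instead of the input values, so each row has an epsilon chance of being sampled uniformly. A standalone sketch of that sampling scheme (plain C++; the function and its parameters are illustrative, not Dragon API):

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Draw one index from `probs`, but with probability `eps` sample
// uniformly instead (epsilon-greedy style exploration).
int sample(const std::vector<double>& probs, double eps, std::mt19937& rng) {
    std::uniform_real_distribution<double> coin(0.0, 1.0);
    const int n = (int)probs.size();
    const bool uniform = eps > 0.0 && coin(rng) < eps;
    // Build the (unnormalized) cumulative distribution.
    std::vector<double> cdf(n);
    double running_total = 0.0;
    for (int j = 0; j < n; ++j) {
        running_total += uniform ? 1.0 / n : probs[j];
        cdf[j] = running_total;
    }
    // Invert it: first CDF entry strictly greater than r.
    std::uniform_real_distribution<double> dist(0.0, running_total);
    const double r = dist(rng);
    return (int)(std::upper_bound(cdf.begin(), cdf.end(), r) - cdf.begin());
}
```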
Dragon/src/operators/array/one_hot_op.cc
View file @ d1f714e
...
@@ -29,17 +29,9 @@ void OneHotOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "int32", "int64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, int, int64_t>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(OneHot);
...
...
Dragon/src/operators/array/pad_op.cc
View file @ d1f714e
...
@@ -112,28 +112,10 @@ void PadOp<Context>::RunOnDevice() {
     TENSOR_FROM_VEC(X_strides_, X(0).strides(), int);
     TENSOR_FROM_VEC(Y_dims_, Y_dims, int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -213,28 +195,10 @@ void PadGradientOp<Context>::RunOnDevice() {
     TENSOR_FROM_VEC(Y_strides_, X(0).strides(), int);
     TENSOR_FROM_VEC(X_dims_, X_dims, int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Pad);
...
...
Dragon/src/operators/array/reduce_op.cc
View file @ d1f714e
...
@@ -14,33 +14,6 @@ namespace dragon {
 template <class Context> template <typename T>
 void ReduceOp<Context>::RunImpl() {
-    dims_ = X(0).dims();
-    dims32_.assign(dims_.begin(), dims_.end());
-    axes32_.assign(axes_.begin(), axes_.end());
-
-    if (axes32_.empty()) {
-        // Reduce to a Scalar if missing axes
-        for (int i = 0; i < X(0).ndim(); ++i)
-            axes32_.push_back(i);
-    }
-
-    for (int i = 0; i < axes32_.size(); i++) {
-        int axis = axes32_[i];
-        axes32_[i] = axis < 0 ? axis + X(0).ndim() : axis;
-        CHECK(axes32_[i] >= 0 && axes32_[i] < X(0).ndim()) \
-            << "\nExcepted the axis in [-" << X(0).ndim()
-            << ", " << X(0).ndim() << "), got " << axis << ".";
-        dims_[axes32_[i]] = 1;
-    }
-
-    vec64_t out_shape;
-    for (const auto& dim : dims_) {
-        if (dim != 1 || keep_dims_)
-            out_shape.emplace_back(dim);
-    }
-    Y(0)->Reshape(out_shape);
     auto* x = X(0).template data<T, Context>();
     auto* y = Y(0)->template mutable_data<T, Context>();
...
@@ -64,31 +37,8 @@ void ReduceOp<Context>::RunImpl() {
 template <class Context>
 void ReduceOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
-}
-
-template <class Context> template <typename T>
-void ReduceGradientOp<Context>::RunImpl() {
-    y_dims_ = X(0).dims();
+    dims_ = X(0).dims();
+    dims32_.assign(dims_.begin(), dims_.end());
     axes32_.assign(axes_.begin(), axes_.end());
     if (axes32_.empty()) {
...
@@ -103,11 +53,25 @@ void ReduceGradientOp<Context>::RunImpl() {
         CHECK(axes32_[i] >= 0 && axes32_[i] < X(0).ndim()) \
             << "\nExcepted the axis in [-" << X(0).ndim()
             << ", " << X(0).ndim() << "), got " << axis << ".";
-        y_dims_[axes32_[i]] = 1;
+        dims_[axes32_[i]] = 1;
     }

-    Y(0)->ReshapeLike(X(0));
+    vec64_t out_shape;
+    for (const auto& dim : dims_) {
+        if (dim != 1 || keep_dims_)
+            out_shape.emplace_back(dim);
+    }
+    Y(0)->Reshape(out_shape);
+
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
+}
+
+template <class Context> template <typename T>
+void ReduceGradientOp<Context>::RunImpl() {
     auto* dy = X(1).template data<T, Context>();
     auto* dx = Y(0)->template mutable_data<T, Context>();
...
@@ -153,26 +117,30 @@ void ReduceGradientOp<Context>::RunImpl() {
 template <class Context>
 void ReduceGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    y_dims_ = X(0).dims();
+    axes32_.assign(axes_.begin(), axes_.end());
+
+    if (axes32_.empty()) {
+        // Reduce to a Scalar if missing axes
+        for (int i = 0; i < X(0).ndim(); ++i)
+            axes32_.push_back(i);
+    }
+
+    for (int i = 0; i < axes32_.size(); i++) {
+        int axis = axes32_[i];
+        axes32_[i] = axis < 0 ? axis + X(0).ndim() : axis;
+        CHECK(axes32_[i] >= 0 && axes32_[i] < X(0).ndim()) \
+            << "\nExcepted the axis in [-" << X(0).ndim()
+            << ", " << X(0).ndim() << "), got " << axis << ".";
+        y_dims_[axes32_[i]] = 1;
+    }
+
+    Y(0)->ReshapeLike(X(0));
+
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Reduce);
...
...
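Both reduce paths above wrap negative axes Python-style (axis + ndim) and then mark each reduced dimension as 1 before building the output shape. A compact standalone version of that normalization step, under the same convention (names here are illustrative):

```cpp
#include <stdexcept>
#include <vector>

// Wrap negative axes into [0, ndim) and reject out-of-range values,
// mirroring the checks in the reduce operators above.
std::vector<int> normalize_axes(std::vector<int> axes, int ndim) {
    if (axes.empty()) {
        // No axes given: reduce over every dimension (to a scalar).
        for (int i = 0; i < ndim; ++i) axes.push_back(i);
        return axes;
    }
    for (auto& axis : axes) {
        int wrapped = axis < 0 ? axis + ndim : axis;
        if (wrapped < 0 || wrapped >= ndim)
            throw std::out_of_range("axis out of [-ndim, ndim)");
        axis = wrapped;
    }
    return axes;
}
```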
Dragon/src/operators/array/repeat_op.cc
View file @ d1f714e
...
@@ -45,28 +45,10 @@ void RepeatOp<Context>::RunOnDevice() {
         Y(0)->Reshape(out_shape);
     }
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -98,26 +80,10 @@ void RepeatGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Repeat);
...
...
Dragon/src/operators/array/slice_op.cc
View file @ d1f714e
...
@@ -65,28 +65,10 @@ void SliceOp<Context>::RunOnDevice() {
     outer_dim_ = X(0).count(0, axis_);
     inner_dim_ = X(0).count(axis_ + 1);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -136,28 +118,10 @@ void SliceGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Slice);
...
...
Dragon/src/operators/array/stack_op.cc
View file @ d1f714e
...
@@ -51,28 +51,10 @@ void StackOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -104,28 +86,10 @@ void StackGradientOp<Context>::RunOnDevice() {
     for (int i = 0; i < YSize(); i++)
         Y(i)->ReshapeLike(X(i));
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Stack);
...
...
Dragon/src/operators/array/tile_op.cc
View file @ d1f714e
...
@@ -60,28 +60,10 @@ void TileOp<Context>::RunOnDevice() {
     TENSOR_FROM_VEC(X_dims_, X(0).dims(), int);
     TENSOR_FROM_VEC(Y_dims_, Y_dims, int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -99,7 +81,7 @@ void TileGradientOp<Context>::RunImpl() {
 template <class Context>
 void TileGradientOp<Context>::RunOnDevice() {
     // Add the axes
     vector<pair<int, int>> dispatch_axes;
     for (int i = 0; i < X(0).ndim(); i++) {
         auto m = multiples(i);
         if (m > 1) { dispatch_axes.push_back({ m, i }); }
...
@@ -128,26 +110,11 @@ void TileGradientOp<Context>::RunOnDevice() {
         rows_ = dst_->count(0, axis_);
         cols_ = dst_->count(axis_);
-        if (XIsType(X(0), int8_t)) {
-            RunImpl<int8_t>();
-        } else if (XIsType(X(0), uint8_t)) {
-            RunImpl<uint8_t>();
-        } else if (XIsType(X(0), int)) {
-            RunImpl<int>();
-        } else if (XIsType(X(0), int64_t)) {
-            RunImpl<int64_t>();
-        } else if (XIsType(X(0), float16)) {
-            RunImpl<float16>();
-        } else if (XIsType(X(0), float)) {
-            RunImpl<float>();
-        } else if (XIsType(X(0), double)) {
-            RunImpl<double>();
-        } else {
-            LOG(FATAL) << DTypeString(X(0),
-                { "int8", "uint8", "int32", "int64",
-                  "float16", "float32", "float64" });
-        }
+        DispatchHelper<TensorTypes
+            <int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+        >::Call(this, X(0));
         ctx()->FinishDeviceCompution();
         // Protect X if num_axes >= 2
         std::swap(src_, dst_);
...
...
Dragon/src/operators/array/transpose_op.cc
View file @ d1f714e
...
@@ -54,28 +54,10 @@ void TransposeOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -125,28 +107,10 @@ void TransposeGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "bool", "int8", "uint8", "int32", "int64",
-              "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+        float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Transpose);
...
...
Dragon/src/operators/control_flow/assign_op.cc  View file @ d1f714e
...
@@ -116,28 +116,10 @@ void AssignOp<Context>::RunOnDevice() {
     TENSOR_FROM_VECTOR(X_dims_, X_dims, int);
     TENSOR_FROM_VECTOR(Y_strides_, Y(0)->strides(), int);
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "bool", "int8", "uint8", "int32",
-            "int64", "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Assign);
...
Dragon/src/operators/control_flow/copy_op.cc  View file @ d1f714e
...
@@ -14,28 +14,10 @@ template <class Context>
 void CopyOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "bool", "int8", "uint8", "int32",
-            "int64", "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(Copy);
...
Dragon/src/operators/control_flow/masked_assign_op.cc  View file @ d1f714e
...
@@ -65,28 +65,10 @@ void MaskedAssignOp<Context>::RunOnDevice() {
     CHECK(XIsType(X(1), bool) || XIsType(X(1), uint8_t))
         << "\nExcepted bool or uint8 mask.";
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "bool", "int8", "uint8", "int32",
-            "int64", "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(MaskedAssign);
...
Dragon/src/operators/loss/ctc_loss_op.cc  View file @ d1f714e
...
@@ -23,13 +23,8 @@ void CTCLossGradientOp<Context>::RunImpl() {
 template <class Context>
 void CTCLossGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(CTCLoss);
...
Dragon/src/operators/loss/l1_loss_op.cc  View file @ d1f714e
...
@@ -15,9 +15,7 @@ void L1LossOp<Context>::RunImpl() {
         ->ReshapeLike(X(0))
         ->template mutable_data<T, Context>();
     auto* y = Y(0)
-        ->template mutable_data<T, Context>();
+        ->Reshape({})
+        ->template mutable_data<T, Context>();
     if (XSize() > 1) {
         auto* target = X(1).template data<T, Context>();
...
@@ -53,13 +51,10 @@ void L1LossOp<Context>::RunOnDevice() {
            << "while " << X(0).DimString()
            << " is required.";
     }
-    if (XIsType(X(0), float)) {
-        Y(0)->Reshape({});
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(
-            X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -104,13 +99,8 @@ void L1LossGradientOp<Context>::RunImpl() {
 template <class Context>
 void L1LossGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(L1Loss);
...
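A detail worth noting in the loss hunks above: `Y(0)->Reshape({})` moves out of the dispatch branch and into `RunImpl`, where it gives the output a 0-d (scalar) shape before the buffer is requested; the chaining works because `Reshape` returns the tensor pointer. A toy tensor illustrating the pattern (API shape assumed, not Dragon's):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct Tensor {
    std::vector<int64_t> dims;
    std::vector<float> data;
    Tensor* Reshape(const std::vector<int64_t>& d) {
        dims = d;
        int64_t n = 1;
        for (auto v : d) n *= v;  // empty dims -> n == 1: a scalar
        data.resize(n);
        return this;              // enables ->Reshape({})->mutable_data()
    }
    float* mutable_data() { return data.data(); }
};

int main() {
    Tensor y;
    float* buf = y.Reshape({})->mutable_data();
    buf[0] = 0.5f;  // the reduced loss value
    std::cout << "scalar loss: " << buf[0] << "\n";
    return 0;
}
```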
Dragon/src/operators/loss/l2_loss_op.cc  View file @ d1f714e
...
@@ -14,9 +14,7 @@ void L2LossOp<Context>::RunImpl() {
         ->ReshapeLike(X(0))
         ->template mutable_data<T, Context>();
     auto* y = Y(0)
-        ->template mutable_data<float, Context>();
+        ->Reshape({})
+        ->template mutable_data<float, Context>();
     if (XSize() > 1) {
         auto* target = X(1).template data<T, Context>();
...
@@ -56,15 +54,10 @@ void L2LossOp<Context>::RunOnDevice() {
            << "while " << X(0).DimString()
            << " is required.";
     }
-    if (XIsType(X(0), float)) {
-        Y(0)->Reshape({});
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -107,15 +100,8 @@ void L2LossGradientOp<Context>::RunImpl() {
 template <class Context>
 void L2LossGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CPU(L2Loss);
...
Dragon/src/operators/loss/sigmoid_ce_loss_op.cc  View file @ d1f714e
...
@@ -50,13 +50,8 @@ void SigmoidCrossEntropyOp<Context>::RunOnDevice() {
     loss_.ReshapeLike(X(0));
     flag_.ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -113,13 +108,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
     flag_.ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(SigmoidCrossEntropy);
...
Dragon/src/operators/loss/smooth_l1_loss_op.cc  View file @ d1f714e
...
@@ -44,7 +44,6 @@ void SmoothL1LossOp<Context>::RunImpl() {
         normalizer = X(0).count();
     }
+    Y(0)->Reshape({});
     auto* y = Y(0)
         ->template mutable_data<T, Context>();
     math::Sum(nelements, 1. / normalizer, err, y, ctx());
 }
...
@@ -53,13 +52,10 @@ template <class Context>
 void SmoothL1LossOp<Context>::RunOnDevice() {
     CHECK(X(0).count() == X(1).count());
-    if (XIsType(X(0), float)) {
-        Y(0)->Reshape({});
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(
-            X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -110,13 +106,8 @@ void SmoothL1LossGradientOp<Context>::RunImpl() {
 template <class Context>
 void SmoothL1LossGradientOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(SmoothL1Loss);
...
Dragon/src/operators/loss/softmax_ce_loss_op.cc  View file @ d1f714e
...
@@ -89,13 +89,8 @@ void SoftmaxCrossEntropyOp<Context>::RunOnDevice() {
     SoftmaxRun();
     loss_.ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -161,13 +156,8 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(SoftmaxCrossEntropy);
...
Dragon/src/operators/misc/accuracy_op.cc  View file @ d1f714e
...
@@ -24,7 +24,7 @@ void AccuracyOp<Context>::RunImpl() {
     const int label = target[i * inner_dim_ + j];
     for (int k = 0; k < ignore_.count(); k++)
         if (label == ignore[k]) continue;
     vector<pair<Tx, int>> vec;
     for (int k = 0; k < axis_dim_; k++)
         vec.push_back(
             std::make_pair(
...
@@ -35,7 +35,7 @@ void AccuracyOp<Context>::RunImpl() {
         vec.begin(),
         vec.begin() + top_k_,
         vec.end(),
         std::greater<pair<Tx, int>>()
     );
     for (int k = 0; k < top_k_; k++) {
         if (vec[k].second == label) { acc++; break; }
...
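The call that consumes these iterator arguments is cut off by the diff fold, but the shape — begin, begin + `top_k_`, end, plus a `std::greater` comparator — is the standard partial-sort top-k idiom. A standalone sketch of the accuracy test it sets up (`std::partial_sort` is an inference here, and the values are made up):

```cpp
#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

int main() {
    std::vector<float> scores = {0.1f, 0.7f, 0.05f, 0.15f};
    int label = 3, top_k = 2;

    // (score, class index) pairs, as the loop above builds them.
    std::vector<std::pair<float, int>> vec;
    for (int k = 0; k < (int)scores.size(); k++)
        vec.push_back(std::make_pair(scores[k], k));

    // Only the first top_k slots need to be ordered, highest score first.
    std::partial_sort(
        vec.begin(), vec.begin() + top_k, vec.end(),
        std::greater<std::pair<float, int>>());

    bool hit = false;
    for (int k = 0; k < top_k; k++)
        if (vec[k].second == label) { hit = true; break; }
    std::cout << (hit ? "top-k hit" : "top-k miss") << "\n";
    return 0;
}
```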
Dragon/src/operators/misc/gradient_op.cc  View file @ d1f714e
...
@@ -18,43 +18,19 @@ void GradientGenerateOp<Context>::RunImpl() {
 template <class Context>
 void GradientGenerateOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), bool)) {
-        RunImpl<bool>();
-    } else if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "bool", "int8", "uint8", "int32",
-            "int64", "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

-DEPLOY_CPU(GradientGenerate);
-#ifdef WITH_CUDA
-DEPLOY_CUDA(GradientGenerate);
-#endif
-OPERATOR_SCHEMA(GradientGenerate);

 template <class Context> template <typename T>
 void GradientGatherOp<Context>::RunImpl() {
     int64_t count = Y(0)->count();
     auto* y = Y(0)->template mutable_data<T, Context>();
     if (indices.size() == 1) {
         auto* x = X(indices[0]).template data<T, Context>();
-        ctx()->template Copy<T, Context, Context>(count, y, x);
+        math::Copy(count, x, y, ctx());
     } else if (indices.size() == 2) {
         CHECK_EQ(count, X(indices[1]).count());
         auto* a = X(indices[0]).template data<T, Context>();
...
@@ -63,7 +39,7 @@ void GradientGatherOp<Context>::RunImpl() {
     } else {
         size_t i = 1;
         auto* x = X(indices[0]).template data<T, Context>();
-        ctx()->template Copy<T, Context, Context>(count, y, x);
+        math::Copy(count, x, y, ctx());
         while (i < indices.size()) {
             if (indices.size() - i >= 2) {
                 auto* a = X(indices[i]).template data<T, Context>();
...
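Note that the argument order flips with the API: the removed `ctx()->template Copy<T, Context, Context>(count, y, x)` took the destination first, while `math::Copy(count, x, y, ctx())` takes source before destination. A toy CPU-only analogue of the new signature (assumed; it mirrors only what the call sites above show):

```cpp
#include <cstdint>
#include <cstring>

namespace math {

// y <- x: (count, src, dest, ctx), matching the call sites above.
template <typename T, class Context>
void Copy(int64_t n, const T* x, T* y, Context* ctx) {
    (void)ctx;  // a real context would select a device stream here
    std::memcpy(y, x, sizeof(T) * n);
}

}  // namespace math
```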
@@ -84,34 +60,12 @@ void GradientGatherOp<Context>::RunOnDevice() {
     auto& Xi = X(indices[0]);
     Y(0)->ReshapeLike(Xi);
-    if (XIsType(Xi, int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(Xi, uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(Xi, int)) {
-        RunImpl<int>();
-    } else if (XIsType(Xi, int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(Xi, float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(Xi, float)) {
-        RunImpl<float>();
-    } else if (XIsType(Xi, double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(Xi, {
-            "int8", "uint8", "int32", "int64",
-            "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, Xi);
 }

-DEPLOY_CPU(GradientGather);
-#ifdef WITH_CUDA
-DEPLOY_CUDA(GradientGather);
-#endif
-OPERATOR_SCHEMA(GradientGather).NumOutputs(1);

 template <class Context> template <typename T>
 void GradientAddOp<Context>::RunImpl() {
     auto* x = X(1).template data<T, Context>();
...
@@ -124,37 +78,12 @@ void GradientAddOp<Context>::RunOnDevice() {
     CHECK_EQ(X(0).name(), Y(0)->name())
         << "\nRequires X(0) == Y(0).";
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), uint8_t)) {
-        RunImpl<uint8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "int8", "uint8", "int32", "int64",
-            "float16", "float32", "float64",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

-DEPLOY_CPU(GradientAdd);
-#ifdef WITH_CUDA
-DEPLOY_CUDA(GradientAdd);
-#endif
-OPERATOR_SCHEMA(GradientAdd)
-    .NumInputs(2).NumOutputs(1)
-    .Inplace({ { 0, 0 } });

 template <class Context>
 void StopGradientOp<Context>::RunOnDevice() {
     if (Y(0)->name() != X(0).name()) {
...
@@ -163,14 +92,53 @@ void StopGradientOp<Context>::RunOnDevice() {
     }
 }

+DEPLOY_CPU(GradientGenerate);
+#ifdef WITH_CUDA
+DEPLOY_CUDA(GradientGenerate);
+#endif
+
+DEPLOY_CPU(GradientGather);
+#ifdef WITH_CUDA
+DEPLOY_CUDA(GradientGather);
+#endif
+
+DEPLOY_CPU(GradientAdd);
+#ifdef WITH_CUDA
+DEPLOY_CUDA(GradientAdd);
+#endif
+
 DEPLOY_CPU(StopGradient);
 #ifdef WITH_CUDA
 DEPLOY_CUDA(StopGradient);
 #endif

+OPERATOR_SCHEMA(GradientGenerate)
+     /* X(0), ... */
+    .NumInputs(1, INT_MAX)
+     /* Y(0), ... */
+    .NumOutputs(1, INT_MAX);
+
+OPERATOR_SCHEMA(GradientGather)
+     /* X(0), ... */
+    .NumInputs(1, INT_MAX)
+     /* Y */
+    .NumOutputs(1);
+
+OPERATOR_SCHEMA(GradientAdd)
+     /* X(0), X(1) */
+    .NumInputs(2)
+     /* Y */
+    .NumOutputs(1)
+     /* X(0) => Y */
+    .Inplace({ { 0, 0 } });
+
 OPERATOR_SCHEMA(StopGradient)
-    .NumInputs(1).NumOutputs(1)
-    .Inplace({ { 0, 0 } });;
+     /* X */
+    .NumInputs(1)
+     /* Y */
+    .NumOutputs(1)
+     /* X => Y */
+    .Inplace({ { 0, 0 } });

 NO_GRADIENT(StopGradient);
...
Dragon/src/operators/misc/initialize_op.cc  View file @ d1f714e
...
@@ -5,7 +5,7 @@ namespace dragon {
 template <class Context> template <typename T>
 void InitializeOp<Context>::RunImpl() {
     unique_ptr<Filler<T, Context>> f;
     f.reset(CreateFiller<T, Context>(proto_));
     f->Fill(Y(0), ctx());
 }
...
Dragon/src/operators/misc/python_op.cc  View file @ d1f714e
...
@@ -152,9 +152,11 @@ DEPLOY_CUDA(TemplateGradient);
 #endif

 OPERATOR_SCHEMA(TemplateGradient);

-class GetTemplateGradient final : public GradientMakerBase {
+namespace {
+
+class GradientMaker final : public GradientMakerBase {
  public:
-    GRADIENT_MAKER_CTOR(GetTemplateGradient);
+    GRADIENT_MAKER_CTOR(GradientMaker);
     vector<OperatorDef> MakeDef() override {
         vector<string> inputs, outputs;
         for (auto input : def.input()) inputs.push_back(input);
...
@@ -164,7 +166,9 @@ class GetTemplateGradient final : public GradientMakerBase {
     }
 };

-REGISTER_GRADIENT(Template, GetTemplateGradient);
+}  // namespace
+
+REGISTER_GRADIENT(Template, GradientMaker);

 }  // namespace dragon
...
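Besides the rename, the maker moves into an anonymous namespace, giving it internal linkage so the short name `GradientMaker` can be reused in every operator's `.cc` file without colliding at link time. A minimal standalone analogue of that register-at-namespace-scope pattern (the registry and its types are assumptions, not Dragon's `REGISTER_GRADIENT` machinery):

```cpp
#include <functional>
#include <map>
#include <string>

// A toy registry keyed by operator name.
static std::map<std::string, std::function<int()>>& Registry() {
    static std::map<std::string, std::function<int()>> r;
    return r;
}

namespace {  // internal linkage: this class name is private to the file

struct GradientMaker {
    static int MakeDef() { return 0; }  // stand-in for the real MakeDef()
};

}  // namespace

// Registration happens at namespace scope via a static initializer,
// which is what a REGISTER_* macro typically expands to.
static const bool registered_ =
    (Registry()["Template"] = &GradientMaker::MakeDef, true);
```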
Dragon/src/operators/mpi/mpi_broadcast_op.cc  View file @ d1f714e
...
@@ -35,22 +35,10 @@ void MPIBroadcastOp<Context>::RunOnDevice() {
     BCast(dims.data(), ndim);
     Y(0)->Reshape(dims);
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "int8", "int32", "int64",
-            "float16", "float32",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -74,22 +62,10 @@ template <class Context>
 void MPIBroadcastGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(-1));
-    if (XIsType(X(-1), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(-1), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(-1), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(-1), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(-1), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(-1), {
-            "int8", "int32", "int64",
-            "float16", "float32",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(-1));
 }

 DEPLOY_CPU(MPIBroadcast);
...
Dragon/src/operators/mpi/mpi_gather_op.cc  View file @ d1f714e
...
@@ -50,22 +50,10 @@ void MPIGatherOp<Context>::RunOnDevice() {
         }
     }
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "int8", "int32", "int64",
-            "float16", "float32",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -88,22 +76,10 @@ template <class Context>
 void MPIGatherGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), int8_t)) {
-        RunImpl<int8_t>();
-    } else if (XIsType(X(0), int)) {
-        RunImpl<int>();
-    } else if (XIsType(X(0), int64_t)) {
-        RunImpl<int64_t>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), {
-            "int8", "int32", "int64",
-            "float16", "float32",
-        });
-    }
+    DispatchHelper<TensorTypes
+        <bool, int8_t, uint8_t, int, int64_t,
+            float16, float, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(MPIGather);
...
Dragon/src/operators/norm/l2_norm_op.cc  View file @ d1f714e
...
@@ -73,17 +73,9 @@ void L2NormOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -177,17 +169,9 @@ void L2NormGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), double)) {
-        RunImpl<double>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float16", "float32", "float64" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16, double>
+    >::Call(this, X(0));
 }

 DEPLOY_CPU(L2Norm);
...
Dragon/src/operators/recurrent/cudnn_recurrent_op.cc  View file @ d1f714e
...
@@ -211,15 +211,8 @@ void CuDNNRecurrentOp<Context>::RunImpl() {
 template <class Context>
 void CuDNNRecurrentOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -313,15 +306,8 @@ void CuDNNRecurrentGradientOp<Context>::RunOnDevice() {
     Y(2)->ReshapeLike(X(2));  // dHx
     Y(3)->ReshapeLike(X(3));  // dCx
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(Recurrent);
...
Dragon/src/operators/recurrent/rnn_param_op.cc  View file @ d1f714e
...
@@ -52,15 +52,8 @@ void RNNParamSetOp<Context>::RunImpl() {
 template <class Context>
 void RNNParamSetOp<Context>::RunOnDevice() {
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CPU(RNNParamSet);
...
Dragon/src/operators/vision/bias_add_op.cc  View file @ d1f714e
...
@@ -43,13 +43,8 @@ void BiasAddOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -102,13 +97,8 @@ void BiasAddGradientOp<Context>::RunOnDevice() {
     Y(1)->ReshapeLike(X(0));
-    if (XIsType(X(-1), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(-1), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(-1));
 }

 DEPLOY_CPU(BiasAdd);
...
Dragon/src/operators/vision/bilinear_resize_op.cc  View file @ d1f714e
...
@@ -48,14 +48,9 @@ void BilinearResizeOp<Context>::RunOnDevice() {
     }
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -90,14 +85,9 @@ void BilinearResizeGradientOp<Context>::RunImpl() {
 template <class Context>
 void BilinearResizeGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(BilinearResize);
...
Dragon/src/operators/vision/conv2d_op.cc  View file @ d1f714e
...
@@ -27,13 +27,8 @@ void Conv2dOp<Context>::RunOnDevice() {
     if (data_format() == "NHWC" && group_ != 1)
         LOG(FATAL) << "GroupConv(NHWC) is not supported.";
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -66,13 +61,8 @@ void Conv2dGradientOp<Context>::RunOnDevice() {
     if (data_format() == "NHWC" && group_ != 1)
         LOG(FATAL) << "GroupConv(NHWC) is not supported.";
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(Conv2d);
...
Dragon/src/operators/vision/conv2d_transpose_op.cc  View file @ d1f714e
...
@@ -32,13 +32,8 @@ void ConvTranspose2dOp<Context>::RunOnDevice() {
     for (int i = 0; i < num_axes_; i++)
         out_shape_[i] = X(0).dim(axis_ + i);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -76,13 +71,8 @@ void ConvTranspose2dGradientOp<Context>::RunOnDevice() {
     for (int i = 0; i < num_axes_; i++)
         out_shape_[i] = X(0).dim(axis_ + i);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(ConvTranspose2d);
...
Dragon/src/operators/vision/cudnn_bias_add_op.cc  View file @ d1f714e
...
@@ -62,15 +62,8 @@ void CuDNNBiasAddOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -129,15 +122,8 @@ void CuDNNBiasAddGradientOp<Context>::RunOnDevice() {
     Y(1)->ReshapeLike(X(0));
-    if (XIsType(X(-1), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(-1), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(-1),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(-1));
 }

 DEPLOY_CUDNN(BiasAdd);
...
Dragon/src/operators/vision/cudnn_conv2d_op.cc  View file @ d1f714e
...
@@ -216,15 +216,8 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
 #endif
     ConvOpBase<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context>
...
@@ -474,15 +467,8 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
 #endif
     ConvOpBase<Context>::Reshape(true);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(Conv2d);
...
Dragon/src/operators/vision/cudnn_conv2d_transpose_op.cc  View file @ d1f714e
...
@@ -214,15 +214,8 @@ void CuDNNConvTranspose2dOp<Context>::RunOnDevice() {
 #endif
     ConvOpBase<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(ConvTranspose2d);
...
@@ -471,15 +464,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::RunOnDevice() {
 #endif
     ConvOpBase<Context>::Reshape(true);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(ConvTranspose2dGradient);
...
Dragon/src/operators/vision/cudnn_depthwise_conv2d_op.cc  View file @ d1f714e
...
@@ -68,13 +68,8 @@ void CuDNNDepthwiseConv2dOp<Context>::RunOnDevice() {
         << "\nExcepted in/out channels unchanged.";
     ConvOpBase<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -152,13 +147,8 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::RunOnDevice() {
         == "NCHW" ? X(0).dim(1) : X(0).dim(-1);
     ConvOpBase<Context>::Reshape(true);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(DepthwiseConv2d);
...
Dragon/src/operators/vision/cudnn_lrn_op.cc  View file @ d1f714e
...
@@ -14,10 +14,13 @@ void CuDNNLRNOp<Context>::RunImpl() {
         auto* y = Y(0)->template mutable_data<T, Context>();
         CUDNN_CHECK(cudnnLRNCrossChannelForward(
-            ctx()->cudnn_handle(), lrn_desc_,
+            ctx()->cudnn_handle(),
+            lrn_desc_,
             CUDNN_LRN_CROSS_CHANNEL_DIM1,
-            CuDNNType<T>::one, input_desc_, x,
-            CuDNNType<T>::zero, output_desc_, y
+            CuDNNType<T>::one,
+            input_desc_, x,
+            CuDNNType<T>::zero,
+            output_desc_, y
         ));
     } else {
         LOG(FATAL) << "Unknown DataFormat: " << data_format();
...
@@ -29,15 +32,8 @@ void CuDNNLRNOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
     if (this->mode_ == "ACROSS_CHANNELS") {
-        if (XIsType(X(0), float)) {
-            RunImpl<float>();
-        } else if (XIsType(X(0), float16)) {
-            RunImpl<float16>();
-        } else {
-            LOG(FATAL) << DTypeString(X(0),
-                { "float32", "float16" });
-        }
+        DispatchHelper<TensorTypes
+            <float, float16>>::Call(this, X(0));
     } else if (this->mode_ == "WITHIN_CHANNEL") {
         LRNOp<Context>::RunOnDevice();
     } else {
...
@@ -57,7 +53,8 @@ void CuDNNLRNGradientOp<Context>::RunImpl() {
         auto* dx = Y(0)->template mutable_data<T, Context>();
         CUDNN_CHECK(cudnnLRNCrossChannelBackward(
-            ctx()->cudnn_handle(), lrn_desc_,
+            ctx()->cudnn_handle(),
+            lrn_desc_,
             CUDNN_LRN_CROSS_CHANNEL_DIM1,
             CuDNNType<T>::one,
             input_desc_, y,
             input_desc_, dy,
...
@@ -76,15 +73,8 @@ void CuDNNLRNGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
     if (this->mode_ == "ACROSS_CHANNELS") {
-        if (XIsType(X(0), float)) {
-            RunImpl<float>();
-        } else if (XIsType(X(0), float16)) {
-            RunImpl<float16>();
-        } else {
-            LOG(FATAL) << DTypeString(X(0),
-                { "float32", "float16" });
-        }
+        DispatchHelper<TensorTypes
+            <float, float16>>::Call(this, X(0));
     } else if (this->mode_ == "WITHIN_CHANNEL") {
        LRNGradientOp<Context>::RunOnDevice();
     } else {
...
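The `CuDNNType<T>::one` / `::zero` arguments being reflowed above are cuDNN's standard scaling factors: every such routine computes `y = alpha * op(x) + beta * y`, with alpha and beta passed as host pointers. A sketch of the same forward call with plain floats in place of the trait (the trait is assumed to wrap exactly this):

```cpp
#include <cudnn.h>

cudnnStatus_t RunLRN(
        cudnnHandle_t handle, cudnnLRNDescriptor_t lrn_desc,
        cudnnTensorDescriptor_t x_desc, const void* x,
        cudnnTensorDescriptor_t y_desc, void* y) {
    const float one = 1.f;   // alpha: keep the full LRN response
    const float zero = 0.f;  // beta: overwrite y instead of accumulating
    return cudnnLRNCrossChannelForward(
        handle, lrn_desc,
        CUDNN_LRN_CROSS_CHANNEL_DIM1,
        &one, x_desc, x,
        &zero, y_desc, y);
}
```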
Dragon/src/operators/vision/cudnn_pool2d_op.cc  View file @ d1f714e
...
@@ -39,9 +39,12 @@ void CuDNNPool2dOp<Context>::RunImpl() {
     auto* y = Y(0)->template mutable_data<T, Context>();
     CUDNN_CHECK(cudnnPoolingForward(
-        ctx()->cudnn_handle(), pool_desc_,
-        CuDNNType<T>::one, input_desc_, x,
-        CuDNNType<T>::zero, output_desc_, y
+        ctx()->cudnn_handle(),
+        pool_desc_,
+        CuDNNType<T>::one, input_desc_, x,
+        CuDNNType<T>::zero, output_desc_, y
     ));
 }
...
@@ -49,15 +52,8 @@ template <class Context>
 void CuDNNPool2dOp<Context>::RunOnDevice() {
     Pool2dOp<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -97,7 +93,8 @@ void CuDNNPool2dGradientOp<Context>::RunImpl() {
     auto* dx = Y(0)->template mutable_data<T, Context>();
     CUDNN_CHECK(cudnnPoolingBackward(
-        ctx()->cudnn_handle(), pool_desc_,
+        ctx()->cudnn_handle(),
+        pool_desc_,
         CuDNNType<T>::one,
         input_desc_, y,
         input_desc_, dy,
...
@@ -111,15 +108,8 @@ template <class Context>
 void CuDNNPool2dGradientOp<Context>::RunOnDevice() {
     Pool2dGradientOp<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CUDNN(Pool2d);
...
Dragon/src/operators/vision/depthwise_conv2d_op.cc  View file @ d1f714e
...
@@ -40,13 +40,8 @@ void DepthwiseConv2dOp<Context>::RunOnDevice() {
         << "\nExcepted in/out channels unchanged.";
     ConvOpBase<Context>::Reshape();
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -98,13 +93,8 @@ void DepthwiseConv2dGradientOp<Context>::RunOnDevice() {
         X(0).dim(1) : X(0).dim(-1);
     ConvOpBase<Context>::Reshape(true);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0), { "float32" });
-    }
+    DispatchHelper<TensorTypes
+        <float>>::Call(this, X(0));
 }

 DEPLOY_CPU(DepthwiseConv2d);
...
Dragon/src/operators/vision/drop_block2d_op.cc  View file @ d1f714e
...
@@ -102,15 +102,8 @@ template <class Context>
 void DropBlock2dOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
@@ -144,15 +137,8 @@ template <class Context>
 void DropBlock2dGradientOp<Context>::RunOnDevice() {
     Y(0)->ReshapeLike(X(0));
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 DEPLOY_CPU(DropBlock2d);
...
Dragon/src/operators/vision/nn_resize_op.cc  View file @ d1f714e
...
@@ -49,15 +49,8 @@ void NNResizeOp<Context>::RunOnDevice() {
     Y(0)->Reshape(out_shape);
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    }
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
Dragon/src/operators/vision/roi_align_op.cc  View file @ d1f714e
...
@@ -33,15 +33,8 @@ void ROIAlignOp<Context>::RunOnDevice() {
         pool_w_  /* feature_w */
     });
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    };
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
Dragon/src/operators/vision/roi_pool_op.cc  View file @ d1f714e
...
@@ -37,15 +37,8 @@ void ROIPoolOp<Context>::RunOnDevice() {
         pool_w_  /* feature_w */
     });
-    if (XIsType(X(0), float)) {
-        RunImpl<float>();
-    } else if (XIsType(X(0), float16)) {
-        RunImpl<float16>();
-    } else {
-        LOG(FATAL) << DTypeString(X(0),
-            { "float32", "float16" });
-    };
+    DispatchHelper<TensorTypes
+        <float, float16>>::Call(this, X(0));
 }

 template <class Context> template <typename T>
...
Dragon/src/utils/math_functions.cu  View file @ d1f714e
...
@@ -111,8 +111,8 @@ DEFINE_COPY_FUNC(double);
         T* y, \
         CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, x, y \
     ); \
 }
...
@@ -245,8 +245,8 @@ __global__ void _AddScalar(
             sizeof(T) * n, ctx->cuda_stream())); \
     } else { \
         _Set \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, alpha, y \
         ); \
     } \
...
@@ -273,15 +273,15 @@ DEFINE_SET_FUNC(double);
     if (type == 0) { \
         /*! Row - BroadcastX */ \
         _RowBroadcastSet \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, x, y \
         ); \
     } else if (type == 1) { \
         /*! Col - BroadcastX */ \
         _ColBroadcastSet \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, x, y \
         ); \
     } \
...
@@ -307,8 +307,8 @@ DEFINE_BROADCAST_SET_FUNC(double);
         T* y, \
         CUDAContext* ctx) { \
     _Pow \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, cast::to<T>(exp), x, y \
     ); \
 }
...
@@ -337,8 +337,8 @@ DEFINE_POWX_FUNC(double);
         } return; \
     } \
     _Scale \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, _alpha_, x, y \
     ); \
 }
...
@@ -386,8 +386,8 @@ DEFINE_CUBLAS_SCALE_FUNC(double, cublasDscal_v2);
         T* y, \
         CUDAContext* ctx) { \
     _Axpy \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, cast::to<T>(alpha), x, y \
     ); \
 }
...
@@ -434,8 +434,8 @@ DEFINE_AXPY_FUNC(int64_t);
         T* y, \
         CUDAContext* ctx) { \
     _Axpby \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, \
         cast::to<T>(alpha), x, \
         cast::to<T>(beta), y \
...
@@ -461,8 +461,8 @@ DEFINE_AXPBY_FUNC(double);
     T _alpha_ = (T)alpha; \
     if (_alpha_ == T(0)) return; \
     _AddScalar \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, _alpha_, y \
     ); \
 }
...
@@ -506,8 +506,8 @@ __global__ void _InvStd(
         T* y, \
         CUDAContext* ctx) { \
     _InvStd \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, cast::to<T>(eps), x, y \
     ); \
 }
...
@@ -598,8 +598,8 @@ template <> float ASum<float, CUDAContext>(
         T* y, \
         CUDAContext* ctx) { \
     _##name \
-        << < CUDA_BLOCKS(n), CUDA_THREADS, \
-            0, ctx->cuda_stream() >> >( \
+        <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+            0, ctx->cuda_stream() >>>( \
         n, a, b, y \
     ); \
 }
...
@@ -829,29 +829,29 @@ __global__ void _ColBroadcastDiv(
     if (type == 0) { \
         /*! Row - BroadcastB */ \
         _RowBroadcast##name<T, false> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, a, b, y \
         ); \
     } else if (type == 1) { \
         /*! Col - BroadcastB */ \
         _ColBroadcast##name<T, false> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, a, b, y \
         ); \
     } else if (type == 2) { \
         /*! Row - BroadcastA */ \
         _RowBroadcast##name<T, true> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, a, b, y \
         ); \
     } else if (type == 3) { \
         /*! Col - BroadcastA */ \
         _ColBroadcast##name<T, true> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, a, b, y \
         ); \
     } else { \
...
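Every hunk in this file makes the same small fix: the split kernel-launch tokens `<< <` and `>> >` become the standard `<<<` and `>>>`. A minimal self-contained launch in the resulting style (the `CUDA_BLOCKS`/`CUDA_THREADS` definitions below are assumed stand-ins for Dragon's macros, not the library's):

```cuda
#include <cstdio>

constexpr int CUDA_THREADS = 256;
inline int CUDA_BLOCKS(int n) {
    return (n + CUDA_THREADS - 1) / CUDA_THREADS;  // ceil(n / threads)
}

__global__ void _AddScalar(int n, float alpha, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] += alpha;
}

int main() {
    int n = 1024;
    float* y;
    cudaMalloc(&y, n * sizeof(float));
    cudaMemset(y, 0, n * sizeof(float));
    _AddScalar
        <<< CUDA_BLOCKS(n), CUDA_THREADS,
            0 /* shared mem */, 0 /* default stream */ >>>(
        n, 1.f, y);
    cudaDeviceSynchronize();
    printf("launched %d blocks of %d threads\n", CUDA_BLOCKS(n), CUDA_THREADS);
    cudaFree(y);
    return 0;
}
```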
Dragon/src/utils/math_functions.fp16.cu
View file @
d1f714e
 ...
@@ -48,16 +48,16 @@ template <> void Exp<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _ExpHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _ExpHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
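From here on, the fp16 file repeats one dispatch shape: when `n` is even, a `half2` kernel processes `n >> 1` packed pairs (two fp16 lanes per thread on sm_53+ GPUs); when `n` is odd, a scalar `half` kernel runs instead. The `reinterpret_cast` is valid because `half2` is just two contiguous `half` values. A self-contained sketch of that pattern, assuming the `h2exp`/`hexp` intrinsics from `cuda_fp16.h` and the illustrative grid helpers from the first sketch:

```cpp
#include <cuda_fp16.h>

constexpr int kThreads = 256;               // stand-in for CUDA_THREADS
inline int NumBlocks(int n) {               // stand-in for CUDA_BLOCKS(n)
    return (n + kThreads - 1) / kThreads;
}

__global__ void ExpHalf2(int n2, const half2* x, half2* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n2) y[i] = h2exp(x[i]);         // two fp16 lanes per thread
}

__global__ void ExpHalf(int n, const half* x, half* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = hexp(x[i]);
}

void Exp(int n, const half* x, half* y, cudaStream_t stream) {
    if ((n & 1) == 0) {                     // even count: vectorized path
        ExpHalf2<<<NumBlocks(n >> 1), kThreads, 0, stream>>>(
            n >> 1,
            reinterpret_cast<const half2*>(x),
            reinterpret_cast<half2*>(y));
    } else {                                // odd count: scalar fallback
        ExpHalf<<<NumBlocks(n), kThreads, 0, stream>>>(n, x, y);
    }
}
```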
 ...
@@ -94,16 +94,16 @@ template <> void Log<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _LogHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _LogHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
 ...
@@ -140,16 +140,16 @@ template <> void Inv<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _InvHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _InvHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
 ...
@@ -186,16 +186,16 @@ template <> void Sqrt<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _SqrtHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _SqrtHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
 ...
@@ -232,16 +232,16 @@ template <> void RSqrt<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _RSqrtHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _RSqrtHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
 ...
@@ -278,16 +278,16 @@ template <> void Square<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _SquareHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             reinterpret_cast<const half2*>(x),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _SquareHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             reinterpret_cast<const half*>(x),
             reinterpret_cast<half*>(y)
 ...
@@ -330,16 +330,16 @@ template <> void Set<float16, CUDAContext>(
     }
     if ((n & 1) == 0) {
         _SetHalf<half2>
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             cast::to<half2>(alpha),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _SetHalf<float16>
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n, alpha, y
         );
     }
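`Set` differs from the element-wise hunks only in that the constant itself must be widened for the paired path: `cast::to<half2>(alpha)` has to place `alpha` in both lanes so every packed store writes two identical values. A hedged guess at what that conversion amounts to (Dragon's `cast::` helpers are not shown in this diff, and `ToHalf2` below assumes a recent `cuda_fp16.h` where the conversions are host-callable):

```cpp
#include <cuda_fp16.h>

// Presumed lane replication behind cast::to<half2>(alpha).
inline half2 ToHalf2(float alpha) {
    half h = __float2half(alpha);
    return __halves2half2(h, h);     // same value in both half lanes
}

template <typename T>
__global__ void SetKernel(int n, T alpha, T* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = alpha;         // T is half2 on the even path
}
```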
 ...
@@ -380,8 +380,8 @@ template <> void Pow<float16, CUDAContext>(
     CHECK(alpha == 2.f) << "\nRequired power = 2";
     if ((n & 1) == 0) {
         _PowHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             alpha,
             reinterpret_cast<const half2*>(x),
 ...
@@ -389,8 +389,8 @@ template <> void Pow<float16, CUDAContext>(
         );
     } else {
         _PowHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             alpha,
             reinterpret_cast<const half*>(x),
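Note the `CHECK(alpha == 2.f)` context line: the fp16 `Pow` specialization only supports squaring, and `alpha` is carried along purely to satisfy the generic signature. Under that restriction the paired kernel can reduce to one multiply per lane pair; an illustrative sketch, not the Dragon kernel itself:

```cpp
#include <cuda_fp16.h>

__global__ void PowHalf2(int n2, float /*alpha, pinned to 2*/,
                         const half2* x, half2* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n2) y[i] = __hmul2(x[i], x[i]);  // x^2 on both lanes
}
```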
 ...
@@ -487,16 +487,16 @@ template <> void AddScalar<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _AddScalarHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             cast::to<half2>(alpha),
             reinterpret_cast<half2*>(y)
         );
     } else {
         _AddScalarHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             cast::to<half>(alpha),
             reinterpret_cast<half*>(y)
 ...
@@ -546,8 +546,8 @@ template <> void InvStd<float16, CUDAContext>(
     CUDAContext* ctx) {
     if ((n & 1) == 0) {
         _InvStdHalf2
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n >> 1,
             cast::to<half2>(eps),
             reinterpret_cast<const half2*>(x),
 ...
@@ -555,8 +555,8 @@ template <> void InvStd<float16, CUDAContext>(
         );
     } else {
         _InvStdHalf
-            << < CUDA_BLOCKS(n), CUDA_THREADS,
-                0, ctx->cuda_stream() >> >(
+            <<< CUDA_BLOCKS(n), CUDA_THREADS,
+                0, ctx->cuda_stream() >>>(
             n,
             cast::to<half>(eps),
             reinterpret_cast<const half*>(x),
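Judging by the name and the `eps` operand, `InvStd` maps a variance tensor to the inverse standard deviation, `y = 1 / sqrt(x + eps)`, with `eps` guarding against division by zero. The kernel body below is an assumption consistent with that reading, not code from this diff:

```cpp
#include <cuda_fp16.h>

__global__ void InvStdHalf(int n, half eps, const half* x, half* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = hrsqrt(__hadd(x[i], eps));  // 1 / sqrt(x + eps)
}
```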
 ...
@@ -668,8 +668,8 @@ __global__ void _DivHalf(
     CUDAContext* ctx) { \
     if ((n & 1) == 0) { \
         _##name##Half2 \
-            << < CUDA_BLOCKS(n >> 1), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n >> 1), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n >> 1, \
             reinterpret_cast<const half2*>(a), \
             reinterpret_cast<const half2*>(b), \
 ...
@@ -677,8 +677,8 @@ __global__ void _DivHalf(
         ); \
     } else { \
         _##name##Half \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, \
             reinterpret_cast<const half*>(a), \
             reinterpret_cast<const half*>(b), \
 ...
@@ -699,8 +699,8 @@ template <> void Div<float16, CUDAContext>(
     float16* y,
     CUDAContext* ctx) {
     _DivHalf
-        << < CUDA_BLOCKS(n), CUDA_THREADS,
-            0, ctx->cuda_stream() >> >(
+        <<< CUDA_BLOCKS(n), CUDA_THREADS,
+            0, ctx->cuda_stream() >>>(
         n,
         reinterpret_cast<const half*>(a),
         reinterpret_cast<const half*>(b),
 ...
@@ -884,8 +884,8 @@ __global__ void _ColBroadcastDivHalf(
     if (type == 0) { \
         /*! Row - BroadcastB */ \
         _RowBroadcast##name##Half<false> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, \
             reinterpret_cast<const half*>(a), \
             reinterpret_cast<const half*>(b), \
 ...
@@ -894,8 +894,8 @@ __global__ void _ColBroadcastDivHalf(
     } else if (type == 1) { \
         /*! Col - BroadcastB */ \
         _ColBroadcast##name##Half<false> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, \
             reinterpret_cast<const half*>(a), \
             reinterpret_cast<const half*>(b), \
 ...
@@ -904,8 +904,8 @@ __global__ void _ColBroadcastDivHalf(
     } else if (type == 2) { \
         /*! Row - BroadcastA */ \
         _RowBroadcast##name##Half<true> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, \
             reinterpret_cast<const half*>(a), \
             reinterpret_cast<const half*>(b), \
 ...
@@ -914,8 +914,8 @@ __global__ void _ColBroadcastDivHalf(
     } else if (type == 3) { \
        /*! Col - BroadcastA */ \
         _ColBroadcast##name##Half<true> \
-            << < CUDA_BLOCKS(n), CUDA_THREADS, \
-                0, ctx->cuda_stream() >> >( \
+            <<< CUDA_BLOCKS(n), CUDA_THREADS, \
+                0, ctx->cuda_stream() >>>( \
             n, cols, \
             reinterpret_cast<const half*>(a), \
             reinterpret_cast<const half*>(b), \
 ...