Commit 5cd0761b by Ting PAN

Unlock CUDA Async Streams

1 parent 3b990761
Showing with 3330 additions and 2487 deletions
...@@ -52,9 +52,9 @@ using Set = std::unordered_set<Value> ; ...@@ -52,9 +52,9 @@ using Set = std::unordered_set<Value> ;
/* /*
* Define the Kernel version. * Define the Kernel version.
* *
* | Major(2) | Minor(2) | Patch(10) | * | Major(2) | Minor(2) | Patch(11) |
*/ */
#define DRAGON_VERSION 2210 #define DRAGON_VERSION 2211
/* /*
* Define the default random seed. * Define the default random seed.
......
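The bump from 2210 to 2211 follows the packed decimal scheme documented in the comment above. A minimal sketch of decoding it (the exact split is an assumption inferred from that comment):

    #include <cstdio>

    #define DRAGON_VERSION 2211  // | Major(2) | Minor(2) | Patch(11) |

    int main() {
        int major = DRAGON_VERSION / 1000;          // 2
        int minor = (DRAGON_VERSION % 1000) / 100;  // 2
        int patch = DRAGON_VERSION % 100;           // 11
        std::printf("Dragon %d.%d.%d\n", major, minor, patch);
        return 0;
    }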
...@@ -34,6 +34,8 @@ class CPUContext { ...@@ -34,6 +34,8 @@ class CPUContext {
virtual ~CPUContext() {} virtual ~CPUContext() {}
inline void SwitchToDevice() {} inline void SwitchToDevice() {}
inline void SwitchToDevice(int stream_id) {}
inline void FinishDeviceCompution() {} inline void FinishDeviceCompution() {}
inline static void* New(size_t nbytes) { inline static void* New(size_t nbytes) {
...@@ -47,7 +49,15 @@ class CPUContext { ...@@ -47,7 +49,15 @@ class CPUContext {
return data; return data;
} }
inline static void Memset(size_t nbytes, void* ptr) { inline static void Memset(
size_t nbytes,
void* ptr) {
memset(ptr, 0, nbytes);
}
inline void MemsetAsync(
size_t nbytes,
void* ptr) {
memset(ptr, 0, nbytes); memset(ptr, 0, nbytes);
} }
...@@ -59,18 +69,16 @@ class CPUContext { ...@@ -59,18 +69,16 @@ class CPUContext {
memcpy(dst, src, nbytes); memcpy(dst, src, nbytes);
} }
inline static void Delete(void* data) { free(data); }
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline static void MemcpyAsync( inline void MemcpyAsync(
size_t nbytes, size_t nbytes,
void* dst, void* dst,
const void* src) { const void* src) {
NOT_IMPLEMENTED; memcpy(dst, src, nbytes);
} }
template<typename T, class DstContext, class SrcContext> template<typename T, class DstContext, class SrcContext>
inline static void Copy( inline void Copy(
int n, int n,
T* dst, T* dst,
const T* src) { const T* src) {
...@@ -82,7 +90,10 @@ class CPUContext { ...@@ -82,7 +90,10 @@ class CPUContext {
else for (int i = 0; i < n; i++) dst[i] = src[i]; else for (int i = 0; i < n; i++) dst[i] = src[i];
} }
inline static void Delete(void* data) { free(data); }
inline int device_id() const { return 0; } inline int device_id() const { return 0; }
inline void set_stream_id(int stream_id) {}
inline std::mt19937* rand_generator() { inline std::mt19937* rand_generator() {
if (!rand_generator_.get()) if (!rand_generator_.get())
......
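CPUContext gains the same surface as CUDAContext: SwitchToDevice(stream_id) and set_stream_id() are no-ops, and the *Async calls fall back to synchronous memset/memcpy (MemcpyAsync previously hit NOT_IMPLEMENTED). A trimmed, self-contained sketch of the fallback behavior:

    #include <cstddef>
    #include <cstring>

    // Trimmed sketch: on CPU the "async" API simply runs synchronously,
    // so templated operator code can call it on either context.
    class CPUContextSketch {
     public:
        inline void SwitchToDevice(int stream_id) {}  // streams mean nothing on CPU
        inline void FinishDeviceCompution() {}        // nothing pending to wait for

        inline static void Memset(size_t nbytes, void* ptr) {
            memset(ptr, 0, nbytes);
        }
        inline void MemsetAsync(size_t nbytes, void* ptr) {
            memset(ptr, 0, nbytes);  // same as Memset, kept for API symmetry
        }
        template <class DstContext, class SrcContext>
        inline void MemcpyAsync(size_t nbytes, void* dst, const void* src) {
            memcpy(dst, src, nbytes);  // was NOT_IMPLEMENTED before this commit
        }
    };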
...@@ -23,8 +23,7 @@ namespace dragon { ...@@ -23,8 +23,7 @@ namespace dragon {
class CUDAObject { class CUDAObject {
public: public:
CUDAObject(int default_stream = 1) CUDAObject() {
: default_stream(default_stream) {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
cuda_streams[i] = vector<cudaStream_t>(); cuda_streams[i] = vector<cudaStream_t>();
cublas_handles[i] = vector<cublasHandle_t>(); cublas_handles[i] = vector<cublasHandle_t>();
...@@ -38,7 +37,7 @@ class CUDAObject { ...@@ -38,7 +37,7 @@ class CUDAObject {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams[i].size(); j++) { for (int j = 0; j < cuda_streams[i].size(); j++) {
auto& stream = cuda_streams[i][j]; auto& stream = cuda_streams[i][j];
// follow caffe2, do not check the stream destroying // follow caffe2, do not check errors when destroying streams
// Error code 29 (driver shutting down) is inevitable // Error code 29 (driver shutting down) is inevitable
// TODO(PhyscalX): Can someone solve this issue? // TODO(PhyscalX): Can someone solve this issue?
if (stream) cudaStreamDestroy(stream); if (stream) cudaStreamDestroy(stream);
...@@ -52,19 +51,21 @@ class CUDAObject { ...@@ -52,19 +51,21 @@ class CUDAObject {
} }
} }
/** // follow caffe2,
* Each device takes a group of streams. // each device takes a group of non-blocking streams
* // the stream 0 is reserved for the default stream,
* The stream 0 is reserved for the default stream, // as some computations really require it,
* stream 1 or higher is created as ``cudaStreamNonBlocking``. // e.g. cublas.asum() and mixed cpu/cuda operations
*/ // besides, some calls, such as cudnn.conv() and cudnn.rnn(),
// produce wrong results if running them on non-blocking streams
// note that caffe2 also uses default streams (within CuDNNState)
cudaStream_t GetStream(int device_id, int stream_id) { cudaStream_t GetStream(int device_id, int stream_id) {
vector<cudaStream_t>& dev_streams = cuda_streams[device_id]; vector<cudaStream_t>& dev_streams = cuda_streams[device_id];
if (dev_streams.size() <= (unsigned)stream_id) if (dev_streams.size() <= (unsigned)stream_id)
dev_streams.resize(stream_id + 1, nullptr); dev_streams.resize(stream_id + 1, nullptr);
if (!dev_streams[stream_id]) { if (!dev_streams[stream_id]) {
DeviceGuard guard(device_id); DeviceGuard guard(device_id);
unsigned int flags = !stream_id && default_stream ? unsigned int flags = !stream_id ?
cudaStreamDefault : cudaStreamNonBlocking; cudaStreamDefault : cudaStreamNonBlocking;
CUDA_CHECK(cudaStreamCreateWithFlags( CUDA_CHECK(cudaStreamCreateWithFlags(
&dev_streams[stream_id], flags)); &dev_streams[stream_id], flags));
...@@ -102,8 +103,6 @@ class CUDAObject { ...@@ -102,8 +103,6 @@ class CUDAObject {
} }
#endif #endif
int default_stream;
vector<cudaStream_t> cuda_streams[CUDA_MAX_DEVICES]; vector<cudaStream_t> cuda_streams[CUDA_MAX_DEVICES];
vector<cublasHandle_t> cublas_handles[CUDA_MAX_DEVICES]; vector<cublasHandle_t> cublas_handles[CUDA_MAX_DEVICES];
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -129,11 +128,10 @@ class CUDAContext { ...@@ -129,11 +128,10 @@ class CUDAContext {
stream_id_ = stream_id; stream_id_ = stream_id;
} }
inline void SwitchToDevice() { SwitchToDevice(0); } inline void SwitchToDevice() { SwitchToDevice(1); }
inline void FinishDeviceCompution() { inline void FinishDeviceCompution() {
cudaStreamSynchronize(cuda_object_ cudaStreamSynchronize(cuda_stream());
.GetStream(device_id_, stream_id_));
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess) CHECK_EQ(error, cudaSuccess)
<< "\nCUDA Error: " << cudaGetErrorString(error); << "\nCUDA Error: " << cudaGetErrorString(error);
...@@ -147,8 +145,17 @@ class CUDAContext { ...@@ -147,8 +145,17 @@ class CUDAContext {
return data; return data;
} }
inline static void Memset(size_t nbytes, void* ptr) { inline static void Memset(
cudaMemset(ptr, 0, nbytes); size_t nbytes,
void* ptr) {
CUDA_CHECK(cudaMemset(ptr, 0, nbytes));
}
inline void MemsetAsync(
size_t nbytes,
void* ptr) {
CUDA_CHECK(cudaMemsetAsync(ptr, 0,
nbytes, cuda_stream()));
} }
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
...@@ -169,20 +176,22 @@ class CUDAContext { ...@@ -169,20 +176,22 @@ class CUDAContext {
cudaMemcpyDefault, cuda_stream())); cudaMemcpyDefault, cuda_stream()));
} }
inline static void Delete(void* data) { cudaFree(data); }
template<typename T, class DstContext, class SrcContext> template<typename T, class DstContext, class SrcContext>
static void Copy( inline void Copy(
int n, int n,
T* dst, T* dst,
const T* src) { const T* src) {
if (dst == src) return; if (dst == src) return;
Memcpy<SrcContext, DstContext>( MemcpyAsync<SrcContext, DstContext>(
n * sizeof(T), (void*)dst, (const void*)src); n * sizeof(T), (void*)dst, (const void*)src);
} }
inline static void Delete(void* data) { cudaFree(data); }
inline int device_id() const { return device_id_; } inline int device_id() const { return device_id_; }
inline void set_stream_id(int stream_id) { stream_id_ = stream_id; }
inline cudaStream_t cuda_stream() { inline cudaStream_t cuda_stream() {
return cuda_stream(device_id_, stream_id_); return cuda_stream(device_id_, stream_id_);
} }
...@@ -227,7 +236,7 @@ class CUDAContext { ...@@ -227,7 +236,7 @@ class CUDAContext {
static thread_local CUDAObject cuda_object_; static thread_local CUDAObject cuda_object_;
private: private:
int device_id_, stream_id_ = 0, random_seed_; int device_id_, stream_id_ = 1, random_seed_;
unique_ptr<std::mt19937> rand_generator_; unique_ptr<std::mt19937> rand_generator_;
curandGenerator_t curand_generator_ = nullptr; curandGenerator_t curand_generator_ = nullptr;
}; };
...@@ -271,7 +280,7 @@ class CUDAClosure { ...@@ -271,7 +280,7 @@ class CUDAClosure {
protected: protected:
Context* ctx_; Context* ctx_;
CUDAObject cuda_object_ = 0; CUDAObject cuda_object_;
vector<int> active_streams_; vector<int> active_streams_;
}; };
...@@ -283,8 +292,22 @@ class CUDAContext { ...@@ -283,8 +292,22 @@ class CUDAContext {
CUDAContext(const int device_id = 0) { CUDA_NOT_COMPILED; } CUDAContext(const int device_id = 0) { CUDA_NOT_COMPILED; }
inline void SwitchToDevice() { CUDA_NOT_COMPILED; } inline void SwitchToDevice() { CUDA_NOT_COMPILED; }
inline void SwitchToDevice(int stream_id) { CUDA_NOT_COMPILED; }
inline void FinishDeviceCompution() { CUDA_NOT_COMPILED; } inline void FinishDeviceCompution() { CUDA_NOT_COMPILED; }
inline static void Memset(
size_t nbytes,
void* ptr) {
CUDA_NOT_COMPILED;
}
inline void MemsetAsync(
size_t nbytes,
void* ptr) {
CUDA_NOT_COMPILED;
}
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline static void Memcpy( inline static void Memcpy(
size_t nbytes, size_t nbytes,
...@@ -302,6 +325,7 @@ class CUDAContext { ...@@ -302,6 +325,7 @@ class CUDAContext {
} }
inline int device_id() const { return 0; } inline int device_id() const { return 0; }
inline void set_stream_id(int stream_id) {}
}; };
#endif // WITH_CUDA #endif // WITH_CUDA
......
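The stream policy after this change, as the rewritten comment explains: stream 0 keeps cudaStreamDefault because cublas.asum(), mixed cpu/cuda code, and some cudnn calls misbehave on non-blocking streams; any higher id is created lazily with cudaStreamNonBlocking, and SwitchToDevice() now defaults to stream 1 so ordinary ops run asynchronously. A standalone sketch of the lazy creation, mirroring CUDAObject::GetStream above:

    #include <vector>
    #include <cuda_runtime.h>

    // Mirrors CUDAObject::GetStream: id 0 -> legacy default-stream
    // semantics, id >= 1 -> non-blocking stream, created on first use.
    cudaStream_t GetStreamSketch(
        std::vector<cudaStream_t>& dev_streams, int stream_id) {
        if (dev_streams.size() <= (size_t)stream_id)
            dev_streams.resize(stream_id + 1, nullptr);
        if (!dev_streams[stream_id]) {
            unsigned int flags = !stream_id ?
                cudaStreamDefault : cudaStreamNonBlocking;
            cudaStreamCreateWithFlags(&dev_streams[stream_id], flags);
        }
        return dev_streams[stream_id];
    }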
...@@ -37,7 +37,8 @@ class GraphBase { ...@@ -37,7 +37,8 @@ class GraphBase {
virtual bool Run( virtual bool Run(
const string& include, const string& include,
const string& exclude) = 0; const string& exclude,
const int stream_id = 1) = 0;
inline string name() const { return name_; } inline string name() const { return name_; }
...@@ -58,7 +59,8 @@ class Graph final : public GraphBase { ...@@ -58,7 +59,8 @@ class Graph final : public GraphBase {
bool Run( bool Run(
const string& include, const string& include,
const string& exclude) override; const string& exclude,
const int stream_id = 1) override;
GraphDef Prune(const GraphDef& meta_graph); GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph); GraphDef MakeUpdate(const GraphDef& meta_graph);
......
...@@ -44,7 +44,7 @@ class OperatorBase { ...@@ -44,7 +44,7 @@ class OperatorBase {
const string& anchor); const string& anchor);
inline void SwitchToPhase(const string& phase) { phase_ = phase; } inline void SwitchToPhase(const string& phase) { phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; } virtual void Run(int stream_id = 1) { NOT_IMPLEMENTED; }
inline const string& name() const { return def_.name(); } inline const string& name() const { return def_.name(); }
inline const string& type() const { return def_.type(); } inline const string& type() const { return def_.type(); }
...@@ -100,13 +100,13 @@ class Operator : public OperatorBase { ...@@ -100,13 +100,13 @@ class Operator : public OperatorBase {
Output(0)->name() == "ignore")); Output(0)->name() == "ignore"));
} }
virtual void Run() final { void Run(int stream_id = 1) final {
if (!allow_run_) return; if (!allow_run_) return;
if (allow_recompute_) MakeResource(); if (allow_recompute_) MakeResource();
ctx().SwitchToDevice(); ctx()->SwitchToDevice(stream_id);
MemorySwitch(); MemorySwitch();
RunOnDevice(); RunOnDevice();
if (do_sync_) ctx().FinishDeviceCompution(); if (do_sync_) ctx()->FinishDeviceCompution();
if (allow_recompute_) CleanResource(); if (allow_recompute_) CleanResource();
} }
...@@ -123,7 +123,7 @@ class Operator : public OperatorBase { ...@@ -123,7 +123,7 @@ class Operator : public OperatorBase {
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
inline Context& ctx() { return ctx_; } inline Context* ctx() { return &ctx_; }
inline bool AllowRun() { return allow_run_; } inline bool AllowRun() { return allow_run_; }
protected: protected:
...@@ -192,6 +192,27 @@ DECLARE_REGISTRY( ...@@ -192,6 +192,27 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
Workspace*); Workspace*);
#define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \
<< "may be specify a filler for it ?"; \
tensor.Reshape(shape); \
unique_ptr< Filler<type, Context> > filler( \
CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor, ctx()); \
ctx()->FinishDeviceCompution(); \
} else { \
TIndex count = 1; \
for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nModel request " << "Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect Tensor before ?"; \
tensor.Reshape(shape); \
}
#define TENSOR_FILL(tensor, shape) \ #define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \ CHECK(ws()->GetFiller(tensor.name())) \
...@@ -200,7 +221,8 @@ DECLARE_REGISTRY( ...@@ -200,7 +221,8 @@ DECLARE_REGISTRY(
tensor.Reshape(shape); \ tensor.Reshape(shape); \
unique_ptr< Filler<T, Context> > filler( \ unique_ptr< Filler<T, Context> > filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \ CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor, &ctx()); \ filler->Fill(&tensor, ctx()); \
ctx()->FinishDeviceCompution(); \
} else { \ } else { \
TIndex count = 1; \ TIndex count = 1; \
for(int i = 0; i < shape.size(); i++) count *= shape[i]; \ for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
...@@ -217,7 +239,7 @@ DECLARE_REGISTRY( ...@@ -217,7 +239,7 @@ DECLARE_REGISTRY(
if (size > ptr_tensor->count()) { \ if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape({ size }); \ ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \ math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
ptr_tensor->template mutable_data<T, Context>()); \ ptr_tensor->template mutable_data<T, Context>(), ctx()); \
} \ } \
} }
......
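Run() is now parameterized by a stream id that is pushed into the context before any kernel launches; FinishDeviceCompution() becomes an optional barrier, and the fill macros call it explicitly because the filler runs asynchronously on the op's stream and the parameter must be fully written before other streams read it. A condensed, compilable sketch of the new control flow (context calls stubbed out; the real class is Operator<Context>):

    #include <cstdio>

    struct ContextStub {
        void SwitchToDevice(int stream_id) { std::printf("-> stream %d\n", stream_id); }
        void FinishDeviceCompution() { std::printf("-> sync\n"); }
    };

    struct OperatorSketch {
        ContextStub ctx_;
        bool do_sync_ = false;
        ContextStub* ctx() { return &ctx_; }
        void RunOnDevice() { /* kernels are queued here, not awaited */ }
        void Run(int stream_id = 1) {
            ctx()->SwitchToDevice(stream_id);  // bind this op to the stream
            RunOnDevice();                     // launches asynchronously
            if (do_sync_) ctx()->FinishDeviceCompution();  // optional barrier
        }
    };

    int main() { OperatorSketch op; op.Run(); return 0; }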
...@@ -74,7 +74,9 @@ class Tensor { ...@@ -74,7 +74,9 @@ class Tensor {
for (TIndex i = start; i < end; i++) ret *= dim(i); for (TIndex i = start; i < end; i++) ret *= dim(i);
return ret; return ret;
} }
inline TIndex count() const { return size_; } inline TIndex count() const { return size_; }
inline TIndex count(const TIndex start) const { inline TIndex count(const TIndex start) const {
return count(start, ndim()); return count(start, ndim());
} }
...@@ -197,7 +199,7 @@ class Tensor { ...@@ -197,7 +199,7 @@ class Tensor {
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
// call the constructors // call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_); if (meta.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize(); capacity_ = size_ * meta.itemsize(), require_init_ = true;
return data_ptr; return data_ptr;
} }
...@@ -225,6 +227,15 @@ class Tensor { ...@@ -225,6 +227,15 @@ class Tensor {
} }
template <typename T, class Context> template <typename T, class Context>
T* mutable_data(Context* ctx) {
auto* data = mutable_data<T, Context>();
if (!require_init_) return data;
ctx->MemsetAsync(nbytes(), (void*)data);
require_init_ = false;
return data;
}
template <typename T, class Context>
const T* data() const { const T* data() const {
CHECK(meta_ == TypeMeta::Make<T>()) CHECK(meta_ == TypeMeta::Make<T>())
<< "\nThe DType of Tensor(" << name() << ") is " << "\nThe DType of Tensor(" << name() << ") is "
...@@ -234,27 +245,31 @@ class Tensor { ...@@ -234,27 +245,31 @@ class Tensor {
} }
template <class Context> template <class Context>
inline void CopyFrom(const Tensor& other) { inline void CopyFrom(const Tensor& other, Context* ctx) {
if ((void*)&other == (void*)this) return;
CHECK_EQ(size_, other.size_); CHECK_EQ(size_, other.size_);
auto* src = other.template raw_data<Context>(); auto* src = other.template raw_data<Context>();
auto* dst = raw_mutable_data<Context>(other.meta_); auto* dst = raw_mutable_data<Context>(other.meta_);
if (dst == src) return; ctx->template MemcpyAsync<Context, Context>(
if (TypeMeta::Id<Context>() == nbytes(), dst, src);
TypeMeta::Id<CPUContext>()) { require_init_ = false;
CPUContext::Memcpy<Context, Context>(nbytes(), dst, src);
} else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<Context, Context>(nbytes(), dst, src);
}
} }
inline void Move(MixedMemory* mem) { inline void Move(MixedMemory* mem) {
if (mem != nullptr) ex_memory_ = mem; if (mem != nullptr) {
else ex_memory_ = new MixedMemory(TypeMeta::Make<float>(), 4); ex_memory_ = mem;
own_mem_ = false; require_init_ = false;
} else {
ex_memory_ = new MixedMemory(
TypeMeta::Make<float>(), 4);
require_init_ = true;
} own_mem_ = false;
} }
inline void Share(MixedMemory* mem) { Move(mem); is_shared_ = true; } inline void Share(MixedMemory* mem) {
Move(mem); is_shared_ = true;
require_init_ = false;
}
inline void Reset() { inline void Reset() {
size_ = capacity_ = 0; size_ = capacity_ = 0;
...@@ -275,7 +290,7 @@ class Tensor { ...@@ -275,7 +290,7 @@ class Tensor {
shared_ptr<MixedMemory> memory_; shared_ptr<MixedMemory> memory_;
MixedMemory* ex_memory_ = nullptr; MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, is_shared_ = false; bool is_corrupted_ = false, is_shared_ = false;
bool own_mem_ = true; bool own_mem_ = true, require_init_ = true;
}; };
} // namespace dragon } // namespace dragon
......
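The new require_init_ flag gives tensors lazy zero-initialization: raw_mutable_data() marks freshly (re)allocated memory as uninitialized, the new mutable_data<T>(ctx) overload clears it exactly once via the context's MemsetAsync, and CopyFrom()/Share() clear the flag because the memory is immediately overwritten or owned elsewhere. A minimal CPU-only sketch of the idea:

    #include <cstdlib>
    #include <cstring>

    struct CtxStub { void MemsetAsync(size_t n, void* p) { memset(p, 0, n); } };

    // Minimal sketch: zero a buffer at most once, on first typed access.
    struct TensorSketch {
        void* data_ = nullptr;
        size_t nbytes_ = 0;
        bool require_init_ = true;

        float* mutable_data(CtxStub* ctx, size_t count) {
            if (!data_) { nbytes_ = count * sizeof(float); data_ = malloc(nbytes_); }
            float* data = static_cast<float*>(data_);
            if (!require_init_) return data;   // already initialized
            ctx->MemsetAsync(nbytes_, data);   // queued on the op's stream in CUDA
            require_init_ = false;
            return data;
        }
    };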
...@@ -179,29 +179,28 @@ class Workspace { ...@@ -179,29 +179,28 @@ class Workspace {
template <class Context> template <class Context>
inline vector<void*> caches( inline vector<void*> caches(
const vector<size_t>& segments) { const vector<size_t>& segments) {
TIndex total_size = 0; TIndex nbytes = 0;
for (auto& segment : segments) total_size += (TIndex)segment; for (auto& segment : segments) nbytes += (TIndex)segment;
Tensor* cacheT = CreateTensor("/share/cache"); Tensor* cache_t = CreateTensor("/share/cache");
cacheT->Reshape({ total_size }); cache_t->Reshape({ nbytes });
vector<void*> caches(segments.size()); vector<void*> Bcaches(segments.size());
caches[0] = cacheT->template mutable_data<uint8_t, Context>(); Bcaches[0] = cache_t->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); i++) for (int i = 1; i < segments.size(); i++)
caches[i] = (uint8_t*)caches[i - 1] + segments[i - 1]; Bcaches[i] = (uint8_t*)Bcaches[i - 1] + segments[i - 1];
return caches; return Bcaches;
} }
template <typename T, class Context> template <typename T, class Context>
inline vector<T*> caches( inline vector<T*> caches(
const vector<TIndex>& segments) { const vector<TIndex>& segments) {
TIndex total_count = 0; vector<size_t> Tsegments;
for (auto& segment : segments) total_count += segment; for (auto& segment : segments)
Tensor* cacheT = CreateTensor("/share/cache"); Tsegments.emplace_back(segment * sizeof(T));
cacheT->Reshape({ total_count }); vector<void*> Bcaches = caches<Context>(Tsegments);
vector<T*> caches(segments.size()); vector<T*> Tcaches(segments.size());
caches[0] = cacheT->template mutable_data<T, Context>(); for (int i = 0; i < segments.size(); i++)
for (int i = 1; i < segments.size(); i++) Tcaches[i] = (T*)Bcaches[i];
caches[i] = caches[i - 1] + segments[i - 1]; return Tcaches;
return caches;
} }
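Both caches() overloads now carve segments out of the single "/share/cache" tensor; the typed version just converts element counts to byte counts and delegates to the byte version. A standalone sketch of the carving arithmetic:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Split one byte blob at the given segment sizes, exactly the
    // pointer arithmetic used by Workspace::caches<Context>.
    std::vector<void*> Carve(uint8_t* blob, const std::vector<size_t>& segments) {
        std::vector<void*> ptrs(segments.size());
        ptrs[0] = blob;
        for (size_t i = 1; i < segments.size(); i++)
            ptrs[i] = (uint8_t*)ptrs[i - 1] + segments[i - 1];
        return ptrs;
    }

    int main() {
        std::vector<uint8_t> blob(64 * sizeof(float));
        auto ptrs = Carve(blob.data(), { 16 * sizeof(float), 48 * sizeof(float) });
        std::printf("segment 1 begins %td bytes in\n",
                    (uint8_t*)ptrs[1] - (uint8_t*)ptrs[0]);  // 64
        return 0;
    }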
/******************** Operator ********************/ /******************** Operator ********************/
...@@ -259,11 +258,12 @@ class Workspace { ...@@ -259,11 +258,12 @@ class Workspace {
void RunGraph( void RunGraph(
const string& graph_name, const string& graph_name,
const string& include, const string& include,
const string& exclude) { const string& exclude,
const int stream_id = 1) {
if (!graph_map_.count(graph_name)) if (!graph_map_.count(graph_name))
LOG(FATAL) << "Graph(" << graph_name LOG(FATAL) << "Graph(" << graph_name
<< ") does not exist."; << ") does not exist.";
graph_map_[graph_name]->Run(include, exclude); graph_map_[graph_name]->Run(include, exclude, stream_id);
} }
vector<string> GetGraphs() { vector<string> GetGraphs() {
......
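Graph execution now takes the stream id too, defaulting to 1 (the first non-blocking stream). A hypothetical call site against Workspace::RunGraph as declared above (include path and graph names are illustrative):

    #include "core/workspace.h"  // assumed include path

    // Hypothetical: queue two graphs on different streams of one GPU so
    // their kernels may overlap; graph names are illustrative.
    void RunConcurrently(dragon::Workspace* ws) {
        ws->RunGraph("GraphA", "", "", 1);  // the default stream id
        ws->RunGraph("GraphB", "", "", 2);  // a second, lazily-created stream
    }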
...@@ -36,7 +36,6 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -36,7 +36,6 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
void SoftmaxRunFP16();
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunWithType();
......
...@@ -42,7 +42,7 @@ public: ...@@ -42,7 +42,7 @@ public:
// simply copy the dY to dX // simply copy the dY to dX
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (Output(0)->name() != Input(-1).name()) if (Output(0)->name() != Input(-1).name())
Output(0)->template CopyFrom<Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1), ctx());
} }
}; };
......
...@@ -34,7 +34,6 @@ class L2NormOp final : public Operator<Context> { ...@@ -34,7 +34,6 @@ class L2NormOp final : public Operator<Context> {
TIndex axis, num_axes, end_axis; TIndex axis, num_axes, end_axis;
float eps; float eps;
string mode; string mode;
bool across_inner;
Tensor* norm, buffer; Tensor* norm, buffer;
TIndex outer_dim, dim, inner_dim, spatial_dim; TIndex outer_dim, dim, inner_dim, spatial_dim;
}; };
...@@ -55,7 +54,6 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -55,7 +54,6 @@ class L2NormGradientOp final : public Operator<Context> {
protected: protected:
TIndex axis, num_axes, end_axis; TIndex axis, num_axes, end_axis;
string mode; string mode;
bool across_inner;
Tensor* norm, buffer, buffer_inner; Tensor* norm, buffer, buffer_inner;
TIndex outer_dim, dim, inner_dim; TIndex outer_dim, dim, inner_dim;
}; };
......
...@@ -24,7 +24,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> { ...@@ -24,7 +24,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat32() override;
void ComputeRunWithFloat16() override; void ComputeRunWithFloat16() override;
protected: protected:
......
...@@ -43,10 +43,26 @@ class CollectiveUpdateOp final : public Operator<Context> { ...@@ -43,10 +43,26 @@ class CollectiveUpdateOp final : public Operator<Context> {
void InitNCCL(); void InitNCCL();
void RunOnDevice() override; void RunOnDevice() override;
void MPIAllReduceWithFloat();
void NCCLAllReduceWithFloat(); template <typename T> void MPIAllReduce(
void MPIBcastWithFloat(); Tensor* tensor,
void NCCLBcastWithFloat(); MPI_Datatype dtype);
template <typename T> void MPIBcast(
Tensor* tensor,
MPI_Datatype dtype);
#ifdef WITH_MPI_NCCL
template <typename T> void NCCLAllReduce(
Tensor* tensor,
ncclDataType_t dtype,
cudaStream_t& stream);
template <typename T> void NCCLBcast(
Tensor* tensor,
ncclDataType_t dtype,
cudaStream_t& stream);
#endif
protected: protected:
int comm_size, comm_rank, comm_root; int comm_size, comm_rank, comm_root;
......
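The float-only collective helpers become templates over the element type plus the matching MPI/NCCL dtype tag, so float16 gradients can reuse the same path. A minimal sketch of the core of such an MPI all-reduce (communicator and Tensor plumbing simplified; MPI_COMM_WORLD assumed, not from this diff):

    #include <mpi.h>

    // Sketch: in-place sum-reduce a raw buffer across all ranks, the
    // heart of a templated MPIAllReduce(Tensor*, MPI_Datatype).
    template <typename T>
    void AllReduceSum(T* data, int count, MPI_Datatype dtype) {
        MPI_Allreduce(MPI_IN_PLACE, data, count, dtype, MPI_SUM, MPI_COMM_WORLD);
    }

    // e.g. AllReduceSum<float>(grad, n, MPI_FLOAT);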
...@@ -24,7 +24,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> { ...@@ -24,7 +24,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat32() override;
void ComputeRunWithFloat16() override; void ComputeRunWithFloat16() override;
protected: protected:
......
...@@ -24,7 +24,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> { ...@@ -24,7 +24,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat32() override;
void ComputeRunWithFloat16() override; void ComputeRunWithFloat16() override;
protected: protected:
......
...@@ -25,7 +25,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> { ...@@ -25,7 +25,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat32() override;
void ComputeRunWithFloat16() override; void ComputeRunWithFloat16() override;
protected: protected:
......
...@@ -35,13 +35,11 @@ class UpdateOpBase : public Operator<Context> { ...@@ -35,13 +35,11 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void PreprocessRunWithType(); template <typename T> void PreprocessRunWithType();
virtual void ComputeRunWithFloat() = 0; virtual void ComputeRunWithFloat32() = 0;
virtual void ComputeRunWithFloat16() = 0;
virtual void ComputeRunWithFloat16() { void UpdateRunWithFloat32();
LOG(FATAL) << "This Updater does not support FP16."; void UpdateRunWithFloat16();
}
template <typename T> void UpdateRunWithType();
protected: protected:
float lr_mult, decay_mult; float lr_mult, decay_mult;
......
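Every updater now overrides ComputeRunWithFloat32() and ComputeRunWithFloat16() explicitly; the old optional FP16 hook that raised LOG(FATAL) is gone. A sketch of a conforming subclass (name and kernel bodies are illustrative):

    // Sketch of the new updater contract; the bodies would dispatch to
    // e.g. kernel::SGDUpdate<T, Context>(..., ctx()).
    template <class Context>
    class MyUpdateOp final : public UpdateOpBase<Context> {
     public:
        USE_OPERATOR_FUNCTIONS;
        USE_UPDATER_FUNCTIONS(Context);

        void ComputeRunWithFloat32() override { /* float path */ }
        void ComputeRunWithFloat16() override { /* float16 path */ }
    };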
...@@ -80,7 +80,8 @@ class ConvOpBase : public Operator<Context> { ...@@ -80,7 +80,8 @@ class ConvOpBase : public Operator<Context> {
dilation[0], dilation[1], dilation[0], dilation[1],
data_format, data_format,
im, im,
col); col,
ctx());
} else LOG(FATAL) << "ConvNd has not been implemented yet"; } else LOG(FATAL) << "ConvNd has not been implemented yet";
} }
template <typename T> void Col2Im(const T* col, T* im) { template <typename T> void Col2Im(const T* col, T* im) {
...@@ -94,7 +95,8 @@ class ConvOpBase : public Operator<Context> { ...@@ -94,7 +95,8 @@ class ConvOpBase : public Operator<Context> {
dilation[0], dilation[1], dilation[0], dilation[1],
data_format, data_format,
col, col,
im); im,
ctx());
} else LOG(FATAL) << "ConvNd has not been implemented yet"; } else LOG(FATAL) << "ConvNd has not been implemented yet";
} }
}; };
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
namespace dragon { namespace dragon {
#define HFLT_MIN 6.10e-5F
template <typename DestType, typename SrcType> template <typename DestType, typename SrcType>
DestType dragon_cast(SrcType val); DestType dragon_cast(SrcType val);
......
...@@ -29,9 +29,17 @@ namespace dragon { ...@@ -29,9 +29,17 @@ namespace dragon {
#ifdef WITH_CUDA #ifdef WITH_CUDA
static const int CUDA_THREADS = 1024; // The number of cuda threads to use. We set it to
// We do have a server with 10 GPUs :-) // 1024, which works for compute capability 2.x.
#define CUDA_MAX_DEVICES 10 // Set it to 512 if using compute capability 1.x.
const int CUDA_THREADS = 1024;
// The maximum number of blocks to use in the default kernel call. We set it to
// 65535, which works for compute capability 2.x (where 65536 is the limit).
const int CUDA_MAX_BLOCKS = 65535;
// You really need an NVIDIA DGX-2 !!! :-)
#define CUDA_MAX_DEVICES 16
#define CUDA_VERSION_MIN(major, minor, patch) \ #define CUDA_VERSION_MIN(major, minor, patch) \
(CUDA_VERSION >= (major * 1000 + minor * 100 + patch)) (CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
...@@ -67,12 +75,16 @@ static const int CUDA_THREADS = 1024; ...@@ -67,12 +75,16 @@ static const int CUDA_THREADS = 1024;
} while (0) } while (0)
#endif // WITH_MPI_NCCL #endif // WITH_MPI_NCCL
#define CUDA_KERNEL_LOOP(i, n) \ #define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; \
i < n; i += blockDim.x * gridDim.x) i < n; i += blockDim.x * gridDim.x)
inline int CUDA_BLOCKS(const int N) { inline int CUDA_BLOCKS(const int N) {
return (N + CUDA_THREADS - 1) / CUDA_THREADS; return std::max(
std::min(
(N + CUDA_THREADS - 1) / CUDA_THREADS,
CUDA_MAX_BLOCKS
), 1);
} }
#if CUDA_VERSION_MAX(9, 0, 0) #if CUDA_VERSION_MAX(9, 0, 0)
......
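CUDA_BLOCKS() now clamps the grid to CUDA_MAX_BLOCKS (and to at least 1), which is safe because CUDA_1D_KERNEL_LOOP is a grid-stride loop: when n exceeds gridDim.x * blockDim.x, each thread simply handles several elements. A self-contained illustration with a toy kernel (_AddScalar is illustrative, not from this diff):

    #include <algorithm>
    #include <cuda_runtime.h>

    const int CUDA_THREADS = 1024;
    const int CUDA_MAX_BLOCKS = 65535;

    #define CUDA_1D_KERNEL_LOOP(i, n) \
        for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; \
             i < n; i += blockDim.x * gridDim.x)

    inline int CUDA_BLOCKS(const int N) {
        return std::max(std::min(
            (N + CUDA_THREADS - 1) / CUDA_THREADS, CUDA_MAX_BLOCKS), 1);
    }

    template <typename T>
    __global__ void _AddScalar(const int n, const T alpha, T* y) {
        CUDA_1D_KERNEL_LOOP(i, n) y[i] += alpha;  // grid-stride covers any n
    }

    // Launch on the context's stream so it orders with other async work:
    // _AddScalar<float>
    //     <<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(n, 1.f, y);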
...@@ -44,6 +44,7 @@ template<> class CUDNNType<float> { ...@@ -44,6 +44,7 @@ template<> class CUDNNType<float> {
static const cudnnDataType_t type = CUDNN_DATA_FLOAT; static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
static float oneval, zeroval; static float oneval, zeroval;
static const void *one, *zero; static const void *one, *zero;
typedef float BNParamType;
}; };
template<> class CUDNNType<double> { template<> class CUDNNType<double> {
...@@ -51,6 +52,7 @@ template<> class CUDNNType<double> { ...@@ -51,6 +52,7 @@ template<> class CUDNNType<double> {
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
static double oneval, zeroval; static double oneval, zeroval;
static const void *one, *zero; static const void *one, *zero;
typedef double BNParamType;
}; };
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
...@@ -59,6 +61,7 @@ template<> class CUDNNType<float16> { ...@@ -59,6 +61,7 @@ template<> class CUDNNType<float16> {
static const cudnnDataType_t type = CUDNN_DATA_HALF; static const cudnnDataType_t type = CUDNN_DATA_HALF;
static float oneval, zeroval; static float oneval, zeroval;
static const void *one, *zero; static const void *one, *zero;
typedef float BNParamType;
}; };
#endif #endif
......
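The new BNParamType typedef records the parameter type cuDNN batch normalization expects: float data uses float, double uses double, and float16 data still uses float, since the cudnnBatchNormalization* routines require single-precision scale/bias/mean/variance for half inputs. A sketch of how a caller might pick it (the descriptor call is commented and shown only for shape):

    // Sketch: derive the batch-norm parameter type from the data type,
    // so float16 data still gets float scale/bias as cuDNN requires.
    template <typename T>
    void SetupBNParamDesc(/* cudnnTensorDescriptor_t bn_desc, int C */) {
        typedef typename CUDNNType<T>::BNParamType ParamT;  // float for float16
        // cudnnSetTensor4dDescriptor(bn_desc, CUDNN_TENSOR_NCHW,
        //     CUDNNType<ParamT>::type, 1, C, 1, 1);
    }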
...@@ -40,7 +40,7 @@ class ConstantFiller final : public Filler<T, Context> { ...@@ -40,7 +40,7 @@ class ConstantFiller final : public Filler<T, Context> {
void Fill(Tensor* tensor, Context* ctx) override { void Fill(Tensor* tensor, Context* ctx) override {
math::Set<T, Context>(tensor->count(), math::Set<T, Context>(tensor->count(),
dragon_cast<T, float>(filler().value()), dragon_cast<T, float>(filler().value()),
tensor->mutable_data<T, Context>()); tensor->mutable_data<T, Context>(), ctx);
} }
protected: protected:
...@@ -71,11 +71,11 @@ class TruncatedNormalFiller final : public Filler<T, Context> { ...@@ -71,11 +71,11 @@ class TruncatedNormalFiller final : public Filler<T, Context> {
void Fill(Tensor* tensor, Context* ctx) override { void Fill(Tensor* tensor, Context* ctx) override {
// implementing it on the gpu is difficult
static CPUContext cpu_ctx; static CPUContext cctx;
math::RandomTruncatedNormal<T, CPUContext>(tensor->count(), math::RandomTruncatedNormal<T, CPUContext>(tensor->count(),
filler().mean(), filler().std(), filler().mean(), filler().std(),
filler().low(), filler().high(), filler().low(), filler().high(),
tensor->mutable_data<T, CPUContext>(), &cpu_ctx); tensor->mutable_data<T, CPUContext>(), &cctx);
} }
protected: protected:
......
...@@ -36,7 +36,8 @@ template <typename T, class Context> ...@@ -36,7 +36,8 @@ template <typename T, class Context>
void Set( void Set(
const int n, const int n,
const T alpha, const T alpha,
T* x); T* x,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void RandomUniform( void RandomUniform(
...@@ -78,73 +79,84 @@ void Add( ...@@ -78,73 +79,84 @@ void Add(
const int n, const int n,
const T* a, const T* a,
const T* b, const T* b,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Sub( void Sub(
const int n, const int n,
const T* a, const T* a,
const T* b, const T* b,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Mul( void Mul(
const int n, const int n,
const T* a, const T* a,
const T* b, const T* b,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Div( void Div(
const int n, const int n,
const T* a, const T* a,
const T* b, const T* b,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Clip( void Clip(
const int n, const int n,
const float low, const float low,
const float high, const float high,
T* x); T* x,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Exp( void Exp(
const int n, const int n,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Log( void Log(
const int n, const int n,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Square( void Square(
const int n, const int n,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Sqrt( void Sqrt(
const int n, const int n,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Pow( void Pow(
const int n, const int n,
const float alpha, const float alpha,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Inv( void Inv(
const int n, const int n,
const float numerator, const float numerator,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
/******************** Level-2 ********************/ /******************** Level-2 ********************/
...@@ -164,19 +176,21 @@ void Scale( ...@@ -164,19 +176,21 @@ void Scale(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
T StridedDot( void StridedDot(
const int n, const int n,
const T* a, const T* a,
const int incx, const int incx,
const T* b, const T* b,
const int incy, const int incy,
T* y,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
float Dot( void Dot(
const int n, const int n,
const T* a, const T* a,
const T* b, const T* b,
T* y,
Context* ctx); Context* ctx);
template<typename T, class Context> template<typename T, class Context>
...@@ -188,13 +202,15 @@ template<typename T, class Context> ...@@ -188,13 +202,15 @@ template<typename T, class Context>
void AddScalar( void AddScalar(
const int n, const int n,
const float alpha, const float alpha,
T* y); T* y,
Context* ctx);
template<typename T, class Context> template<typename T, class Context>
void MulScalar( void MulScalar(
const int n, const int n,
const float alpha, const float alpha,
T* y); T* y,
Context* ctx);
template<typename T, class Context> template<typename T, class Context>
void Axpy( void Axpy(
......
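Every math:: routine now threads a Context* through so the CUDA implementations can launch on the caller's stream rather than the global default (for CPUContext the extra argument is inert). Inside an op's RunWithType(), a call presumably now looks like this (illustrative fragment, not from the diff):

    // Illustrative fragment inside some Operator<Context>::RunWithType():
    auto* y = Output(0)->template mutable_data<T, Context>();
    math::Set<T, Context>(Output(0)->count(),
        dragon_cast<T, float>(0.f), y, ctx());       // ctx() picks the stream
    math::AddScalar<T, Context>(Output(0)->count(), 1.f, y, ctx());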
...@@ -49,7 +49,8 @@ void Elu( ...@@ -49,7 +49,8 @@ void Elu(
const int count, const int count,
const float alpha, const float alpha,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void EluGrad( void EluGrad(
...@@ -57,7 +58,8 @@ void EluGrad( ...@@ -57,7 +58,8 @@ void EluGrad(
const float alpha, const float alpha,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx); T* dx,
Context* ctx);
/******************** activation.prelu ********************/ /******************** activation.prelu ********************/
...@@ -70,7 +72,8 @@ void PRelu( ...@@ -70,7 +72,8 @@ void PRelu(
const string& data_format, const string& data_format,
const T* x, const T* x,
const T* w, const T* w,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void PReluGrad( void PReluGrad(
...@@ -82,7 +85,8 @@ void PReluGrad( ...@@ -82,7 +85,8 @@ void PReluGrad(
const T* dy, const T* dy,
const T* x, const T* x,
const T* w, const T* w,
T* dx); T* dx,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void PReluWGrad( void PReluWGrad(
...@@ -106,7 +110,8 @@ void Relu( ...@@ -106,7 +110,8 @@ void Relu(
const int count, const int count,
const float slope, const float slope,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ReluGrad( void ReluGrad(
...@@ -114,7 +119,8 @@ void ReluGrad( ...@@ -114,7 +119,8 @@ void ReluGrad(
const float slope, const float slope,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx); T* dx,
Context* ctx);
/******************** activation.selu ********************/ /******************** activation.selu ********************/
...@@ -122,14 +128,16 @@ template <typename T, class Context> ...@@ -122,14 +128,16 @@ template <typename T, class Context>
void SElu( void SElu(
const int count, const int count,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SEluGrad( void SEluGrad(
const int count, const int count,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx); T* dx,
Context* ctx);
/******************** activation.sigmoid ********************/ /******************** activation.sigmoid ********************/
...@@ -137,14 +145,16 @@ template <typename T, class Context> ...@@ -137,14 +145,16 @@ template <typename T, class Context>
void Sigmoid( void Sigmoid(
const int count, const int count,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SigmoidGrad( void SigmoidGrad(
const int count, const int count,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx); T* dx,
Context* ctx);
/******************** activation.softmax ********************/ /******************** activation.softmax ********************/
...@@ -179,14 +189,16 @@ template <typename T, class Context> ...@@ -179,14 +189,16 @@ template <typename T, class Context>
void Tanh( void Tanh(
const int count, const int count,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void TanhGrad( void TanhGrad(
const int count, const int count,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx); T* dx,
Context* ctx);
/******************** arithmetic.affine ********************/ /******************** arithmetic.affine ********************/
...@@ -223,7 +235,8 @@ void Clip( ...@@ -223,7 +235,8 @@ void Clip(
const float high, const float high,
const T* x, const T* x,
T* mask, T* mask,
T* y); T* y,
Context* ctx);
/******************** control_flow.compare ********************/ /******************** control_flow.compare ********************/
...@@ -232,7 +245,8 @@ void Equal( ...@@ -232,7 +245,8 @@ void Equal(
const int count, const int count,
const T* a, const T* a,
const T* b, const T* b,
T* y); T* y,
Context* ctx);
/******************** loss.l1_loss ********************/ /******************** loss.l1_loss ********************/
...@@ -240,7 +254,8 @@ template <typename T, class Context> ...@@ -240,7 +254,8 @@ template <typename T, class Context>
void AbsGrad( void AbsGrad(
const int count, const int count,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** loss.sigmoid_cross_entropy ********************/ /******************** loss.sigmoid_cross_entropy ********************/
...@@ -301,14 +316,16 @@ void SmoothL1( ...@@ -301,14 +316,16 @@ void SmoothL1(
const int count, const int count,
const float beta, const float beta,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SmoothL1Grad( void SmoothL1Grad(
const int count, const int count,
const float beta, const float beta,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** loss.softmax_cross_entropy ********************/ /******************** loss.softmax_cross_entropy ********************/
...@@ -317,7 +334,8 @@ void SoftmaxCrossEntropy( ...@@ -317,7 +334,8 @@ void SoftmaxCrossEntropy(
const int count, const int count,
const T* prob, const T* prob,
const T* target, const T* target,
T* loss); T* loss,
Context* ctx);
/******************** loss.softmax_focal_loss ********************/ /******************** loss.softmax_focal_loss ********************/
...@@ -366,8 +384,8 @@ void SparseSoftmaxCrossEntropy( ...@@ -366,8 +384,8 @@ void SparseSoftmaxCrossEntropy(
const Ty* labels, const Ty* labels,
const int* ignores, const int* ignores,
const int num_ignores, const int num_ignores,
Tx* losses, float* losses,
Tx* flags, float* flags,
Context* ctx); Context* ctx);
template <typename Tx, typename Ty, class Context> template <typename Tx, typename Ty, class Context>
...@@ -380,7 +398,7 @@ void SparseSoftmaxCrossEntropyGrad( ...@@ -380,7 +398,7 @@ void SparseSoftmaxCrossEntropyGrad(
const int* ignores, const int* ignores,
const int num_ignores, const int num_ignores,
Tx* dx, Tx* dx,
Tx* flags, float* flags,
Context* ctx); Context* ctx);
/******************** misc.astype ********************/ /******************** misc.astype ********************/
...@@ -389,7 +407,8 @@ template <typename Ta, typename Tb, class Context> ...@@ -389,7 +407,8 @@ template <typename Ta, typename Tb, class Context>
void TypeA2B( void TypeA2B(
const int count, const int count,
const Ta* a, const Ta* a,
Tb* b); Tb* b,
Context* ctx);
/******************** misc.image_data ********************/ /******************** misc.image_data ********************/
...@@ -404,7 +423,8 @@ void ImageData( ...@@ -404,7 +423,8 @@ void ImageData(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const Tx* x, const Tx* x,
Ty* y); Ty* y,
Context* ctx);
/******************** ndarray.arange ********************/ /******************** ndarray.arange ********************/
...@@ -413,7 +433,8 @@ void Arange( ...@@ -413,7 +433,8 @@ void Arange(
const int count, const int count,
const int start, const int start,
const int step, const int step,
T* y); T* y,
Context* ctx);
/******************** ndarray.argreduce ********************/ /******************** ndarray.argreduce ********************/
...@@ -425,7 +446,8 @@ void Argmax( ...@@ -425,7 +446,8 @@ void Argmax(
const int top_k, const int top_k,
const T* x, const T* x,
int64_t* indices, int64_t* indices,
T* values); T* values,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Argmin( void Argmin(
...@@ -435,7 +457,8 @@ void Argmin( ...@@ -435,7 +457,8 @@ void Argmin(
const int top_k, const int top_k,
const T* x, const T* x,
int64_t* indices, int64_t* indices,
T* values); T* values,
Context* ctx);
/******************** ndarray.gather ********************/ /******************** ndarray.gather ********************/
...@@ -443,7 +466,8 @@ template <typename T, class Context> ...@@ -443,7 +466,8 @@ template <typename T, class Context>
void CanonicalAxis( void CanonicalAxis(
const int count, const int count,
const int dim, const int dim,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Gather( void Gather(
...@@ -454,7 +478,8 @@ void Gather( ...@@ -454,7 +478,8 @@ void Gather(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void GatherGrad( void GatherGrad(
...@@ -465,7 +490,8 @@ void GatherGrad( ...@@ -465,7 +490,8 @@ void GatherGrad(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.concat ********************/ /******************** ndarray.concat ********************/
...@@ -478,7 +504,8 @@ void Concat( ...@@ -478,7 +504,8 @@ void Concat(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ConcatGrad( void ConcatGrad(
...@@ -489,7 +516,8 @@ void ConcatGrad( ...@@ -489,7 +516,8 @@ void ConcatGrad(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.crop ********************/ /******************** ndarray.crop ********************/
...@@ -501,7 +529,8 @@ void Crop1D( ...@@ -501,7 +529,8 @@ void Crop1D(
const int inner_dim, const int inner_dim,
const int start, const int start,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Crop1DGrad( void Crop1DGrad(
...@@ -512,7 +541,8 @@ void Crop1DGrad( ...@@ -512,7 +541,8 @@ void Crop1DGrad(
const int start, const int start,
const int end, const int end,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.pad ********************/ /******************** ndarray.pad ********************/
...@@ -525,7 +555,8 @@ void ConstPad1D( ...@@ -525,7 +555,8 @@ void ConstPad1D(
const int pad_l, const int pad_l,
const float value, const float value,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ReflectPad1D( void ReflectPad1D(
...@@ -535,7 +566,8 @@ void ReflectPad1D( ...@@ -535,7 +566,8 @@ void ReflectPad1D(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void EdgePad1D( void EdgePad1D(
...@@ -545,7 +577,8 @@ void EdgePad1D( ...@@ -545,7 +577,8 @@ void EdgePad1D(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ConstPad1DGrad( void ConstPad1DGrad(
...@@ -555,7 +588,8 @@ void ConstPad1DGrad( ...@@ -555,7 +588,8 @@ void ConstPad1DGrad(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ReflectPad1DGrad( void ReflectPad1DGrad(
...@@ -565,7 +599,8 @@ void ReflectPad1DGrad( ...@@ -565,7 +599,8 @@ void ReflectPad1DGrad(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void EdgePad1DGrad( void EdgePad1DGrad(
...@@ -575,7 +610,8 @@ void EdgePad1DGrad( ...@@ -575,7 +610,8 @@ void EdgePad1DGrad(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.one_hot ********************/ /******************** ndarray.one_hot ********************/
...@@ -585,7 +621,8 @@ void OneHot( ...@@ -585,7 +621,8 @@ void OneHot(
const int depth, const int depth,
const int on_value, const int on_value,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
/******************** ndarray.reduce ********************/ /******************** ndarray.reduce ********************/
...@@ -595,7 +632,8 @@ void Sum( ...@@ -595,7 +632,8 @@ void Sum(
const int axis_dim, const int axis_dim,
const int inner_dim, const int inner_dim,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SumGrad( void SumGrad(
...@@ -604,7 +642,8 @@ void SumGrad( ...@@ -604,7 +642,8 @@ void SumGrad(
const int inner_dim, const int inner_dim,
const T coeff, const T coeff,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.repeat ********************/ /******************** ndarray.repeat ********************/
...@@ -616,7 +655,8 @@ void Repeat( ...@@ -616,7 +655,8 @@ void Repeat(
const int inner_dim, const int inner_dim,
const int repeats, const int repeats,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void RepeatGrad( void RepeatGrad(
...@@ -640,7 +680,8 @@ void Slice( ...@@ -640,7 +680,8 @@ void Slice(
const int y_slice_dim, const int y_slice_dim,
const int slice_offset, const int slice_offset,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SliceGrad( void SliceGrad(
...@@ -651,7 +692,8 @@ void SliceGrad( ...@@ -651,7 +692,8 @@ void SliceGrad(
const int y_slice_dim, const int y_slice_dim,
const int slice_offset, const int slice_offset,
const T* dy, const T* dy,
T* x); T* x,
Context* ctx);
/******************** ndarray.tile ********************/ /******************** ndarray.tile ********************/
...@@ -662,7 +704,8 @@ void Tile( ...@@ -662,7 +704,8 @@ void Tile(
const int ex_inner_dim, const int ex_inner_dim,
const int multiple, const int multiple,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void TileGrad( void TileGrad(
...@@ -684,7 +727,8 @@ void Transpose( ...@@ -684,7 +727,8 @@ void Transpose(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void TransposeGrad( void TransposeGrad(
...@@ -694,7 +738,8 @@ void TransposeGrad( ...@@ -694,7 +738,8 @@ void TransposeGrad(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** recurrent.lstm_cell ********************/ /******************** recurrent.lstm_cell ********************/
...@@ -706,7 +751,8 @@ void LSTMCell( ...@@ -706,7 +751,8 @@ void LSTMCell(
const T* cx, const T* cx,
T* xact, T* xact,
T* c, T* c,
T* h); T* h,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void LSTMCellGrad( void LSTMCellGrad(
...@@ -719,7 +765,8 @@ void LSTMCellGrad( ...@@ -719,7 +765,8 @@ void LSTMCellGrad(
const T* dc, const T* dc,
const T* dh, const T* dh,
T* dcx, T* dcx,
T* dx); T* dx,
Context* ctx);
/******************** update.adam_update ********************/ /******************** update.adam_update ********************/
...@@ -732,7 +779,8 @@ void AdamUpdate( ...@@ -732,7 +779,8 @@ void AdamUpdate(
const float eps, const float eps,
T* g, T* g,
T* m, T* m,
T* v); T* v,
Context* ctx);
/******************** update.nesterov_update ********************/ /******************** update.nesterov_update ********************/
...@@ -742,7 +790,8 @@ void NesterovUpdate( ...@@ -742,7 +790,8 @@ void NesterovUpdate(
const float lr, const float lr,
const float momentum, const float momentum,
T* g, T* g,
T* h); T* h,
Context* ctx);
/******************** update.rmsprop_update ********************/ /******************** update.rmsprop_update ********************/
...@@ -753,7 +802,8 @@ void RMSPropUpdate( ...@@ -753,7 +802,8 @@ void RMSPropUpdate(
const float decay, const float decay,
const float eps, const float eps,
T* g, T* g,
T* h); T* h,
Context* ctx);
/******************** update.sgd_update ********************/ /******************** update.sgd_update ********************/
...@@ -763,7 +813,8 @@ void SGDUpdate( ...@@ -763,7 +813,8 @@ void SGDUpdate(
const float lr, const float lr,
const float momentum, const float momentum,
T* g, T* g,
T* h); T* h,
Context* ctx);
/******************** vision.bias_add ********************/ /******************** vision.bias_add ********************/
...@@ -792,7 +843,8 @@ void BilinearResize( ...@@ -792,7 +843,8 @@ void BilinearResize(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void BilinearResizeGrad( void BilinearResizeGrad(
...@@ -805,7 +857,8 @@ void BilinearResizeGrad( ...@@ -805,7 +857,8 @@ void BilinearResizeGrad(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** vision.conv ********************/ /******************** vision.conv ********************/
...@@ -826,7 +879,8 @@ void Im2Col2d( ...@@ -826,7 +879,8 @@ void Im2Col2d(
const int dilation_w, const int dilation_w,
const string& data_format, const string& data_format,
const T* im, const T* im,
T* col); T* col,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Col2Im2d( void Col2Im2d(
...@@ -845,7 +899,8 @@ void Col2Im2d( ...@@ -845,7 +899,8 @@ void Col2Im2d(
const int dilation_w, const int dilation_w,
const string& data_format, const string& data_format,
const T* col, const T* col,
T* im); T* im,
Context* ctx);
/******************** vision.nn_resize ********************/ /******************** vision.nn_resize ********************/
...@@ -860,7 +915,8 @@ void NNResize( ...@@ -860,7 +915,8 @@ void NNResize(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void NNResizeGrad( void NNResizeGrad(
...@@ -873,7 +929,8 @@ void NNResizeGrad( ...@@ -873,7 +929,8 @@ void NNResizeGrad(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** vision.pooling ********************/ /******************** vision.pooling ********************/
...@@ -895,7 +952,8 @@ void MAXPooling2d( ...@@ -895,7 +952,8 @@ void MAXPooling2d(
const string& data_format, const string& data_format,
const T* x, const T* x,
int* mask, int* mask,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void AVGPooling2d( void AVGPooling2d(
...@@ -914,7 +972,8 @@ void AVGPooling2d( ...@@ -914,7 +972,8 @@ void AVGPooling2d(
const int pad_w, const int pad_w,
const string& data_format, const string& data_format,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void MAXPooling2dGrad( void MAXPooling2dGrad(
...@@ -934,7 +993,8 @@ void MAXPooling2dGrad( ...@@ -934,7 +993,8 @@ void MAXPooling2dGrad(
const string& data_format, const string& data_format,
const T* dy, const T* dy,
const int* mask, const int* mask,
T* dx); T* dx,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void AVGPooling2dGrad( void AVGPooling2dGrad(
...@@ -953,7 +1013,8 @@ void AVGPooling2dGrad( ...@@ -953,7 +1013,8 @@ void AVGPooling2dGrad(
const int pad_w, const int pad_w,
const string& data_format, const string& data_format,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** vision.roi_pooling ********************/ /******************** vision.roi_pooling ********************/
...@@ -971,7 +1032,8 @@ void ROIPooling( ...@@ -971,7 +1032,8 @@ void ROIPooling(
const T* x, const T* x,
const T* rois, const T* rois,
int* mask, int* mask,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ROIPoolingGrad( void ROIPoolingGrad(
...@@ -987,7 +1049,8 @@ void ROIPoolingGrad( ...@@ -987,7 +1049,8 @@ void ROIPoolingGrad(
const T* dy, const T* dy,
const T* rois, const T* rois,
const int* mask, const int* mask,
T* dx); T* dx,
Context* ctx);
/******************** vision.roi_align ********************/ /******************** vision.roi_align ********************/
...@@ -1005,7 +1068,8 @@ void ROIAlign( ...@@ -1005,7 +1068,8 @@ void ROIAlign(
const int sampling_ratio, const int sampling_ratio,
const T* x, const T* x,
const T* rois, const T* rois,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ROIAlignGrad( void ROIAlignGrad(
...@@ -1021,7 +1085,8 @@ void ROIAlignGrad( ...@@ -1021,7 +1085,8 @@ void ROIAlignGrad(
const int sampling_ratio, const int sampling_ratio,
const float* dy, const float* dy,
const float* rois, const float* rois,
float* dx); float* dx,
Context* ctx);
} // namespace kernel } // namespace kernel
......
...@@ -80,7 +80,7 @@ T Dot( ...@@ -80,7 +80,7 @@ T Dot(
const T* b); const T* b);
template<typename T> template<typename T>
T ASum( T Sum(
const int n, const int n,
const T* x); const T* x);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#ifdef WITH_SSE #ifdef WITH_SSE
#include <immintrin.h> #include <immintrin.h>
#include <tmmintrin.h>
#include <cstdint> #include <cstdint>
namespace dragon { namespace dragon {
......
...@@ -250,8 +250,9 @@ void LoadCaffemodel( ...@@ -250,8 +250,9 @@ void LoadCaffemodel(
void RunGraph( void RunGraph(
const std::string& graph_name, const std::string& graph_name,
Workspace* ws) { Workspace* ws,
ws->RunGraph(graph_name, "", ""); const int stream_id) {
ws->RunGraph(graph_name, "", "", stream_id);
} }
template <typename T> template <typename T>
......
...@@ -38,8 +38,7 @@ class Device { ...@@ -38,8 +38,7 @@ class Device {
EXPORT const int device_id() const { return device_id_; } EXPORT const int device_id() const { return device_id_; }
private: private:
int device_type_; int device_type_, device_id_;
int device_id_;
}; };
EXPORT Workspace* CreateWorkspace(const std::string& name); EXPORT Workspace* CreateWorkspace(const std::string& name);
...@@ -61,7 +60,8 @@ EXPORT std::string CreateGraph( ...@@ -61,7 +60,8 @@ EXPORT std::string CreateGraph(
EXPORT void RunGraph( EXPORT void RunGraph(
const std::string& graph_name, const std::string& graph_name,
Workspace* ws); Workspace* ws,
const int stream_id = 1);
EXPORT void CreateTensor( EXPORT void CreateTensor(
const std::string& name, const std::string& name,
......
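Note: these two hunks expose the stream choice at the C API boundary: the RunGraph wrapper now threads a stream_id through to Workspace::RunGraph, and the exported declaration defaults it to 1, with stream 0 reserved for the synchronizing default stream (see ProposalOp below). A hedged usage sketch; the workspace and graph names are hypothetical:
    Workspace* ws = CreateWorkspace("default");
    // assume CreateGraph(...) returned the name "GraphDef_1"
    RunGraph("GraphDef_1", ws);      // implicit stream_id = 1
    RunGraph("GraphDef_1", ws, 2);   // enqueue on a second async stream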
...@@ -116,7 +116,7 @@ class NumpyFeeder : public TensorFeederBase { ...@@ -116,7 +116,7 @@ class NumpyFeeder : public TensorFeederBase {
#else #else
LOG(FATAL) << "CUDA was not compiled."; LOG(FATAL) << "CUDA was not compiled.";
#endif #endif
} else{ } else {
CPUContext::Memcpy<CPUContext, CPUContext>(tensor->nbytes(), CPUContext::Memcpy<CPUContext, CPUContext>(tensor->nbytes(),
tensor->raw_mutable_data<CPUContext>(), tensor->raw_mutable_data<CPUContext>(),
static_cast<void*>(PyArray_DATA(array))); static_cast<void*>(PyArray_DATA(array)));
......
...@@ -18,18 +18,22 @@ ...@@ -18,18 +18,22 @@
PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) { PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) {
PyObject* def_string = nullptr; PyObject* def_string = nullptr;
PyObject* py_g_outputs = nullptr; PyObject* py_g_outputs = nullptr;
if (!PyArg_ParseTuple(args, "SO!", &def_string, &PyList_Type, &py_g_outputs)) { if (!PyArg_ParseTuple(args, "SO!",
PyErr_SetString(PyExc_ValueError, "Excepted a serialized string of OperatorDef " &def_string, &PyList_Type, &py_g_outputs)) {
PyErr_SetString(PyExc_ValueError,
"Excepted a serialized string of OperatorDef "
"and a list containing outputs of this GradientOp."); "and a list containing outputs of this GradientOp.");
return nullptr; return nullptr;
} }
OperatorDef def; OperatorDef def;
if (!def.ParseFromString(PyBytes_AsStringEx(def_string))) { if (!def.ParseFromString(PyBytes_AsStringEx(def_string))) {
PyErr_SetString(PyExc_ValueError, "Failed to parse the OperatorDef."); PyErr_SetString(PyExc_ValueError,
"Failed to parse the OperatorDef.");
return nullptr; return nullptr;
} }
if (!GradientRegistry()->Has(def.type())) { if (!GradientRegistry()->Has(def.type())) {
PyErr_SetString(PyExc_KeyError, "This Operator does not register GradientOp."); PyErr_SetString(PyExc_KeyError,
"This Operator does not register GradientOp.");
return nullptr; return nullptr;
} }
vector<string> g_outputs; vector<string> g_outputs;
...@@ -61,7 +65,8 @@ PyObject* RunGradientFlowCC(PyObject* self, PyObject* args) { ...@@ -61,7 +65,8 @@ PyObject* RunGradientFlowCC(PyObject* self, PyObject* args) {
PyObject* py_fp_ops, *py_targets; PyObject* py_fp_ops, *py_targets;
PyObject* py_input_grads, *py_ignore_grads; PyObject* py_input_grads, *py_ignore_grads;
PyObject* py_share_grads, *py_export_graph; PyObject* py_share_grads, *py_export_graph;
if (!PyArg_ParseTuple(args, "OOOOOO", &py_fp_ops, &py_targets, if (!PyArg_ParseTuple(args, "OOOOOO",
&py_fp_ops, &py_targets,
&py_input_grads, &py_ignore_grads, &py_input_grads, &py_ignore_grads,
&py_share_grads, &py_export_graph)) { &py_share_grads, &py_export_graph)) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
...@@ -84,8 +89,8 @@ PyObject* RunGradientFlowCC(PyObject* self, PyObject* args) { ...@@ -84,8 +89,8 @@ PyObject* RunGradientFlowCC(PyObject* self, PyObject* args) {
for (auto& grad : input_grads) maker.AddExternalGrad(grad); for (auto& grad : input_grads) maker.AddExternalGrad(grad);
for (auto& grad : ignore_grads) maker.AddIgnoreGrad(grad); for (auto& grad : ignore_grads) maker.AddIgnoreGrad(grad);
maker.Make(fp_ops, targets, bp_ops); maker.Make(fp_ops, targets, bp_ops);
bool share_grads = (bool)PyObject_IsTrue(py_share_grads); bool share_grads = PyObject_IsTrue(py_share_grads) ? true : false;
bool export_graph = (bool)PyObject_IsTrue(py_export_graph); bool export_graph = PyObject_IsTrue(py_export_graph) ? true : false;
if (share_grads) maker.Share("/share/buffer/grads", bp_ops); if (share_grads) maker.Share("/share/buffer/grads", bp_ops);
if (export_graph) { if (export_graph) {
Tensor* t = ws()->CreateTensor("/export/dynamic_graph/gradient_flow"); Tensor* t = ws()->CreateTensor("/export/dynamic_graph/gradient_flow");
......
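Aside, not part of this commit: PyObject_IsTrue returns -1 when the truth-value check raises, and both the old cast and the new ternary above coerce that -1 to true. A stricter binding would surface the pending error instead, e.g.:
    int ret = PyObject_IsTrue(py_share_grads);
    if (ret < 0) return nullptr;      // propagate the pending Python error
    bool share_grads = (ret > 0);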
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
inline PyObject* SetLogLevelCC(PyObject* self, PyObject* args) { inline PyObject* SetLogLevelCC(PyObject* self, PyObject* args) {
char* cname; char* cname;
if (!PyArg_ParseTuple(args, "s", &cname)) { if (!PyArg_ParseTuple(args, "s", &cname)) {
PyErr_SetString(PyExc_ValueError, "Excepted the logging level."); PyErr_SetString(PyExc_ValueError,
"Excepted the logging level.");
return nullptr; return nullptr;
} }
SetLogDestination(StrToLogSeverity(string(cname))); SetLogDestination(StrToLogSeverity(string(cname)));
......
...@@ -17,16 +17,19 @@ ...@@ -17,16 +17,19 @@
inline PyObject* CreateGraphCC(PyObject* self, PyObject* args) { inline PyObject* CreateGraphCC(PyObject* self, PyObject* args) {
PyObject* graph_str; PyObject* graph_str;
if (!PyArg_ParseTuple(args, "S", &graph_str)) { if (!PyArg_ParseTuple(args, "S", &graph_str)) {
PyErr_SetString(PyExc_ValueError, "Excepted a serialized string of GraphDef."); PyErr_SetString(PyExc_ValueError,
"Excepted a serialized string of GraphDef.");
return nullptr; return nullptr;
} }
GraphDef graph_def; GraphDef graph_def;
if (!graph_def.ParseFromString(PyBytes_AsStringEx(graph_str))) { if (!graph_def.ParseFromString(PyBytes_AsStringEx(graph_str))) {
PyErr_SetString(PyExc_RuntimeError, "Failed to parse the GraphDef."); PyErr_SetString(PyExc_RuntimeError,
"Failed to parse the GraphDef.");
return nullptr; return nullptr;
} }
if (!ws()->CreateGraph(graph_def)) { if (!ws()->CreateGraph(graph_def)) {
PyErr_SetString(PyExc_RuntimeError, "Failed to create the Graph."); PyErr_SetString(PyExc_RuntimeError,
"Failed to create the Graph.");
return nullptr; return nullptr;
} }
Py_RETURN_TRUE; Py_RETURN_TRUE;
...@@ -34,11 +37,17 @@ inline PyObject* CreateGraphCC(PyObject* self, PyObject* args) { ...@@ -34,11 +37,17 @@ inline PyObject* CreateGraphCC(PyObject* self, PyObject* args) {
inline PyObject* RunGraphCC(PyObject* self, PyObject* args) { inline PyObject* RunGraphCC(PyObject* self, PyObject* args) {
char* cname, *include, *exclude; char* cname, *include, *exclude;
if (!PyArg_ParseTuple(args, "sss", &cname, &include, &exclude)) { if (!PyArg_ParseTuple(args, "sss",
PyErr_SetString(PyExc_ValueError, "Excepted the graph name, include and exclude rules."); &cname, &include, &exclude)) {
PyErr_SetString(PyExc_ValueError,
"Excepted the graph name, include and exclude rules.");
return nullptr; return nullptr;
} }
ws()->RunGraph(string(cname), string(include), string(exclude)); ws()->RunGraph(
string(cname),
string(include),
string(exclude)
);
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
......
...@@ -46,7 +46,8 @@ inline PyObject* MPICreateGroupCC(PyObject* self, PyObject* args) { ...@@ -46,7 +46,8 @@ inline PyObject* MPICreateGroupCC(PyObject* self, PyObject* args) {
PyObject *incl, *excl, *ret; PyObject *incl, *excl, *ret;
int local_root, world_size; int local_root, world_size;
if (!PyArg_ParseTuple(args, "iOO", &local_root, &incl, &excl)) { if (!PyArg_ParseTuple(args, "iOO", &local_root, &incl, &excl)) {
PyErr_SetString(PyExc_ValueError, "Excepted the local root, include and exclued list."); PyErr_SetString(PyExc_ValueError,
"Excepted the local root, include and exclued list.");
return nullptr; return nullptr;
} }
MPI_Group world_group, local_group; MPI_Group world_group, local_group;
......
...@@ -37,12 +37,14 @@ inline PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) { ...@@ -37,12 +37,14 @@ inline PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) {
inline PyObject* RunOperatorCC(PyObject* self, PyObject* args) { inline PyObject* RunOperatorCC(PyObject* self, PyObject* args) {
PyObject* op_str; PyObject* op_str;
if (!PyArg_ParseTuple(args, "S", &op_str)) { if (!PyArg_ParseTuple(args, "S", &op_str)) {
PyErr_SetString(PyExc_ValueError, "Excepted a serialized string of OperatorDef."); PyErr_SetString(PyExc_ValueError,
"Excepted a serialized string of OperatorDef.");
return nullptr; return nullptr;
} }
OperatorDef op_def; OperatorDef op_def;
if (!op_def.ParseFromString(PyBytes_AsStringEx(op_str))) { if (!op_def.ParseFromString(PyBytes_AsStringEx(op_str))) {
PyErr_SetString(PyExc_RuntimeError, "Failed to parse the OperatorDef."); PyErr_SetString(PyExc_RuntimeError,
"Failed to parse the OperatorDef.");
return nullptr; return nullptr;
} }
ws()->RunOperator(op_def); ws()->RunOperator(op_def);
...@@ -52,7 +54,8 @@ inline PyObject* RunOperatorCC(PyObject* self, PyObject* args) { ...@@ -52,7 +54,8 @@ inline PyObject* RunOperatorCC(PyObject* self, PyObject* args) {
inline PyObject* RunOperatorsCC(PyObject* self, PyObject* args) { inline PyObject* RunOperatorsCC(PyObject* self, PyObject* args) {
PyObject* py_ops; PyObject* py_ops;
if (!PyArg_ParseTuple(args, "O", &py_ops)) { if (!PyArg_ParseTuple(args, "O", &py_ops)) {
PyErr_SetString(PyExc_ValueError, "Excepted a list of serialized string of OperatorDef."); PyErr_SetString(PyExc_ValueError,
"Excepted a list of serialized string of OperatorDef.");
return nullptr; return nullptr;
} }
OperatorDef op_def; OperatorDef op_def;
...@@ -67,12 +70,14 @@ inline PyObject* RunOperatorsCC(PyObject* self, PyObject* args) { ...@@ -67,12 +70,14 @@ inline PyObject* RunOperatorsCC(PyObject* self, PyObject* args) {
inline PyObject* CreatePersistentOpCC(PyObject* self, PyObject* args) { inline PyObject* CreatePersistentOpCC(PyObject* self, PyObject* args) {
PyObject* op_str; PyObject* op_str;
if (!PyArg_ParseTuple(args, "S", &op_str)) { if (!PyArg_ParseTuple(args, "S", &op_str)) {
PyErr_SetString(PyExc_ValueError, "Excepted a serialized string of OperatorDef."); PyErr_SetString(PyExc_ValueError,
"Excepted a serialized string of OperatorDef.");
return nullptr; return nullptr;
} }
OperatorDef op_def; OperatorDef op_def;
if (!op_def.ParseFromString(PyBytes_AsStringEx(op_str))) { if (!op_def.ParseFromString(PyBytes_AsStringEx(op_str))) {
PyErr_SetString(PyExc_RuntimeError, "Failed to parse the OperatorDef."); PyErr_SetString(PyExc_RuntimeError,
"Failed to parse the OperatorDef.");
return nullptr; return nullptr;
} }
ws()->CreatePersistentOp(op_def); ws()->CreatePersistentOp(op_def);
...@@ -82,8 +87,10 @@ inline PyObject* CreatePersistentOpCC(PyObject* self, PyObject* args) { ...@@ -82,8 +87,10 @@ inline PyObject* CreatePersistentOpCC(PyObject* self, PyObject* args) {
inline PyObject* RunPersistentOpCC(PyObject* self, PyObject* args) { inline PyObject* RunPersistentOpCC(PyObject* self, PyObject* args) {
char* key, *anchor; char* key, *anchor;
PyObject* py_inputs, *py_outputs; PyObject* py_inputs, *py_outputs;
if (!PyArg_ParseTuple(args, "ssOO", &key, &anchor, &py_inputs, &py_outputs)) { if (!PyArg_ParseTuple(args, "ssOO",
PyErr_SetString(PyExc_ValueError, "Excepted a persistent key, anchor, " &key, &anchor, &py_inputs, &py_outputs)) {
PyErr_SetString(PyExc_ValueError,
"Excepted a persistent key, anchor, "
"list of inputs and outputs."); "list of inputs and outputs.");
return nullptr; return nullptr;
} }
......
...@@ -39,12 +39,14 @@ inline PyObject* CreateTensorCC(PyObject* self, PyObject* args) { ...@@ -39,12 +39,14 @@ inline PyObject* CreateTensorCC(PyObject* self, PyObject* args) {
inline PyObject* CreateFillerCC(PyObject* self, PyObject* args) { inline PyObject* CreateFillerCC(PyObject* self, PyObject* args) {
PyObject* filler_string; PyObject* filler_string;
if (!PyArg_ParseTuple(args, "S", &filler_string)) { if (!PyArg_ParseTuple(args, "S", &filler_string)) {
PyErr_SetString(PyExc_ValueError, "Excepted a serialized string of TensorFiller."); PyErr_SetString(PyExc_ValueError,
"Excepted a serialized string of TensorFiller.");
return nullptr; return nullptr;
} }
TensorFiller filler_def; TensorFiller filler_def;
if (!filler_def.ParseFromString(PyBytes_AsStringEx(filler_string))) { if (!filler_def.ParseFromString(PyBytes_AsStringEx(filler_string))) {
PyErr_SetString(PyExc_RuntimeError, "Failed to parse the TensorFiller."); PyErr_SetString(PyExc_RuntimeError,
"Failed to parse the TensorFiller.");
return nullptr; return nullptr;
} }
ws()->CreateFiller(filler_def); ws()->CreateFiller(filler_def);
...@@ -60,7 +62,8 @@ inline PyObject* GetFillerTypeCC(PyObject* self, PyObject* args) { ...@@ -60,7 +62,8 @@ inline PyObject* GetFillerTypeCC(PyObject* self, PyObject* args) {
inline PyObject* RenameTensorCC(PyObject* self, PyObject* args) { inline PyObject* RenameTensorCC(PyObject* self, PyObject* args) {
char* ori_name, *tar_name; char* ori_name, *tar_name;
if (!PyArg_ParseTuple(args, "ss", &ori_name, &tar_name)) { if (!PyArg_ParseTuple(args, "ss", &ori_name, &tar_name)) {
PyErr_SetString(PyExc_ValueError, "Excepted the original and target name."); PyErr_SetString(PyExc_ValueError,
"Excepted the original and target name.");
return nullptr; return nullptr;
} }
if (!ws()->HasTensor(tar_name)) { if (!ws()->HasTensor(tar_name)) {
...@@ -77,7 +80,8 @@ PyObject* TensorFromShapeCC(PyObject* self, PyObject* args) { ...@@ -77,7 +80,8 @@ PyObject* TensorFromShapeCC(PyObject* self, PyObject* args) {
char* cname, *dtype; char* cname, *dtype;
PyObject* shape, *device_option = nullptr; PyObject* shape, *device_option = nullptr;
if (!PyArg_ParseTuple(args, "sOs|O", &cname, &shape, &dtype, &device_option)) { if (!PyArg_ParseTuple(args, "sOs|O", &cname, &shape, &dtype, &device_option)) {
PyErr_SetString(PyExc_ValueError, "Excepted the name, shape, dtype and optional device option."); PyErr_SetString(PyExc_ValueError,
"Excepted the name, shape, dtype and optional device option.");
return nullptr; return nullptr;
} }
const TypeMeta& meta = TypeStringToMeta(dtype); const TypeMeta& meta = TypeStringToMeta(dtype);
...@@ -119,7 +123,8 @@ PyObject* TensorFromPyArrayCC(PyObject* self, PyObject* args) { ...@@ -119,7 +123,8 @@ PyObject* TensorFromPyArrayCC(PyObject* self, PyObject* args) {
char* cname; char* cname;
PyArrayObject* original_array = nullptr; PyArrayObject* original_array = nullptr;
if (!PyArg_ParseTuple(args, "sO", &cname, &original_array)) { if (!PyArg_ParseTuple(args, "sO", &cname, &original_array)) {
PyErr_SetString(PyExc_ValueError, "Failed to create tensor from numpy.ndarray.\n" PyErr_SetString(PyExc_ValueError,
"Failed to create tensor from numpy.ndarray.\n"
"Excepted the name and numpy.ndarray both."); "Excepted the name and numpy.ndarray both.");
return nullptr; return nullptr;
} }
...@@ -214,7 +219,8 @@ inline PyObject* TensorToPyArrayCC(PyObject* self, PyObject* args) { ...@@ -214,7 +219,8 @@ inline PyObject* TensorToPyArrayCC(PyObject* self, PyObject* args) {
return nullptr; return nullptr;
} }
auto* data = tensor->raw_mutable_data<CPUContext>(); auto* data = tensor->raw_mutable_data<CPUContext>();
PyObject* array = PyArray_SimpleNewFromData(tensor->ndim(), dims.data(), npy_type, data); PyObject* array = PyArray_SimpleNewFromData(
(int)tensor->ndim(), dims.data(), npy_type, data);
Py_XINCREF(array); Py_XINCREF(array);
return array; return array;
} }
......
...@@ -30,6 +30,8 @@ class BlobFetcher(Process): ...@@ -30,6 +30,8 @@ class BlobFetcher(Process):
---------- ----------
batch_size : int batch_size : int
The size of a training batch. The size of a training batch.
dtype : str
The data type of the batch. Default is ``float32``.
partition : boolean partition : boolean
Whether to partition the batch. Default is ``False``. Whether to partition the batch. Default is ``False``.
prefetch : int prefetch : int
...@@ -42,6 +44,7 @@ class BlobFetcher(Process): ...@@ -42,6 +44,7 @@ class BlobFetcher(Process):
""" """
super(BlobFetcher, self).__init__() super(BlobFetcher, self).__init__()
self._batch_size = kwargs.get('batch_size', 100) self._batch_size = kwargs.get('batch_size', 100)
self._dtype = kwargs.get('dtype', 'float32')
self._partition = kwargs.get('partition', False) self._partition = kwargs.get('partition', False)
self._mean_values = kwargs.get('mean_values', []) self._mean_values = kwargs.get('mean_values', [])
self._scale = kwargs.get('scale', 1.0) self._scale = kwargs.get('scale', 1.0)
...@@ -68,7 +71,7 @@ class BlobFetcher(Process): ...@@ -68,7 +71,7 @@ class BlobFetcher(Process):
if ix != self._batch_size - 1: im, labels = self.Q_in.get() if ix != self._batch_size - 1: im, labels = self.Q_in.get()
# mean subtraction & numerical scale # mean subtraction & numerical scale
im_blob = im_blob.astype(np.float32) im_blob = im_blob.astype(self._dtype)
if len(self._mean_values) > 0: if len(self._mean_values) > 0:
im_blob -= self._mean_values im_blob -= self._mean_values
if self._scale != 1.0: if self._scale != 1.0:
......
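Note: with the new dtype knob, the fetcher's post-processing amounts to y = scale * (x - mean_values) per channel, computed in the requested precision. For hypothetical settings mean_values = [104, 117, 123] and scale = 0.017, a pixel [110, 120, 130] becomes roughly [0.102, 0.051, 0.119].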
...@@ -70,6 +70,8 @@ class DataBatch(object): ...@@ -70,6 +70,8 @@ class DataBatch(object):
The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``. The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``.
batch_size : int batch_size : int
The size of a training batch. The size of a training batch.
dtype : str
The data type of the batch. Default is ``float32``.
partition : boolean partition : boolean
Whether to partition the batch. Default is ``False``. Whether to partition the batch. Default is ``False``.
prefetch : int prefetch : int
......
...@@ -49,16 +49,14 @@ class DataReader(Process): ...@@ -49,16 +49,14 @@ class DataReader(Process):
self._source = kwargs.get('source', '') self._source = kwargs.get('source', '')
self._multiple_nodes = kwargs.get('multiple_nodes', False) self._multiple_nodes = kwargs.get('multiple_nodes', False)
self._use_shuffle = kwargs.get('shuffle', False) self._use_shuffle = kwargs.get('shuffle', False)
self._use_instance_chunk = kwargs.get('instance_chunk', False)
self._num_chunks = kwargs.get('num_chunks', 2048) self._num_chunks = kwargs.get('num_chunks', 2048)
self._chunk_size = kwargs.get('chunk_size', -1) self._chunk_size = kwargs.get('chunk_size', -1)
self._num_parts = 1 self._part_idx, self._num_parts = 0, 1
self._part_idx = 0 self._cur_idx, self._cur_chunk_idx = 0, 0
self._random_seed = config.GetRandomSeed() self._random_seed = config.GetRandomSeed()
self._cur_idx = 0
self._cur_chunk_idx = 0
self.Q_out = None self.Q_out = None
self.daemon = True self.daemon = True
...@@ -167,12 +165,13 @@ class DataReader(Process): ...@@ -167,12 +165,13 @@ class DataReader(Process):
self._db.open(self._source) self._db.open(self._source)
self._zfill = self._db.zfill() self._zfill = self._db.zfill()
self._num_entries = self._db.num_entries() self._num_entries = self._db.num_entries()
self._epoch_size = int(self._num_entries / self._num_parts + 1) self._epoch_size = int(self._num_entries / self._num_parts + 1)
if self._use_shuffle: if self._use_shuffle:
if self._chunk_size == 1: if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle] # each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._num_entries / self._chunk_size / self._num_parts) + 1 self._chunk_size, self._num_shuffle_parts = \
1, int(self._num_entries / self._num_parts) + 1
else: else:
if self._use_shuffle and self._chunk_size == -1: if self._use_shuffle and self._chunk_size == -1:
# search an optimal chunk size by chunks [For Chunk Shuffle] # search an optimal chunk size by chunks [For Chunk Shuffle]
...@@ -183,6 +182,11 @@ class DataReader(Process): ...@@ -183,6 +182,11 @@ class DataReader(Process):
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 / self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20))) (self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1) self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
limit = (self._num_parts - 0.5) * self._num_shuffle_parts * self._chunk_size
if self._num_entries <= limit:
# roll back to fully shuffle
self._chunk_size, self._num_shuffle_parts = \
1, int(self._num_entries / self._num_parts) + 1
else: else:
# each chunk has at most K records [For Multiple Nodes] # each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are all ``False``, # note that if ``shuffle`` and ``multiple_nodes`` are all ``False``,
......
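Note: the rollback guard added above is easier to see with numbers (all hypothetical, assuming _total_size is in bytes). Take a 20 GB LMDB with 1,000,000 entries on one node (num_parts = 1) and num_chunks = 2048: max_chunk_size = 20480 MB / 2048 = 10 MB, the power-of-two search stops at chunk_size = 8 MB, then num_shuffle_parts = ceil(20 GB * 1.1 / 8 MB) = 2816 and chunk_size is re-expressed as int(1,000,000 / 2816) + 1 = 356 records. The guard computes limit = (1 - 0.5) * 2816 * 356 = 501,248; since 1,000,000 > limit, chunk shuffle stands. Only when the scheduled chunks would cover the dataset with too much slack (num_entries <= limit) does the reader roll back to fully shuffling with one record per chunk.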
...@@ -14,7 +14,7 @@ from __future__ import division ...@@ -14,7 +14,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
version = '0.2.2' version = '0.2.2'
full_version = '0.2.2.10' full_version = '0.2.2.11'
release = False release = False
if not release: if not release:
......
...@@ -364,7 +364,7 @@ class BatchNormLayer(Layer): ...@@ -364,7 +364,7 @@ class BatchNormLayer(Layer):
var = Tensor(scope + '/param:1').Constant(value=0.0) var = Tensor(scope + '/param:1').Constant(value=0.0)
factor = Tensor(scope + '/param:2').Constant(value=0.0) factor = Tensor(scope + '/param:2').Constant(value=0.0)
# in dragon, set diff as None will ignore computing grad automatically # in dragon, set diff as None will ignore computing grad automatically
# but in bvlc-caffe1, you must set lr_mult = 0 manually # but in bvlc-caffe, you must set lr_mult = 0 manually
self._blobs.append({'data': mean, 'diff': None}) self._blobs.append({'data': mean, 'diff': None})
self._blobs.append({'data': var, 'diff': None}) self._blobs.append({'data': var, 'diff': None})
self._blobs.append({'data': factor, 'diff': None}) self._blobs.append({'data': factor, 'diff': None})
......
...@@ -20,7 +20,7 @@ from .arithmetic import ( ...@@ -20,7 +20,7 @@ from .arithmetic import (
from .ndarray import ( from .ndarray import (
squeeze, unsqueeze, squeeze, unsqueeze,
sum, mean, argmin, argmax, max, topk, sum, mean, argmin, argmax, max, min, topk,
cat, gather, cat, gather,
) )
......
...@@ -13,7 +13,6 @@ from __future__ import absolute_import ...@@ -13,7 +13,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.torch.tensor import Tensor from dragon.vm.torch.tensor import Tensor
from dragon.vm.torch.ops.primitive import MakeContext, WrapScalar from dragon.vm.torch.ops.primitive import MakeContext, WrapScalar
from dragon.vm.torch.ops.factory import get_module from dragon.vm.torch.ops.factory import get_module
...@@ -26,7 +25,6 @@ def _fundamental(input, value, op='Add', out=None): ...@@ -26,7 +25,6 @@ def _fundamental(input, value, op='Add', out=None):
raise TypeError('Type of value should be numerical, got {}.' raise TypeError('Type of value should be numerical, got {}.'
.format(type(value))) .format(type(value)))
value = WrapScalar(value, input._dtype, input._ctx) value = WrapScalar(value, input._dtype, input._ctx)
ctx = MakeContext(inputs=[input, value]) ctx = MakeContext(inputs=[input, value])
key = 'torch/ops/{}/{}:{}'.format(op.lower(), ctx[0].lower(), ctx[1]) key = 'torch/ops/{}/{}:{}'.format(op.lower(), ctx[0].lower(), ctx[1])
module = get_module(Fundamental, key, ctx, op_type=op) module = get_module(Fundamental, key, ctx, op_type=op)
......
...@@ -13,7 +13,7 @@ from __future__ import absolute_import ...@@ -13,7 +13,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.torch.utils.data.io.data_reader import DataReader from dragon.io.data_reader import DataReader
from dragon.vm.torch.utils.data.io.data_transformer import DataTransformer from dragon.vm.torch.utils.data.io.data_transformer import DataTransformer
......
...@@ -19,7 +19,7 @@ from multiprocessing import Queue ...@@ -19,7 +19,7 @@ from multiprocessing import Queue
import dragon.core.mpi as mpi import dragon.core.mpi as mpi
from .data_reader import DataReader from dragon.io.data_reader import DataReader
from .data_transformer import DataTransformer from .data_transformer import DataTransformer
from .blob_fetcher import BlobFetcher from .blob_fetcher import BlobFetcher
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import numpy.random as npr
from multiprocessing import Process
import dragon.config as config
from dragon.tools.db import LMDB
class DataReader(Process):
"""DataReader is deployed to queue encoded str from `LMDB`_.
It is supported to adaptively partition and shuffle records over all distributed nodes.
"""
def __init__(self, **kwargs):
"""Construct a ``DataReader``.
Parameters
----------
source : str
The path of database.
multiple_nodes : boolean
Whether to split data for multiple parallel nodes. Default is ``False``.
shuffle : boolean
Whether to shuffle the data. Default is ``False``.
num_chunks : int
The number of chunks to split. Default is ``2048``.
chunk_size : int
The size (MB) of each chunk. Default is ``-1`` (refer to ``num_chunks``).
"""
super(DataReader, self).__init__()
self._source = kwargs.get('source', '')
self._multiple_nodes = kwargs.get('multiple_nodes', False)
self._use_shuffle = kwargs.get('shuffle', False)
self._num_chunks = kwargs.get('num_chunks', 2048)
self._chunk_size = kwargs.get('chunk_size', -1)
self._num_parts = 1
self._part_idx = 0
self._random_seed = config.GetRandomSeed()
self._cur_idx = 0
self._cur_chunk_idx = 0
self.Q_out = None
self.daemon = True
def element(self):
"""Get the value of current record.
Returns
-------
str
The encoded str.
"""
return self._db.value()
def redirect(self, target_idx):
"""Redirect to the target position.
Parameters
----------
target_idx : int
The key of the instance in ``LMDB``.
Returns
-------
None
Notes
-----
The redirection reopens the ``LMDB``.
You can drop caches by ``echo 3 > /proc/sys/vm/drop_caches``.
This helps to avoid getting stuck when ``Database Size`` >> ``RAM Size``.
"""
self._db.close()
self._db.open(self._source)
self._cur_idx = target_idx
self._db.set(str(self._cur_idx).zfill(self._zfill))
def reset(self):
"""Reset the cursor and environment.
Returns
-------
None
"""
if self._multiple_nodes or self._use_shuffle:
if self._use_shuffle: self._perm = npr.permutation(self._num_shuffle_parts)
self._cur_chunk_idx = 0
self._start_idx = int(self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx])
self._start_idx = int(self._start_idx * self._chunk_size)
if self._start_idx >= self._num_entries: self.next_chunk()
self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._num_entries, self._end_idx)
else:
self._start_idx = 0
self._end_idx = self._num_entries
self.redirect(self._start_idx)
def next_record(self):
"""Step the cursor of records.
Returns
-------
None
"""
self._cur_idx += 1
self._db.next()
def next_chunk(self):
"""Step the cursor of shuffling chunks.
Returns
-------
None
"""
self._cur_chunk_idx += 1
if self._cur_chunk_idx >= self._num_shuffle_parts: self.reset()
else:
self._start_idx = self._part_idx * self._num_shuffle_parts + self._perm[self._cur_chunk_idx]
self._start_idx = self._start_idx * self._chunk_size
if self._start_idx >= self._num_entries: self.next_chunk()
else:
self._end_idx = self._start_idx + self._chunk_size
self._end_idx = min(self._num_entries, self._end_idx)
self.redirect(self._start_idx)
def run(self):
"""Start the process.
Returns
-------
None
"""
# fix seed
npr.seed(self._random_seed)
# init db
self._db = LMDB()
self._db.open(self._source)
self._zfill = self._db.zfill()
self._num_entries = self._db.num_entries()
self._epoch_size = int(self._num_entries / self._num_parts + 1)
if self._use_shuffle:
if self._chunk_size == 1:
# each chunk has at most 1 record [For Fully Shuffle]
self._num_shuffle_parts = int(self._num_entries / self._chunk_size / self._num_parts) + 1
else:
if self._use_shuffle and self._chunk_size == -1:
# search an optimal chunk size by chunks [For Chunk Shuffle]
max_chunk_size = self._db._total_size / ((self._num_chunks * (1 << 20)))
min_chunk_size = 1
while min_chunk_size * 2 < max_chunk_size: min_chunk_size *= 2
self._chunk_size = min_chunk_size
self._num_shuffle_parts = int(math.ceil(self._db._total_size * 1.1 /
(self._num_parts * self._chunk_size << 20)))
self._chunk_size = int(self._num_entries / self._num_shuffle_parts / self._num_parts + 1)
else:
# each chunk has at most K records [For Multiple Nodes]
# note that if ``shuffle`` and ``multiple_nodes`` are all ``False``,
# ``chunk_size`` and ``num_shuffle_parts`` are meaningless
self._chunk_size = int(self._num_entries / self._num_parts) + 1
self._num_shuffle_parts = 1
self._perm = np.arange(self._num_shuffle_parts)
# init env
self.reset()
# run
while True:
self.Q_out.put(self.element())
self.next_record()
if self._cur_idx >= self._end_idx:
if self._multiple_nodes or \
self._use_shuffle: self.next_chunk()
else: self.reset()
\ No newline at end of file
...@@ -42,7 +42,7 @@ find_modules() ...@@ -42,7 +42,7 @@ find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.2.10', version='0.2.2.11',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/seetaresearch/Dragon', url='https://github.com/seetaresearch/Dragon',
author='Ting Pan', author='Ting Pan',
......
...@@ -19,7 +19,8 @@ template <> void GenerateProposals<float, CPUContext>( ...@@ -19,7 +19,8 @@ template <> void GenerateProposals<float, CPUContext>(
const float* scores, const float* scores,
const float* bbox_deltas, const float* bbox_deltas,
const float* anchors, const float* anchors,
float* proposals) { float* proposals,
CPUContext* ctx) {
float* proposal = proposals; float* proposal = proposals;
const int K = feat_h * feat_w; const int K = feat_h * feat_w;
for (int h = 0; h < feat_h; ++h) { for (int h = 0; h < feat_h; ++h) {
...@@ -57,7 +58,8 @@ template <> void GenerateProposals_v2<float, CPUContext>( ...@@ -57,7 +58,8 @@ template <> void GenerateProposals_v2<float, CPUContext>(
const float min_box_w, const float min_box_w,
const float* scores, const float* scores,
const float* bbox_deltas, const float* bbox_deltas,
float* proposals) { float* proposals,
CPUContext* ctx) {
float* proposal = proposals; float* proposal = proposals;
for (int i = 0; i < total_anchors; ++i) { for (int i = 0; i < total_anchors; ++i) {
// bbox_deltas: [1, 4, total_anchors] // bbox_deltas: [1, 4, total_anchors]
...@@ -98,7 +100,8 @@ template <> void ApplyNMS<float, CPUContext>( ...@@ -98,7 +100,8 @@ template <> void ApplyNMS<float, CPUContext>(
const float thresh, const float thresh,
const float* boxes, const float* boxes,
int* keep_indices, int* keep_indices,
int& num_keep) { int& num_keep,
CPUContext* ctx) {
int count = 0; int count = 0;
std::vector<char> is_dead(num_boxes); std::vector<char> is_dead(num_boxes);
for (int i = 0; i < num_boxes; ++i) is_dead[i] = 0; for (int i = 0; i < num_boxes; ++i) is_dead[i] = 0;
......
...@@ -62,7 +62,7 @@ __global__ void _GenerateProposals( ...@@ -62,7 +62,7 @@ __global__ void _GenerateProposals(
const T* bbox_deltas, const T* bbox_deltas,
const T* anchors, const T* anchors,
T* proposals) { T* proposals) {
CUDA_KERNEL_LOOP(idx, nthreads) { CUDA_1D_KERNEL_LOOP(idx, nthreads) {
const int h = idx / A / feat_w; const int h = idx / A / feat_w;
const int w = (idx / A) % feat_w; const int w = (idx / A) % feat_w;
const int a = idx % A; const int a = idx % A;
...@@ -99,11 +99,13 @@ template <> void GenerateProposals<float, CUDAContext>( ...@@ -99,11 +99,13 @@ template <> void GenerateProposals<float, CUDAContext>(
const float* scores, const float* scores,
const float* bbox_deltas, const float* bbox_deltas,
const float* anchors, const float* anchors,
float* proposals) { float* proposals,
CUDAContext* ctx) {
const int num_proposals = A * feat_h * feat_w; const int num_proposals = A * feat_h * feat_w;
_GenerateProposals<float> _GenerateProposals<float>
<< <CUDA_BLOCKS(num_proposals), CUDA_THREADS >> >( << < CUDA_BLOCKS(num_proposals), CUDA_THREADS,
num_proposals, A, feat_h, feat_w, stride, 0, ctx->cuda_stream() >> >(num_proposals,
A, feat_h, feat_w, stride,
im_h, im_w, min_box_h, min_box_w, im_h, im_w, min_box_h, min_box_w,
scores, bbox_deltas, anchors, proposals); scores, bbox_deltas, anchors, proposals);
} }
...@@ -118,7 +120,7 @@ __global__ void _GenerateProposals_v2( ...@@ -118,7 +120,7 @@ __global__ void _GenerateProposals_v2(
const T* scores, const T* scores,
const T* bbox_deltas, const T* bbox_deltas,
T* proposals) { T* proposals) {
CUDA_KERNEL_LOOP(idx, nthreads) { CUDA_1D_KERNEL_LOOP(idx, nthreads) {
const float dx = bbox_deltas[idx]; const float dx = bbox_deltas[idx];
const float dy = bbox_deltas[nthreads + idx]; const float dy = bbox_deltas[nthreads + idx];
const float d_log_w = bbox_deltas[2 * nthreads + idx]; const float d_log_w = bbox_deltas[2 * nthreads + idx];
...@@ -139,10 +141,12 @@ template <> void GenerateProposals_v2<float, CUDAContext>( ...@@ -139,10 +141,12 @@ template <> void GenerateProposals_v2<float, CUDAContext>(
const float min_box_w, const float min_box_w,
const float* scores, const float* scores,
const float* bbox_deltas, const float* bbox_deltas,
float* proposals) { float* proposals,
CUDAContext* ctx) {
_GenerateProposals_v2<float> _GenerateProposals_v2<float>
<< <CUDA_BLOCKS(total_anchors), CUDA_THREADS >> >( << < CUDA_BLOCKS(total_anchors), CUDA_THREADS,
total_anchors, im_h, im_w, min_box_h, min_box_w, 0, ctx->cuda_stream() >> >(total_anchors,
im_h, im_w, min_box_h, min_box_w,
scores, bbox_deltas, proposals); scores, bbox_deltas, proposals);
} }
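Note: CUDA_KERNEL_LOOP is renamed CUDA_1D_KERNEL_LOOP in the kernels above. In Caffe2-style frameworks this macro is the standard grid-stride loop; a sketch of the usual expansion (the exact definition lives in the CUDA device header, which this diff does not show):
    #define CUDA_1D_KERNEL_LOOP(i, n) \
        for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
             i < (n); i += blockDim.x * gridDim.x)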
...@@ -170,7 +174,7 @@ __global__ void nms_mask( ...@@ -170,7 +174,7 @@ __global__ void nms_mask(
const int num_boxes, const int num_boxes,
const T nms_thresh, const T nms_thresh,
const T* boxes, const T* boxes,
unsigned long long* mask) { uint64_t* mask) {
const int i_start = blockIdx.x * NMS_BLOCK_SIZE; const int i_start = blockIdx.x * NMS_BLOCK_SIZE;
const int di_end = min(num_boxes - i_start, NMS_BLOCK_SIZE); const int di_end = min(num_boxes - i_start, NMS_BLOCK_SIZE);
const int j_start = blockIdx.y * NMS_BLOCK_SIZE; const int j_start = blockIdx.y * NMS_BLOCK_SIZE;
...@@ -209,25 +213,30 @@ void _ApplyNMS( ...@@ -209,25 +213,30 @@ void _ApplyNMS(
const float thresh, const float thresh,
const T* boxes, const T* boxes,
int* keep_indices, int* keep_indices,
int& num_keep) { int& num_keep,
CUDAContext* ctx) {
const int num_blocks = DIV_UP(num_boxes, NMS_BLOCK_SIZE); const int num_blocks = DIV_UP(num_boxes, NMS_BLOCK_SIZE);
const dim3 blocks(num_blocks, num_blocks); const dim3 blocks(num_blocks, num_blocks);
size_t mask_nbytes = num_boxes * num_blocks * sizeof(unsigned long long); size_t mask_nbytes = num_boxes * num_blocks * sizeof(uint64_t);
size_t boxes_nbytes = num_boxes * 5 * sizeof(T); size_t boxes_nbytes = num_boxes * 5 * sizeof(T);
void* boxes_dev, *mask_dev; void* boxes_dev, *mask_dev;
CUDA_CHECK(cudaMalloc(&boxes_dev, boxes_nbytes)); CUDA_CHECK(cudaMalloc(&boxes_dev, boxes_nbytes));
CUDA_CHECK(cudaMalloc(&mask_dev, mask_nbytes)); CUDA_CHECK(cudaMalloc(&mask_dev, mask_nbytes));
CUDA_CHECK(cudaMemcpy(boxes_dev, boxes, boxes_nbytes, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaMemcpy(boxes_dev, boxes,
nms_mask<T> << <blocks, NMS_BLOCK_SIZE >> > ( boxes_nbytes, cudaMemcpyHostToDevice));
num_boxes, thresh, (T*)boxes_dev, (unsigned long long*)mask_dev); nms_mask<T>
<< < blocks, NMS_BLOCK_SIZE,
0, ctx->cuda_stream() >> > (num_boxes,
thresh, (T*)boxes_dev, (uint64_t*)mask_dev);
CUDA_CHECK(cudaPeekAtLastError()); CUDA_CHECK(cudaPeekAtLastError());
std::vector<unsigned long long> mask_host(num_boxes * num_blocks); std::vector<uint64_t> mask_host(num_boxes * num_blocks);
CUDA_CHECK(cudaMemcpy(&mask_host[0], mask_dev, mask_nbytes, cudaMemcpyDeviceToHost)); CUDA_CHECK(cudaMemcpy(&mask_host[0], mask_dev,
mask_nbytes, cudaMemcpyDeviceToHost));
std::vector<unsigned long long> dead_bit(num_blocks); std::vector<uint64_t> dead_bit(num_blocks);
memset(&dead_bit[0], 0, sizeof(unsigned long long) * num_blocks); memset(&dead_bit[0], 0, sizeof(uint64_t) * num_blocks);
int num_selected = 0; int num_selected = 0;
for (int i = 0; i < num_boxes; ++i) { for (int i = 0; i < num_boxes; ++i) {
...@@ -235,7 +244,7 @@ void _ApplyNMS( ...@@ -235,7 +244,7 @@ void _ApplyNMS(
const int inblock = i % NMS_BLOCK_SIZE; const int inblock = i % NMS_BLOCK_SIZE;
if (!(dead_bit[nblock] & (1ULL << inblock))) { if (!(dead_bit[nblock] & (1ULL << inblock))) {
keep_indices[num_selected++] = i; keep_indices[num_selected++] = i;
unsigned long long* mask_i = &mask_host[0] + i * num_blocks; uint64_t* mask_i = &mask_host[0] + i * num_blocks;
for (int j = nblock; j < num_blocks; ++j) dead_bit[j] |= mask_i[j]; for (int j = nblock; j < num_blocks; ++j) dead_bit[j] |= mask_i[j];
if (num_selected == max_keeps) break; if (num_selected == max_keeps) break;
} }
...@@ -251,9 +260,10 @@ template <> void ApplyNMS<float, CUDAContext>( ...@@ -251,9 +260,10 @@ template <> void ApplyNMS<float, CUDAContext>(
const float thresh, const float thresh,
const float* boxes, const float* boxes,
int* keep_indices, int* keep_indices,
int& num_keep) { int& num_keep,
CUDAContext* ctx) {
_ApplyNMS<float>(num_boxes, max_keeps, thresh, _ApplyNMS<float>(num_boxes, max_keeps, thresh,
boxes, keep_indices, num_keep); boxes, keep_indices, num_keep, ctx);
} }
} // namespace rcnn } // namespace rcnn
......
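Note: besides the stream plumbing, this hunk fixes the mask words to uint64_t, matching the 64-box tiles the nms_mask kernel works on. The host-side selection above is the classic bitmask NMS reduction; a condensed sketch, assuming NMS_BLOCK_SIZE is 64 and boxes are pre-sorted by descending score:
    // mask_host[i * num_blocks + j]: bit b set means box i suppresses
    // box j * 64 + b.
    std::vector<uint64_t> dead_bit(num_blocks, 0);
    int num_selected = 0;
    for (int i = 0; i < num_boxes; ++i) {
        const int nblock = i / 64, inblock = i % 64;
        if (dead_bit[nblock] & (1ULL << inblock)) continue;  // suppressed
        keep_indices[num_selected++] = i;                    // survivor
        const uint64_t* row = &mask_host[0] + i * num_blocks;
        for (int j = nblock; j < num_blocks; ++j) dead_bit[j] |= row[j];
        if (num_selected == max_keeps) break;
    }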
...@@ -126,7 +126,8 @@ void GenerateProposals( ...@@ -126,7 +126,8 @@ void GenerateProposals(
const T* scores, const T* scores,
const T* bbox_deltas, const T* bbox_deltas,
const T* anchors, const T* anchors,
T* proposals); T* proposals,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void GenerateProposals_v2( void GenerateProposals_v2(
...@@ -137,7 +138,8 @@ void GenerateProposals_v2( ...@@ -137,7 +138,8 @@ void GenerateProposals_v2(
const float min_box_w, const float min_box_w,
const T* scores, const T* scores,
const T* bbox_deltas, const T* bbox_deltas,
T* proposals); T* proposals,
Context* ctx);
template <typename T> template <typename T>
inline void SortProposals( inline void SortProposals(
...@@ -246,7 +248,8 @@ void ApplyNMS( ...@@ -246,7 +248,8 @@ void ApplyNMS(
const T thresh, const T thresh,
const T* boxes, const T* boxes,
int* keep_indices, int* keep_indices,
int& num_keep); int& num_keep,
Context* ctx);
} // namespace rcnn } // namespace rcnn
......
...@@ -37,7 +37,7 @@ void ProposalOp<Context>::RunWithType() { ...@@ -37,7 +37,7 @@ void ProposalOp<Context>::RunWithType() {
Input(0).template data<T, Context>(), Input(0).template data<T, Context>(),
Input(1).template data<T, Context>(), Input(1).template data<T, Context>(),
anchors_.template mutable_data<T, Context>(), anchors_.template mutable_data<T, Context>(),
proposals_.template mutable_data<T, Context>()); proposals_.template mutable_data<T, Context>(), ctx());
rcnn::SortProposals(0, num_proposals - 1, pre_nms_top_n, rcnn::SortProposals(0, num_proposals - 1, pre_nms_top_n,
proposals_.template mutable_data<T, CPUContext>()); proposals_.template mutable_data<T, CPUContext>());
...@@ -45,7 +45,8 @@ void ProposalOp<Context>::RunWithType() { ...@@ -45,7 +45,8 @@ void ProposalOp<Context>::RunWithType() {
rcnn::ApplyNMS<T, Context>( rcnn::ApplyNMS<T, Context>(
pre_nms_topn, post_nms_top_n, nms_thresh, pre_nms_topn, post_nms_top_n, nms_thresh,
proposals_.template mutable_data<T, Context>(), proposals_.template mutable_data<T, Context>(),
roi_indices_.template mutable_data<int, CPUContext>(), num_rois); roi_indices_.template mutable_data<int, CPUContext>(),
num_rois, ctx());
rcnn::RetrieveRoIs<T>(num_rois, n, rcnn::RetrieveRoIs<T>(num_rois, n,
proposals_.template mutable_data<T, CPUContext>(), proposals_.template mutable_data<T, CPUContext>(),
...@@ -95,14 +96,15 @@ void ProposalOp<Context>::RunWithType() { ...@@ -95,14 +96,15 @@ void ProposalOp<Context>::RunWithType() {
im_height, im_width, min_box_h, min_box_w, im_height, im_width, min_box_h, min_box_w,
Input(-3).template data<T, Context>(), Input(-3).template data<T, Context>(),
Input(-2).template data<T, Context>(), Input(-2).template data<T, Context>(),
proposals_.template mutable_data<T, Context>()); proposals_.template mutable_data<T, Context>(), ctx());
rcnn::SortProposals(0, total_proposals - 1, pre_nms_top_n, rcnn::SortProposals(0, total_proposals - 1, pre_nms_top_n,
proposals_.template mutable_data<T, CPUContext>()); proposals_.template mutable_data<T, CPUContext>());
rcnn::ApplyNMS<T, Context>(pre_nms_topn, post_nms_top_n, nms_thresh, rcnn::ApplyNMS<T, Context>(pre_nms_topn, post_nms_top_n, nms_thresh,
proposals_.template mutable_data<T, Context>(), proposals_.template mutable_data<T, Context>(),
roi_indices_.template mutable_data<int, CPUContext>(), num_rois); roi_indices_.template mutable_data<int, CPUContext>(),
num_rois, ctx());
rcnn::RetrieveRoIs<T>(num_rois, n, rcnn::RetrieveRoIs<T>(num_rois, n,
proposals_.template mutable_data<T, CPUContext>(), proposals_.template mutable_data<T, CPUContext>(),
...@@ -128,7 +130,7 @@ void ProposalOp<Context>::RunWithType() { ...@@ -128,7 +130,7 @@ void ProposalOp<Context>::RunWithType() {
collective_rois.ReshapeLike(*Output(0)); collective_rois.ReshapeLike(*Output(0));
auto* rois = collective_rois.template mutable_data<T, CPUContext>(); auto* rois = collective_rois.template mutable_data<T, CPUContext>();
CPUContext::template Copy<T, CPUContext, CPUContext>( ctx()->template Copy<T, CPUContext, CPUContext>(
collective_rois.count(), rois, collective_rois.count(), rois,
Output(0)->template data<T, CPUContext>()); Output(0)->template data<T, CPUContext>());
...@@ -147,6 +149,8 @@ void ProposalOp<Context>::RunWithType() { ...@@ -147,6 +149,8 @@ void ProposalOp<Context>::RunWithType() {
template <class Context> template <class Context>
void ProposalOp<Context>::RunOnDevice() { void ProposalOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
num_images = Input(0).dim(0); num_images = Input(0).dim(0);
CHECK_EQ(Input(-1).dim(0), num_images) CHECK_EQ(Input(-1).dim(0), num_images)
<< "\nExcepted " << num_images << " groups image info, " << "\nExcepted " << num_images << " groups image info, "
......
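Note: set_stream_id(0) is the flip side of unlocking async streams. ProposalOp interleaves device kernels with host-side sorting and RoI retrieval over CPU-mapped data, so it pins itself back to the default, synchronizing stream. The pattern, sketched with a hypothetical operator:
    template <class Context>
    void MyMixedOp<Context>::RunOnDevice() {  // "MyMixedOp" is illustrative
        ctx()->set_stream_id(0);  // host/device phases must stay ordered
        // ... launch kernels, then post-process via CPUContext data ...
    }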
...@@ -455,7 +455,10 @@ Graph::Graph(const GraphDef& meta_graph, Workspace* ws) ...@@ -455,7 +455,10 @@ Graph::Graph(const GraphDef& meta_graph, Workspace* ws)
RecomputingAware(optimized_graph, ws); RecomputingAware(optimized_graph, ws);
} }
bool Graph::Run(const string& include, const string& exclude) { bool Graph::Run(
const string& include,
const string& exclude,
const int stream_id) {
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run Graph: " << name();
for (auto op : ops_) { for (auto op : ops_) {
if (!include.empty()) if (!include.empty())
...@@ -464,7 +467,7 @@ bool Graph::Run(const string& include, const string& exclude) { ...@@ -464,7 +467,7 @@ bool Graph::Run(const string& include, const string& exclude) {
if (op->type().find(exclude) != string::npos) continue; if (op->type().find(exclude) != string::npos) continue;
op->SwitchToPhase(this->args_["phase"].s()); op->SwitchToPhase(this->args_["phase"].s());
LOG(DEBUG) << "$ Before Operator: " << op->name(); LOG(DEBUG) << "$ Before Operator: " << op->name();
op->Run(); op->Run(stream_id);
LOG(DEBUG) << "$ After Operator: " << op->name(); LOG(DEBUG) << "$ After Operator: " << op->name();
} }
return true; return true;
......
...@@ -8,7 +8,6 @@ void MixedMemory::ToCPU() { ...@@ -8,7 +8,6 @@ void MixedMemory::ToCPU() {
switch (state_) { switch (state_) {
case UNINITIALIZED: case UNINITIALIZED:
cpu_ptr_ = CPUContext::New(nbytes_); cpu_ptr_ = CPUContext::New(nbytes_);
CPUContext::Memset(nbytes_, cpu_ptr_);
state_ = STATE_AT_CPU; state_ = STATE_AT_CPU;
break; break;
case STATE_AT_CUDA: case STATE_AT_CUDA:
...@@ -32,7 +31,6 @@ void MixedMemory::ToCUDA() { ...@@ -32,7 +31,6 @@ void MixedMemory::ToCUDA() {
switch (state_) { switch (state_) {
case UNINITIALIZED: case UNINITIALIZED:
cuda_ptr_ = CUDAContext::New(nbytes_); cuda_ptr_ = CUDAContext::New(nbytes_);
CUDAContext::Memset(nbytes_, cuda_ptr_);
state_ = STATE_AT_CUDA; state_ = STATE_AT_CUDA;
break; break;
case STATE_AT_CPU: case STATE_AT_CPU:
......
...@@ -15,33 +15,35 @@ void CuDNNDropoutOp<Context>::RunWithType() { ...@@ -15,33 +15,35 @@ void CuDNNDropoutOp<Context>::RunWithType() {
float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0; float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0;
if (phase() == "TEST") { if (phase() == "TEST") {
if (Output(0) != &Input(0)) { if (Output(0) != &Input(0)) {
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), Ydata, Xdata); Output(0)->count(), Ydata, Xdata);
if (scale == 1.0) if (scale == 1.0)
math::Scal<T, Context>(Output(0)->count(), math::Scal<T, Context>(Output(0)->count(),
1.0 - prob(), Ydata, &ctx()); 1.0 - prob(), Ydata, ctx());
} }
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
CHECK(use_scale) << "\nCuDNN only supports scale-dropout"; CHECK(use_scale) << "\nCuDNN only supports scale-dropout";
Tensor* mask = ws()->CreateTensor("/mnt/" + anchor() + "/dropout/mask"); Tensor* mask = ws()->CreateTensor(
"/mnt/" + anchor() + "/dropout/mask");
// determine the dropout states // determine the dropout states
if (!states_initialized) { if (!states_initialized) {
states_initialized = true; states_initialized = true;
CUDNN_CHECK(cudnnDropoutGetStatesSize( CUDNN_CHECK(cudnnDropoutGetStatesSize(
ctx().cudnn_handle(), &states_size)); ctx()->cudnn_handle(), &states_size));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
Tensor* states = ws()->CreateTensor("/share/cudnn/dropout:" + Tensor* states = ws()->CreateTensor(
dragon_cast<string, unsigned long long>(random_seed) + "/states"); "/share/cudnn/dropout:" + dragon_cast<string,
unsigned long long>(random_seed) + "/states");
if (states->count() > 0) { if (states->count() > 0) {
auto* Sdata = states->template mutable_data<uint8_t, Context>(); auto* Sdata = states->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnRestoreDropoutDescriptor( CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc, ctx().cudnn_handle(), prob(), dropout_desc, ctx()->cudnn_handle(), prob(),
Sdata, states_size, random_seed)); Sdata, states_size, random_seed));
} else { } else {
states->Reshape({ (TIndex)states_size }); states->Reshape({ (TIndex)states_size });
auto* Sdata = states->template mutable_data<uint8_t, Context>(); auto* Sdata = states->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnSetDropoutDescriptor( CUDNN_CHECK(cudnnSetDropoutDescriptor(
dropout_desc, ctx().cudnn_handle(), prob(), dropout_desc, ctx()->cudnn_handle(), prob(),
Sdata, states_size, random_seed)); Sdata, states_size, random_seed));
} }
} }
...@@ -53,7 +55,7 @@ void CuDNNDropoutOp<Context>::RunWithType() { ...@@ -53,7 +55,7 @@ void CuDNNDropoutOp<Context>::RunWithType() {
mask->Reshape({ (TIndex)reserve_space_size }); mask->Reshape({ (TIndex)reserve_space_size });
auto* Rdata = mask->template mutable_data<uint8_t, Context>(); auto* Rdata = mask->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnDropoutForward( CUDNN_CHECK(cudnnDropoutForward(
ctx().cudnn_handle(), dropout_desc, ctx()->cudnn_handle(), dropout_desc,
input_desc, Xdata, input_desc, Xdata,
input_desc, Ydata, input_desc, Ydata,
Rdata, reserve_space_size)); Rdata, reserve_space_size));
...@@ -65,7 +67,9 @@ void CuDNNDropoutOp<Context>::RunOnDevice() { ...@@ -65,7 +67,9 @@ void CuDNNDropoutOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
#ifdef WITH_CUDA_FP16
else if (XIsType(Input(0), float16)) RunWithType<float16>(); else if (XIsType(Input(0), float16)) RunWithType<float16>();
#endif
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
...@@ -76,19 +80,21 @@ void CuDNNDropoutGradientOp<Context>::RunWithType() { ...@@ -76,19 +80,21 @@ void CuDNNDropoutGradientOp<Context>::RunWithType() {
if (phase() == "TEST") { NOT_IMPLEMENTED; } if (phase() == "TEST") { NOT_IMPLEMENTED; }
else if (phase() == "TRAIN") { else if (phase() == "TRAIN") {
CHECK(use_scale) << "\nCuDNN only supports scale-dropout"; CHECK(use_scale) << "\nCuDNN only supports scale-dropout";
Tensor* mask = ws()->GetTensor("/mnt/" + anchor() + "/dropout/mask"); Tensor* mask = ws()->GetTensor(
"/mnt/" + anchor() + "/dropout/mask");
// determine the dropout states // determine the dropout states
if (!states_initialized) { if (!states_initialized) {
states_initialized = true; states_initialized = true;
CUDNN_CHECK(cudnnDropoutGetStatesSize( CUDNN_CHECK(cudnnDropoutGetStatesSize(
ctx().cudnn_handle(), &states_size)); ctx()->cudnn_handle(), &states_size));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
Tensor* states = ws()->CreateTensor("/share/cudnn/dropout:" + Tensor* states = ws()->CreateTensor(
dragon_cast<string, unsigned long long>(random_seed) + "/states"); "/share/cudnn/dropout:" + dragon_cast<string,
unsigned long long>(random_seed) + "/states");
if (states->count() > 0) { if (states->count() > 0) {
auto* Sdata = states->template mutable_data<uint8_t, Context>(); auto* Sdata = states->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnRestoreDropoutDescriptor( CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc, ctx().cudnn_handle(), prob(), dropout_desc, ctx()->cudnn_handle(), prob(),
Sdata, states_size, random_seed)); Sdata, states_size, random_seed));
} else { LOG(FATAL) << "Missing states with seed: " << random_seed; } } else { LOG(FATAL) << "Missing states with seed: " << random_seed; }
} }
...@@ -101,7 +107,7 @@ void CuDNNDropoutGradientOp<Context>::RunWithType() { ...@@ -101,7 +107,7 @@ void CuDNNDropoutGradientOp<Context>::RunWithType() {
input_desc, &reserve_space_size)); input_desc, &reserve_space_size));
auto* Rdata = mask->template mutable_data<uint8_t, Context>(); auto* Rdata = mask->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnDropoutBackward( CUDNN_CHECK(cudnnDropoutBackward(
ctx().cudnn_handle(), dropout_desc, ctx()->cudnn_handle(), dropout_desc,
input_desc, dYdata, input_desc, dYdata,
input_desc, dXdata, input_desc, dXdata,
Rdata, reserve_space_size)); Rdata, reserve_space_size));
...@@ -113,7 +119,9 @@ void CuDNNDropoutGradientOp<Context>::RunOnDevice() { ...@@ -113,7 +119,9 @@ void CuDNNDropoutGradientOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
#ifdef WITH_CUDA_FP16
else if (XIsType(Input(0), float16)) RunWithType<float16>(); else if (XIsType(Input(0), float16)) RunWithType<float16>();
#endif
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
......
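Note: the cuDNN hunks here and below share one mechanical migration: ctx() now yields a Context* rather than a reference, so every handle access changes shape accordingly:
    ctx()->cudnn_handle();   // new
    ctx().cudnn_handle();    // old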
...@@ -14,7 +14,7 @@ void CuDNNEluOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void CuDNNEluOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
} }
...@@ -41,7 +41,7 @@ void CuDNNEluGradientOp<Context>::RunWithType() { ...@@ -41,7 +41,7 @@ void CuDNNEluGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
...@@ -13,7 +13,7 @@ void CuDNNReluOp<Context>::RunWithType() { ...@@ -13,7 +13,7 @@ void CuDNNReluOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
#else #else
...@@ -49,7 +49,7 @@ void CuDNNReluGradientOp<Context>::RunWithType() { ...@@ -49,7 +49,7 @@ void CuDNNReluGradientOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
...@@ -13,12 +13,12 @@ void CuDNNSigmoidOp<Context>::RunWithType() { ...@@ -13,12 +13,12 @@ void CuDNNSigmoidOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
#else #else
CUDNN_CHECK(cudnnActivationForward_v4( CUDNN_CHECK(cudnnActivationForward_v4(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<Dtype>::one, input_desc, Xdata, CUDNNType<Dtype>::one, input_desc, Xdata,
CUDNNType<Dtype>::zero, output_desc, Ydata)); CUDNNType<Dtype>::zero, output_desc, Ydata));
#endif #endif
...@@ -47,13 +47,13 @@ void CuDNNSigmoidGradientOp<Context>::RunWithType() { ...@@ -47,13 +47,13 @@ void CuDNNSigmoidGradientOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
#else #else
CUDNN_CHECK(cudnnActivationBackward_v4( CUDNN_CHECK(cudnnActivationBackward_v4(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
...@@ -7,8 +7,7 @@ namespace dragon { ...@@ -7,8 +7,7 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNSoftmaxOp<Context>::RunWithType() { void CuDNNSoftmaxOp<Context>::RunWithType() {
Tensor fake_tensor(vector<TIndex>( Tensor fake_tensor(vector<TIndex>(
{ outer_dim, Input(0).dim(axis), inner_dim }) { outer_dim, Input(0).dim(axis), inner_dim }));
);
cudnnSetTensorDesc<T>(&input_desc, &fake_tensor); cudnnSetTensorDesc<T>(&input_desc, &fake_tensor);
cudnnSetTensorDesc<T>(&output_desc, &fake_tensor); cudnnSetTensorDesc<T>(&output_desc, &fake_tensor);
...@@ -16,7 +15,7 @@ void CuDNNSoftmaxOp<Context>::RunWithType() { ...@@ -16,7 +15,7 @@ void CuDNNSoftmaxOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnSoftmaxForward( CUDNN_CHECK(cudnnSoftmaxForward(
ctx().cudnn_handle(), ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
...@@ -41,8 +40,7 @@ DEPLOY_CUDNN(Softmax); ...@@ -41,8 +40,7 @@ DEPLOY_CUDNN(Softmax);
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNSoftmaxGradientOp<Context>::RunWithType() { void CuDNNSoftmaxGradientOp<Context>::RunWithType() {
Tensor fake_tensor(vector<TIndex>( Tensor fake_tensor(vector<TIndex>(
{ outer_dim, Input(0).dim(axis), inner_dim }) { outer_dim, Input(0).dim(axis), inner_dim }));
);
cudnnSetTensorDesc<T>(&input_desc, &fake_tensor); cudnnSetTensorDesc<T>(&input_desc, &fake_tensor);
cudnnSetTensorDesc<T>(&output_desc, &fake_tensor); cudnnSetTensorDesc<T>(&output_desc, &fake_tensor);
...@@ -50,7 +48,7 @@ void CuDNNSoftmaxGradientOp<Context>::RunWithType() { ...@@ -50,7 +48,7 @@ void CuDNNSoftmaxGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnSoftmaxBackward( CUDNN_CHECK(cudnnSoftmaxBackward(
ctx().cudnn_handle(), ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL,
CUDNNType<T>::one, input_desc, Ydata, input_desc, dYdata, CUDNNType<T>::one, input_desc, Ydata, input_desc, dYdata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
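
The fake_tensor above gives cuDNN a logical { outer_dim, dim(axis), inner_dim } view so CUDNN_SOFTMAX_MODE_CHANNEL normalizes over the chosen axis. A small sketch of how those two folded dimensions fall out of an arbitrary axis under row-major layout (helper names are illustrative only):

    #include <cstdint>
    #include <vector>

    struct SoftmaxView { int64_t outer, channels, inner; };

    // Fold everything before `axis` into outer, everything after
    // into inner; the softmax then runs over `channels`.
    SoftmaxView MakeView(const std::vector<int64_t>& dims, int axis) {
      SoftmaxView v{1, dims[axis], 1};
      for (int i = 0; i < axis; ++i) v.outer *= dims[i];
      for (size_t i = axis + 1; i < dims.size(); ++i) v.inner *= dims[i];
      return v;
    }
    // e.g. dims = {8, 3, 32, 32}, axis = 1  ->  view {8, 3, 1024}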
...@@ -13,12 +13,12 @@ void CuDNNTanhOp<Context>::RunWithType() { ...@@ -13,12 +13,12 @@ void CuDNNTanhOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
#else #else
CUDNN_CHECK(cudnnActivationForward_v4( CUDNN_CHECK(cudnnActivationForward_v4(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<Dtype>::one, input_desc, Xdata, CUDNNType<Dtype>::one, input_desc, Xdata,
CUDNNType<Dtype>::zero, output_desc, Ydata)); CUDNNType<Dtype>::zero, output_desc, Ydata));
#endif #endif
...@@ -47,13 +47,13 @@ void CuDNNTanhGradientOp<Context>::RunWithType() { ...@@ -47,13 +47,13 @@ void CuDNNTanhGradientOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
#else #else
CUDNN_CHECK(cudnnActivationBackward_v4( CUDNN_CHECK(cudnnActivationBackward_v4(
ctx().cudnn_handle(), act_desc, ctx()->cudnn_handle(), act_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Ydata, input_desc, dYdata, output_desc, Ydata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
...@@ -11,10 +11,10 @@ void DropoutOp<Context>::RunWithType() { ...@@ -11,10 +11,10 @@ void DropoutOp<Context>::RunWithType() {
float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0; float scale = use_scale ? 1.0 / (1.0 - prob()) : 1.0;
if (phase() == "TEST") { if (phase() == "TEST") {
if (Output(0) != &Input(0)) { if (Output(0) != &Input(0)) {
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), Ydata, Xdata); Output(0)->count(), Ydata, Xdata);
if (scale == 1.0) math::Scal<T, Context>( if (scale == 1.0) math::Scal<T, Context>(
Output(0)->count(), 1.0 - prob(), Ydata, &ctx()); Output(0)->count(), 1.0 - prob(), Ydata, ctx());
} }
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
Tensor* mask = ws()->CreateTensor( Tensor* mask = ws()->CreateTensor(
...@@ -23,7 +23,7 @@ void DropoutOp<Context>::RunWithType() { ...@@ -23,7 +23,7 @@ void DropoutOp<Context>::RunWithType() {
uint32_t* Mdata = mask->template mutable_data<uint32_t, Context>(); uint32_t* Mdata = mask->template mutable_data<uint32_t, Context>();
kernel::Dropout<T, Context>( kernel::Dropout<T, Context>(
Output(0)->count(), prob(), scale, Output(0)->count(), prob(), scale,
Xdata, Mdata, Ydata, &ctx()); Xdata, Mdata, Ydata, ctx());
} else LOG(FATAL) << "Incorrect Op phase: " << phase(); } else LOG(FATAL) << "Incorrect Op phase: " << phase();
} }
...@@ -52,7 +52,8 @@ void DropoutGradientOp<Context>::RunWithType() { ...@@ -52,7 +52,8 @@ void DropoutGradientOp<Context>::RunWithType() {
else if (phase() == "TRAIN") { else if (phase() == "TRAIN") {
kernel::DropoutGrad<T, Context>( kernel::DropoutGrad<T, Context>(
Output(0)->count(), prob(), scale, Output(0)->count(), prob(), scale,
dYdata, Mdata, dXdata, &ctx()); dYdata, Mdata, dXdata, ctx());
ctx()->FinishDeviceCompution();
mask->Reset(); mask->Reset();
} else LOG(FATAL) << "Incorrect Op phase: " << phase(); } else LOG(FATAL) << "Incorrect Op phase: " << phase();
} }
......
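
The one genuinely new line in the dropout hunk is ctx()->FinishDeviceCompution() before mask->Reset(). With kernels now launched on async streams, Reset() would otherwise free memory that the in-flight DropoutGrad launch may still be reading. The hazard in raw CUDA terms (placeholder names, not Dragon's):

    #include <cuda_runtime.h>

    // Sketch: freeing an input of an async kernel is only safe once
    // the launching stream has drained.
    void SafeFreeAfterAsyncUse(void* mask, cudaStream_t stream) {
      // ... a kernel reading `mask` was queued on `stream` ...
      cudaStreamSynchronize(stream);  // = FinishDeviceCompution()
      cudaFree(mask);                 // the kernel can no longer race us
    }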
...@@ -8,7 +8,8 @@ template <class Context> template <typename T> ...@@ -8,7 +8,8 @@ template <class Context> template <typename T>
void EluOp<Context>::RunWithType() { void EluOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Elu<T, Context>(Output(0)->count(), alpha, Xdata, Ydata); kernel::Elu<T, Context>(Output(0)->count(),
alpha, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -30,8 +31,8 @@ void EluGradientOp<Context>::RunWithType() { ...@@ -30,8 +31,8 @@ void EluGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dYdata = Input(1).template data<T, Context>(); auto* dYdata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::EluGrad<T, Context>( kernel::EluGrad<T, Context>(Output(0)->count(),
Output(0)->count(), alpha, dYdata, Ydata, dXdata); alpha, dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
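
Every kernel::* call in this commit grows a trailing ctx() argument. The payoff is that a CUDA backend can launch on the context's stream instead of the legacy default stream, so independent operators no longer serialize globally. A minimal sketch of what such a kernel can look like (the wrapper signature is assumed, not Dragon's exact one):

    #include <cuda_runtime.h>

    __global__ void EluKernel(int n, float alpha,
                              const float* x, float* y) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) y[i] = x[i] > 0.f ? x[i] : alpha * (expf(x[i]) - 1.f);
    }

    // Launch on the caller's stream so the op serializes only with
    // work issued through the same context.
    void Elu(int n, float alpha, const float* x, float* y,
             cudaStream_t stream) {
      const int threads = 256;
      EluKernel<<<(n + threads - 1) / threads, threads, 0, stream>>>(
          n, alpha, x, y);
    }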
...@@ -18,7 +18,7 @@ void PReluOp<Context>::RunWithType() { ...@@ -18,7 +18,7 @@ void PReluOp<Context>::RunWithType() {
kernel::PRelu<T, Context>( kernel::PRelu<T, Context>(
Output(0)->count(), channels, dim, Output(0)->count(), channels, dim,
channel_shared ? true : false, data_format, channel_shared ? true : false, data_format,
Xdata, Wdata, Ydata); Xdata, Wdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -49,12 +49,12 @@ void PReluGradientOp<Context>::RunWithType() { ...@@ -49,12 +49,12 @@ void PReluGradientOp<Context>::RunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
DECLARE_MULTIPLIER(multiplier, channels * dim); DECLARE_MULTIPLIER(multiplier, channels * dim);
auto* dWdata = Output(1)->template mutable_data<T, Context>(); auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
auto* dWBdata = ws()->template caches<T, Context>({ channels * dim })[0]; auto* dWBdata = ws()->template caches<T, Context>({ channels * dim })[0];
kernel::PReluWGrad<T, Context>( kernel::PReluWGrad<T, Context>(
Input(0).dim(0), Input(0).count(1), channels, dim, Input(0).dim(0), Input(0).count(1), channels, dim,
channel_shared ? true : false, data_format, channel_shared ? true : false, data_format,
dYdata, Xdata, multiplier, dWBdata, dWdata, &ctx()); dYdata, Xdata, multiplier, dWBdata, dWdata, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
...@@ -63,7 +63,7 @@ void PReluGradientOp<Context>::RunWithType() { ...@@ -63,7 +63,7 @@ void PReluGradientOp<Context>::RunWithType() {
kernel::PReluGrad<T, Context>( kernel::PReluGrad<T, Context>(
Output(0)->count(), channels, dim, Output(0)->count(), channels, dim,
channel_shared ? true : false, data_format, channel_shared ? true : false, data_format,
dYdata, Xdata, Wdata, dXdata); dYdata, Xdata, Wdata, dXdata, ctx());
} }
} }
......
...@@ -8,7 +8,8 @@ template <class Context> template <typename T> ...@@ -8,7 +8,8 @@ template <class Context> template <typename T>
void ReluOp<Context>::RunWithType() { void ReluOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Relu<T, Context>(Output(0)->count(), slope, Xdata, Ydata); kernel::Relu<T, Context>(Output(0)->count(),
slope, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -24,15 +25,17 @@ DEPLOY_CPU(Relu); ...@@ -24,15 +25,17 @@ DEPLOY_CPU(Relu);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Relu); DEPLOY_CUDA(Relu);
#endif #endif
OPERATOR_SCHEMA(Relu).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } }); OPERATOR_SCHEMA(Relu)
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context> template <typename T> template <class Context> template <typename T>
void ReluGradientOp<Context>::RunWithType() { void ReluGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dYdata = Input(1).template data<T, Context>(); auto* dYdata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::ReluGrad<T, Context>( kernel::ReluGrad<T, Context>(Output(0)->count(),
Output(0)->count(), slope, dYdata, Ydata, dXdata); slope, dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -47,7 +50,9 @@ DEPLOY_CPU(ReluGradient); ...@@ -47,7 +50,9 @@ DEPLOY_CPU(ReluGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ReluGradient); DEPLOY_CUDA(ReluGradient);
#endif #endif
OPERATOR_SCHEMA(ReluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 }}); OPERATOR_SCHEMA(ReluGradient)
.NumInputs(2).NumOutputs(1)
.Inplace({ { 1, 0 }});
class GetReluGradient final : public GradientMakerBase { class GetReluGradient final : public GradientMakerBase {
public: public:
......
...@@ -8,7 +8,7 @@ template <class Context> template <typename T> ...@@ -8,7 +8,7 @@ template <class Context> template <typename T>
void SEluOp<Context>::RunWithType() { void SEluOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::SElu<T, Context>(Output(0)->count(), Xdata, Ydata); kernel::SElu<T, Context>(Output(0)->count(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -23,15 +23,17 @@ DEPLOY_CPU(SElu); ...@@ -23,15 +23,17 @@ DEPLOY_CPU(SElu);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(SElu); DEPLOY_CUDA(SElu);
#endif #endif
OPERATOR_SCHEMA(SElu).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } }); OPERATOR_SCHEMA(SElu)
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context> template <typename T> template <class Context> template <typename T>
void SEluGradientOp<Context>::RunWithType() { void SEluGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dYdata = Input(1).template data<T, Context>(); auto* dYdata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::SEluGrad<T, Context>( kernel::SEluGrad<T, Context>(Output(0)->count(),
Output(0)->count(), dYdata, Ydata, dXdata); dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -46,7 +48,9 @@ DEPLOY_CPU(SEluGradient); ...@@ -46,7 +48,9 @@ DEPLOY_CPU(SEluGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(SEluGradient); DEPLOY_CUDA(SEluGradient);
#endif #endif
OPERATOR_SCHEMA(SEluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 }}); OPERATOR_SCHEMA(SEluGradient)
.NumInputs(2).NumOutputs(1)
.Inplace({ { 1, 0 }});
class GetSEluGradient final : public GradientMakerBase { class GetSEluGradient final : public GradientMakerBase {
public: public:
......
...@@ -8,7 +8,7 @@ template <class Context> template <typename T> ...@@ -8,7 +8,7 @@ template <class Context> template <typename T>
void SigmoidOp<Context>::RunWithType() { void SigmoidOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Sigmoid<T, Context>(Output(0)->count(), Xdata, Ydata); kernel::Sigmoid<T, Context>(Output(0)->count(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -30,8 +30,8 @@ void SigmoidGradientOp<Context>::RunWithType() { ...@@ -30,8 +30,8 @@ void SigmoidGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dYdata = Input(1).template data<T, Context>(); auto* dYdata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::SigmoidGrad<T, Context>( kernel::SigmoidGrad<T, Context>(Output(0)->count(),
Output(0)->count(), dYdata, Ydata, dXdata); dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -12,13 +12,13 @@ void SoftmaxOp<Context>::RunWithType() { ...@@ -12,13 +12,13 @@ void SoftmaxOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Input(0).count(), Ydata, Xdata); Input(0).count(), Ydata, Xdata);
kernel::Softmax<T, Context>( kernel::Softmax<T, Context>(
Output(0)->count(), Input(0).dim(axis), Output(0)->count(), Input(0).dim(axis),
outer_dim, inner_dim, multiplier, outer_dim, inner_dim, multiplier,
Xdata, WSdata, Ydata, &ctx()); Xdata, WSdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -36,7 +36,9 @@ DEPLOY_CPU(Softmax); ...@@ -36,7 +36,9 @@ DEPLOY_CPU(Softmax);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Softmax); DEPLOY_CUDA(Softmax);
#endif #endif
OPERATOR_SCHEMA(Softmax).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } }); OPERATOR_SCHEMA(Softmax)
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context> template <typename T> template <class Context> template <typename T>
void SoftmaxGradientOp<Context>::RunWithType() { void SoftmaxGradientOp<Context>::RunWithType() {
...@@ -44,15 +46,16 @@ void SoftmaxGradientOp<Context>::RunWithType() { ...@@ -44,15 +46,16 @@ void SoftmaxGradientOp<Context>::RunWithType() {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>(
{ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Input(0).count(), dXdata, dYdata); Input(0).count(), dXdata, dYdata);
kernel::SoftmaxGrad<T, Context>( kernel::SoftmaxGrad<T, Context>(
Output(0)->count(), Input(0).dim(axis), Output(0)->count(), Input(0).dim(axis),
outer_dim, inner_dim, multiplier, outer_dim, inner_dim, multiplier,
dYdata, Ydata, WSdata, dXdata, &ctx()); dYdata, Ydata, WSdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -70,7 +73,9 @@ DEPLOY_CPU(SoftmaxGradient); ...@@ -70,7 +73,9 @@ DEPLOY_CPU(SoftmaxGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(SoftmaxGradient); DEPLOY_CUDA(SoftmaxGradient);
#endif #endif
OPERATOR_SCHEMA(SoftmaxGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(SoftmaxGradient)
.NumInputs(2).NumOutputs(1)
.Inplace({ { 1, 0 } });
class GetSoftmaxGradient final : public GradientMakerBase { class GetSoftmaxGradient final : public GradientMakerBase {
public: public:
......
...@@ -8,7 +8,7 @@ template <class Context> template <typename T> ...@@ -8,7 +8,7 @@ template <class Context> template <typename T>
void TanhOp<Context>::RunWithType() { void TanhOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Tanh<T, Context>(Output(0)->count(), Xdata, Ydata); kernel::Tanh<T, Context>(Output(0)->count(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -30,8 +30,8 @@ void TanhGradientOp<Context>::RunWithType() { ...@@ -30,8 +30,8 @@ void TanhGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context>(); auto* Ydata = Input(0).template data<T, Context>();
auto* dYdata = Input(1).template data<T, Context>(); auto* dYdata = Input(1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::TanhGrad<T, Context>( kernel::TanhGrad<T, Context>(Output(0)->count(),
Output(0)->count(), dYdata, Ydata, dXdata); dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -9,7 +9,7 @@ void AddOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void AddOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Add<T, Context>(Output(0)->count(), x1, x2, y); math::Add<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -19,23 +19,24 @@ void AddOp<Context>::BroadcastRunWithType(int type) { ...@@ -19,23 +19,24 @@ void AddOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), y, x1); Output(0)->count(), y, x1);
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(0).count(); x2 = Input(1).template data<T, CPUContext>();
inner_dim = 1; math::AddScalar<T, Context>(Output(0)->count(),
dragon_cast<float, T>(x2[0]), y, ctx());
} else { } else {
outer_dim = Input(0).count(0, Input(0).axis(-1)); outer_dim = Input(0).count(0, Input(0).axis(-1));
inner_dim = Input(0).dim(-1); inner_dim = Input(0).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x2, 1.0, multiplier, x2,
1.0, y, &ctx()); 1.0, y, ctx());
}
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(0).dim(0); outer_dim = Input(0).dim(0);
inner_dim = Input(0).count(1); inner_dim = Input(0).count(1);
...@@ -44,7 +45,7 @@ void AddOp<Context>::BroadcastRunWithType(int type) { ...@@ -44,7 +45,7 @@ void AddOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x2, multiplier, 1.0, x2, multiplier,
1.0, y, &ctx()); 1.0, y, ctx());
} }
} }
...@@ -77,13 +78,13 @@ void AddGradientOp<Context>::EltwiseRunWithType() { ...@@ -77,13 +78,13 @@ void AddGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(1)->count(), dx2, dy); Output(1)->count(), dx2, dy);
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dx1, dy); Output(0)->count(), dx1, dy);
} }
} }
...@@ -108,7 +109,7 @@ void AddGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -108,7 +109,7 @@ void AddGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = X1->dim(0); outer_dim = X1->dim(0);
inner_dim = X1->count(1); inner_dim = X1->count(1);
...@@ -116,13 +117,13 @@ void AddGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -116,13 +117,13 @@ void AddGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} }
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
X1->count(), dx1, dy); X1->count(), dx1, dy);
} }
} }
......
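
The rewritten type-0 branch of AddOp reads the single element of Input(1) through CPUContext and calls math::AddScalar, replacing the old rank-1 Gemm against a ones multiplier. One observation (mine, not the diff's): materializing a device scalar on the host implies a synchronizing copy, the usual price for the much cheaper elementwise kernel. In raw CUDA the fetch looks like:

    #include <cuda_runtime.h>

    // Sketch: pull a 1-element device tensor to the host so it can be
    // passed as a plain scalar to a broadcast add.
    float FetchScalar(const float* device_scalar, cudaStream_t stream) {
      float s = 0.f;
      cudaMemcpyAsync(&s, device_scalar, sizeof(float),
                      cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // host may read s only after this
      return s;
    }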
...@@ -34,7 +34,7 @@ void AffineOp<Context>::RunWithType() { ...@@ -34,7 +34,7 @@ void AffineOp<Context>::RunWithType() {
kernel::Affine<T, Context>( kernel::Affine<T, Context>(
Output(0)->count(), outer_dim, scale_dim, inner_dim, Output(0)->count(), outer_dim, scale_dim, inner_dim,
Xdata, Adata, Bdata, bias_multiplier, Ydata, &ctx()); Xdata, Adata, Bdata, bias_multiplier, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -58,13 +58,13 @@ void AffineGradientOp<Context>::BiasRunWithType() { ...@@ -58,13 +58,13 @@ void AffineGradientOp<Context>::BiasRunWithType() {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dBias = Output(2)->template mutable_data<T, Context>(); auto* dBias = Output(2)->template mutable_data<T, Context>(ctx());
for (int n = 0; n < outer_dim; n++) { for (int n = 0; n < outer_dim; n++) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, scale_dim, inner_dim, CblasNoTrans, scale_dim, inner_dim,
1.0, dYdata, multiplier, 1.0, dYdata, multiplier,
1.0, dBias, &ctx()); 1.0, dBias, ctx());
dYdata += dim; dYdata += dim;
} }
} }
...@@ -79,45 +79,36 @@ void AffineGradientOp<Context>::ScaleRunWithType() { ...@@ -79,45 +79,36 @@ void AffineGradientOp<Context>::ScaleRunWithType() {
bool is_eltwise = (Input(-1).count() == Input(1).count()); bool is_eltwise = (Input(-1).count() == Input(1).count());
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* dScale = Output(1)->template mutable_data<T, Context>(); auto* dScale = Output(1)->template mutable_data<T, Context>(ctx());
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
auto* dYxX = dXdata; auto* dYxX = dXdata;
math::Mul<T, Context>(Output(0)->count(), dYdata, Xdata, dYxX); math::Mul<T, Context>(Output(0)->count(), dYdata, Xdata, dYxX, ctx());
if (!is_eltwise) { if (!is_eltwise) {
T* SRes_data = nullptr; T* SRes_data = nullptr;
// reduce inner dimensions
if (inner_dim == 1) { if (inner_dim == 1) {
SRes_data = dYxX; SRes_data = dYxX;
} else if (sum_result.count() == 1) { // handle inner only
dScale = Output(1)->template mutable_data<T, CPUContext>();
T result = math::Dot<T, Context>(
inner_dim, dYxX, multiplier, &ctx());
*dScale += result;
} else { } else {
SRes_data = (outer_dim == 1) ? // handle scale only SRes_data = (outer_dim == 1) ?
dScale : sum_result.template mutable_data<T, Context>(); dScale : sum_result.template mutable_data<T, Context>();
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, sum_result.count(), inner_dim, CblasNoTrans, sum_result.count(), inner_dim,
1.0, dYxX, multiplier, 1.0, dYxX, multiplier,
SRes_data == dScale ? 1.0 : 0.0, SRes_data, &ctx()); SRes_data == dScale ? 1.0 : 0.0,
SRes_data, ctx());
} }
// reduce outer dimensions
if (outer_dim != 1) { if (outer_dim != 1) {
if (scale_dim == 1) { // handle outer only
dScale = Output(1)->template mutable_data<T, CPUContext>();
T result = math::Dot<T, Context>(
outer_dim, multiplier, SRes_data, &ctx());
*dScale += result;
} else {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, scale_dim, CblasTrans, outer_dim, scale_dim,
1.0, SRes_data, multiplier, 1.0, SRes_data, multiplier,
1.0, dScale, &ctx()); 1.0, dScale, ctx());
}
} }
} else { } else {
math::Axpy<T, Context>(Output(1)->count(), math::Axpy<T, Context>(Output(1)->count(),
1.f, dYxX, dScale, &ctx()); 1.f, dYxX, dScale, ctx());
} }
} }
...@@ -131,7 +122,7 @@ void AffineGradientOp<Context>::RunWithType() { ...@@ -131,7 +122,7 @@ void AffineGradientOp<Context>::RunWithType() {
kernel::AffineGrad<T, Context>( kernel::AffineGrad<T, Context>(
Output(0)->count(), outer_dim, scale_dim, inner_dim, Output(0)->count(), outer_dim, scale_dim, inner_dim,
dYdata, Adata, dXdata, &ctx()); dYdata, Adata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
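
Both dBias and dScale above are now requested via mutable_data<T, Context>(ctx()), and every Gemv into them runs with beta = 1.0, i.e. it accumulates. That contract only holds if the buffer starts from a known value, so the ctx-taking overload presumably zero-fills on allocation — an inference from the beta = 1 usage, not something this hunk shows. A reference version of the accumulation being performed:

    #include <vector>

    // Sketch: dBias[c] += sum over inner of dY[n][c][i], for each n,
    // matching the repeated beta = 1 Gemv above. dBias starts zeroed.
    void AccumulateBiasGrad(const float* dY, int outer, int scale,
                            int inner, std::vector<float>& dBias) {
      for (int n = 0; n < outer; ++n)
        for (int c = 0; c < scale; ++c)
          for (int i = 0; i < inner; ++i)
            dBias[c] += dY[(n * scale + c) * inner + i];
    }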
...@@ -15,7 +15,7 @@ void ClipOp<Context>::RunWithType() { ...@@ -15,7 +15,7 @@ void ClipOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template mutable_data<T, Context>(); auto* Mdata = mask->template mutable_data<T, Context>();
kernel::Clip<T, Context>(Output(0)->count(), kernel::Clip<T, Context>(Output(0)->count(),
low, high, Xdata, Mdata, Ydata); low, high, Xdata, Mdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -30,7 +30,9 @@ DEPLOY_CPU(Clip); ...@@ -30,7 +30,9 @@ DEPLOY_CPU(Clip);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Clip); DEPLOY_CUDA(Clip);
#endif #endif
OPERATOR_SCHEMA(Clip).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } }); OPERATOR_SCHEMA(Clip)
.NumInputs(1).NumOutputs(1)
.Inplace({ { 0, 0 } });
template <class Context> template <typename T> template <class Context> template <typename T>
void ClipGradientOp<Context>::RunWithType() { void ClipGradientOp<Context>::RunWithType() {
...@@ -39,7 +41,8 @@ void ClipGradientOp<Context>::RunWithType() { ...@@ -39,7 +41,8 @@ void ClipGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<T, Context>(); auto* Mdata = mask->template data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), dXdata, Mdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
dXdata, Mdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -54,7 +57,9 @@ DEPLOY_CPU(ClipGradient); ...@@ -54,7 +57,9 @@ DEPLOY_CPU(ClipGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ClipGradient); DEPLOY_CUDA(ClipGradient);
#endif #endif
OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(ClipGradient)
.NumInputs(2).NumOutputs(1)
.Inplace({ { 1, 0 } });
class GetClipGradient final : public GradientMakerBase { class GetClipGradient final : public GradientMakerBase {
public: public:
......
...@@ -23,7 +23,7 @@ void CuDNNAffineOp<Context>::RunWithType() { ...@@ -23,7 +23,7 @@ void CuDNNAffineOp<Context>::RunWithType() {
mul_desc, CUDNN_OP_TENSOR_MUL, mul_desc, CUDNN_OP_TENSOR_MUL,
CUDNNType<T>::type, CUDNN_PROPAGATE_NAN)); CUDNNType<T>::type, CUDNN_PROPAGATE_NAN));
CUDNN_CHECK(cudnnOpTensor( CUDNN_CHECK(cudnnOpTensor(
ctx().cudnn_handle(), mul_desc, ctx()->cudnn_handle(), mul_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::one, param_desc, Adata, CUDNNType<T>::one, param_desc, Adata,
CUDNNType<T>::zero, input_desc, Ydata)); CUDNNType<T>::zero, input_desc, Ydata));
...@@ -36,7 +36,7 @@ void CuDNNAffineOp<Context>::RunWithType() { ...@@ -36,7 +36,7 @@ void CuDNNAffineOp<Context>::RunWithType() {
add_desc, CUDNN_OP_TENSOR_ADD, add_desc, CUDNN_OP_TENSOR_ADD,
CUDNNType<T>::type, CUDNN_PROPAGATE_NAN)); CUDNNType<T>::type, CUDNN_PROPAGATE_NAN));
CUDNN_CHECK(cudnnOpTensor( CUDNN_CHECK(cudnnOpTensor(
ctx().cudnn_handle(), add_desc, ctx()->cudnn_handle(), add_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
CUDNNType<T>::one, param_desc, Bdata, CUDNNType<T>::one, param_desc, Bdata,
CUDNNType<T>::zero, input_desc, Ydata)); CUDNNType<T>::zero, input_desc, Ydata));
...@@ -48,7 +48,9 @@ void CuDNNAffineOp<Context>::RunOnDevice() { ...@@ -48,7 +48,9 @@ void CuDNNAffineOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
#ifdef WITH_CUDA_FP16
else if (XIsType(Input(0), float16)) RunWithType<float16>(); else if (XIsType(Input(0), float16)) RunWithType<float16>();
#endif
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
...@@ -76,17 +78,17 @@ void CuDNNAffineGradientOp<Context>::RunWithType() { ...@@ -76,17 +78,17 @@ void CuDNNAffineGradientOp<Context>::RunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
Output(1)->ReshapeLike(Input(1)); Output(1)->ReshapeLike(Input(1));
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* dAdata = Output(1)->template mutable_data<T, Context>(); auto* dAdata = Output(1)->template mutable_data<T, Context>(ctx());
// eltwise // eltwise
if (Input(0).count() == Input(1).count()) { if (Input(0).count() == Input(1).count()) {
CUDNN_CHECK(cudnnOpTensor( CUDNN_CHECK(cudnnOpTensor(
ctx().cudnn_handle(), mul_desc, ctx()->cudnn_handle(), mul_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::one, input_desc, dYdata, CUDNNType<T>::one, input_desc, dYdata,
CUDNNType<T>::one, param_desc, dAdata)); CUDNNType<T>::one, param_desc, dAdata));
} else { } else {
CUDNN_CHECK(cudnnOpTensor( CUDNN_CHECK(cudnnOpTensor(
ctx().cudnn_handle(), mul_desc, ctx()->cudnn_handle(), mul_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::one, input_desc, dYdata, CUDNNType<T>::one, input_desc, dYdata,
CUDNNType<T>::zero, input_desc, dXdata)); CUDNNType<T>::zero, input_desc, dXdata));
...@@ -97,11 +99,11 @@ void CuDNNAffineGradientOp<Context>::RunWithType() { ...@@ -97,11 +99,11 @@ void CuDNNAffineGradientOp<Context>::RunWithType() {
// db = dy // db = dy
if (Output(2)->name() != "ignore") { if (Output(2)->name() != "ignore") {
Output(2)->ReshapeLike(Input(1)); Output(2)->ReshapeLike(Input(1));
auto* dBdata = Output(2)->template mutable_data<T, Context>(); auto* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
// eltwise // eltwise
if (Input(-1).count() == Input(1).count()) { if (Input(-1).count() == Input(1).count()) {
math::Axpy<T, Context>(Output(2)->count(), math::Axpy<T, Context>(Output(2)->count(),
1.f, dYdata, dBdata, &ctx()); 1.f, dYdata, dBdata, ctx());
} else { } else {
ComputeBiasGradient_v2<T>(dYdata, dBdata); ComputeBiasGradient_v2<T>(dYdata, dBdata);
} }
...@@ -109,7 +111,7 @@ void CuDNNAffineGradientOp<Context>::RunWithType() { ...@@ -109,7 +111,7 @@ void CuDNNAffineGradientOp<Context>::RunWithType() {
// dx = alpha * dy // dx = alpha * dy
CUDNN_CHECK(cudnnOpTensor( CUDNN_CHECK(cudnnOpTensor(
ctx().cudnn_handle(), mul_desc, ctx()->cudnn_handle(), mul_desc,
CUDNNType<T>::one, input_desc, dYdata, CUDNNType<T>::one, input_desc, dYdata,
CUDNNType<T>::one, param_desc, Adata, CUDNNType<T>::one, param_desc, Adata,
CUDNNType<T>::zero, input_desc, dXdata)); CUDNNType<T>::zero, input_desc, dXdata));
...@@ -126,11 +128,11 @@ void CuDNNAffineGradientOp<Context>::ComputeScaleGradient( ...@@ -126,11 +128,11 @@ void CuDNNAffineGradientOp<Context>::ComputeScaleGradient(
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES)); CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
size_t workspace_size = 0; size_t workspace_size = 0;
CUDNN_CHECK(cudnnGetReductionWorkspaceSize( CUDNN_CHECK(cudnnGetReductionWorkspaceSize(
ctx().cudnn_handle(), reduce_desc, ctx()->cudnn_handle(), reduce_desc,
input_desc, param_desc, &workspace_size)); input_desc, param_desc, &workspace_size));
auto* WSdata = ws()->template caches<Context>({ workspace_size })[0]; auto* WSdata = ws()->template caches<Context>({ workspace_size })[0];
CUDNN_CHECK(cudnnReduceTensor( CUDNN_CHECK(cudnnReduceTensor(
ctx().cudnn_handle(), reduce_desc, ctx()->cudnn_handle(), reduce_desc,
nullptr, 0, WSdata, workspace_size, nullptr, 0, WSdata, workspace_size,
CUDNNType<T>::one, input_desc, dYxX, CUDNNType<T>::one, input_desc, dYxX,
CUDNNType<T>::one, param_desc, dA)); CUDNNType<T>::one, param_desc, dA));
...@@ -145,32 +147,23 @@ void CuDNNAffineGradientOp<Context>::ComputeScaleGradient_v2( ...@@ -145,32 +147,23 @@ void CuDNNAffineGradientOp<Context>::ComputeScaleGradient_v2(
sum_result.Reshape({ outer_dim * scale_dim }); sum_result.Reshape({ outer_dim * scale_dim });
T* SRes_data = nullptr; T* SRes_data = nullptr;
if (inner_dim == 1) SRes_data = dYxX; // reduce inner dimensions
else if (sum_result.count() == 1) { if (inner_dim == 1) {
auto* dAC = Output(1)->template mutable_data<T, CPUContext>(); SRes_data = dYxX;
T result = math::Dot<T, Context>(
inner_dim, dYxX, multiplier, &ctx());
*dAC += result;
} else { } else {
SRes_data = (outer_dim == 1) ? SRes_data = (outer_dim == 1) ?
dA : sum_result.template mutable_data<T, Context>(); dA : sum_result.template mutable_data<T, Context>();
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, sum_result.count(), inner_dim, CblasNoTrans, sum_result.count(), inner_dim,
1.0, dYxX, multiplier, 1.0, dYxX, multiplier,
SRes_data == dA ? 1.0 : 0.0, SRes_data, &ctx()); SRes_data == dA ? 1.0 : 0.0, SRes_data, ctx());
} }
// reduce outer dimensions
if (outer_dim != 1) { if (outer_dim != 1) {
if (scale_dim == 1) {
auto* dAC = Output(1)->template mutable_data<T, CPUContext>();
T result = math::Dot<T, Context>(
outer_dim, multiplier, SRes_data, &ctx());
*dAC += result;
} else {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, scale_dim, CblasTrans, outer_dim, scale_dim,
1.0, SRes_data, multiplier, 1.0, SRes_data, multiplier,
1.0, dA, &ctx()); 1.0, dA, ctx());
}
} }
} }
...@@ -185,11 +178,11 @@ void CuDNNAffineGradientOp<Context>::ComputeBiasGradient( ...@@ -185,11 +178,11 @@ void CuDNNAffineGradientOp<Context>::ComputeBiasGradient(
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES)); CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
size_t workspace_size = 0; size_t workspace_size = 0;
CUDNN_CHECK(cudnnGetReductionWorkspaceSize( CUDNN_CHECK(cudnnGetReductionWorkspaceSize(
ctx().cudnn_handle(), reduce_desc, ctx()->cudnn_handle(), reduce_desc,
input_desc, param_desc, &workspace_size)); input_desc, param_desc, &workspace_size));
auto* WSdata = ws()->template caches<Context>({ workspace_size })[0]; auto* WSdata = ws()->template caches<Context>({ workspace_size })[0];
CUDNN_CHECK(cudnnReduceTensor( CUDNN_CHECK(cudnnReduceTensor(
ctx().cudnn_handle(), reduce_desc, ctx()->cudnn_handle(), reduce_desc,
nullptr, 0, WSdata, workspace_size, nullptr, 0, WSdata, workspace_size,
CUDNNType<T>::one, input_desc, dY, CUDNNType<T>::one, input_desc, dY,
CUDNNType<T>::one, param_desc, dB)); CUDNNType<T>::one, param_desc, dB));
...@@ -205,7 +198,7 @@ void CuDNNAffineGradientOp<Context>::ComputeBiasGradient_v2( ...@@ -205,7 +198,7 @@ void CuDNNAffineGradientOp<Context>::ComputeBiasGradient_v2(
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, scale_dim, inner_dim, CblasNoTrans, scale_dim, inner_dim,
1.0, dY, multiplier, 1.0, dY, multiplier,
1.0, dB, &ctx()); 1.0, dB, ctx());
dY += dim; dY += dim;
} }
} }
......
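
The scale/bias reductions above follow cuDNN's two-step protocol: query the workspace size for the configured reduce descriptor, borrow scratch memory (here from ws()->caches), then call cudnnReduceTensor with beta = 1 to accumulate into the gradient. A compressed sketch with descriptors assumed preconfigured and error macros elided:

    #include <cuda_runtime.h>
    #include <cudnn.h>

    // Sketch: sum-reduce src into dst, accumulating (beta = 1).
    void ReduceSum(cudnnHandle_t h, cudnnReduceTensorDescriptor_t reduce,
                   cudnnTensorDescriptor_t srcDesc, const float* src,
                   cudnnTensorDescriptor_t dstDesc, float* dst) {
      size_t ws_size = 0;
      cudnnGetReductionWorkspaceSize(h, reduce, srcDesc, dstDesc, &ws_size);
      void* ws = nullptr;
      cudaMalloc(&ws, ws_size);  // the real code reuses a cached buffer
      const float one = 1.f;
      cudnnReduceTensor(h, reduce,
                        nullptr, 0,      // no indices requested
                        ws, ws_size,
                        &one, srcDesc, src,
                        &one, dstDesc, dst);
      cudaFree(ws);  // cudaFree synchronizes before the memory is reused
    }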
...@@ -9,7 +9,7 @@ void DivOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void DivOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Div<T, Context>(Output(0)->count(), x1, x2, y); math::Div<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -18,34 +18,40 @@ void DivOp<Context>::BroadcastRunWithType(int type) { ...@@ -18,34 +18,40 @@ void DivOp<Context>::BroadcastRunWithType(int type) {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({
Output(0)->count() })[0];
if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(0).count(); x2 = Input(1).template data<T, CPUContext>();
inner_dim = 1; float inverse_x2 = 1.f / dragon_cast<float, T>(x2[0]);
} else { ctx()->template Copy<T, Context, Context>(
Output(0)->count(), y, x1);
math::MulScalar<T, Context>(
Output(0)->count(), inverse_x2, y, ctx());
} else if (type == 1) {
outer_dim = Input(0).count(0, Input(0).axis(-1)); outer_dim = Input(0).count(0, Input(0).axis(-1));
inner_dim = Input(0).dim(-1); inner_dim = Input(0).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x2, 1.0, multiplier, x2,
0.0, c, &ctx()); 0.0, c, ctx());
math::Div<T, Context>(Output(0)->count(), x1, c, y); math::Div<T, Context>(
Output(0)->count(), x1, c, y, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(0).dim(0); outer_dim = Input(0).dim(0);
inner_dim = Input(0).count(1); inner_dim = Input(0).count(1);
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x2, multiplier, 1.0, x2, multiplier,
0.0, c, &ctx()); 0.0, c, ctx());
math::Div<T, Context>(Output(0)->count(), x1, c, y); math::Div<T, Context>(
Output(0)->count(), x1, c, y, ctx());
} }
} }
...@@ -82,16 +88,16 @@ void DivGradientOp<Context>::EltwiseRunWithType() { ...@@ -82,16 +88,16 @@ void DivGradientOp<Context>::EltwiseRunWithType() {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({ X1->count() })[0]; auto* c = ws()->template caches<T, Context>({ X1->count() })[0];
math::Mul<T,Context>(X1->count(), dy, x1, c); // dY * X1 math::Mul<T,Context>(X1->count(), dy, x1, c, ctx()); // dY * X1
math::Square<T, Context>(X2->count(), x2, dx2); // X2^{2} math::Square<T, Context>(X2->count(), x2, dx2, ctx()); // X2^{2}
math::Inv<T, Context>(X2->count(), -1, dx2, dx2); // -1 / X2^{2} math::Inv<T, Context>(X2->count(), -1, dx2, dx2, ctx()); // -1 / X2^{2}
math::Mul<T, Context>(X2->count(), c, dx2, dx2); math::Mul<T, Context>(X2->count(), c, dx2, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
math::Div<T, Context>(X1->count(), dy, x2, dx1); math::Div<T, Context>(X1->count(), dy, x2, dx1, ctx());
} }
} }
...@@ -118,23 +124,23 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -118,23 +124,23 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
auto cs = ws()->template caches<T, Context>( auto cs = ws()->template caches<T, Context>(
{ X1->count(), X2->count() }); { X1->count(), X2->count() });
math::Mul<T, Context>(X1->count(), dy, x1, cs[0]); // dY * X1 math::Mul<T, Context>(X1->count(), dy, x1, cs[0], ctx()); // dY * X1
math::Square<T, Context>(X2->count(), x2, dx2); // X2^{2} math::Square<T, Context>(X2->count(), x2, dx2, ctx()); // X2^{2}
math::Inv<T, Context>(X2->count(), -1.0, dx2, dx2); // -1 / X2^{2} math::Inv<T, Context>(X2->count(), -1, dx2, dx2, ctx()); // -1 / X2^{2}
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, cs[0], multiplier, 1.0, cs[0], multiplier,
0.0, cs[1], &ctx()); 0.0, cs[1], ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, cs[0], multiplier, 1.0, cs[0], multiplier,
0.0, cs[1], &ctx()); 0.0, cs[1], ctx());
} }
math::Mul<T, Context>(X2->count(), cs[1], dx2, dx2); math::Mul<T, Context>(X2->count(), cs[1], dx2, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
...@@ -146,16 +152,16 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -146,16 +152,16 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x2, 1.0, multiplier, x2,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x2, multiplier, 1.0, x2, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
math::Div<T, Context>(X1->count(), dy, dx1, dx1); math::Div<T, Context>(X1->count(), dy, dx1, dx1, ctx());
} }
} }
......
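
The new scalar branch of DivOp turns the division into a copy plus MulScalar with a host-computed reciprocal: one multiply per element instead of one divide, at the cost of a single extra rounding of 1/x2. Schematically:

    // Sketch of the rewritten type-0 path: y = x1 * (1 / x2[0]).
    void DivByScalar(int n, const float* x1, float x2_scalar, float* y) {
      const float inv = 1.f / x2_scalar;  // computed once on the host
      for (int i = 0; i < n; ++i) y[i] = x1[i] * inv;
    }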
...@@ -7,9 +7,13 @@ template <class Context> template <typename T> ...@@ -7,9 +7,13 @@ template <class Context> template <typename T>
void DotOp<Context>::DotRunWithType() { void DotOp<Context>::DotRunWithType() {
auto* X1data = Input(0).template data<T, Context>(); auto* X1data = Input(0).template data<T, Context>();
auto* X2data = Input(1).template data<T, Context>(); auto* X2data = Input(1).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, CPUContext>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
Ydata[0] = math::Dot<T, Context>(
Input(0).count(), X1data, X2data, &ctx()); T result_host;
math::Dot<T, Context>(Input(0).count(),
X1data, X2data, &result_host, ctx());
ctx()->template Copy<T, Context, CPUContext>(
1, Ydata, &result_host);
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -22,7 +26,7 @@ void DotOp<Context>::GemmRunWithType() { ...@@ -22,7 +26,7 @@ void DotOp<Context>::GemmRunWithType() {
TransB ? CblasTrans : CblasNoTrans, TransB ? CblasTrans : CblasNoTrans,
M, N1, K1, M, N1, K1,
1.0, X1data, X2data, 1.0, X1data, X2data,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -33,7 +37,7 @@ void DotOp<Context>::GemvRunWithType() { ...@@ -33,7 +37,7 @@ void DotOp<Context>::GemvRunWithType() {
math::Gemv<T, Context>( math::Gemv<T, Context>(
TransA ? CblasTrans : CblasNoTrans, M, N1, TransA ? CblasTrans : CblasNoTrans, M, N1,
1.0, X1data, X2data, 1.0, X1data, X2data,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -98,12 +102,14 @@ void DotGradientOp<Context>::DotRunWithType() { ...@@ -98,12 +102,14 @@ void DotGradientOp<Context>::DotRunWithType() {
auto* dYdata = Input(2).template data<T, CPUContext>(); auto* dYdata = Input(2).template data<T, CPUContext>();
auto* dX1data = Output(0)->template mutable_data<T, Context>(); auto* dX1data = Output(0)->template mutable_data<T, Context>();
auto* dX2data = Output(1)->template mutable_data<T, Context>(); auto* dX2data = Output(1)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dX1data, X2data); Output(0)->count(), dX1data, X2data);
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(1)->count(), dX2data, X1data); Output(1)->count(), dX2data, X1data);
math::MulScalar<T, Context>(Output(0)->count(), dYdata[0], dX1data); math::MulScalar<T, Context>(
math::MulScalar<T, Context>(Output(1)->count(), dYdata[0], dX2data); Output(0)->count(), dYdata[0], dX1data, ctx());
math::MulScalar<T, Context>(
Output(1)->count(), dYdata[0], dX2data, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -118,13 +124,13 @@ void DotGradientOp<Context>::GemmRunWithType() { ...@@ -118,13 +124,13 @@ void DotGradientOp<Context>::GemmRunWithType() {
TransB ? CblasNoTrans : CblasTrans, TransB ? CblasNoTrans : CblasTrans,
M, K1, N1, M, K1, N1,
1.0, dYdata, X2data, 1.0, dYdata, X2data,
0.0, dX1data, &ctx()); 0.0, dX1data, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
TransA ? CblasNoTrans : CblasTrans, TransA ? CblasNoTrans : CblasTrans,
CblasNoTrans, CblasNoTrans,
K1, N1, M, K1, N1, M,
1.0, X1data, dYdata, 1.0, X1data, dYdata,
0.0, dX2data, &ctx()); 0.0, dX2data, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -138,11 +144,11 @@ void DotGradientOp<Context>::GemvRunWithType() { ...@@ -138,11 +144,11 @@ void DotGradientOp<Context>::GemvRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
M, N1, 1, M, N1, 1,
1.0, dYdata, X2data, 1.0, dYdata, X2data,
0.0, dX1data, &ctx()); 0.0, dX1data, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
TransA ? CblasNoTrans : CblasTrans, M, N1, TransA ? CblasNoTrans : CblasTrans, M, N1,
1.0, X1data, dYdata, 1.0, X1data, dYdata,
0.0, dX2data, &ctx()); 0.0, dX2data, ctx());
} }
template <class Context> template <class Context>
......
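
DotRunWithType used to write its result through a CPU-mapped pointer into Output(0); now math::Dot deposits the scalar in a host temporary and the context copies it onto the device, keeping the output resident where downstream ops expect it. The same shape in raw cuBLAS (helper name assumed):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>

    // Sketch: reduce to a host scalar, then push it to the device output.
    void DotToDevice(cublasHandle_t h, int n, const float* x,
                     const float* y, float* device_out,
                     cudaStream_t stream) {
      float result = 0.f;
      cublasSdot(h, n, x, 1, y, 1, &result);  // host pointer mode: blocks
      cudaMemcpyAsync(device_out, &result, sizeof(float),
                      cudaMemcpyHostToDevice, stream);
      cudaStreamSynchronize(stream);  // `result` lives on the stack
    }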
...@@ -7,10 +7,11 @@ template <class Context> template <typename T> ...@@ -7,10 +7,11 @@ template <class Context> template <typename T>
void EltwiseOp<Context>::SumRunWithType() { void EltwiseOp<Context>::SumRunWithType() {
TIndex count = Output(0)->count(); TIndex count = Output(0)->count();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(count, dragon_cast<T, float>(0), Ydata); math::Set<T, Context>(count,
dragon_cast<T, float>(0), Ydata, ctx());
for (int i = 0; i < InputSize(); ++i) { for (int i = 0; i < InputSize(); ++i) {
math::Axpy<T, Context>(count, coeffs[i], math::Axpy<T, Context>(count, coeffs[i],
Input(i).template data<T, Context>(), Ydata, &ctx()); Input(i).template data<T, Context>(), Ydata, ctx());
} }
} }
...@@ -21,19 +22,24 @@ void EltwiseOp<Context>::ProdRunWithType() { ...@@ -21,19 +22,24 @@ void EltwiseOp<Context>::ProdRunWithType() {
math::Mul<T, Context>(count, math::Mul<T, Context>(count,
Input(0).template data<T, Context>(), Input(0).template data<T, Context>(),
Input(1).template data<T, Context>(), Input(1).template data<T, Context>(),
Ydata); Ydata, ctx());
for (int i = 2; i < InputSize(); i++) { for (int i = 2; i < InputSize(); i++) {
math::Mul<T, Context>(count, math::Mul<T, Context>(count,
Ydata, Ydata,
Input(i).template data<T, Context>(), Input(i).template data<T, Context>(),
Ydata); Ydata, ctx());
} }
} }
template <class Context> template <class Context>
void EltwiseOp<Context>::RunOnDevice() { void EltwiseOp<Context>::RunOnDevice() {
for (int i = 1; i < InputSize(); i++) for (int i = 1; i < InputSize(); i++) {
CHECK(Input(i).dims() == Input(0).dims()); CHECK(Input(i).dims() == Input(0).dims())
<< "\nExcepted Input(" << i << ")'s dims as "
<< Input(0).DimString() << ",\n but got "
<< Input(1).DimString() << ".";
}
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (operation == "SUM") { if (operation == "SUM") {
...@@ -65,12 +71,12 @@ void EltwiseGradientOp<Context>::SumRunWithType() { ...@@ -65,12 +71,12 @@ void EltwiseGradientOp<Context>::SumRunWithType() {
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); i++) {
if (Output(i)->name() == "ignore") continue; if (Output(i)->name() == "ignore") continue;
auto* dXdata = Output(i)->template mutable_data<T, Context>(); auto* dXdata = Output(i)->template mutable_data<T, Context>();
if (coeffs[i] == float(1)) { if (coeffs[i] == 1.f) {
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
count, dXdata, dYdata); count, dXdata, dYdata);
} else { } else {
math::Scale<T, Context>(count, math::Scale<T, Context>(count,
coeffs[i], dYdata, dXdata, &ctx()); coeffs[i], dYdata, dXdata, ctx());
} }
} }
} }
...@@ -88,11 +94,11 @@ void EltwiseGradientOp<Context>::ProdRunWithType() { ...@@ -88,11 +94,11 @@ void EltwiseGradientOp<Context>::ProdRunWithType() {
if (i == j) continue; if (i == j) continue;
auto* Xdata = Input(j).template data<T, Context>(); auto* Xdata = Input(j).template data<T, Context>();
if (!initialized) { if (!initialized) {
ctx().template Copy<T, Context, Context>(count, dXdata, Xdata); ctx()->template Copy<T, Context, Context>(count, dXdata, Xdata);
initialized = true; initialized = true;
} else math::Mul<T, Context>(count, Xdata, dXdata, dXdata); } else math::Mul<T, Context>(count, Xdata, dXdata, dXdata, ctx());
} }
math::Mul<T, Context>(count, dYdata, dXdata, dXdata); math::Mul<T, Context>(count, dYdata, dXdata, dXdata, ctx());
} }
} }
......
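
EltwiseOp's SUM path zeroes the output with math::Set and folds each input in with Axpy, i.e. Y = sum_i coeffs[i] * X_i; the gradient shortcuts coeffs[i] == 1 to a plain Copy. A reference version of the forward fold:

    #include <vector>

    // Sketch: set-then-axpy accumulation, mirroring SumRunWithType.
    void EltwiseSum(const std::vector<const float*>& inputs,
                    const std::vector<float>& coeffs,
                    int count, float* Y) {
      for (int j = 0; j < count; ++j) Y[j] = 0.f;   // math::Set
      for (size_t i = 0; i < inputs.size(); ++i)    // math::Axpy per input
        for (int j = 0; j < count; ++j)
          Y[j] += coeffs[i] * inputs[i][j];
    }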
...@@ -8,7 +8,7 @@ template <class Context> template <typename T> ...@@ -8,7 +8,7 @@ template <class Context> template <typename T>
void ExpOp<Context>::RunWithType() { void ExpOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Exp<T, Context>(Output(0)->count(), Xdata, Ydata); math::Exp<T, Context>(Output(0)->count(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -30,7 +30,8 @@ void ExpGradientOp<Context>::RunWithType() { ...@@ -30,7 +30,8 @@ void ExpGradientOp<Context>::RunWithType() {
auto* Ydata = Input(0).template data<T, Context >(); auto* Ydata = Input(0).template data<T, Context >();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), dYdata, Ydata, dXdata); math::Mul<T, Context>(Output(0)->count(),
dYdata, Ydata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -12,7 +12,7 @@ void GramMatrixOp<Context>::RunWithType() { ...@@ -12,7 +12,7 @@ void GramMatrixOp<Context>::RunWithType() {
CblasNoTrans, CblasTrans, CblasNoTrans, CblasTrans,
dim, dim, inner_dim, dim, dim, inner_dim,
1.0, Xdata, Xdata, 1.0, Xdata, Xdata,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
Xdata += x_offset; Xdata += x_offset;
Ydata += y_offset; Ydata += y_offset;
} }
...@@ -47,7 +47,7 @@ void GramMatrixGradientOp<Context>::RunWithType() { ...@@ -47,7 +47,7 @@ void GramMatrixGradientOp<Context>::RunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
dim, inner_dim, dim, dim, inner_dim, dim,
2.0, dYdata, Xdata, 2.0, dYdata, Xdata,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
dYdata += y_offset; dYdata += y_offset;
dXdata += x_offset; dXdata += x_offset;
} }
......
...@@ -23,7 +23,7 @@ void InnerProductOp<Context>::TransRunWithType() { ...@@ -23,7 +23,7 @@ void InnerProductOp<Context>::TransRunWithType() {
CblasNoTrans, CblasTrans, CblasNoTrans, CblasTrans,
M, num_output, K, M, num_output, K,
1.0, Xdata, Wdata, 1.0, Xdata, Wdata,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
if (InputSize() > 2) { if (InputSize() > 2) {
DECLARE_MULTIPLIER(multiplier, M); DECLARE_MULTIPLIER(multiplier, M);
...@@ -32,7 +32,7 @@ void InnerProductOp<Context>::TransRunWithType() { ...@@ -32,7 +32,7 @@ void InnerProductOp<Context>::TransRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
M, num_output, 1, M, num_output, 1,
1.0, multiplier, Bdata, 1.0, multiplier, Bdata,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
} }
...@@ -55,7 +55,7 @@ void InnerProductOp<Context>::NoTransRunWithType() { ...@@ -55,7 +55,7 @@ void InnerProductOp<Context>::NoTransRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
M, num_output, K, M, num_output, K,
1.0, Xdata, Wdata, 1.0, Xdata, Wdata,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
if (InputSize() > 2) { if (InputSize() > 2) {
DECLARE_MULTIPLIER(multiplier, M); DECLARE_MULTIPLIER(multiplier, M);
...@@ -64,7 +64,7 @@ void InnerProductOp<Context>::NoTransRunWithType() { ...@@ -64,7 +64,7 @@ void InnerProductOp<Context>::NoTransRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
M, num_output, 1, M, num_output, 1,
1.0, multiplier, Bdata, 1.0, multiplier, Bdata,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
} }
...@@ -102,30 +102,30 @@ void InnerProductGradientOp<Context>::RunWithType() { ...@@ -102,30 +102,30 @@ void InnerProductGradientOp<Context>::RunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
Output(1)->ReshapeLike(Input(1)); Output(1)->ReshapeLike(Input(1));
auto* dWdata = Output(1)->template mutable_data<T, Context>(); auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
if (TransW) { if (TransW) {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasTrans, CblasNoTrans, CblasTrans, CblasNoTrans,
num_output, K, M, num_output, K, M,
1.0, dYdata, Xdata, 1.0, dYdata, Xdata,
1.0, dWdata, &ctx()); 1.0, dWdata, ctx());
} else { } else {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasTrans, CblasNoTrans, CblasTrans, CblasNoTrans,
K, num_output, M, K, num_output, M,
1.0, Xdata, dYdata, 1.0, Xdata, dYdata,
1.0, dWdata, &ctx()); 1.0, dWdata, ctx());
} }
} }
if (Output(2)->name() != "ignore") { if (Output(2)->name() != "ignore") {
DECLARE_MULTIPLIER(multiplier, M); DECLARE_MULTIPLIER(multiplier, M);
Output(2)->Reshape({ num_output }); Output(2)->Reshape({ num_output });
auto* dBdata = Output(2)->template mutable_data<T, Context>(); auto* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, M, num_output, CblasTrans, M, num_output,
1.0, dYdata, multiplier, 1.0, dYdata, multiplier,
1.0, dBdata, &ctx()); 1.0, dBdata, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
...@@ -136,13 +136,13 @@ void InnerProductGradientOp<Context>::RunWithType() { ...@@ -136,13 +136,13 @@ void InnerProductGradientOp<Context>::RunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
M, K, num_output, M, K, num_output,
1.0, dYdata, Wdata, 1.0, dYdata, Wdata,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} else { } else {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasTrans, CblasNoTrans, CblasTrans,
M, K, num_output, M, K, num_output,
1.0, dYdata, Wdata, 1.0, dYdata, Wdata,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} }
} }
} }
......
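The bias handling here is the usual ones-vector trick: DECLARE_MULTIPLIER provides a length-M vector of ones, so the rank-1 Gemm (ones times b^T, with beta = 1) adds the bias to every row of Y, and the backward Gemv against the same ones vector sums dY over rows to produce dB. A plain-loop sketch of the TransW forward, assuming X is (M, K), W is (num_output, K), and Y is (M, num_output):

    // Y = X * W^T + 1_M * b^T; pass b = nullptr to skip the bias.
    void innerproduct_forward(int M, int N, int K,
                              const float* X, const float* W,
                              const float* b, float* Y) {
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = b ? b[n] : 0.f;              // bias broadcast per row
                for (int k = 0; k < K; ++k)
                    acc += X[m * K + k] * W[n * K + k];  // W applied transposed
                Y[m * N + n] = acc;
            }
    }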
...@@ -7,7 +7,7 @@ template <class Context> template <typename T> ...@@ -7,7 +7,7 @@ template <class Context> template <typename T>
void LogOp<Context>::RunWithType() { void LogOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Log<T, Context>(Output(0)->count(), Xdata, Ydata); math::Log<T, Context>(Output(0)->count(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -29,7 +29,7 @@ void LogGradientOp<Context>::RunWithType() { ...@@ -29,7 +29,7 @@ void LogGradientOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Div<T, Context>(Output(0)->count(), dYdata, Xdata, dXdata); math::Div<T, Context>(Output(0)->count(), dYdata, Xdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -16,7 +16,7 @@ void MatmulOp<Context>::RunWithType() { ...@@ -16,7 +16,7 @@ void MatmulOp<Context>::RunWithType() {
TransB ? CblasTrans : CblasNoTrans, TransB ? CblasTrans : CblasNoTrans,
M, N, K1, M, N, K1,
1.0, X1data, X2data, 1.0, X1data, X2data,
0.0, Ydata, &ctx()); 0.0, Ydata, ctx());
X1data += x1_offset; X1data += x1_offset;
X2data += x2_offset; X2data += x2_offset;
Ydata += y_offset; Ydata += y_offset;
...@@ -76,13 +76,13 @@ void MatmulGradientOp<Context>::RunWithType() { ...@@ -76,13 +76,13 @@ void MatmulGradientOp<Context>::RunWithType() {
TransB ? CblasNoTrans : CblasTrans, TransB ? CblasNoTrans : CblasTrans,
M, K1, N, M, K1, N,
1.0, dYdata, X2data, 1.0, dYdata, X2data,
0.0, dX1data, &ctx()); 0.0, dX1data, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
TransA ? CblasNoTrans : CblasTrans, TransA ? CblasNoTrans : CblasTrans,
CblasNoTrans, CblasNoTrans,
K1, N, M, K1, N, M,
1.0, X1data, dYdata, 1.0, X1data, dYdata,
0.0, dX2data, &ctx()); 0.0, dX2data, ctx());
X1data += x1_offset; X1data += x1_offset;
X2data += x2_offset; X2data += x2_offset;
dX1data += x1_offset; dX1data += x1_offset;
......
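The two backward Gemms implement the standard matmul identities: for Y = X1 * X2 with no transposes, dX1 = dY * X2^T and dX2 = X1^T * dY; the TransA/TransB ternaries only reshuffle which side is transposed, and the pointer bumps by *_offset step through the batch. A naive sketch of the first identity:

    // dX1 (M x K) = dY (M x N) * X2^T, for Y = X1 * X2 with X2 (K x N).
    void matmul_grad_x1(int M, int N, int K,
                        const float* dY, const float* X2, float* dX1) {
        for (int m = 0; m < M; ++m)
            for (int k = 0; k < K; ++k) {
                float acc = 0.f;
                for (int n = 0; n < N; ++n)
                    acc += dY[m * N + n] * X2[k * N + n];  // X2 read transposed
                dX1[m * K + k] = acc;
            }
    }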
...@@ -9,7 +9,7 @@ void MulOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void MulOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), x1, x2, y); math::Mul<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -18,34 +18,39 @@ void MulOp<Context>::BroadcastRunWithType(int type) { ...@@ -18,34 +18,39 @@ void MulOp<Context>::BroadcastRunWithType(int type) {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({
Output(0)->count() })[0];
if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(0).count(); x2 = Input(1).template data<T, CPUContext>();
inner_dim = 1; ctx()->template Copy<T, Context, Context>(
} else { Output(0)->count(), y, x1);
math::MulScalar<T, Context>(Output(0)->count(),
dragon_cast<float, T>(x2[0]), y, ctx());
} else if (type == 1) {
outer_dim = Input(0).count(0, Input(0).axis(-1)); outer_dim = Input(0).count(0, Input(0).axis(-1));
inner_dim = Input(0).dim(-1); inner_dim = Input(0).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x2, 1.0, multiplier, x2,
0.0, c, &ctx()); 0.0, c, ctx());
math::Mul<T, Context>(Output(0)->count(), x1, c, y); math::Mul<T, Context>(
Output(0)->count(), x1, c, y, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(0).dim(0); outer_dim = Input(0).dim(0);
inner_dim = Input(0).count(1); inner_dim = Input(0).count(1);
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x2, multiplier, 1.0, x2, multiplier,
0.0, c, &ctx()); 0.0, c, ctx());
math::Mul<T, Context>(Output(0)->count(), x1, c, y); math::Mul<T, Context>(
Output(0)->count(), x1, c, y, ctx());
} }
} }
...@@ -79,13 +84,13 @@ void MulGradientOp<Context>::EltwiseRunWithType() { ...@@ -79,13 +84,13 @@ void MulGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(1)->count(), dy, x1, dx2); math::Mul<T, Context>(Output(1)->count(), dy, x1, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), dy, x2, dx1); math::Mul<T, Context>(Output(0)->count(), dy, x2, dx1, ctx());
} }
} }
...@@ -110,19 +115,19 @@ void MulGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -110,19 +115,19 @@ void MulGradientOp<Context>::BroadcastRunWithType(int type) {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({ X1->count() })[0]; auto* c = ws()->template caches<T, Context>({ X1->count() })[0];
math::Mul<T, Context>(X1->count(), dy, x1, c); math::Mul<T, Context>(X1->count(), dy, x1, c, ctx());
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} }
} }
...@@ -135,16 +140,16 @@ void MulGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -135,16 +140,16 @@ void MulGradientOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x2, 1.0, multiplier, x2,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x2, multiplier, 1.0, x2, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
math::Mul<T, Context>(X1->count(), dy, dx1, dx1); math::Mul<T, Context>(X1->count(), dy, dx1, dx1, ctx());
} }
} }
......
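All the broadcast branches in these binary ops share one mechanism: a Gemm against a ones vector (multiplier) tiles the smaller operand into a full-size scratch buffer c, after which the op is purely elementwise. Type 0 is the scalar fast path added by this commit (read x2[0] on the CPU, call MulScalar), type 1 broadcasts a vector along the last axis, and type 2 broadcasts along everything after dim 0. A loop sketch of the type-1 tiling, the effect of Gemm(ones, v):

    // Tile a row vector v (length inner_dim) into c (outer_dim x inner_dim);
    // the op then reduces to elementwise y = x1 * c.
    void tile_rows(int outer_dim, int inner_dim, const float* v, float* c) {
        for (int i = 0; i < outer_dim; ++i)
            for (int j = 0; j < inner_dim; ++j)
                c[i * inner_dim + j] = v[j];
    }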
...@@ -9,16 +9,17 @@ void PowOp<Context>::RunWithType() { ...@@ -9,16 +9,17 @@ void PowOp<Context>::RunWithType() {
TIndex count = Input(0).count(); TIndex count = Input(0).count();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
if (power_scale == float(0)) { if (power_scale == 0.f) {
float value = (power == float(0)) ? float(1) : pow(shift, power); float value = (power == 0.f) ? 1.f : pow(shift, power);
math::Set<T, Context>(count, dragon_cast<T, float>(value), Ydata); math::Set<T, Context>(count,
dragon_cast<T, float>(value), Ydata, ctx());
return; return;
} }
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
ctx().template Copy<T, Context, Context>(count, Ydata, Xdata); ctx()->template Copy<T, Context, Context>(count, Ydata, Xdata);
if (scale != float(1)) math::Scal<T, Context>(count, scale, Ydata, &ctx()); if (scale != 1.f) math::Scal<T, Context>(count, scale, Ydata, ctx());
if (shift != float(0)) math::AddScalar<T, Context>(count, shift, Ydata); if (shift != 0.f) math::AddScalar<T, Context>(count, shift, Ydata, ctx());
if (power != float(1)) math::Pow<T, Context>(count, power, Ydata, Ydata); if (power != 1.f) math::Pow<T, Context>(count, power, Ydata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -42,35 +43,36 @@ void PowGradientOp<Context>::RunWithType() { ...@@ -42,35 +43,36 @@ void PowGradientOp<Context>::RunWithType() {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
if (power_scale == float(0) || power == float(1)) { if (power_scale == 0.f || power == 1.f) {
const T value = dragon_cast<T, float>(power_scale); const T value = dragon_cast<T, float>(power_scale);
math::Set<T, Context>(count, value, dXdata); math::Set<T, Context>(count, value, dXdata, ctx());
} else { } else {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
if (power == float(2)) { if (power == 2.f) {
math::Axpby<T, Context>(count, math::Axpby<T, Context>(count,
power_scale * scale, Xdata, power_scale * scale, Xdata,
0, dXdata, &ctx()); 0, dXdata, ctx());
if (shift != float(0)) if (shift != 0.f)
math::AddScalar<T, Context>(count, power_scale * shift, dXdata); math::AddScalar<T, Context>(count,
} else if (shift == float(0)) { power_scale * shift, dXdata, ctx());
} else if (shift == 0.f) {
auto* Ydata = Input(1).template data<T, Context>(); auto* Ydata = Input(1).template data<T, Context>();
math::Div<T, Context>(count, Ydata, Xdata, dXdata); math::Div<T, Context>(count, Ydata, Xdata, dXdata, ctx());
math::Scal<T, Context>(count, power, dXdata, &ctx()); math::Scal<T, Context>(count, power, dXdata, ctx());
} else { } else {
auto* Ydata = Input(1).template data<T, Context>(); auto* Ydata = Input(1).template data<T, Context>();
ctx().template Copy<T, Context, Context>(count, dXdata, Xdata); ctx()->template Copy<T, Context, Context>(count, dXdata, Xdata);
if (scale != float(1)) if (scale != 1.f)
math::Scal<T, Context>(count, scale, dXdata, &ctx()); math::Scal<T, Context>(count, scale, dXdata, ctx());
if (shift != float(0)) if (shift != 0.f)
math::AddScalar<T, Context>(count, shift, dXdata); math::AddScalar<T, Context>(count, shift, dXdata, ctx());
math::Div<T, Context>(count, Ydata, dXdata, dXdata); math::Div<T, Context>(count, Ydata, dXdata, dXdata, ctx());
if (power_scale != float(1)) if (power_scale != 1.f)
math::Scal<T, Context>(count, power_scale, dXdata, &ctx()); math::Scal<T, Context>(count, power_scale, dXdata, ctx());
} }
} }
if (power_scale != float(0)) if (power_scale != 0.f)
math::Mul<T, Context>(count, dYdata, dXdata, dXdata); math::Mul<T, Context>(count, dYdata, dXdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
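PowOp evaluates y = (scale * x + shift)^power, with power_scale caching the product power * scale. The gradient is dy/dx = power * scale * (scale * x + shift)^(power - 1); the branches above only pick the cheapest way to build that factor (a constant when power is 0 or 1, an affine map when power is 2, y divided by the pre-power base otherwise) before the final Mul against dY. A scalar sketch of the general branch:

    #include <cmath>

    // dy/dx for y = (s*x + t)^p, formed as p*s * y / (s*x + t);
    // assumes the base s*x + t is nonzero, as the op's else-branch does.
    float pow_grad(float x, float s, float t, float p) {
        float base = s * x + t;
        float y = std::pow(base, p);
        return p * s * y / base;
    }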
...@@ -9,7 +9,7 @@ void RAddOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void RAddOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Add<T, Context>(Output(0)->count(), x1, x2, y); math::Add<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -19,23 +19,24 @@ void RAddOp<Context>::BroadcastRunWithType(int type) { ...@@ -19,23 +19,24 @@ void RAddOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), y, x2); Output(0)->count(), y, x2);
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(1).count(); x1 = Input(0).template data<T, CPUContext>();
inner_dim = 1; math::AddScalar<T, Context>(Output(0)->count(),
dragon_cast<float, T>(x1[0]), y, ctx());
} else { } else {
outer_dim = Input(1).count(0, Input(1).axis(-1)); outer_dim = Input(1).count(0, Input(1).axis(-1));
inner_dim = Input(1).dim(-1); inner_dim = Input(1).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x1, 1.0, multiplier, x1,
1.0, y, &ctx()); 1.0, y, ctx());
}
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(1).dim(0); outer_dim = Input(1).dim(0);
inner_dim = Input(1).count(1); inner_dim = Input(1).count(1);
...@@ -44,7 +45,7 @@ void RAddOp<Context>::BroadcastRunWithType(int type) { ...@@ -44,7 +45,7 @@ void RAddOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x1, multiplier, 1.0, x1, multiplier,
1.0, y, &ctx()); 1.0, y, ctx());
} }
} }
...@@ -77,13 +78,13 @@ void RAddGradientOp<Context>::EltwiseRunWithType() { ...@@ -77,13 +78,13 @@ void RAddGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(1)->count(), dx2, dy); Output(1)->count(), dx2, dy);
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dx1, dy); Output(0)->count(), dx1, dy);
} }
} }
...@@ -108,7 +109,7 @@ void RAddGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -108,7 +109,7 @@ void RAddGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = X2->dim(0); outer_dim = X2->dim(0);
inner_dim = X2->count(1); inner_dim = X2->count(1);
...@@ -116,13 +117,13 @@ void RAddGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -116,13 +117,13 @@ void RAddGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
} }
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
X2->count(), dx2, dy); X2->count(), dx2, dy);
} }
} }
......
...@@ -9,7 +9,7 @@ void RDivOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void RDivOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Div<T, Context>(Output(0)->count(), x1, x2, y); math::Div<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -34,8 +34,8 @@ void RDivOp<Context>::BroadcastRunWithType(int type) { ...@@ -34,8 +34,8 @@ void RDivOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x1, 1.0, multiplier, x1,
0.0, c, &ctx()); 0.0, c, ctx());
math::Div<T, Context>(Output(0)->count(), c, x2, y); math::Div<T, Context>(Output(0)->count(), c, x2, y, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(1).dim(0); outer_dim = Input(1).dim(0);
inner_dim = Input(1).count(1); inner_dim = Input(1).count(1);
...@@ -44,8 +44,8 @@ void RDivOp<Context>::BroadcastRunWithType(int type) { ...@@ -44,8 +44,8 @@ void RDivOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x1, multiplier, 1.0, x1, multiplier,
0.0, c, &ctx()); 0.0, c, ctx());
math::Div<T, Context>(Output(0)->count(), c, x2, y); math::Div<T, Context>(Output(0)->count(), c, x2, y, ctx());
} }
} }
...@@ -82,16 +82,16 @@ void RDivGradientOp<Context>::EltwiseRunWithType() { ...@@ -82,16 +82,16 @@ void RDivGradientOp<Context>::EltwiseRunWithType() {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({ X1->count() })[0]; auto* c = ws()->template caches<T, Context>({ X1->count() })[0];
math::Mul<T, Context>(X1->count(), dy, x1, c); // dY * X1 math::Mul<T, Context>(X1->count(), dy, x1, c, ctx()); // dY * X1
math::Square<T, Context>(X2->count(), x2, dx2); // X2^{2} math::Square<T, Context>(X2->count(), x2, dx2, ctx()); // X2^{2}
math::Inv<T, Context>(X2->count(), -1, dx2, dx2); // -1 / X2^{2} math::Inv<T, Context>(X2->count(), -1, dx2, dx2, ctx()); // -1 / X2^{2}
math::Mul<T, Context>(X2->count(), c, dx2, dx2); math::Mul<T, Context>(X2->count(), c, dx2, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
math::Div<T, Context>(X1->count(), dy, x2, dx1); math::Div<T, Context>(X1->count(), dy, x2, dx1, ctx());
} }
} }
...@@ -116,19 +116,19 @@ void RDivGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -116,19 +116,19 @@ void RDivGradientOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({ X2->count() })[0]; auto* c = ws()->template caches<T, Context>({ X2->count() })[0];
math::Div<T, Context>(X2->count(), dy, x2, c); math::Div<T, Context>(X2->count(), dy, x2, c, ctx());
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
} }
...@@ -142,18 +142,18 @@ void RDivGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -142,18 +142,18 @@ void RDivGradientOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
-1.0, multiplier, x1, -1.0, multiplier, x1,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
-1.0, x1, multiplier, -1.0, x1, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} }
math::Mul<T, Context>(X2->count(), dy, dx2, dx2); math::Mul<T, Context>(X2->count(), dy, dx2, dx2, ctx());
math::Div<T, Context>(X2->count(), dx2, x2, dx2); math::Div<T, Context>(X2->count(), dx2, x2, dx2, ctx());
math::Div<T, Context>(X2->count(), dx2, x2, dx2); math::Div<T, Context>(X2->count(), dx2, x2, dx2, ctx());
} }
} }
......
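The dx2 sequence encodes d(x1/x2)/dx2 = -x1 / x2^2: Mul builds dY * x1, Square plus Inv(-1) builds -1 / x2^2, and the last Mul combines them; the broadcast branch reaches the same value as (-x1 tiled) * dY / x2 / x2, which is why Div appears twice in a row. The elementwise gradients in scalar form:

    // Gradients of y = x1 / x2 at a single element.
    void rdiv_grad(float x1, float x2, float dy, float* dx1, float* dx2) {
        *dx1 = dy / x2;               // d(x1/x2)/dx1 =  1/x2
        *dx2 = -dy * x1 / (x2 * x2);  // d(x1/x2)/dx2 = -x1/x2^2
    }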
...@@ -9,7 +9,7 @@ void RMulOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void RMulOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), x1, x2, y); math::Mul<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -18,34 +18,39 @@ void RMulOp<Context>::BroadcastRunWithType(int type) { ...@@ -18,34 +18,39 @@ void RMulOp<Context>::BroadcastRunWithType(int type) {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({
Output(0)->count() })[0];
if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(1).count(); x1 = Input(0).template data<T, CPUContext>();
inner_dim = 1; ctx()->template Copy<T, Context, Context>(
} else { Output(0)->count(), y, x2);
math::MulScalar<T, Context>(Output(0)->count(),
dragon_cast<float, T>(x1[0]), y, ctx());
} else if (type == 1) {
outer_dim = Input(1).count(0, Input(1).axis(-1)); outer_dim = Input(1).count(0, Input(1).axis(-1));
inner_dim = Input(1).dim(-1); inner_dim = Input(1).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x1, 1.0, multiplier, x1,
0.0, c, &ctx()); 0.0, c, ctx());
math::Mul<T, Context>(Output(0)->count(), c, x2, y); math::Mul<T, Context>(
Output(0)->count(), c, x2, y, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(1).dim(0); outer_dim = Input(1).dim(0);
inner_dim = Input(1).count(1); inner_dim = Input(1).count(1);
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
auto* c = ws()->template caches<T, Context>(
{ Output(0)->count() })[0];
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x1, multiplier, 1.0, x1, multiplier,
0.0, c, &ctx()); 0.0, c, ctx());
math::Mul<T, Context>(Output(0)->count(), c, x2, y); math::Mul<T, Context>(
Output(0)->count(), c, x2, y, ctx());
} }
} }
...@@ -79,13 +84,13 @@ void RMulGradientOp<Context>::EltwiseRunWithType() { ...@@ -79,13 +84,13 @@ void RMulGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(1)->count(), dy, x1, dx2); math::Mul<T, Context>(Output(1)->count(), dy, x1, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), dy, x2, dx1); math::Mul<T, Context>(Output(0)->count(), dy, x2, dx1, ctx());
} }
} }
...@@ -110,19 +115,19 @@ void RMulGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -110,19 +115,19 @@ void RMulGradientOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
auto* c = ws()->template caches<T, Context>({ X2->count() })[0]; auto* c = ws()->template caches<T, Context>({ X2->count() })[0];
math::Mul<T, Context>(X2->count(), dy, x2, c); math::Mul<T, Context>(X2->count(), dy, x2, c, ctx());
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, c, multiplier, 1.0, c, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
} }
...@@ -135,16 +140,16 @@ void RMulGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -135,16 +140,16 @@ void RMulGradientOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x1, 1.0, multiplier, x1,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} else if (type == 2) { } else if (type == 2) {
DECLARE_MULTIPLIER(multiplier, inner_dim); DECLARE_MULTIPLIER(multiplier, inner_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x1, multiplier, 1.0, x1, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} }
math::Mul<T, Context>(X2->count(), dy, dx2, dx2); math::Mul<T, Context>(X2->count(), dy, dx2, dx2, ctx());
} }
} }
......
...@@ -9,7 +9,7 @@ void RSubOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,7 @@ void RSubOp<Context>::EltwiseRunWithType() {
auto* x1 = Input(0).template data<T, Context>(); auto* x1 = Input(0).template data<T, Context>();
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
math::Sub<T, Context>(Output(0)->count(), x1, x2, y); math::Sub<T, Context>(Output(0)->count(), x1, x2, y, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -19,7 +19,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) { ...@@ -19,7 +19,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), y, x2); Output(0)->count(), y, x2);
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
...@@ -35,7 +35,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) { ...@@ -35,7 +35,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, multiplier, x1, 1.0, multiplier, x1,
-1.0, y, &ctx()); -1.0, y, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = Input(1).dim(0); outer_dim = Input(1).dim(0);
inner_dim = Input(1).count(1); inner_dim = Input(1).count(1);
...@@ -44,7 +44,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) { ...@@ -44,7 +44,7 @@ void RSubOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
1.0, x1, multiplier, 1.0, x1, multiplier,
-1.0, y, &ctx()); -1.0, y, ctx());
} }
} }
...@@ -78,12 +78,12 @@ void RSubGradientOp<Context>::EltwiseRunWithType() { ...@@ -78,12 +78,12 @@ void RSubGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
math::Scale<T, Context>( math::Scale<T, Context>(
Output(1)->count(), -1, dy, dx2, &ctx()); Output(1)->count(), -1, dy, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dx1, dy); Output(0)->count(), dx1, dy);
} }
} }
...@@ -108,7 +108,7 @@ void RSubGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -108,7 +108,7 @@ void RSubGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = X2->dim(0); outer_dim = X2->dim(0);
inner_dim = X2->count(1); inner_dim = X2->count(1);
...@@ -116,14 +116,14 @@ void RSubGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -116,14 +116,14 @@ void RSubGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
1.0, dy, multiplier, 1.0, dy, multiplier,
0.0, dx1, &ctx()); 0.0, dx1, ctx());
} }
} }
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
math::Scale<T, Context>( math::Scale<T, Context>(
X2->count(), -1, dy, dx2, &ctx()); X2->count(), -1, dy, dx2, ctx());
} }
} }
......
...@@ -7,7 +7,7 @@ template <class Context> template <typename T> ...@@ -7,7 +7,7 @@ template <class Context> template <typename T>
void SquareOp<Context>::RunWithType() { void SquareOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Pow<T, Context>(Output(0)->count(), 2.0, Xdata, Ydata); math::Pow<T, Context>(Output(0)->count(), 2.0, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -29,8 +29,8 @@ void SquareGradientOp<Context>::RunWithType() { ...@@ -29,8 +29,8 @@ void SquareGradientOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), dYdata, Xdata, dXdata); math::Mul<T, Context>(Output(0)->count(), dYdata, Xdata, dXdata, ctx());
math::Scal<T, Context>(Output(0)->count(), 2.0, dXdata, &ctx()); math::Scal<T, Context>(Output(0)->count(), 2.0, dXdata, ctx());
} }
template <class Context> template <class Context>
......
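SquareGradientOp is the chain rule for y = x^2, i.e. dX = 2 * x * dY, realized above as an elementwise Mul followed by Scal(2). Collapsed into one loop:

    // dx[i] = 2 * x[i] * dy[i] for y = x^2.
    void square_backward(int n, const float* x, const float* dy, float* dx) {
        for (int i = 0; i < n; ++i) dx[i] = 2.f * x[i] * dy[i];
    }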
...@@ -9,7 +9,8 @@ void SubOp<Context>::EltwiseRunWithType() { ...@@ -9,7 +9,8 @@ void SubOp<Context>::EltwiseRunWithType() {
auto* X1data = Input(0).template data<T, Context>(); auto* X1data = Input(0).template data<T, Context>();
auto* X2data = Input(1).template data<T, Context>(); auto* X2data = Input(1).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Sub<T, Context>(Output(0)->count(), X1data, X2data, Ydata); math::Sub<T, Context>(Output(0)->count(),
X1data, X2data, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -19,23 +20,24 @@ void SubOp<Context>::BroadcastRunWithType(int type) { ...@@ -19,23 +20,24 @@ void SubOp<Context>::BroadcastRunWithType(int type) {
auto* x2 = Input(1).template data<T, Context>(); auto* x2 = Input(1).template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), y, x1); Output(0)->count(), y, x1);
if (type == 0 || type == 1) { if (type == 0 || type == 1) {
if (type == 0) { if (type == 0) {
outer_dim = Input(0).count(); x2 = Input(1).template data<T, CPUContext>();
inner_dim = 1; math::AddScalar<T, Context>(Output(0)->count(),
-dragon_cast<float, T>(x2[0]), y, ctx());
} else { } else {
outer_dim = Input(0).count(0, Input(0).axis(-1)); outer_dim = Input(0).count(0, Input(0).axis(-1));
inner_dim = Input(0).dim(-1); inner_dim = Input(0).dim(-1);
}
DECLARE_MULTIPLIER(multiplier, outer_dim); DECLARE_MULTIPLIER(multiplier, outer_dim);
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
-1.0, multiplier, x2, -1.0, multiplier, x2,
1.0, y, &ctx()); 1.0, y, ctx());
}
} }
else if (type == 2) { else if (type == 2) {
outer_dim = Input(0).dim(0); outer_dim = Input(0).dim(0);
...@@ -45,7 +47,7 @@ void SubOp<Context>::BroadcastRunWithType(int type) { ...@@ -45,7 +47,7 @@ void SubOp<Context>::BroadcastRunWithType(int type) {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
outer_dim, inner_dim, 1, outer_dim, inner_dim, 1,
-1.0, x2, multiplier, -1.0, x2, multiplier,
1.0, y, &ctx()); 1.0, y, ctx());
} }
} }
...@@ -79,12 +81,12 @@ void SubGradientOp<Context>::EltwiseRunWithType() { ...@@ -79,12 +81,12 @@ void SubGradientOp<Context>::EltwiseRunWithType() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* dx2 = Output(1)->template mutable_data<T, Context>(); auto* dx2 = Output(1)->template mutable_data<T, Context>();
math::Scale<T, Context>(Output(1)->count(), math::Scale<T, Context>(Output(1)->count(),
-1.0, dy, dx2, &ctx()); -1.0, dy, dx2, ctx());
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dx1, dy); Output(0)->count(), dx1, dy);
} }
} }
...@@ -109,7 +111,7 @@ void SubGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -109,7 +111,7 @@ void SubGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, outer_dim, inner_dim, CblasTrans, outer_dim, inner_dim,
-1.0, dy, multiplier, -1.0, dy, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} else if (type == 2) { } else if (type == 2) {
outer_dim = X1->dim(0); outer_dim = X1->dim(0);
inner_dim = X1->count(1); inner_dim = X1->count(1);
...@@ -117,13 +119,13 @@ void SubGradientOp<Context>::BroadcastRunWithType(int type) { ...@@ -117,13 +119,13 @@ void SubGradientOp<Context>::BroadcastRunWithType(int type) {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, outer_dim, inner_dim, CblasNoTrans, outer_dim, inner_dim,
-1.0, dy, multiplier, -1.0, dy, multiplier,
0.0, dx2, &ctx()); 0.0, dx2, ctx());
} }
} }
if (Output(0)->name() != "ignore") { if (Output(0)->name() != "ignore") {
auto* dx1 = Output(0)->template mutable_data<T, Context>(); auto* dx1 = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
X1->count(), dx1, dy); X1->count(), dx1, dy);
} }
} }
......
...@@ -8,7 +8,8 @@ void CompareOp<Context>::EqualRunWithType() { ...@@ -8,7 +8,8 @@ void CompareOp<Context>::EqualRunWithType() {
auto* X1data = Input(0).template data<T, Context>(); auto* X1data = Input(0).template data<T, Context>();
auto* X2data = Input(1).template data<T, Context>(); auto* X2data = Input(1).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Equal<T, Context>(Output(0)->count(), X1data, X2data, Ydata); kernel::Equal<T, Context>(Output(0)->count(),
X1data, X2data, Ydata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -7,7 +7,7 @@ void CopyOp<Context>::RunWithType() { ...@@ -7,7 +7,7 @@ void CopyOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), Ydata, Xdata); Output(0)->count(), Ydata, Xdata);
} }
......
...@@ -20,10 +20,10 @@ void CTCLossGradientOp<Context>::RunWithType() { ...@@ -20,10 +20,10 @@ void CTCLossGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scale<T, Context>(Output(0)->count(), math::Scale<T, Context>(Output(0)->count(),
dYdata_host, Gdata, dXdata, &ctx()); dYdata_host, Gdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -45,7 +45,7 @@ void CuDNNCTCLossOp<Context>::RunWithType() { ...@@ -45,7 +45,7 @@ void CuDNNCTCLossOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&grad_desc, Input(0).dims()); cudnnSetTensorDesc<T>(&grad_desc, Input(0).dims());
CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize( CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
ctx().cudnn_handle(), prob_desc, grad_desc, ctx()->cudnn_handle(), prob_desc, grad_desc,
packed_labels.data(), label_lengths.data(), packed_labels.data(), label_lengths.data(),
input_lengths.data(), input_lengths.data(),
ctc_algo, ctc_desc, &workspace_size)); ctc_algo, ctc_desc, &workspace_size));
...@@ -58,7 +58,7 @@ void CuDNNCTCLossOp<Context>::RunWithType() { ...@@ -58,7 +58,7 @@ void CuDNNCTCLossOp<Context>::RunWithType() {
auto* WSdata = (uint8_t*)ws()->template caches<Context>({ auto* WSdata = (uint8_t*)ws()->template caches<Context>({
workspace_size })[0]; workspace_size })[0];
CUDNN_CHECK(cudnnCTCLoss(ctx().cudnn_handle(), CUDNN_CHECK(cudnnCTCLoss(ctx()->cudnn_handle(),
prob_desc, Pdata, packed_labels.data(), prob_desc, Pdata, packed_labels.data(),
label_lengths.data(), input_lengths.data(), label_lengths.data(), input_lengths.data(),
Ydata, grad_desc, Gdata, Ydata, grad_desc, Gdata,
......
...@@ -12,11 +12,13 @@ void L1LossOp<Context>::RunWithType() { ...@@ -12,11 +12,13 @@ void L1LossOp<Context>::RunWithType() {
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template mutable_data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Sub<T, Context>(Input(0).count(), X0data, X1data, diff_data); math::Sub<T, Context>(Input(0).count(),
X0data, X1data, diff_data, ctx());
if (InputSize() > 2) { if (InputSize() > 2) {
CHECK_EQ(Input(0).count(), Input(2).count()); CHECK_EQ(Input(0).count(), Input(2).count());
auto* Wdata = Input(2).template data<T, Context>(); auto* Wdata = Input(2).template data<T, Context>();
math::Mul<T, Context>(diff->count(), Wdata, diff_data, diff_data); math::Mul<T, Context>(diff->count(),
Wdata, diff_data, diff_data, ctx());
} }
T normalizer = 1; T normalizer = 1;
...@@ -27,11 +29,13 @@ void L1LossOp<Context>::RunWithType() { ...@@ -27,11 +29,13 @@ void L1LossOp<Context>::RunWithType() {
} }
T loss = math::ASum<T, Context>(diff->count(), diff_data); T loss = math::ASum<T, Context>(diff->count(), diff_data);
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void L1LossOp<Context>::RunOnDevice() { void L1LossOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
CHECK_EQ(Input(0).count(), Input(1).count()); CHECK_EQ(Input(0).count(), Input(1).count());
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
diff = ws()->CreateTensor("/mnt/" + anchor() + "/l1_loss/diff"); diff = ws()->CreateTensor("/mnt/" + anchor() + "/l1_loss/diff");
...@@ -51,9 +55,11 @@ template <class Context> template <typename T> ...@@ -51,9 +55,11 @@ template <class Context> template <typename T>
void L1LossGradientOp<Context>::RunWithType() { void L1LossGradientOp<Context>::RunWithType() {
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template mutable_data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
kernel::AbsGrad<T, Context>(diff->count(), diff_data, diff_data); ctx()->FinishDeviceCompution();
kernel::AbsGrad<T, Context>(diff->count(),
diff_data, diff_data, ctx());
T alpha = dYdata_host, normalizer = 1; T alpha = dYdata_host, normalizer = 1;
if (normalization == "BATCH_SIZE") { if (normalization == "BATCH_SIZE") {
...@@ -69,7 +75,7 @@ void L1LossGradientOp<Context>::RunWithType() { ...@@ -69,7 +75,7 @@ void L1LossGradientOp<Context>::RunWithType() {
const T sign = (i == 0) ? 1 : -1; const T sign = (i == 0) ? 1 : -1;
alpha *= sign; alpha *= sign;
math::Axpby<T, Context>(Output(i)->count(), math::Axpby<T, Context>(Output(i)->count(),
alpha, diff_data, 0, dXdata, &ctx()); alpha, diff_data, 0, dXdata, ctx());
} }
} }
......
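The FinishDeviceCompution() inserted after the device-to-host Copy is the cost of unlocking async streams: the copy may now be enqueued asynchronously, so the host must wait before it reads dYdata_host and folds it into alpha. A hedged CUDA sketch of the hazard, assuming the copy is issued with cudaMemcpyAsync on the op's stream:

    #include <cuda_runtime.h>

    // Read a device scalar on the host under async streams (sketch).
    float read_scalar(const float* d_val, cudaStream_t stream) {
        float h_val = 0.f;
        cudaMemcpyAsync(&h_val, d_val, sizeof(float),
                        cudaMemcpyDeviceToHost, stream);
        // The FinishDeviceCompution() analogue: without this wait,
        // h_val may still hold its old value when the host reads it.
        cudaStreamSynchronize(stream);
        return h_val;
    }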
...@@ -9,12 +9,14 @@ void L2LossOp<Context>::RunWithType() { ...@@ -9,12 +9,14 @@ void L2LossOp<Context>::RunWithType() {
auto* X0data = Input(0).template data<T, Context>(); auto* X0data = Input(0).template data<T, Context>();
auto* X1data = Input(1).template data<T, Context>(); auto* X1data = Input(1).template data<T, Context>();
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template mutable_data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<float, Context>();
math::Sub<T, Context>(diff->count(), X0data, X1data, diff_data); math::Sub<T, Context>(diff->count(),
X0data, X1data, diff_data, ctx());
if (InputSize() > 2) { if (InputSize() > 2) {
CHECK_EQ(Input(0).count(), Input(2).count()); CHECK_EQ(Input(0).count(), Input(2).count());
auto* Wdata = Input(2).template data<T, Context>(); auto* Wdata = Input(2).template data<T, Context>();
math::Mul<T, Context>(diff->count(), Wdata, diff_data, diff_data); math::Mul<T, Context>(diff->count(),
Wdata, diff_data, diff_data, ctx());
} }
T normalizer = 1; T normalizer = 1;
...@@ -23,10 +25,12 @@ void L2LossOp<Context>::RunWithType() { ...@@ -23,10 +25,12 @@ void L2LossOp<Context>::RunWithType() {
} else if (normalization == "FULL") { } else if (normalization == "FULL") {
normalizer = Input(0).count(); normalizer = Input(0).count();
} }
normalizer *= 2;
T loss = T(0.5) * math::Dot<T, Context>(diff->count(), T loss;
diff_data, diff_data, &ctx()); math::Dot<T, Context>(diff->count(),
math::Set<T, Context>(1, loss / normalizer, Ydata); diff_data, diff_data, &loss, ctx());
math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -48,10 +52,11 @@ OPERATOR_SCHEMA(L2Loss).NumInputs(2, 3).NumOutputs(1); ...@@ -48,10 +52,11 @@ OPERATOR_SCHEMA(L2Loss).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void L2LossGradientOp<Context>::RunWithType() { void L2LossGradientOp<Context>::RunWithType() {
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
ctx()->FinishDeviceCompution();
T alpha = dYdata_host, normalizer = 1; T alpha = dYdata_host, normalizer = 1;
if (normalization == "BATCH_SIZE") { if (normalization == "BATCH_SIZE") {
...@@ -67,7 +72,7 @@ void L2LossGradientOp<Context>::RunWithType() { ...@@ -67,7 +72,7 @@ void L2LossGradientOp<Context>::RunWithType() {
const T sign = (i == 0) ? 1 : -1; const T sign = (i == 0) ? 1 : -1;
alpha *= sign; alpha *= sign;
math::Axpby<T, Context>(Output(i)->count(), math::Axpby<T, Context>(Output(i)->count(),
alpha, diff_data, 0, dXdata, &ctx()); alpha, diff_data, 0, dXdata, ctx());
} }
} }
......
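Two equivalent forms of the L2 loss meet in this hunk: the old code computed 0.5 * dot(d, d) / normalizer, while the new one folds the one-half into the denominator via normalizer *= 2 and receives the Dot result through an out-parameter (letting the routine synchronize its own reduction). The value is unchanged:

    // loss = 0.5 * dot / n  ==  dot / (2 * n); the commit picks the latter.
    float l2_loss(const float* d, int count, float normalizer) {
        float dot = 0.f;
        for (int i = 0; i < count; ++i) dot += d[i] * d[i];
        return dot / (2.f * normalizer);
    }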
...@@ -13,11 +13,11 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() { ...@@ -13,11 +13,11 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() {
auto* Fdata = flags.template mutable_data<T, Context>(); auto* Fdata = flags.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropy<T, Context>( kernel::SigmoidCrossEntropy<T, Context>(
Input(0).count(), Xdata, Tdata, Ldata, Fdata, &ctx()); Input(0).count(), Xdata, Tdata, Ldata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template CopyFrom<Context>(losses); Output(0)->template CopyFrom<Context>(losses, ctx());
return; return;
} }
...@@ -35,11 +35,13 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() { ...@@ -35,11 +35,13 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() {
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SigmoidCrossEntropyOp<Context>::RunOnDevice() { void SigmoidCrossEntropyOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
CHECK_EQ(Input(0).count(), Input(1).count()) CHECK_EQ(Input(0).count(), Input(1).count())
<< "\nNumber of predictions must match the number of labels."; << "\nNumber of predictions must match the number of labels.";
losses.ReshapeLike(Input(0)); losses.ReshapeLike(Input(0));
...@@ -63,12 +65,12 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -63,12 +65,12 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
auto* Fdata = flags.template mutable_data<T, Context>(); auto* Fdata = flags.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropyGrad<T, Context>( kernel::SigmoidCrossEntropyGrad<T, Context>(
Input(0).count(), Xdata, Tdata, dXdata, Fdata, &ctx()); Input(0).count(), Xdata, Tdata, dXdata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), math::Mul<T, Context>(Output(0)->count(),
dYdata, dXdata, dXdata); return; dYdata, dXdata, dXdata, ctx()); return;
} }
T normalizer = 1; T normalizer = 1;
...@@ -83,14 +85,16 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -83,14 +85,16 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
} }
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scal<T, Context>(Output(0)->count(), math::Scal<T, Context>(Output(0)->count(),
dYdata_host / normalizer, dXdata, &ctx()); dYdata_host / normalizer, dXdata, ctx());
} }
template <class Context> template <class Context>
void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() { void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
flags.ReshapeLike(Input(0)); flags.ReshapeLike(Input(0));
......
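The ctx()->set_stream_id(0) lines added to these RunOnDevice methods pin the loss ops to the default stream. A plausible reading: each of them ends in a host-side reduction (math::ASum returns a plain T), so their kernels must be complete before the host consumes the sum, and routing the whole op through stream 0 is the simplest way to keep that ordering once other ops run on async streams. Condensed from the code above (an excerpt, not standalone):

    ctx()->set_stream_id(0);  // enforce default stream
    kernel::SigmoidCrossEntropy<T, Context>(
        Input(0).count(), Xdata, Tdata, Ldata, Fdata, ctx());
    T loss = math::ASum<T, Context>(losses.count(), Ldata);  // host-side value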
...@@ -15,11 +15,11 @@ void SigmoidFocalLossOp<Context>::RunWithType() { ...@@ -15,11 +15,11 @@ void SigmoidFocalLossOp<Context>::RunWithType() {
kernel::SigmoidFocalLoss<T, Context>( kernel::SigmoidFocalLoss<T, Context>(
outer_dim, axis_dim, inner_dim, outer_dim, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
Xdata, Tdata, Ldata, Fdata, &ctx()); Xdata, Tdata, Ldata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template CopyFrom<Context>(losses); Output(0)->template CopyFrom<Context>(losses, ctx());
return; return;
} }
...@@ -37,11 +37,13 @@ void SigmoidFocalLossOp<Context>::RunWithType() { ...@@ -37,11 +37,13 @@ void SigmoidFocalLossOp<Context>::RunWithType() {
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SigmoidFocalLossOp<Context>::RunOnDevice() { void SigmoidFocalLossOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
axis_dim = Input(0).dim(axis); axis_dim = Input(0).dim(axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
...@@ -71,12 +73,12 @@ void SigmoidFocalLossGradientOp<Context>::RunWithType() { ...@@ -71,12 +73,12 @@ void SigmoidFocalLossGradientOp<Context>::RunWithType() {
kernel::SigmoidFocalLossGradient<T, Context>( kernel::SigmoidFocalLossGradient<T, Context>(
outer_dim, axis_dim, inner_dim, outer_dim, axis_dim, inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
Xdata, Tdata, dXdata, Fdata, &ctx()); Xdata, Tdata, dXdata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), math::Mul<T, Context>(Output(0)->count(),
dYdata, dXdata, dXdata); return; dYdata, dXdata, dXdata, ctx()); return;
} }
T normalizer = 1; T normalizer = 1;
...@@ -91,14 +93,16 @@ void SigmoidFocalLossGradientOp<Context>::RunWithType() { ...@@ -91,14 +93,16 @@ void SigmoidFocalLossGradientOp<Context>::RunWithType() {
} }
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scal<T, Context>(Output(0)->count(), math::Scal<T, Context>(Output(0)->count(),
dYdata_host / normalizer, dXdata, &ctx()); dYdata_host / normalizer, dXdata, ctx());
} }
template <class Context> template <class Context>
void SigmoidFocalLossGradientOp<Context>::RunOnDevice() { void SigmoidFocalLossGradientOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
axis_dim = Input(0).dim(axis); axis_dim = Input(0).dim(axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
......
...@@ -11,20 +11,21 @@ void SmoothL1LossOp<Context>::RunWithType() { ...@@ -11,20 +11,21 @@ void SmoothL1LossOp<Context>::RunWithType() {
auto* X1data = Input(1).template data<T, Context>(); auto* X1data = Input(1).template data<T, Context>();
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template mutable_data<T, Context>();
auto* error_data = error->template mutable_data<T, Context>(); auto* error_data = error->template mutable_data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<float, Context>();
math::Sub<T, Context>(diff->count(), X0data, X1data, diff_data); math::Sub<T, Context>(diff->count(),
X0data, X1data, diff_data, ctx());
if (InputSize() > 2) { if (InputSize() > 2) {
auto* inside_w_data = Input(2).template data<T, Context>(); auto* inside_w_data = Input(2).template data<T, Context>();
math::Mul<T, Context>(diff->count(), math::Mul<T, Context>(diff->count(),
inside_w_data, diff_data, diff_data); inside_w_data, diff_data, diff_data, ctx());
} }
kernel::SmoothL1<T, Context>( kernel::SmoothL1<T, Context>(diff->count(),
diff->count(), beta, diff_data, error_data); beta, diff_data, error_data, ctx());
if (InputSize() > 3) { if (InputSize() > 3) {
auto* outside_w_data = Input(3).template data<T, Context>(); auto* outside_w_data = Input(3).template data<T, Context>();
math::Mul<T, Context>(diff->count(), math::Mul<T, Context>(diff->count(),
outside_w_data, error_data, error_data); outside_w_data, error_data, error_data, ctx());
} }
T normalizer = 1; T normalizer = 1;
...@@ -34,12 +35,14 @@ void SmoothL1LossOp<Context>::RunWithType() { ...@@ -34,12 +35,14 @@ void SmoothL1LossOp<Context>::RunWithType() {
normalizer = Input(0).count(); normalizer = Input(0).count();
} }
T loss = math::ASum<T, Context>(error->count(), error_data); float loss = math::ASum<float, Context>(error->count(), error_data);
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<float, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SmoothL1LossOp<Context>::RunOnDevice() { void SmoothL1LossOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
CHECK(Input(0).dims() == Input(1).dims()); CHECK(Input(0).dims() == Input(1).dims());
if (InputSize() > 2) CHECK(Input(0).dims() == Input(2).dims()); if (InputSize() > 2) CHECK(Input(0).dims() == Input(2).dims());
if (InputSize() > 3) CHECK(Input(0).dims() == Input(3).dims()); if (InputSize() > 3) CHECK(Input(0).dims() == Input(3).dims());
...@@ -64,10 +67,12 @@ template <class Context> template <typename T> ...@@ -64,10 +67,12 @@ template <class Context> template <typename T>
void SmoothL1LossGradientOp<Context>::RunWithType() { void SmoothL1LossGradientOp<Context>::RunWithType() {
auto* diff_data = diff->template mutable_data<T, Context>(); auto* diff_data = diff->template mutable_data<T, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
kernel::SmoothL1Grad<T, Context>( ctx()->FinishDeviceCompution();
diff->count(), beta, diff_data, diff_data);
kernel::SmoothL1Grad<T, Context>(diff->count(),
beta, diff_data, diff_data, ctx());
T alpha = dYdata_host, normalizer = 1; T alpha = dYdata_host, normalizer = 1;
if (normalization == "BATCH_SIZE") { if (normalization == "BATCH_SIZE") {
...@@ -83,16 +88,16 @@ void SmoothL1LossGradientOp<Context>::RunWithType() { ...@@ -83,16 +88,16 @@ void SmoothL1LossGradientOp<Context>::RunWithType() {
const T sign = (i == 0) ? 1 : -1; const T sign = (i == 0) ? 1 : -1;
alpha *= sign; alpha *= sign;
math::Axpby<T, Context>(Output(i)->count(), math::Axpby<T, Context>(Output(i)->count(),
alpha, diff_data, 0, dXdata, &ctx()); alpha, diff_data, 0, dXdata, ctx());
if (InputSize() > 3) { if (InputSize() > 3) {
auto* inside_w_data = Input(2).template data<T, Context>(); auto* inside_w_data = Input(2).template data<T, Context>();
math::Mul<T, Context>(Output(i)->count(), math::Mul<T, Context>(Output(i)->count(),
inside_w_data, dXdata, dXdata); inside_w_data, dXdata, dXdata, ctx());
} }
if (InputSize() > 4) { if (InputSize() > 4) {
auto* outside_w_data = Input(3).template data<T, Context>(); auto* outside_w_data = Input(3).template data<T, Context>();
math::Mul<T, Context>(Output(i)->count(), math::Mul<T, Context>(Output(i)->count(),
outside_w_data, dXdata, dXdata); outside_w_data, dXdata, dXdata, ctx());
} }
} }
} }
......
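The FinishDeviceCompution() added after the Copy above is the synchronization point: with async streams the device-to-host copy of dYdata is only enqueued, so the host must not read dYdata_host before a barrier. A hedged stand-in where std::async plays the role of the stream copy:

#include <cstdio>
#include <future>

int main() {
    float device_dY = 2.f;   // stands in for device memory
    float host_dY = 0.f;
    // MemcpyAsync: the copy is merely enqueued, not finished
    auto pending = std::async(std::launch::async,
                              [&] { host_dY = device_dY; });
    // reading host_dY here would race with the pending copy
    pending.wait();          // FinishDeviceCompution(): barrier before host read
    float normalizer = 4.f;
    std::printf("alpha = %f\n", host_dY / normalizer);  // now safe
    return 0;
}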
...@@ -26,15 +26,15 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() { ...@@ -26,15 +26,15 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
auto* Pdata = prob->template data<T, Context>(); auto* Pdata = prob->template data<T, Context>();
auto* Tdata = Input(1).template data<T, Context>(); auto* Tdata = Input(1).template data<T, Context>();
auto* Ldata = losses.template mutable_data<T, Context>(); auto* Ldata = losses.template mutable_data<T, Context>();
kernel::SoftmaxCrossEntropy<T, Context>( kernel::SoftmaxCrossEntropy<T, Context>(Input(0).count(),
Input(0).count(), Pdata, Tdata, Ldata); Pdata, Tdata, Ldata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->Reshape({ outer_dim * inner_dim }); Output(0)->Reshape({ outer_dim * inner_dim });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Sum<T, Context>(outer_dim * inner_dim, kernel::Sum<T, Context>(outer_dim * inner_dim,
Input(0).dim(axis), inner_dim, Input(0).dim(axis), inner_dim,
Ldata, Ydata); return; Ldata, Ydata, ctx()); return;
} }
T normalizer = 1; T normalizer = 1;
...@@ -47,11 +47,13 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() { ...@@ -47,11 +47,13 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SoftmaxCrossEntropyOp<Context>::RunOnDevice() { void SoftmaxCrossEntropyOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
CHECK_EQ(Input(0).count(), Input(1).count()) CHECK_EQ(Input(0).count(), Input(1).count())
...@@ -76,16 +78,16 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -76,16 +78,16 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
auto* Tdata = Input(1).template data<T, Context>(); auto* Tdata = Input(1).template data<T, Context>();
auto* Pdata = prob->template mutable_data<T, Context>(); auto* Pdata = prob->template mutable_data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata); ctx()->template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
math::Axpy<T, Context>(Output(0)->count(), math::Axpy<T, Context>(Output(0)->count(),
-1.0, Tdata, dXdata, &ctx()); -1.0, Tdata, dXdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(outer_dim * inner_dim, kernel::SumGrad<T, Context>(outer_dim * inner_dim,
Input(0).dim(axis), inner_dim, 1.0, dYdata, Pdata); Input(0).dim(axis), inner_dim, 1.0, dYdata, Pdata, ctx());
math::Mul<T, Context>(Output(0)->count(), math::Mul<T, Context>(Output(0)->count(),
Pdata, dXdata, dXdata); return; Pdata, dXdata, dXdata, ctx()); return;
} }
T normalizer = 1; T normalizer = 1;
...@@ -96,10 +98,10 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -96,10 +98,10 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
} }
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scal<T, Context>(Output(0)->count(), math::Scal<T, Context>(Output(0)->count(),
dYdata_host / normalizer, dXdata, &ctx()); dYdata_host / normalizer, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -20,11 +20,11 @@ void SoftmaxFocalLossOp<Context>::RunWithType() { ...@@ -20,11 +20,11 @@ void SoftmaxFocalLossOp<Context>::RunWithType() {
outer_dim, Input(0).dim(axis), inner_dim, outer_dim, Input(0).dim(axis), inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
Pdata, Tdata, Idata, this->ignores.count(), Pdata, Tdata, Idata, this->ignores.count(),
Ldata, Fdata, &ctx()); Ldata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template CopyFrom<Context>(losses); Output(0)->template CopyFrom<Context>(losses, ctx());
return; return;
} }
...@@ -42,11 +42,13 @@ void SoftmaxFocalLossOp<Context>::RunWithType() { ...@@ -42,11 +42,13 @@ void SoftmaxFocalLossOp<Context>::RunWithType() {
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(1, loss / normalizer, Ydata); math::Set<T, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SoftmaxFocalLossOp<Context>::RunOnDevice() { void SoftmaxFocalLossOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
CHECK_EQ(outer_dim * inner_dim, Input(1).count()) CHECK_EQ(outer_dim * inner_dim, Input(1).count())
...@@ -80,16 +82,16 @@ void SoftmaxFocalLossGradientOp<Context>::RunWithType() { ...@@ -80,16 +82,16 @@ void SoftmaxFocalLossGradientOp<Context>::RunWithType() {
outer_dim, Output(0)->dim(axis), inner_dim, outer_dim, Output(0)->dim(axis), inner_dim,
pos_alpha, neg_alpha, gamma, neg_id, pos_alpha, neg_alpha, gamma, neg_id,
Pdata, Tdata, Idata, this->ignores.count(), Pdata, Tdata, Idata, this->ignores.count(),
dXdata, Fdata, &ctx()); dXdata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>( kernel::SumGrad<T, Context>(
Input(0).count() / Input(0).dim(axis), Input(0).count() / Input(0).dim(axis),
Input(0).dim(axis), inner_dim, Input(0).dim(axis), inner_dim,
1.0, dYdata, Pdata); 1.0, dYdata, Pdata, ctx());
math::Mul<T, Context>(Output(0)->count(), math::Mul<T, Context>(Output(0)->count(),
Pdata, dXdata, dXdata); return; Pdata, dXdata, dXdata, ctx()); return;
} }
T normalizer = 1; T normalizer = 1;
...@@ -104,14 +106,16 @@ void SoftmaxFocalLossGradientOp<Context>::RunWithType() { ...@@ -104,14 +106,16 @@ void SoftmaxFocalLossGradientOp<Context>::RunWithType() {
} }
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
T dYdata_host; ctx().template Copy<T, CPUContext, Context>( T dYdata_host; ctx()->template Copy<T, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scal<T, Context>(Output(0)->count(), math::Scal<T, Context>(Output(0)->count(),
dYdata_host / normalizer, dXdata, &ctx()); dYdata_host / normalizer, dXdata, ctx());
} }
template <class Context> template <class Context>
void SoftmaxFocalLossGradientOp<Context>::RunOnDevice() { void SoftmaxFocalLossGradientOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob"); this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob");
outer_dim = this->prob->count(0, axis); outer_dim = this->prob->count(0, axis);
inner_dim = this->prob->count(axis + 1); inner_dim = this->prob->count(axis + 1);
......
...@@ -21,83 +21,66 @@ void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRun() { ...@@ -21,83 +21,66 @@ void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRun() {
softmax_op->Run(); softmax_op->Run();
} }
template <class Context>
void SparseSoftmaxCrossEntropyOp<Context>::SoftmaxRunFP16() {
Tensor* XF32 = ws()->CreateTensor(
"/mnt/" + anchor() + "/softmax/xf32");
XF32->ReshapeLike(Input(0));
auto* XdataF16 = Input(0).template data<float16, Context>();
auto* XdataF32 = XF32->template mutable_data<float, Context>();
kernel::TypeA2B<float16, float, Context>(
Input(0).count(), XdataF16, XdataF32);
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ XF32->name() }),
vector<string>({ "/mnt/" + anchor() + "/softmax/prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (def().has_device_option())
softmax_def.mutable_device_option()
->CopyFrom(def().device_option());
if (!softmax_op) softmax_op.reset(
CreateOperator(softmax_def, ws()));
else softmax_op->MutableOp(softmax_def);
softmax_op->Run();
}
template <class Context> template <typename Tx, typename Ty> template <class Context> template <typename Tx, typename Ty>
void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() { void SparseSoftmaxCrossEntropyOp<Context>::RunWithType() {
auto* Pdata = prob->template data<Tx, Context>(); auto* Pdata = prob->template data<Tx, Context>();
auto* Tdata = Input(1).template data<Ty, Context>(); auto* Tdata = Input(1).template data<Ty, Context>();
auto* Idata = !ignores.count() ? nullptr : auto* Idata = !ignores.count() ? nullptr :
ignores.template data<int, Context>(); ignores.template data<int, Context>();
auto* Ldata = losses.template mutable_data<Tx, Context>(); auto* Ldata = losses.template mutable_data<float, Context>();
auto* Fdata = flags.template mutable_data<Tx, Context>(); auto* Fdata = flags.template mutable_data<float, Context>();
kernel::SparseSoftmaxCrossEntropy<Tx, Ty, Context>( kernel::SparseSoftmaxCrossEntropy<Tx, Ty, Context>(
outer_dim, Input(0).dim(axis), inner_dim, outer_dim, Input(0).dim(axis), inner_dim,
Pdata, Tdata, Idata, ignores.count(), Pdata, Tdata, Idata, ignores.count(),
Ldata, Fdata, &ctx()); Ldata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
Output(0)->ReshapeLike(losses); Output(0)->ReshapeLike(losses);
Output(0)->template CopyFrom<Context>(losses); Output(0)->template CopyFrom<Context>(losses, ctx());
return; return;
} }
Tx normalizer = 1; float normalizer = 1;
if (normalization == "VALID") { if (normalization == "VALID") {
normalizer = std::max( normalizer = std::max(
math::ASum<Tx, Context>( math::ASum<float, Context>(
flags.count(), Fdata), (Tx)1.f); flags.count(), Fdata), 1.f);
} else if (normalization == "BATCH_SIZE") { } else if (normalization == "BATCH_SIZE") {
normalizer = Input(0).dim(0); normalizer = Input(0).dim(0);
} else if (normalization == "FULL") { } else if (normalization == "FULL") {
normalizer = outer_dim * inner_dim; normalizer = outer_dim * inner_dim;
} }
Tx loss = math::ASum<Tx, Context>(losses.count(), Ldata); float loss = math::ASum<float, Context>(losses.count(), Ldata);
Output(0)->Reshape({ 1 }); Output(0)->Reshape({ 1 });
auto* Ydata = Output(0)->template mutable_data<Tx, Context>(); auto* Ydata = Output(0)->template mutable_data<float, Context>();
math::Set<Tx, Context>(1, loss / normalizer, Ydata); math::Set<float, Context>(1, loss / normalizer, Ydata, ctx());
} }
template <class Context> template <class Context>
void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() { void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
CHECK_EQ(outer_dim * inner_dim, Input(1).count()) CHECK_EQ(outer_dim * inner_dim, Input(1).count())
<< "\nNumber of predictions must match the number of labels."; << "\nNumber of predictions must match the number of labels.";
losses.Reshape({ outer_dim * inner_dim }); losses.Reshape({ outer_dim * inner_dim });
flags.Reshape({ outer_dim * inner_dim }); flags.Reshape({ outer_dim * inner_dim });
prob = ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob"); prob = ws()->CreateTensor("/mnt/" + anchor() + "/softmax/prob");
SoftmaxRun();
if (XIsType(Input(0), float) || if (XIsType(Input(0), float)) {
XIsType(Input(0), float16)) {
if (XIsType(Input(0), float16)) SoftmaxRunFP16();
else SoftmaxRun();
if (XIsType(Input(1), float)) RunWithType<float, float>(); if (XIsType(Input(1), float)) RunWithType<float, float>();
else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>(); else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" }); else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32" }); } else if (XIsType(Input(0), float16)) {
if (XIsType(Input(1), float)) RunWithType<float16, float>();
else if (XIsType(Input(1), int64_t)) RunWithType<float16, int64_t>();
else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
DEPLOY_CPU(SparseSoftmaxCrossEntropy); DEPLOY_CPU(SparseSoftmaxCrossEntropy);
...@@ -113,62 +96,66 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -113,62 +96,66 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
auto* Idata = !ignores.count() ? nullptr : auto* Idata = !ignores.count() ? nullptr :
ignores.template data<int, Context>(); ignores.template data<int, Context>();
auto* dXdata = Output(0)->template mutable_data<Tx, Context>(); auto* dXdata = Output(0)->template mutable_data<Tx, Context>();
auto* Fdata = flags.template mutable_data<Tx, Context>(); auto* Fdata = flags.template mutable_data<float, Context>();
ctx().template Copy<Tx, Context, Context>( ctx()->template Copy<Tx, Context, Context>(
prob->count(), dXdata, Pdata); prob->count(), dXdata, Pdata);
kernel::SparseSoftmaxCrossEntropyGrad<Tx, Ty, Context>( kernel::SparseSoftmaxCrossEntropyGrad<Tx, Ty, Context>(
outer_dim, Output(0)->dim(axis), inner_dim, outer_dim, Output(0)->dim(axis), inner_dim,
Pdata, Tdata, Idata, ignores.count(), Pdata, Tdata, Idata, ignores.count(),
dXdata, Fdata, &ctx()); dXdata, Fdata, ctx());
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = Input(-1).template data<Tx, Context>(); auto* dYdata = Input(-1).template data<float, Context>();
kernel::SumGrad<Tx, Context>( auto* WSdata = ws()->template caches<float, Context>(
{ Input(0).count() })[0];
kernel::SumGrad<float, Context>(
Input(0).count() / Input(0).dim(axis), Input(0).count() / Input(0).dim(axis),
Input(0).dim(axis), inner_dim, Input(0).dim(axis), inner_dim,
1.0, dYdata, Pdata); 1.0, dYdata, WSdata, ctx());
math::Mul<Tx, Context>( kernel::TypeA2B<float, Tx, Context>(
Output(0)->count(), Pdata, dXdata, dXdata); Input(0).count(), WSdata, Pdata, ctx());
math::Mul<Tx, Context>(Output(0)->count(),
Pdata, dXdata, dXdata, ctx());
return; return;
} }
Tx normalizer = 1; float normalizer = 1;
if (normalization == "VALID") { if (normalization == "VALID") {
normalizer = std::max( normalizer = std::max(
math::ASum<Tx, Context>( math::ASum<float, Context>(
flags.count(), Fdata), (Tx)1.f); flags.count(), Fdata), 1.f);
} else if (normalization == "BATCH_SIZE") { } else if (normalization == "BATCH_SIZE") {
normalizer = Input(0).dim(0); normalizer = Input(0).dim(0);
} else if (normalization == "FULL") { } else if (normalization == "FULL") {
normalizer = outer_dim * inner_dim; normalizer = outer_dim * inner_dim;
} }
auto* dYdata = Input(-1).template data<Tx, Context>(); auto* dYdata = Input(-1).template data<float, Context>();
Tx dYdata_host; ctx().template Copy<Tx, CPUContext, Context>( float dYdata_host; ctx()->template Copy<float, CPUContext, Context>(
1, &dYdata_host, dYdata); 1, &dYdata_host, dYdata);
math::Scal<Tx, Context>(Output(0)->count(), math::Scal<Tx, Context>(Output(0)->count(),
dYdata_host / normalizer, dXdata, &ctx()); dYdata_host / normalizer, dXdata, ctx());
} }
template <class Context> template <class Context>
void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() { void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob"); prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax/prob");
outer_dim = prob->count(0, axis); outer_dim = prob->count(0, axis);
inner_dim = prob->count(axis + 1); inner_dim = prob->count(axis + 1);
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
flags.Reshape({ outer_dim * inner_dim }); flags.Reshape({ outer_dim * inner_dim });
if (XIsType(Input(0), float) || XIsType(Input(0), float16)) { if (XIsType(Input(0), float)) {
if (XIsType(Input(1), float)) RunWithType<float, float>(); if (XIsType(Input(1), float)) RunWithType<float, float>();
else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>(); else if (XIsType(Input(1), int64_t)) RunWithType<float, int64_t>();
else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" }); else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
auto* dXdataF32 = Output(0)->template data<float, Context>(); if (XIsType(Input(1), float)) RunWithType<float16, float>();
auto* dXdataF16 = prob->template mutable_data<float16, Context>(); else if (XIsType(Input(1), int64_t)) RunWithType<float16, int64_t>();
kernel::TypeA2B<float, float16, Context>(Output(0)->count(), dXdataF32, dXdataF16); else LOG(FATAL) << DTypeHelper(Input(1), { "float32", "int64" });
Output(0)->template CopyFrom<Context>(*prob);
}
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
......
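With the float16 staging tensor gone, float and float16 inputs dispatch the same templated body, and only the loss/flag accumulators are fixed to float. A small sketch of that dispatch; half16 is a stand-in for the real float16 type:

#include <cstdio>
#include <cstdint>

struct half16 { std::uint16_t bits; };  // stand-in for float16 storage

// Tx only selects the prob dtype; the accumulation is float either way
template <typename Tx>
float SumLossesAsFloat(const float* losses, int n) {
    float s = 0.f;
    for (int i = 0; i < n; i++) s += losses[i];
    return s;
}

int main() {
    float losses[3] = { 0.5f, 1.f, 1.5f };
    // both input types run the same body, no staging conversion
    std::printf("%f\n", SumLossesAsFloat<float>(losses, 3));
    std::printf("%f\n", SumLossesAsFloat<half16>(losses, 3));
    return 0;
}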
...@@ -9,23 +9,27 @@ namespace dragon { ...@@ -9,23 +9,27 @@ namespace dragon {
template <class Context> template <typename Tx, typename Ty> template <class Context> template <typename Tx, typename Ty>
void AccuracyOp<Context>::RunWithType() { void AccuracyOp<Context>::RunWithType() {
static CPUContext cctx;
float* Y1data, *Y2data = nullptr;
Y1data = Output(0)->template mutable_data<float, CPUContext>();
if (OutputSize() > 1) { if (OutputSize() > 1) {
math::Set<float, CPUContext>(num_classes, 0, Y2data = Output(1)->template mutable_data<float, CPUContext>();
Output(1)->template mutable_data<float, CPUContext>()); math::Set<float, CPUContext>(num_classes, 0, Y2data, &cctx);
} }
Map<int, TIndex> num_per_class; Map<int, TIndex> num_per_class;
TIndex acc = 0, count = 0; TIndex acc = 0, count = 0;
const Tx* Xdata; const Tx* Xdata;
if (XIsType(Input(0), float16)) { if (XIsType(Input(0), float16)) {
Tensor* XF32 = ws()->CreateTensor("/mnt/" + anchor() + "/accuracy/xf32"); Tensor* X32T = ws()->CreateTensor(
XF32->ReshapeLike(Input(0)); "/mnt/" + anchor() + "/accuracy/f32");
auto* XdataF16 = Input(0).template data<float16, CPUContext>(); X32T->ReshapeLike(Input(0));
auto* XdataF32 = XF32->template mutable_data<float, CPUContext>(); auto* X16 = Input(0).template data<float16, CPUContext>();
auto* X32 = X32T->template mutable_data<float, CPUContext>();
kernel::TypeA2B<float16, float, CPUContext>( kernel::TypeA2B<float16, float, CPUContext>(
Input(0).count(), XdataF16, XdataF32); Input(0).count(), X16, X32, &cctx);
Xdata = XdataF32; Xdata = X32;
} else Xdata = Input(0).template data<Tx, CPUContext>(); } else Xdata = Input(0).template data<Tx, CPUContext>();
auto* labels = Input(1).template data<Ty, CPUContext>(); auto* labels = Input(1).template data<Ty, CPUContext>();
...@@ -41,15 +45,13 @@ void AccuracyOp<Context>::RunWithType() { ...@@ -41,15 +45,13 @@ void AccuracyOp<Context>::RunWithType() {
vector<pair<Tx, int> > vec; vector<pair<Tx, int> > vec;
for (int k = 0; k < num_classes; k++) for (int k = 0; k < num_classes; k++)
vec.push_back( vec.push_back(
std::make_pair(Xdata[i * dim + k * inner_dim + j], k) std::make_pair(Xdata[i * dim + k * inner_dim + j], k));
);
std::partial_sort( std::partial_sort(
vec.begin(), vec.begin() + top_k, vec.end(), vec.begin(), vec.begin() + top_k, vec.end(),
std::greater<pair<Tx, int> >()); std::greater<pair<Tx, int> >());
for (int k = 0; k < top_k; k++) { for (int k = 0; k < top_k; k++) {
if (vec[k].second == label) { if (vec[k].second == label) {
if (OutputSize() > 1) if (OutputSize() > 1) Y2data[label]++;
Output(1)->template mutable_data<float, CPUContext>()[label]++;
acc++; acc++;
break; break;
} }
...@@ -58,12 +60,11 @@ void AccuracyOp<Context>::RunWithType() { ...@@ -58,12 +60,11 @@ void AccuracyOp<Context>::RunWithType() {
} // end inner_dim } // end inner_dim
} // end outer_dim } // end outer_dim
Output(0)->template mutable_data<float, CPUContext>()[0] = (float)acc / count; Y1data[0] = (float)acc / count;
if (OutputSize() > 1) { if (Y2data) {
auto* acc_per_class = Output(1)->template mutable_data<float, CPUContext>();
for (int i = 0; i < num_classes; i++) for (int i = 0; i < num_classes; i++)
acc_per_class[i] = num_per_class[i] == 0 ? Y2data[i] = num_per_class[i] == 0 ?
0 : acc_per_class[i] / num_per_class[i]; 0 : Y2data[i] / num_per_class[i];
} }
} }
......
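The accuracy rewrite hoists both output pointers out of the per-element loop instead of re-fetching mutable_data on every hit, and keeps one function-local static CPUContext for the CPU math. A toy illustration of the hoisting; ToyTensor is invented here:

#include <cstdio>
#include <vector>

struct ToyTensor {
    std::vector<float> buf;
    int fetches;
    float* mutable_data() { fetches++; return buf.data(); }  // count lookups
};

int main() {
    ToyTensor acc_per_class{ std::vector<float>(10, 0.f), 0 };
    float* Y2data = acc_per_class.mutable_data();  // fetched once, up front
    for (int label = 0; label < 10; label++)
        Y2data[label]++;                           // no per-hit re-fetch
    std::printf("fetches = %d\n", acc_per_class.fetches);  // prints 1
    return 0;
}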
...@@ -14,14 +14,14 @@ namespace dragon { ...@@ -14,14 +14,14 @@ namespace dragon {
Output(0)->ReshapeLike(Input(0)); \ Output(0)->ReshapeLike(Input(0)); \
auto* Xdata = Input(0).template data<type_a, Context>(); \ auto* Xdata = Input(0).template data<type_a, Context>(); \
auto* Ydata = Output(0)->template mutable_data<type_b, Context>(); \ auto* Ydata = Output(0)->template mutable_data<type_b, Context>(); \
kernel::TypeA2B<type_a, type_b, Context>(Input(0).count(), Xdata, Ydata); \ kernel::TypeA2B<type_a, type_b, Context>(Input(0).count(), Xdata, Ydata, ctx()); \
} else { \ } else { \
TIndex count = Output(0)->count(); \ TIndex count = Output(0)->count(); \
auto* Xdata = Output(0)->template data<type_a, Context>(); \ auto* Xdata = Output(0)->template data<type_a, Context>(); \
auto* Cdata = ws()->template caches<type_b, Context>({ count })[0]; \ auto* Cdata = ws()->template caches<type_b, Context>({ count })[0]; \
kernel::TypeA2B<type_a, type_b, Context>(count, Xdata, Cdata); \ kernel::TypeA2B<type_a, type_b, Context>(count, Xdata, Cdata, ctx()); \
auto* Ydata = Output(0)->template mutable_data<type_b, Context>(); \ auto* Ydata = Output(0)->template mutable_data<type_b, Context>(); \
ctx().template Copy<type_b, Context, Context>(count, Ydata, Cdata); \ ctx()->template Copy<type_b, Context, Context>(count, Ydata, Cdata); \
} \ } \
return; \ return; \
} }
......
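Because the in-place branch of this macro reads and writes the same tensor, conversion is staged through a scratch cache and then copied back. A self-contained sketch of the staging, with plain vectors standing in for the workspace cache:

#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    // pretend the tensor currently holds int32 and must become float32 in place
    std::vector<int> storage = { 1, 2, 3, 4 };
    std::vector<float> cache(storage.size());       // ws()->caches<type_b>(...)
    for (size_t i = 0; i < storage.size(); i++)
        cache[i] = static_cast<float>(storage[i]);  // TypeA2B into the cache
    // reinterpret the storage as float and copy the cache back over it
    static_assert(sizeof(int) == sizeof(float), "same-width reinterpret");
    std::memcpy(storage.data(), cache.data(),
                cache.size() * sizeof(float));      // Copy<type_b> back
    float y0;
    std::memcpy(&y0, &storage[0], sizeof(float));
    std::printf("%f\n", y0);                        // prints 1.000000
    return 0;
}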
...@@ -11,7 +11,7 @@ void GradientGenerateOp<Context>::RunWithType() { ...@@ -11,7 +11,7 @@ void GradientGenerateOp<Context>::RunWithType() {
Output(i)->ReshapeLike(Input(i)); Output(i)->ReshapeLike(Input(i));
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), math::Set<T, Context>(Output(0)->count(),
dragon_cast<T, float>(defaults[i]), dXdata); dragon_cast<T, float>(defaults[i]), dXdata, ctx());
} }
} }
...@@ -37,12 +37,13 @@ void GradientGatherOp<Context>::RunWithType() { ...@@ -37,12 +37,13 @@ void GradientGatherOp<Context>::RunWithType() {
CHECK(Output(0)->dims() == Input(indices[i]).dims()); CHECK(Output(0)->dims() == Input(indices[i]).dims());
auto* dYdata = Input(indices[i]).template data<T, Context>(); auto* dYdata = Input(indices[i]).template data<T, Context>();
if (i == 0) { if (i == 0) {
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
count, dXdata, dYdata); count, dXdata, dYdata);
} else { } else {
math::Add<T, Context>( math::Add<T, Context>(
count, dXdata, dYdata, dXdata); count, dXdata, dYdata, dXdata, ctx());
} }
ctx()->FinishDeviceCompution();
Input(indices[i]).Reset(); Input(indices[i]).Reset();
} }
} }
...@@ -68,7 +69,7 @@ template <class Context> ...@@ -68,7 +69,7 @@ template <class Context>
void StopGradientOp<Context>::RunOnDevice() { void StopGradientOp<Context>::RunOnDevice() {
if (Output(0)->name() != Input(0).name()) { if (Output(0)->name() != Input(0).name()) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template CopyFrom<Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0), ctx());
} }
} }
......
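GradientGather now synchronizes before Reset() releases each input: with async streams the Add may still be reading dYdata when the host frees the buffer. A stand-in sketch of sync-before-free, with std::async again playing the stream:

#include <cstdio>
#include <future>
#include <vector>

int main() {
    std::vector<float> dX(4, 0.f);
    for (int i = 0; i < 3; i++) {
        auto* dY = new std::vector<float>(4, 1.f);  // Input(indices[i])
        auto pending = std::async(std::launch::async, [&, dY] {
            for (size_t k = 0; k < dX.size(); k++) dX[k] += (*dY)[k];
        });
        pending.wait();  // FinishDeviceCompution(): the Add must finish...
        delete dY;       // ...before Reset() releases the buffer
    }
    std::printf("dX[0] = %f\n", dX[0]);  // prints 3.000000
    return 0;
}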
...@@ -14,7 +14,7 @@ void ImageDataOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void ImageDataOp<Context>::RunWithType() {
kernel::ImageData<Tx, Ty, Context>( kernel::ImageData<Tx, Ty, Context>(
Output(0)->count(), n, c, h, w, Mdata, Sdata, Output(0)->count(), n, c, h, w, Mdata, Sdata,
data_format, Xdata, Ydata); data_format, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -7,7 +7,7 @@ template <class Context> template <typename T> ...@@ -7,7 +7,7 @@ template <class Context> template <typename T>
void InitializeOp<Context>::RunWithType() { void InitializeOp<Context>::RunWithType() {
unique_ptr< Filler<T, Context> > f; unique_ptr< Filler<T, Context> > f;
f.reset(CreateFiller<T, Context>(filler)); f.reset(CreateFiller<T, Context>(filler));
f->Fill(Output(0), &ctx()); f->Fill(Output(0), ctx());
} }
template <class Context> template <class Context>
......
...@@ -14,7 +14,7 @@ void MPIBroadcastOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
auto* Xdata = Input(0).template mutable_data<T, CPUContext>(); auto* Xdata = Input(0).template mutable_data<T, CPUContext>();
#endif #endif
MPI_Bcast(Xdata, Input(0).count(), mpi_dtype(), comm_root, comm); MPI_Bcast(Xdata, Input(0).count(), mpi_dtype(), comm_root, comm);
Output(0)->template CopyFrom<Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0), ctx());
} else { } else {
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
...@@ -62,12 +62,13 @@ void MPIBroadcastGradientOp<Context>::RunWithType() { ...@@ -62,12 +62,13 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
auto* dYdata = Input(-1).template mutable_data<T, Context>(); auto* dYdata = Input(-1).template mutable_data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), dXdata, dYdata); Output(0)->count(), dXdata, dYdata);
#else #else
auto* dYdata = Input(-1).template mutable_data<T, CPUContext>(); auto* dYdata = Input(-1).template mutable_data<T, CPUContext>();
auto* dXdata = Output(0)->template mutable_data<T, CPUContext>(); auto* dXdata = Output(0)->template mutable_data<T, CPUContext>();
CPUContext::template Copy<T, CPUContext, CPUContext>( static CPUContext cctx;
cctx.template Copy<T, CPUContext, CPUContext>(
Output(0)->count(), dXdata, dYdata); Output(0)->count(), dXdata, dYdata);
#endif #endif
for (int i = 0; i < comm_size; i++) { for (int i = 0; i < comm_size; i++) {
...@@ -76,10 +77,10 @@ void MPIBroadcastGradientOp<Context>::RunWithType() { ...@@ -76,10 +77,10 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
i, 0, comm, MPI_STATUS_IGNORE); i, 0, comm, MPI_STATUS_IGNORE);
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
math::Add<T, Context>(Output(0)->count(), math::Add<T, Context>(Output(0)->count(),
dYdata, dXdata, dXdata); dYdata, dXdata, dXdata, ctx());
#else #else
math::Add<T, CPUContext>(Output(0)->count(), math::Add<T, CPUContext>(Output(0)->count(),
dYdata, dXdata, dXdata); dYdata, dXdata, dXdata, &cctx);
#endif #endif
} }
} }
......
...@@ -8,7 +8,7 @@ namespace dragon { ...@@ -8,7 +8,7 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void MPIGatherOp<Context>::RunWithType() { void MPIGatherOp<Context>::RunWithType() {
if (comm_rank == comm_root) { if (comm_rank == comm_root) {
Output(comm_rank)->template CopyFrom<Context>(Input(0)); Output(comm_rank)->template CopyFrom<Context>(Input(0), ctx());
for (int i = 0; i < comm_size; i++) { for (int i = 0; i < comm_size; i++) {
if (i == comm_root) continue; if (i == comm_root) continue;
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
...@@ -76,7 +76,8 @@ OPERATOR_SCHEMA(MPIGather).NumInputs(1).NumOutputs(1, INT_MAX); ...@@ -76,7 +76,8 @@ OPERATOR_SCHEMA(MPIGather).NumInputs(1).NumOutputs(1, INT_MAX);
template <class Context> template <typename T> template <class Context> template <typename T>
void MPIGatherGradientOp<Context>::RunWithType() { void MPIGatherGradientOp<Context>::RunWithType() {
if (comm_rank == comm_root) { if (comm_rank == comm_root) {
Output(0)->template CopyFrom<Context>(Input(this->comm_rank + 1)); Output(0)->template CopyFrom<Context>(
Input(this->comm_rank + 1), ctx());
for (int i = 0; i < comm_size; i++) { for (int i = 0; i < comm_size; i++) {
if (i == comm_root) continue; if (i == comm_root) continue;
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
......
...@@ -11,7 +11,7 @@ void ArangeOp<Context>::RunWithType() { ...@@ -11,7 +11,7 @@ void ArangeOp<Context>::RunWithType() {
count = (stop_ - start_ - 1) / step_ + 1; count = (stop_ - start_ - 1) / step_ + 1;
Output(0)->Reshape({ count }); Output(0)->Reshape({ count });
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Arange<T, Context>(count, start_, step_, Ydata); kernel::Arange<T, Context>(count, start_, step_, Ydata, ctx());
} }
template <class Context> template <class Context>
......
#include "utils/op_kernel.h" #include "utils/op_kernel.h"
#include "utils/math_functions.h"
#include "operators/ndarray/argreduce_op.h" #include "operators/ndarray/argreduce_op.h"
namespace dragon { namespace dragon {
...@@ -12,14 +13,15 @@ void ArgReduceOp<Context>::RunWithType() { ...@@ -12,14 +13,15 @@ void ArgReduceOp<Context>::RunWithType() {
auto* Idata = Output(0)->template mutable_data<int64_t, CPUContext>(); auto* Idata = Output(0)->template mutable_data<int64_t, CPUContext>();
auto* Vdata = OutputSize() == 2 ? auto* Vdata = OutputSize() == 2 ?
Output(1)->template mutable_data<T, CPUContext>() : nullptr; Output(1)->template mutable_data<T, CPUContext>() : nullptr;
static CPUContext cctx;
if (operation == "ARGMAX") { if (operation == "ARGMAX") {
kernel::Argmax<T, CPUContext>( kernel::Argmax<T, CPUContext>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
top_k, Xdata, Idata, Vdata); top_k, Xdata, Idata, Vdata, &cctx);
} else if (operation == "ARGMIN") { } else if (operation == "ARGMIN") {
kernel::Argmin<T, CPUContext>( kernel::Argmin<T, CPUContext>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
top_k, Xdata, Idata, Vdata); top_k, Xdata, Idata, Vdata, &cctx);
} else LOG(FATAL) << "Unknown operation: [" << operation << "]."; } else LOG(FATAL) << "Unknown operation: [" << operation << "].";
} else { } else {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
...@@ -29,11 +31,11 @@ void ArgReduceOp<Context>::RunWithType() { ...@@ -29,11 +31,11 @@ void ArgReduceOp<Context>::RunWithType() {
if (operation == "ARGMAX") { if (operation == "ARGMAX") {
kernel::Argmax<T, Context>( kernel::Argmax<T, Context>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
top_k, Xdata, Idata, Vdata); top_k, Xdata, Idata, Vdata, ctx());
} else if (operation == "ARGMIN") { } else if (operation == "ARGMIN") {
kernel::Argmin<T, Context>( kernel::Argmin<T, Context>(
count, axis_dim, inner_dim, count, axis_dim, inner_dim,
top_k, Xdata, Idata, Vdata); top_k, Xdata, Idata, Vdata, ctx());
} else LOG(FATAL) << "Unknown operation: [" << operation << "]."; } else LOG(FATAL) << "Unknown operation: [" << operation << "].";
} }
} }
......
...@@ -14,7 +14,7 @@ void ConcatOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void ConcatOp<Context>::RunWithType() {
kernel::Concat<T, Context>( kernel::Concat<T, Context>(
count, outer_dim, inner_dim, count, outer_dim, inner_dim,
x_concat_dim, y_concat_dim, x_concat_dim, y_concat_dim,
concat_offset, Xdata, Ydata); concat_offset, Xdata, Ydata, ctx());
concat_offset += x_concat_dim; concat_offset += x_concat_dim;
} }
} }
...@@ -61,7 +61,7 @@ void ConcatGradientOp<Context>::RunWithType() { ...@@ -61,7 +61,7 @@ void ConcatGradientOp<Context>::RunWithType() {
kernel::ConcatGrad<T, Context>( kernel::ConcatGrad<T, Context>(
count, outer_dim, inner_dim, count, outer_dim, inner_dim,
x_concat_dim, y_concat_dim, x_concat_dim, y_concat_dim,
concat_offset, dYdata, dXdata); concat_offset, dYdata, dXdata, ctx());
} }
concat_offset += x_concat_dim; concat_offset += x_concat_dim;
} }
......
...@@ -17,7 +17,7 @@ void CropOp<Context>::RunWithType() { ...@@ -17,7 +17,7 @@ void CropOp<Context>::RunWithType() {
kernel::Crop1D<T, Context>(dest->count(), kernel::Crop1D<T, Context>(dest->count(),
dim, ed[axis] - st[axis], inner_dim, dim, ed[axis] - st[axis], inner_dim,
st[axis], Xdata, Ydata); st[axis], Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -46,7 +46,7 @@ void CropOp<Context>::Setup() { ...@@ -46,7 +46,7 @@ void CropOp<Context>::Setup() {
// make ends // make ends
ed.assign(Input(0).ndim(), 0); ed.assign(Input(0).ndim(), 0);
keep_dims.resize(Input(0).ndim(), 0); keep_dims.assign(Input(0).ndim(), 1);
if (shape.size() + shape_like.size() != 0) { if (shape.size() + shape_like.size() != 0) {
CHECK(shape.size() * shape_like.size() == 0) CHECK(shape.size() * shape_like.size() == 0)
<< "\nCan not set shape and shape_like both."; << "\nCan not set shape and shape_like both.";
...@@ -75,7 +75,6 @@ void CropOp<Context>::Setup() { ...@@ -75,7 +75,6 @@ void CropOp<Context>::Setup() {
// static crop // static crop
int n_given = (int)GET_ARGUMENTS_SIZE(ends); int n_given = (int)GET_ARGUMENTS_SIZE(ends);
for (int i = 0; i < ed.size(); i++) { for (int i = 0; i < ed.size(); i++) {
keep_dims[i] = 1;
if (i < n_given) ed[i] = ends(i); if (i < n_given) ed[i] = ends(i);
if (ed[i] == 0) ed[i] = Input(0).dim(i); if (ed[i] == 0) ed[i] = Input(0).dim(i);
if (ed[i] == -1) { ed[i] = st[i] + 1; keep_dims[i] = 0; } if (ed[i] == -1) { ed[i] = st[i] + 1; keep_dims[i] = 0; }
...@@ -125,7 +124,7 @@ void CropOp<Context>::RunOnDevice() { ...@@ -125,7 +124,7 @@ void CropOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template CopyFrom<Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0), ctx());
// squeeze dimensions // squeeze dimensions
vector<TIndex> squeeze_shape; vector<TIndex> squeeze_shape;
for (int i = 0; i < keep_dims.size(); i++) for (int i = 0; i < keep_dims.size(); i++)
...@@ -149,6 +148,7 @@ void CropOp<Context>::RunOnDevice() { ...@@ -149,6 +148,7 @@ void CropOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), int)) RunWithType<int>(); else if (XIsType(Input(0), int)) RunWithType<int>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" });
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
...@@ -209,7 +209,7 @@ void CropGradientOp<Context>::RunWithType() { ...@@ -209,7 +209,7 @@ void CropGradientOp<Context>::RunWithType() {
kernel::Crop1DGrad<T, Context>(dest->count(), kernel::Crop1DGrad<T, Context>(dest->count(),
Input(0).dim(axis), dim, inner_dim, Input(0).dim(axis), dim, inner_dim,
st[axis], ed[axis], dYdata, dXdata); st[axis], ed[axis], dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -229,7 +229,7 @@ void CropGradientOp<Context>::RunOnDevice() { ...@@ -229,7 +229,7 @@ void CropGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template CopyFrom<Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1), ctx());
return; return;
} }
...@@ -248,6 +248,7 @@ void CropGradientOp<Context>::RunOnDevice() { ...@@ -248,6 +248,7 @@ void CropGradientOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), int)) RunWithType<int>(); else if (XIsType(Input(0), int)) RunWithType<int>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" });
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
......
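The per-axis crop tasks ping-pong between source and dest, and the new FinishDeviceCompution() guarantees each task's output is complete before std::swap turns it into the next task's input. A sketch of the ping-pong, with a doubling pass as the stand-in for Crop1D:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> a = { 1, 2, 3, 4 }, b(4);
    std::vector<float>* source = &a, * dest = &b;
    for (int task = 0; task < 2; task++) {  // one task per processed axis
        for (size_t i = 0; i < source->size(); i++)
            (*dest)[i] = (*source)[i] * 2;  // the per-axis kernel
        // FinishDeviceCompution() here: dest must be complete before it
        // becomes the next task's source
        std::swap(source, dest);            // ping-pong the buffers
    }
    std::printf("%f\n", (*source)[0]);      // prints 4.000000
    return 0;
}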
...@@ -12,11 +12,11 @@ void GatherOp<Context>::RunWithType() { ...@@ -12,11 +12,11 @@ void GatherOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::CanonicalAxis<int, Context>( kernel::CanonicalAxis<int, Context>(
Input(1).count(), x_slice_dim, indices); Input(1).count(), x_slice_dim, indices, ctx());
kernel::Gather<T, Context>( kernel::Gather<T, Context>(Output(0)->count(),
Output(0)->count(), outer_dim, inner_dim, outer_dim, inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, y_slice_dim, indices, Xdata, Ydata); indices, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -46,13 +46,18 @@ template <class Context> template <typename T> ...@@ -46,13 +46,18 @@ template <class Context> template <typename T>
void GatherGradientOp<Context>::RunWithType() { void GatherGradientOp<Context>::RunWithType() {
auto* indices = Input(1).template data<int, Context>(); auto* indices = Input(1).template data<int, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>();
if (!acc_grad) math::Set<T, Context>(Output(0)->count(), 0, dXdata); T* dXdata = nullptr;
if (!acc_grad) {
dXdata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), 0, dXdata, ctx());
} else {
dXdata = Output(0)->template mutable_data<T, Context>(ctx());
}
kernel::GatherGrad<T, Context>( kernel::GatherGrad<T, Context>(Input(-1).count(),
Input(-1).count(), outer_dim, inner_dim, outer_dim, inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, y_slice_dim, indices, dYdata, dXdata); indices, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
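When accumulating, the zero-fill is skipped and the buffer is obtained through the context-aware mutable_data(ctx()) overload; when overwriting, math::Set clears it first. The sketch below mimics only the conditional zero-fill, with illustrative names:

#include <algorithm>
#include <cstdio>
#include <vector>

void GatherGrad(const std::vector<int>& idx, const std::vector<float>& dY,
                std::vector<float>& dX, bool acc_grad) {
    if (!acc_grad)
        std::fill(dX.begin(), dX.end(), 0.f);  // math::Set only when overwriting
    for (size_t i = 0; i < idx.size(); i++)
        dX[idx[i]] += dY[i];                   // scatter-add either way
}

int main() {
    std::vector<float> dX(3, 5.f), dY = { 1.f, 1.f };
    GatherGrad({ 0, 2 }, dY, dX, /*acc_grad=*/true);
    std::printf("%f %f %f\n", dX[0], dX[1], dX[2]);  // 6 5 6
    return 0;
}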
...@@ -10,10 +10,10 @@ void OneHotOp<Context>::RunWithType() { ...@@ -10,10 +10,10 @@ void OneHotOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), math::Set<T, Context>(Output(0)->count(),
dragon_cast<T, float>(float(off_value)), Ydata); dragon_cast<T, float>(float(off_value)), Ydata, ctx());
kernel::OneHot<T, Context>(Input(0).count(), kernel::OneHot<T, Context>(Input(0).count(),
depth, on_value, Xdata, Ydata); depth, on_value, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -17,7 +17,7 @@ void PadOp<Context>::ConstRunWithType() { ...@@ -17,7 +17,7 @@ void PadOp<Context>::ConstRunWithType() {
kernel::ConstPad1D<T, Context>(dest->count(), kernel::ConstPad1D<T, Context>(dest->count(),
dim, dim + pad_l[axis] + pad_r[axis], inner_dim, dim, dim + pad_l[axis] + pad_r[axis], inner_dim,
pad_l[axis], value, Xdata, Ydata); pad_l[axis], value, Xdata, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -32,7 +32,7 @@ void PadOp<Context>::ReflectRunWithType() { ...@@ -32,7 +32,7 @@ void PadOp<Context>::ReflectRunWithType() {
kernel::ReflectPad1D<T, Context>(dest->count(), kernel::ReflectPad1D<T, Context>(dest->count(),
dim, dim + pad_l[axis] + pad_r[axis], inner_dim, dim, dim + pad_l[axis] + pad_r[axis], inner_dim,
pad_l[axis], Xdata, Ydata); pad_l[axis], Xdata, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -47,7 +47,7 @@ void PadOp<Context>::EdgeRunWithType() { ...@@ -47,7 +47,7 @@ void PadOp<Context>::EdgeRunWithType() {
kernel::EdgePad1D<T, Context>(dest->count(), kernel::EdgePad1D<T, Context>(dest->count(),
dim, dim + pad_l[axis] + pad_r[axis], inner_dim, dim, dim + pad_l[axis] + pad_r[axis], inner_dim,
pad_l[axis], Xdata, Ydata); pad_l[axis], Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -61,7 +61,7 @@ void PadOp<Context>::RunOnDevice() { ...@@ -61,7 +61,7 @@ void PadOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template CopyFrom<Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0), ctx());
return; return;
} }
...@@ -99,6 +99,7 @@ void PadOp<Context>::RunOnDevice() { ...@@ -99,6 +99,7 @@ void PadOp<Context>::RunOnDevice() {
} else { } else {
LOG(FATAL) << "Unsupported padding mode: " << mode << "."; LOG(FATAL) << "Unsupported padding mode: " << mode << ".";
} }
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
...@@ -127,7 +128,7 @@ void PadGradientOp<Context>::ConstRunWithType() { ...@@ -127,7 +128,7 @@ void PadGradientOp<Context>::ConstRunWithType() {
kernel::ConstPad1DGrad<T, Context>(dest->count(), kernel::ConstPad1DGrad<T, Context>(dest->count(),
dim - pad_l[axis] - pad_r[axis], dim, inner_dim, dim - pad_l[axis] - pad_r[axis], dim, inner_dim,
pad_l[axis], dYdata, dXdata); pad_l[axis], dYdata, dXdata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -140,11 +141,11 @@ void PadGradientOp<Context>::ReflectRunWithType() { ...@@ -140,11 +141,11 @@ void PadGradientOp<Context>::ReflectRunWithType() {
dXdata = ws()->template caches<T, Context>({ dest->count() })[0]; dXdata = ws()->template caches<T, Context>({ dest->count() })[0];
} else { dXdata = dest->template mutable_data<T, Context>(); } } else { dXdata = dest->template mutable_data<T, Context>(); }
math::Set<T, Context>(dest->count(), 0, dXdata); math::Set<T, Context>(dest->count(), 0, dXdata, ctx());
kernel::ReflectPad1DGrad<T, Context>(source->count(), kernel::ReflectPad1DGrad<T, Context>(source->count(),
dim - pad_l[axis] - pad_r[axis], dim, inner_dim, dim - pad_l[axis] - pad_r[axis], dim, inner_dim,
pad_l[axis], dYdata, dXdata); pad_l[axis], dYdata, dXdata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -157,11 +158,11 @@ void PadGradientOp<Context>::EdgeRunWithType() { ...@@ -157,11 +158,11 @@ void PadGradientOp<Context>::EdgeRunWithType() {
dXdata = ws()->template caches<T, Context>({ dest->count() })[0]; dXdata = ws()->template caches<T, Context>({ dest->count() })[0];
} else { dXdata = dest->template mutable_data<T, Context>(); } } else { dXdata = dest->template mutable_data<T, Context>(); }
math::Set<T, Context>(dest->count(), 0, dXdata); math::Set<T, Context>(dest->count(), 0, dXdata, ctx());
kernel::EdgePad1DGrad<T, Context>(source->count(), kernel::EdgePad1DGrad<T, Context>(source->count(),
dim - pad_l[axis] - pad_r[axis], dim, inner_dim, dim - pad_l[axis] - pad_r[axis], dim, inner_dim,
pad_l[axis], dYdata, dXdata); pad_l[axis], dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -175,7 +176,7 @@ void PadGradientOp<Context>::RunOnDevice() { ...@@ -175,7 +176,7 @@ void PadGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template CopyFrom<Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1), ctx());
return; return;
} }
...@@ -213,6 +214,7 @@ void PadGradientOp<Context>::RunOnDevice() { ...@@ -213,6 +214,7 @@ void PadGradientOp<Context>::RunOnDevice() {
} else { } else {
LOG(FATAL) << "Unsupported padding mode: " << mode << "."; LOG(FATAL) << "Unsupported padding mode: " << mode << ".";
} }
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
......
...@@ -9,15 +9,15 @@ template <class Context> template <typename T> ...@@ -9,15 +9,15 @@ template <class Context> template <typename T>
void RandomPickOp<Context>::RunWithType() { void RandomPickOp<Context>::RunWithType() {
auto* indices = pick_indices->template mutable_data<int, CPUContext>(); auto* indices = pick_indices->template mutable_data<int, CPUContext>();
for (int i = 0; i < pick_indices->count(); i++) for (int i = 0; i < pick_indices->count(); i++)
indices[i] = int((*ctx().rand_generator())() % x_slice_dim); indices[i] = int((*ctx()->rand_generator())() % x_slice_dim);
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
indices = pick_indices->template mutable_data<int, Context>(); indices = pick_indices->template mutable_data<int, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Gather<T, Context>( kernel::Gather<T, Context>(Output(0)->count(),
Output(0)->count(), outer_dim, inner_dim, outer_dim, inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, y_slice_dim, indices, Xdata, Ydata); indices, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -39,7 +39,7 @@ void RandomPickOp<Context>::RunOnDevice() { ...@@ -39,7 +39,7 @@ void RandomPickOp<Context>::RunOnDevice() {
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
Output(1)->ReshapeLike(*pick_indices); Output(1)->ReshapeLike(*pick_indices);
Output(1)->template CopyFrom<Context>(*pick_indices); Output(1)->template CopyFrom<Context>(*pick_indices, ctx());
} }
} }
...@@ -55,11 +55,11 @@ void RandomPickGradientOp<Context>::RunWithType() { ...@@ -55,11 +55,11 @@ void RandomPickGradientOp<Context>::RunWithType() {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), 0, dXdata); math::Set<T, Context>(Output(0)->count(), 0, dXdata, ctx());
kernel::GatherGrad<T, Context>( kernel::GatherGrad<T, Context>(Input(-1).count(),
Input(-1).count(), outer_dim, inner_dim, outer_dim, inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, y_slice_dim, indices, dYdata, dXdata); indices, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
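Pick indices are drawn from the context's own generator, now reached through the pointer. A minimal equivalent with a toy context; the seed is arbitrary:

#include <cstdio>
#include <random>

struct ToyContext {
    std::mt19937* rand_generator() {
        static std::mt19937 gen(1337);  // lazily shared, like the context's RNG
        return &gen;
    }
};

int main() {
    ToyContext ctx_storage, *ctx = &ctx_storage;
    const int x_slice_dim = 8;
    int indices[4];
    for (int i = 0; i < 4; i++)
        indices[i] = int((*ctx->rand_generator())() % x_slice_dim);
    for (int i = 0; i < 4; i++) std::printf("%d ", indices[i]);
    std::printf("\n");
    return 0;
}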
...@@ -8,14 +8,17 @@ namespace dragon { ...@@ -8,14 +8,17 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void ReduceOp<Context>::SumRunWithType() { void ReduceOp<Context>::SumRunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>();
if (axis == -1) { if (axis == -1) {
DECLARE_MULTIPLIER(multiplier, Input(0).count()); DECLARE_MULTIPLIER(multiplier, Input(0).count());
auto* Ydata = Output(0)->template mutable_data<T, CPUContext>(); T result_host;
Ydata[0] = math::Dot<T, Context>( math::Dot<T, Context>(Input(0).count(),
Input(0).count(), multiplier, Xdata, &ctx()); multiplier, Xdata, &result_host, ctx());
ctx()->template Copy<T, Context, CPUContext>(
1, Ydata, &result_host);
} else { } else {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); kernel::Sum<T, Context>(count,
kernel::Sum<T, Context>(count, axis_dim, inner_dim, Xdata, Ydata); axis_dim, inner_dim, Xdata, Ydata, ctx());
} }
} }
...@@ -24,7 +27,7 @@ void ReduceOp<Context>::MeanRunWithType() { ...@@ -24,7 +27,7 @@ void ReduceOp<Context>::MeanRunWithType() {
SumRunWithType<T>(); SumRunWithType<T>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
T coeff = axis != -1 ? 1.0 / axis_dim : 1.0 / Input(0).count(); T coeff = axis != -1 ? 1.0 / axis_dim : 1.0 / Input(0).count();
math::Scal<T, Context>(Output(0)->count(), coeff, Ydata, &ctx()); math::Scal<T, Context>(Output(0)->count(), coeff, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -62,11 +65,12 @@ void ReduceGradientOp<Context>::SumRunWithType() { ...@@ -62,11 +65,12 @@ void ReduceGradientOp<Context>::SumRunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
if (axis == -1) { if (axis == -1) {
auto* dYdata = Input(-1).template data<T, CPUContext>(); auto* dYdata = Input(-1).template data<T, CPUContext>();
math::Set<T, Context>(Output(0)->count(), dYdata[0], dXdata); math::Set<T, Context>(Output(0)->count(),
dYdata[0], dXdata, ctx());
} else { } else {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(count, kernel::SumGrad<T, Context>(count,
axis_dim, inner_dim, 1.0, dYdata, dXdata); axis_dim, inner_dim, 1.0, dYdata, dXdata, ctx());
} }
} }
...@@ -76,11 +80,12 @@ void ReduceGradientOp<Context>::MeanRunWithType() { ...@@ -76,11 +80,12 @@ void ReduceGradientOp<Context>::MeanRunWithType() {
if (axis == -1) { if (axis == -1) {
auto* dYdata = Input(-1).template data<T, CPUContext>(); auto* dYdata = Input(-1).template data<T, CPUContext>();
math::Set<T, Context>(Output(0)->count(), math::Set<T, Context>(Output(0)->count(),
dYdata[0] / Input(0).count(), dXdata); dYdata[0] / Input(0).count(), dXdata, ctx());
} else { } else {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(count, kernel::SumGrad<T, Context>(count,
axis_dim, inner_dim, 1.0 / axis_dim, dYdata, dXdata); axis_dim, inner_dim, 1.0 / axis_dim,
dYdata, dXdata, ctx());
} }
} }
......
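math::Dot no longer returns the reduction through a CPU-visible output pointer; it writes into a host scalar which is then uploaded to the device output, so the op never dereferences device memory on the host. A sketch under those assumptions:

#include <cstdio>
#include <vector>

// new-style Dot: the result goes through an out-parameter on the host
void Dot(int n, const float* a, const float* b, float* result_host) {
    float s = 0.f;
    for (int i = 0; i < n; i++) s += a[i] * b[i];
    *result_host = s;  // on CUDA this would end in a D2H copy + sync
}

int main() {
    std::vector<float> multiplier(4, 1.f), x = { 1, 2, 3, 4 };
    float result_host = 0.f;
    Dot(4, multiplier.data(), x.data(), &result_host);
    float Ydata[1];
    Ydata[0] = result_host;  // Copy<T, Context, CPUContext>: upload the scalar
    std::printf("sum = %f\n", Ydata[0]);  // prints 10.000000
    return 0;
}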
...@@ -10,7 +10,7 @@ void RepeatOp<Context>::RunWithType() { ...@@ -10,7 +10,7 @@ void RepeatOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::Repeat<T, Context>( kernel::Repeat<T, Context>(
Output(0)->count(), outer_dim, dim, Output(0)->count(), outer_dim, dim,
inner_dim, repeats(), Xdata, Ydata); inner_dim, repeats(), Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -44,7 +44,7 @@ void RepeatGradientOp<Context>::RunWithType() { ...@@ -44,7 +44,7 @@ void RepeatGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
kernel::RepeatGrad<T, Context>( kernel::RepeatGrad<T, Context>(
Output(0)->count(), outer_dim, dim, inner_dim, Output(0)->count(), outer_dim, dim, inner_dim,
repeats(), dYdata, dXdata, &ctx()); repeats(), dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
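For reference, a plain-CPU reading of what kernel::Repeat computes under the argument layout used above (an inferred reference, not the kernel's actual source): each of the dim entries along the repeated axis is written repeats times, with outer_dim and inner_dim flattening the axes before and after it.

template <typename T>
void RepeatRef(int outer_dim, int dim, int inner_dim, int repeats,
               const T* x, T* y) {
    // Y is outer_dim x (dim * repeats) x inner_dim in row-major order.
    for (int o = 0; o < outer_dim; ++o)
        for (int d = 0; d < dim; ++d)
            for (int r = 0; r < repeats; ++r)
                for (int i = 0; i < inner_dim; ++i)
                    y[((o * dim + d) * repeats + r) * inner_dim + i] =
                        x[(o * dim + d) * inner_dim + i];
}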
...@@ -10,8 +10,9 @@ void SliceOp<Context>::RunWithType() { ...@@ -10,8 +10,9 @@ void SliceOp<Context>::RunWithType() {
for (int i = 0; i < nout; i++) { for (int i = 0; i < nout; i++) {
auto* Ydata = Output(i)->template mutable_data<T, Context>(); auto* Ydata = Output(i)->template mutable_data<T, Context>();
TIndex count = Output(i)->count(); TIndex count = Output(i)->count();
kernel::Slice<T, Context>(count, outer_dim, inner_dim, kernel::Slice<T, Context>(count,
x_slice_dim, y_slice_dim, slice_offset, Xdata, Ydata); outer_dim, inner_dim, x_slice_dim, y_slice_dim,
slice_offset, Xdata, Ydata, ctx());
slice_offset += y_slice_dim; slice_offset += y_slice_dim;
} }
} }
...@@ -46,8 +47,9 @@ void SliceGradientOp<Context>::RunWithType() { ...@@ -46,8 +47,9 @@ void SliceGradientOp<Context>::RunWithType() {
if (Input(i + 1).name() == "ignore") continue; if (Input(i + 1).name() == "ignore") continue;
auto* dYdata = Input(i + 1).template data<T, Context>(); auto* dYdata = Input(i + 1).template data<T, Context>();
TIndex count = Input(i + 1).count(); TIndex count = Input(i + 1).count();
kernel::SliceGrad<T, Context>(count, outer_dim, inner_dim, kernel::SliceGrad<T, Context>(count,
x_slice_dim, y_slice_dim, slice_offset, dYdata, dXdata); outer_dim, inner_dim, x_slice_dim, y_slice_dim,
slice_offset, dYdata, dXdata, ctx());
slice_offset += y_slice_dim; slice_offset += y_slice_dim;
} }
} }
......
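The slice loop above hands each output a contiguous window of the sliced axis and advances slice_offset by y_slice_dim per output. A CPU sketch under the assumed row-major layout (inferred semantics, not the kernel source):

// X is outer_dim x x_slice_dim x inner_dim; each output is
// outer_dim x y_slice_dim x inner_dim, taken from slice_offset onward.
template <typename T>
void SliceRef(int outer_dim, int inner_dim,
              int x_slice_dim, int y_slice_dim, int slice_offset,
              const T* x, T* y) {
    for (int o = 0; o < outer_dim; ++o)
        for (int s = 0; s < y_slice_dim; ++s)
            for (int i = 0; i < inner_dim; ++i)
                y[(o * y_slice_dim + s) * inner_dim + i] =
                    x[(o * x_slice_dim + slice_offset + s) * inner_dim + i];
}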
...@@ -14,7 +14,7 @@ void StackOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void StackOp<Context>::RunWithType() {
kernel::Concat<T, Context>( kernel::Concat<T, Context>(
count, outer_dim, inner_dim, count, outer_dim, inner_dim,
x_concat_dim, y_concat_dim, x_concat_dim, y_concat_dim,
concat_offset, Xdata, Ydata); concat_offset, Xdata, Ydata, ctx());
concat_offset += x_concat_dim; concat_offset += x_concat_dim;
} }
} }
...@@ -59,7 +59,7 @@ void StackGradientOp<Context>::RunWithType() { ...@@ -59,7 +59,7 @@ void StackGradientOp<Context>::RunWithType() {
kernel::ConcatGrad<T, Context>( kernel::ConcatGrad<T, Context>(
count, outer_dim, inner_dim, count, outer_dim, inner_dim,
x_concat_dim, y_concat_dim, x_concat_dim, y_concat_dim,
concat_offset, dYdata, dXdata); concat_offset, dYdata, dXdata, ctx());
} }
concat_offset += x_concat_dim; concat_offset += x_concat_dim;
} }
......
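StackOp reuses the Concat kernel, which is the mirror image of the slice sketch above: input i writes x_concat_dim entries starting at concat_offset into Y. Note that in the gradient the offset advances outside the conditional, so a skipped gradient still keeps the later inputs aligned. A CPU sketch with the same inferred layout:

template <typename T>
void ConcatRef(int outer_dim, int inner_dim,
               int x_concat_dim, int y_concat_dim, int concat_offset,
               const T* x, T* y) {
    // Inverse of SliceRef: scatter X into a window of Y's concat axis.
    for (int o = 0; o < outer_dim; ++o)
        for (int c = 0; c < x_concat_dim; ++c)
            for (int i = 0; i < inner_dim; ++i)
                y[(o * y_concat_dim + concat_offset + c) * inner_dim + i] =
                    x[(o * x_concat_dim + c) * inner_dim + i];
}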
...@@ -22,7 +22,7 @@ void TileOp<Context>::TileRunWithType() { ...@@ -22,7 +22,7 @@ void TileOp<Context>::TileRunWithType() {
kernel::Tile<T, Context>(dest->count(), kernel::Tile<T, Context>(dest->count(),
outer_dim, ex_inner_dim, outer_dim, ex_inner_dim,
multiple, Xdata, Ydata); multiple, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -35,7 +35,7 @@ void TileOp<Context>::RunOnDevice() { ...@@ -35,7 +35,7 @@ void TileOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
Output(0)->template CopyFrom<Context>(Input(0)); Output(0)->template CopyFrom<Context>(Input(0), ctx());
return; return;
} }
...@@ -48,6 +48,7 @@ void TileOp<Context>::RunOnDevice() { ...@@ -48,6 +48,7 @@ void TileOp<Context>::RunOnDevice() {
axis = task.second; multiple = task.first; axis = task.second; multiple = task.first;
if (XIsType(Input(0), float)) TileRunWithType<float>(); if (XIsType(Input(0), float)) TileRunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
...@@ -82,7 +83,7 @@ void TileGradientOp<Context>::TileRunWithType() { ...@@ -82,7 +83,7 @@ void TileGradientOp<Context>::TileRunWithType() {
kernel::TileGrad<T, Context>( kernel::TileGrad<T, Context>(
dest->count(), outer_dim, ex_inner_dim, dest->count(), outer_dim, ex_inner_dim,
multiple, dYdata, dXdata, &ctx()); multiple, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -96,7 +97,7 @@ void TileGradientOp<Context>::RunOnDevice() { ...@@ -96,7 +97,7 @@ void TileGradientOp<Context>::RunOnDevice() {
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
Output(0)->ReshapeLike(Input(-1)); Output(0)->ReshapeLike(Input(-1));
Output(0)->template CopyFrom<Context>(Input(-1)); Output(0)->template CopyFrom<Context>(Input(-1), ctx());
return; return;
} }
...@@ -109,6 +110,7 @@ void TileGradientOp<Context>::RunOnDevice() { ...@@ -109,6 +110,7 @@ void TileGradientOp<Context>::RunOnDevice() {
axis = task.second; multiple = task.first; axis = task.second; multiple = task.first;
if (XIsType(Input(0), float)) TileRunWithType<float>(); if (XIsType(Input(0), float)) TileRunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
ctx()->FinishDeviceCompution();
// allow buffer to protect X if the num of tasks >= 2 // allow buffer to protect X if the num of tasks >= 2
std::swap(source, dest); std::swap(source, dest);
if (process_axes.size() % 2 == 1) { if (process_axes.size() % 2 == 1) {
......
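The newly inserted ctx()->FinishDeviceCompution() calls are the part of this change that actually needs care: each tile pass ping-pongs between source and dest, and once kernels are launched on an async stream the host must make sure a pass has finished before its output buffer is swapped in and reused. A sketch of the loop shape, assuming FinishDeviceCompution maps to a stream synchronize on CUDA (on CPUContext it is a no-op):

#include <utility>
#include <vector>

// Hypothetical stand-ins for the loop's moving parts.
struct Ctx { void FinishDeviceCompution() { /* cudaStreamSynchronize on CUDA */ } };

void TilePass(const std::vector<float>& src, std::vector<float>& dst, Ctx* /*ctx*/) {
    dst.assign(src.begin(), src.end());  // placeholder for an async kernel::Tile launch
}

void TileAllAxes(std::vector<float>& a, std::vector<float>& b, int num_tasks, Ctx* ctx) {
    std::vector<float>* source = &a;
    std::vector<float>* dest = &b;
    for (int t = 0; t < num_tasks; ++t) {
        TilePass(*source, *dest, ctx);
        ctx->FinishDeviceCompution();  // wait before the buffers are swapped and reused
        std::swap(source, dest);       // ping-pong: this pass's dest feeds the next pass
    }
}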
...@@ -14,7 +14,7 @@ void TransposeOp<Context>::RunWithType() { ...@@ -14,7 +14,7 @@ void TransposeOp<Context>::RunWithType() {
kernel::Transpose<T, Context>( kernel::Transpose<T, Context>(
Output(0)->count(), (int)Output(0)->ndim(), Output(0)->count(), (int)Output(0)->ndim(),
ORdata, OSdata, NSdata, Xdata, Ydata); ORdata, OSdata, NSdata, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -75,7 +75,7 @@ void TransposeGradientOp<Context>::RunWithType() { ...@@ -75,7 +75,7 @@ void TransposeGradientOp<Context>::RunWithType() {
kernel::TransposeGrad<T, Context>( kernel::TransposeGrad<T, Context>(
Input(-1).count(), order->count(), Input(-1).count(), order->count(),
ORdata, OSdata, NSdata, dYdata, dXdata); ORdata, OSdata, NSdata, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
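A CPU reference for the transpose index arithmetic, under my reading of the three arrays (order is the axis permutation, old_steps/new_steps are the row-major strides of X and Y): decompose the output offset with new_steps, then rebuild the input offset with old_steps[order[d]].

template <typename T>
void TransposeRef(int count, int ndim, const int* order,
                  const int* old_steps, const int* new_steps,
                  const T* x, T* y) {
    for (int yi = 0; yi < count; ++yi) {
        int xi = 0, r = yi;
        for (int d = 0; d < ndim; ++d) {
            const int k = r / new_steps[d];  // coordinate on output axis d
            r %= new_steps[d];
            xi += k * old_steps[order[d]];   // same coordinate on input axis order[d]
        }
        y[yi] = x[xi];
    }
}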
...@@ -20,23 +20,23 @@ void BatchNormOp<Context>::TrainingRunWithType() { ...@@ -20,23 +20,23 @@ void BatchNormOp<Context>::TrainingRunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* NCdata = nc.template mutable_data<T, Context>(); auto* NCdata = nc.template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata); ctx()->template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
// compute mean // compute mean
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0, NCdata, &ctx()); 0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0, Tmean, &ctx()); 0, Tmean, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0, Tmean, &ctx()); 0, Tmean, ctx());
} }
// subtract mean // subtract mean
...@@ -45,37 +45,37 @@ void BatchNormOp<Context>::TrainingRunWithType() { ...@@ -45,37 +45,37 @@ void BatchNormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
-1.0, NCdata, MXmult, -1.0, NCdata, MXmult,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
-1.0, MXmult, Tmean, -1.0, MXmult, Tmean,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
// compute variance // compute variance
// note that we use VAR(X) = E((X - EX) ^ 2) // note that we use VAR(X) = E((X - EX) ^ 2)
math::Square<T, Context>(Output(0)->count(), Ydata, WSdata); math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0 / NS, WSdata, MXmult, 1.0 / NS, WSdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0 / NS, WSdata, MXmult, 1.0 / NS, WSdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
} }
// compute moving average // compute moving average
...@@ -92,21 +92,21 @@ void BatchNormOp<Context>::TrainingRunWithType() { ...@@ -92,21 +92,21 @@ void BatchNormOp<Context>::TrainingRunWithType() {
float coeff = m > 1 ? float(m) / (m - 1) : 1; float coeff = m > 1 ? float(m) / (m - 1) : 1;
// History(X) = Cur(X) + momentum * History(X) // History(X) = Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), math::Axpby<T, Context>(mean.count(),
1.0, Tmean, momentum, Hmean, &ctx()); 1.0, Tmean, momentum, Hmean, ctx());
math::Axpby<T, Context>(var->count(), math::Axpby<T, Context>(var->count(),
coeff, Tvar, momentum, Hvar, &ctx()); coeff, Tvar, momentum, Hvar, ctx());
} else { } else {
// History(X) = (1 - momentum) * Cur(X) + momentum * History(X) // History(X) = (1 - momentum) * Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), math::Axpby<T, Context>(mean.count(),
1.0 - momentum, Tmean, momentum, Hmean, &ctx()); 1.0 - momentum, Tmean, momentum, Hmean, ctx());
math::Axpby<T, Context>(var->count(), math::Axpby<T, Context>(var->count(),
1.0 - momentum, Tvar, momentum, Hvar, &ctx()); 1.0 - momentum, Tvar, momentum, Hvar, ctx());
} }
} }
// compute stddev // compute stddev
math::AddScalar<T, Context>(var->count(), eps, Tvar); math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
math::Sqrt<T, Context>(var->count(), Tvar, Tvar); math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -114,20 +114,21 @@ void BatchNormOp<Context>::TrainingRunWithType() { ...@@ -114,20 +114,21 @@ void BatchNormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata); math::Div<T, Context>(Output(0)->count(),
Ydata, WSdata, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -145,7 +146,7 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -145,7 +146,7 @@ void BatchNormOp<Context>::InferenceRunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* NCdata = nc.template mutable_data<T, Context>(); auto* NCdata = nc.template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata); ctx()->template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata);
// scale the mean and variance if necessary // scale the mean and variance if necessary
if (mode == "CAFFE") { if (mode == "CAFFE") {
...@@ -156,12 +157,12 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -156,12 +157,12 @@ void BatchNormOp<Context>::InferenceRunWithType() {
const float factor = dragon_cast<float, T>(hFact_data[0]); const float factor = dragon_cast<float, T>(hFact_data[0]);
const float scale = factor == 0 ? 0 : 1.0 / factor; const float scale = factor == 0 ? 0 : 1.0 / factor;
math::Scale<T, Context>(mean.count(), math::Scale<T, Context>(mean.count(),
scale, Hmean, Tmean, &ctx()); scale, Hmean, Tmean, ctx());
math::Scale<T, Context>(var->count(), math::Scale<T, Context>(var->count(),
scale, Hvar, Tvar, &ctx()); scale, Hvar, Tvar, ctx());
} else { } else {
ctx().template Copy<T, Context, Context>(mean.count(), Tmean, Hmean); ctx()->template Copy<T, Context, Context>(mean.count(), Tmean, Hmean);
ctx().template Copy<T, Context, Context>(var->count(), Tvar, Hvar); ctx()->template Copy<T, Context, Context>(var->count(), Tvar, Hvar);
} }
// subtract mean // subtract mean
...@@ -170,23 +171,23 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -170,23 +171,23 @@ void BatchNormOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
-1.0, NCdata, MXmult, -1.0, NCdata, MXmult,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
-1.0, MXmult, Tmean, -1.0, MXmult, Tmean,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
// compute stddev // compute stddev
math::AddScalar<T, Context>(var->count(), eps, Tvar); math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
math::Sqrt<T, Context>(var->count(), Tvar, Tvar); math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -194,20 +195,21 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -194,20 +195,21 @@ void BatchNormOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata); math::Div<T, Context>(Output(0)->count(),
Ydata, WSdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -246,10 +248,7 @@ void BatchNormOp<Context>::RunOnDevice() { ...@@ -246,10 +248,7 @@ void BatchNormOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
if (use_global_stats) InferenceRunWithType<float>(); if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>(); else TrainingRunWithType<float>();
} else if (XIsType(Input(0), float16)) { } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
DEPLOY_CPU(BatchNorm); DEPLOY_CPU(BatchNorm);
...@@ -273,97 +272,100 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -273,97 +272,100 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
auto* Ydata = Input(1).template data<T, Context>(); auto* Ydata = Input(1).template data<T, Context>();
math::Mul<T, Context>(Output(0)->count(), Ydata, dYdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
Ydata, dYdata, dXdata, ctx());
// sum(dE/dY \cdot Y) // sum(dE/dY \cdot Y)
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0, dXdata, MXmult, 1.0, dXdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0, dXdata, MXmult, 1.0, dXdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} }
// sum(dE/dY \cdot Y) \cdot Y // sum(dE/dY \cdot Y) \cdot Y
math::Mul<T, Context>(Output(0)->count(), Ydata, dXdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
Ydata, dXdata, dXdata, ctx());
// sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y // sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0, dYdata, MXmult, 1.0, dYdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
1.0, dXdata, &ctx()); 1.0, dXdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0, dYdata, MXmult, 1.0, dYdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
1.0, dXdata, &ctx()); 1.0, dXdata, ctx());
} }
// dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y // dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y
// = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y) // = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y)
math::Axpby<T, Context>(Output(0)->count(), math::Axpby<T, Context>(Output(0)->count(),
1.0, dYdata, -1.0 / NS, dXdata, &ctx()); 1.0, dYdata, -1.0 / NS, dXdata, ctx());
// divide by stddev // divide by stddev
math::Div<T, Context>(Output(0)->count(), dXdata, WSdata, dXdata); math::Div<T, Context>(Output(0)->count(),
dXdata, WSdata, dXdata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -381,21 +383,22 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -381,21 +383,22 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), dYdata, WSdata, dXdata); math::Div<T, Context>(Output(0)->count(),
dYdata, WSdata, dXdata, ctx());
} }
template <class Context> template <class Context>
...@@ -430,10 +433,7 @@ void BatchNormGradientOp<Context>::RunOnDevice() { ...@@ -430,10 +433,7 @@ void BatchNormGradientOp<Context>::RunOnDevice() {
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
if (use_global_stats) InferenceRunWithType<float>(); if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>(); else TrainingRunWithType<float>();
} else if (XIsType(Input(0), float16)) { } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
DEPLOY_CPU(BatchNormGradient); DEPLOY_CPU(BatchNormGradient);
......
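Two details worth spelling out in the batch-norm math above. First, the per-channel statistics are computed with Gemv against a ones vector (MXmult): in NCHW one Gemv collapses the spatial axis (NC x S -> NC) and a second collapses the batch axis (N x C -> C), giving mu_c = 1/NS * sum over n,s of x[n,c,s]. Second, the CAFFE-mode coefficient m/(m-1) on the variance is the usual Bessel correction for an unbiased running estimate. A plain-loop equivalent of the two-Gemv mean (reference only, not the real code path):

template <typename T>
void ChannelMeanNCHW(int N, int C, int S, const T* x, T* mean) {
    for (int c = 0; c < C; ++c) mean[c] = T(0);
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int s = 0; s < S; ++s)
                mean[c] += x[(n * C + c) * S + s];  // both Gemv reductions as plain sums
    for (int c = 0; c < C; ++c) mean[c] = mean[c] / T(N * S);  // the 1/NS scaling
}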
...@@ -20,7 +20,7 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -20,7 +20,7 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* NCdata = nc.template mutable_data<T, Context>(); auto* NCdata = nc.template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata); ctx()->template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata);
auto* Td = d.template mutable_data<T, Context>(); auto* Td = d.template mutable_data<T, Context>();
auto* Tr = r->template mutable_data<T, Context>(); auto* Tr = r->template mutable_data<T, Context>();
...@@ -35,11 +35,11 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -35,11 +35,11 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
auto* hFact_data = Input(3).template mutable_data<T, CPUContext>(); auto* hFact_data = Input(3).template mutable_data<T, CPUContext>();
const float factor = dragon_cast<float, T>(hFact_data[0]); const float factor = dragon_cast<float, T>(hFact_data[0]);
const float scale = factor == 0 ? 0 : 1.0 / factor; const float scale = factor == 0 ? 0 : 1.0 / factor;
math::Scale<T, Context>(mean.count(), scale, Hmean, THmean, &ctx()); math::Scale<T, Context>(mean.count(), scale, Hmean, THmean, ctx());
math::Scale<T, Context>(mean.count(), scale, Hvar, THvar, &ctx()); math::Scale<T, Context>(mean.count(), scale, Hvar, THvar, ctx());
} else { } else {
ctx().template Copy<T, Context, Context>(mean.count(), THmean, Hmean); ctx()->template Copy<T, Context, Context>(mean.count(), THmean, Hmean);
ctx().template Copy<T, Context, Context>(var->count(), THvar, Hvar); ctx()->template Copy<T, Context, Context>(var->count(), THvar, Hvar);
} }
// compute mean // compute mean
...@@ -47,16 +47,16 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -47,16 +47,16 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
} }
// subtract mean // subtract mean
...@@ -65,37 +65,37 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -65,37 +65,37 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
-1.0, NCdata, MXmult, -1.0, NCdata, MXmult,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
-1.0, MXmult, Tmean, -1.0, MXmult, Tmean,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
// compute variance // compute variance
// note that we use VAR(X) = E((X - EX) ^ 2) // note that we use VAR(X) = E((X - EX) ^ 2)
math::Square<T, Context>(Output(0)->count(), Ydata, WSdata); math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0 / NS, WSdata, MXmult, 1.0 / NS, WSdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0 / NS, WSdata, MXmult, 1.0 / NS, WSdata, MXmult,
0.0, Tvar, &ctx()); 0.0, Tvar, ctx());
} }
// compute moving average // compute moving average
...@@ -112,21 +112,21 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -112,21 +112,21 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
float coeff = m > 1 ? float(m) / (m - 1) : 1; float coeff = m > 1 ? float(m) / (m - 1) : 1;
// History(X) = Cur(X) + momentum * History(X) // History(X) = Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), math::Axpby<T, Context>(mean.count(),
1.0, Tmean, momentum, Hmean, &ctx()); 1.0, Tmean, momentum, Hmean, ctx());
math::Axpby<T, Context>(var->count(), math::Axpby<T, Context>(var->count(),
coeff, Tvar, momentum, Hvar, &ctx()); coeff, Tvar, momentum, Hvar, ctx());
} else { } else {
// History(X) = (1 - momentum) * Cur(X) + momentum * History(X) // History(X) = (1 - momentum) * Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), math::Axpby<T, Context>(mean.count(),
1.0 - momentum, Tmean, momentum, Hmean, &ctx()); 1.0 - momentum, Tmean, momentum, Hmean, ctx());
math::Axpby<T, Context>(var->count(), math::Axpby<T, Context>(var->count(),
1.0 - momentum, Tvar, momentum, Hvar, &ctx()); 1.0 - momentum, Tvar, momentum, Hvar, ctx());
} }
} }
// compute stddev // compute stddev
math::AddScalar<T, Context>(var->count(), eps, Tvar); math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
math::Sqrt<T, Context>(var->count(), Tvar, Tvar); math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -134,35 +134,36 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -134,35 +134,36 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata); math::Div<T, Context>(Output(0)->count(),
Ydata, WSdata, Ydata, ctx());
// compute renorm // compute renorm
if (!is_recomputing) { if (!is_recomputing) {
// compute history stddev // compute history stddev
math::AddScalar<T, Context>(var->count(), eps, THvar); math::AddScalar<T, Context>(var->count(), eps, THvar, ctx());
math::Sqrt<T, Context>(var->count(), THvar, THvar); math::Sqrt<T, Context>(var->count(), THvar, THvar, ctx());
// compute r // compute r
math::Div<T, Context>(var->count(), Tvar, THvar, Tr); math::Div<T, Context>(var->count(), Tvar, THvar, Tr, ctx());
math::Clip<T, Context>(var->count(), 1.0 / t_r_max, t_r_max, Tr); math::Clip<T, Context>(var->count(), 1.0 / t_r_max, t_r_max, Tr, ctx());
// compute d // compute d
math::Sub<T, Context>(mean.count(), Tmean, THmean, Td); math::Sub<T, Context>(mean.count(), Tmean, THmean, Td, ctx());
math::Div<T, Context>(mean.count(), Td, THvar, Td); math::Div<T, Context>(mean.count(), Td, THvar, Td, ctx());
math::Clip<T, Context>(mean.count(), -t_d_max, t_d_max, Td); math::Clip<T, Context>(mean.count(), -t_d_max, t_d_max, Td, ctx());
// update the bound of r & d // update the bound of r & d
t_r_max = r_max / (1.0 + (r_max - 1.0) * exp(-t_val)); t_r_max = r_max / (1.0 + (r_max - 1.0) * exp(-t_val));
...@@ -173,7 +174,7 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -173,7 +174,7 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
// apply renorm // apply renorm
// store x_norm for backward // store x_norm for backward
auto* XNorm_data = x_norm->template mutable_data<T, Context>(); auto* XNorm_data = x_norm->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>( ctx()->template Copy<T, Context, Context>(
Output(0)->count(), XNorm_data, Ydata); Output(0)->count(), XNorm_data, Ydata);
// correction: mul by r // correction: mul by r
...@@ -182,20 +183,21 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -182,20 +183,21 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tr, 1.0, MXmult, Tr,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tr, 1.0, MXmult, Tr,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Mul<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata); math::Mul<T, Context>(Output(0)->count(),
Ydata, WSdata, Ydata, ctx());
// correction: add by d // correction: add by d
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -203,18 +205,18 @@ void BatchRenormOp<Context>::TrainingRunWithType() { ...@@ -203,18 +205,18 @@ void BatchRenormOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Td, 1.0, MXmult, Td,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Td, 1.0, MXmult, Td,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
} }
...@@ -233,7 +235,7 @@ void BatchRenormOp<Context>::InferenceRunWithType() { ...@@ -233,7 +235,7 @@ void BatchRenormOp<Context>::InferenceRunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* NCdata = nc.template mutable_data<T, Context>(); auto* NCdata = nc.template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata); ctx()->template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata);
// scale the mean and variance if necessary // scale the mean and variance if necessary
if (mode == "CAFFE") { if (mode == "CAFFE") {
...@@ -243,11 +245,11 @@ void BatchRenormOp<Context>::InferenceRunWithType() { ...@@ -243,11 +245,11 @@ void BatchRenormOp<Context>::InferenceRunWithType() {
auto* hFact_data = Input(3).template mutable_data<T, CPUContext>(); auto* hFact_data = Input(3).template mutable_data<T, CPUContext>();
const float factor = dragon_cast<float, T>(hFact_data[0]); const float factor = dragon_cast<float, T>(hFact_data[0]);
const float scale = factor == 0 ? 0 : 1.0 / factor; const float scale = factor == 0 ? 0 : 1.0 / factor;
math::Scale<T, Context>(mean.count(), scale, Hmean, Tmean, &ctx()); math::Scale<T, Context>(mean.count(), scale, Hmean, Tmean, ctx());
math::Scale<T, Context>(var->count(), scale, Hvar, Tvar, &ctx()); math::Scale<T, Context>(var->count(), scale, Hvar, Tvar, ctx());
} else { } else {
ctx().template Copy<T, Context, Context>(mean.count(), Tmean, Hmean); ctx()->template Copy<T, Context, Context>(mean.count(), Tmean, Hmean);
ctx().template Copy<T, Context, Context>(var->count(), Tvar, Hvar); ctx()->template Copy<T, Context, Context>(var->count(), Tvar, Hvar);
} }
// subtract mean // subtract mean
...@@ -256,22 +258,22 @@ void BatchRenormOp<Context>::InferenceRunWithType() { ...@@ -256,22 +258,22 @@ void BatchRenormOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, NC, S, 1, CblasNoTrans, CblasNoTrans, NC, S, 1,
-1.0, NCdata, MXmult, -1.0, NCdata, MXmult,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
-1.0, MXmult, Tmean, -1.0, MXmult, Tmean,
1.0, Ydata, &ctx()); 1.0, Ydata, ctx());
} }
// compute stddev // compute stddev
math::AddScalar<T, Context>(var->count(), eps, Tvar); math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
math::Sqrt<T, Context>(var->count(), Tvar, Tvar); math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -279,20 +281,21 @@ void BatchRenormOp<Context>::InferenceRunWithType() { ...@@ -279,20 +281,21 @@ void BatchRenormOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata); math::Div<T, Context>(Output(0)->count(),
Ydata, WSdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -366,93 +369,96 @@ void BatchRenormGradientOp<Context>::TrainingRunWithType() { ...@@ -366,93 +369,96 @@ void BatchRenormGradientOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tr, 1.0, MXmult, Tr,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tr, 1.0, MXmult, Tr,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Mul<T, Context>(Output(0)->count(), dYdata, WSdata, WSdata); math::Mul<T, Context>(Output(0)->count(),
dYdata, WSdata, WSdata, ctx());
// sum(dE/dY \cdot Y) // sum(dE/dY \cdot Y)
math::Mul<T, Context>(Output(0)->count(), XNorm_data, WSdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
XNorm_data, WSdata, dXdata, ctx());
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0, dXdata, MXmult, 1.0, dXdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0, dXdata, MXmult, 1.0, dXdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, dXdata, &ctx()); 0.0, dXdata, ctx());
} }
// sum(dE/dY \cdot Y) \cdot Y // sum(dE/dY \cdot Y) \cdot Y
math::Mul<T, Context>(Output(0)->count(), XNorm_data, dXdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
XNorm_data, dXdata, dXdata, ctx());
// sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y // sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0, WSdata, MXmult, 1.0, WSdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
1.0, dXdata, &ctx()); 1.0, dXdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0, WSdata, MXmult, 1.0, WSdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tmean, 1.0, MXmult, Tmean,
1.0, dXdata, &ctx()); 1.0, dXdata, ctx());
} }
// dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y // dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y
// = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y) // = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y)
math::Axpby<T, Context>(Output(0)->count(), math::Axpby<T, Context>(Output(0)->count(),
1.0, WSdata, -1.0 / NS, dXdata, &ctx()); 1.0, WSdata, -1.0 / NS, dXdata, ctx());
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -460,21 +466,24 @@ void BatchRenormGradientOp<Context>::TrainingRunWithType() { ...@@ -460,21 +466,24 @@ void BatchRenormGradientOp<Context>::TrainingRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), dXdata, WSdata, dXdata); math::Div<T, Context>(Output(0)->count(),
dXdata, WSdata, dXdata, ctx());
ctx()->FinishDeviceCompution();
x_norm->Reset(); x_norm->Reset();
} }
...@@ -493,21 +502,22 @@ void BatchRenormGradientOp<Context>::InferenceRunWithType() { ...@@ -493,21 +502,22 @@ void BatchRenormGradientOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Div<T, Context>(Output(0)->count(), dYdata, WSdata, dXdata); math::Div<T, Context>(Output(0)->count(),
dYdata, WSdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
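Batch renorm's correction terms, as computed in the hunks above: r = clip(sigma_B / sigma_H, 1/r_max, r_max) and d = clip((mu_B - mu_H) / sigma_H, -d_max, d_max), with the bounds annealed through the exp(-t_val) schedule; the forward then applies y = x_hat * r + d, and the backward treats r and d as constants (x_norm is stored for exactly that). A per-channel scalar sketch (sigma_b/sigma_h here already include the +eps and sqrt steps from the diff):

#include <algorithm>

// Mirrors the Tr/Td computation above, one channel at a time.
inline void RenormCorrection(
    float sigma_b, float sigma_h, float mu_b, float mu_h,
    float r_max, float d_max, float* r, float* d) {
    *r = std::min(std::max(sigma_b / sigma_h, 1.0f / r_max), r_max);
    *d = std::min(std::max((mu_b - mu_h) / sigma_h, -d_max), d_max);
    // forward then applies y = x_hat * r + d (the "mul by r" / "add by d" steps)
}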
...@@ -10,6 +10,8 @@ namespace dragon { ...@@ -10,6 +10,8 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNBatchNormOp<Context>::RunWithType() { void CuDNNBatchNormOp<Context>::RunWithType() {
typedef typename CUDNNType<T>::BNParamType BNParamType;
// determine the bn desc // determine the bn desc
if (Input(0).ndim() == 2) { if (Input(0).ndim() == 2) {
bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION; bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION;
...@@ -54,32 +56,32 @@ void CuDNNBatchNormOp<Context>::RunWithType() { ...@@ -54,32 +56,32 @@ void CuDNNBatchNormOp<Context>::RunWithType() {
// derive the bn desc // derive the bn desc
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc, input_desc, bn_mode)); CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc, input_desc, bn_mode));
TENSOR_FILL(Input(1), vector<TIndex>(1, C)); // history_mean TENSOR_FILL_WITH_TYPE(Input(1), vector<TIndex>(1, C), BNParamType); // history_mean
TENSOR_FILL(Input(2), vector<TIndex>(1, C)); // history_var TENSOR_FILL_WITH_TYPE(Input(2), vector<TIndex>(1, C), BNParamType); // history_var
TENSOR_FILL(Input(3), vector<TIndex>(1, C)); // scale TENSOR_FILL_WITH_TYPE(Input(3), vector<TIndex>(1, C), BNParamType); // scale
TENSOR_FILL(Input(4), vector<TIndex>(1, C)); // bias TENSOR_FILL_WITH_TYPE(Input(4), vector<TIndex>(1, C), BNParamType); // bias
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* Hmean = Input(1).template mutable_data<T, Context>(); auto* Hmean = Input(1).template mutable_data<BNParamType, Context>();
auto* Hvar = Input(2).template mutable_data<T, Context>(); auto* Hvar = Input(2).template mutable_data<BNParamType, Context>();
auto* Sdata = Input(3).template data<T, Context>(); auto* Sdata = Input(3).template data<BNParamType, Context>();
auto* Bdata = Input(4).template data<T, Context>(); auto* Bdata = Input(4).template data<BNParamType, Context>();
if (this->use_global_stats) { if (this->use_global_stats) {
CUDNN_CHECK(cudnnBatchNormalizationForwardInference( CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
ctx().cudnn_handle(), bn_mode, ctx()->cudnn_handle(), bn_mode,
CUDNNType<T>::one, CUDNNType<T>::zero, CUDNNType<T>::one, CUDNNType<T>::zero,
input_desc, Xdata, output_desc, Ydata, input_desc, Xdata, output_desc, Ydata,
bn_desc, Sdata, Bdata, bn_desc, Sdata, Bdata,
Hmean, Hvar, eps64)); Hmean, Hvar, eps64));
} else { } else {
auto* Tmean = mean->template mutable_data<T, Context>(); auto* Tmean = mean->template mutable_data<BNParamType, Context>();
auto* Tvar = var->template mutable_data<T, Context>(); auto* Tvar = var->template mutable_data<BNParamType, Context>();
auto mt = this->is_recomputing ? 0.0 : 1.0 - this->momentum; auto mt = this->is_recomputing ? 0.0 : 1.0 - this->momentum;
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining( CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
ctx().cudnn_handle(), bn_mode, ctx()->cudnn_handle(), bn_mode,
CUDNNType<T>::one, CUDNNType<T>::zero, CUDNNType<T>::one, CUDNNType<T>::zero,
input_desc, Xdata, output_desc, Ydata, input_desc, Xdata, output_desc, Ydata,
bn_desc, Sdata, Bdata, bn_desc, Sdata, Bdata,
...@@ -131,7 +133,10 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() { ...@@ -131,7 +133,10 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
#endif #endif
} }
REGISTER_CUDNN_OPERATOR(FusedBatchNorm, CuDNNBatchNormOp<CUDAContext>); REGISTER_CUDNN_OPERATOR(
FusedBatchNorm,
CuDNNBatchNormOp<CUDAContext>
);
INSTANTIATE_CUDNN_OPERATOR(BatchNorm); INSTANTIATE_CUDNN_OPERATOR(BatchNorm);
template <class Context> template <class Context>
...@@ -169,6 +174,8 @@ void CuDNNBatchNormGradientOp<Context>::Setup() { ...@@ -169,6 +174,8 @@ void CuDNNBatchNormGradientOp<Context>::Setup() {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() { void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() {
typedef typename CUDNNType<T>::BNParamType BNParamType;
// determine the bn desc // determine the bn desc
if (Input(0).ndim() == 2) { if (Input(0).ndim() == 2) {
bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION; bn_mode = CUDNN_BATCHNORM_PER_ACTIVATION;
...@@ -218,14 +225,14 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -218,14 +225,14 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* Sdata = Input(3).template data<T, Context>(); auto* Sdata = Input(3).template data<BNParamType, Context>();
auto* dSdata = Output(1)->template mutable_data<T, Context>(); auto* dSdata = Output(1)->template mutable_data<BNParamType, Context>();
auto* dBdata = Output(2)->template mutable_data<T, Context>(); auto* dBdata = Output(2)->template mutable_data<BNParamType, Context>();
auto* Tmean = mean->template data<T, Context>(); auto* Tmean = mean->template data<BNParamType, Context>();
auto* Tvar = var->template data<T, Context>(); auto* Tvar = var->template data<BNParamType, Context>();
CUDNN_CHECK(cudnnBatchNormalizationBackward( CUDNN_CHECK(cudnnBatchNormalizationBackward(
ctx().cudnn_handle(), bn_mode, ctx()->cudnn_handle(), bn_mode,
CUDNNType<T>::one, CUDNNType<T>::zero, CUDNNType<T>::one, CUDNNType<T>::zero,
CUDNNType<T>::one, CUDNNType<T>::one, CUDNNType<T>::one, CUDNNType<T>::one,
output_desc, Xdata, input_desc, dYdata, output_desc, Xdata, input_desc, dYdata,
...@@ -256,16 +263,16 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -256,16 +263,16 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0, dYdata, MXmult, 1.0, dYdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
1.0, dBdata, &ctx()); 1.0, dBdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0, dYdata, MXmult, 1.0, dYdata, MXmult,
1.0, dBdata, &ctx()); 1.0, dBdata, ctx());
} }
} }
...@@ -275,12 +282,12 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -275,12 +282,12 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
// compute stddev // compute stddev
ctx().template Copy<T, Context, Context>(var->count(), Tvar, Hvar); ctx()->template Copy<T, Context, Context>(var->count(), Tvar, Hvar);
math::AddScalar<T, Context>(var->count(), this->eps, Tvar); math::AddScalar<T, Context>(var->count(), this->eps, Tvar, ctx());
math::Sqrt<T, Context>(var->count(), Tvar, Tvar); math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
// divide scale by stddev // divide scale by stddev
math::Div<T, Context>(var->count(), Sdata, Tvar, Tvar); math::Div<T, Context>(var->count(), Sdata, Tvar, Tvar, ctx());
// compute dE/dY \cdot (scale / std(X)) // compute dE/dY \cdot (scale / std(X))
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -288,20 +295,21 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -288,20 +295,21 @@ void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
N, C, 1, N, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NC, S, 1, NC, S, 1,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>( math::Gemm<T, Context>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
NS, C, 1, NS, C, 1,
1.0, MXmult, Tvar, 1.0, MXmult, Tvar,
0.0, WSdata, &ctx()); 0.0, WSdata, ctx());
} }
math::Mul<T, Context>(Output(0)->count(), dYdata, WSdata, dXdata); math::Mul<T, Context>(Output(0)->count(),
dYdata, WSdata, dXdata, ctx());
} }
} }
...@@ -314,8 +322,10 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() { ...@@ -314,8 +322,10 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
if (this->use_global_stats) InferenceRunWithType<float>(); if (this->use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>(); else TrainingRunWithType<float>();
} else if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
if (this->use_global_stats) InferenceRunWithType<float16>(); if (this->use_global_stats) {
else TrainingRunWithType<float16>(); // fp16 is disabled during inference
LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
} else TrainingRunWithType<float16>();
} else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
#else #else
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
...@@ -325,7 +335,10 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() { ...@@ -325,7 +335,10 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
#endif #endif
} }
REGISTER_CUDNN_OPERATOR(FusedBatchNormGradient, CuDNNBatchNormGradientOp<CUDAContext>); REGISTER_CUDNN_OPERATOR(
FusedBatchNormGradient,
CuDNNBatchNormGradientOp<CUDAContext>
);
INSTANTIATE_CUDNN_OPERATOR(BatchNormGradient); INSTANTIATE_CUDNN_OPERATOR(BatchNormGradient);
} // namespace dragon } // namespace dragon
......
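The new BNParamType typedef reflects a cuDNN requirement: for half-precision batch norm the scale/bias/mean/variance tensors must still be float32, so reading them as T would misinterpret those buffers whenever T is float16. The shape of the mapping, sketched with illustrative types (not the real CUDNNType header):

struct float16 { unsigned short x; };  // stand-in for Dragon's float16

template <typename T> struct CUDNNTypeSketch;
template <> struct CUDNNTypeSketch<float> {
    typedef float BNParamType;  // fp32 data -> fp32 params
};
template <> struct CUDNNTypeSketch<float16> {
    typedef float BNParamType;  // fp16 data -> params must still be fp32 for cuDNN
};

That is also presumably why the fp16 inference path of the gradient now hits LOG(FATAL): it falls back to the hand-rolled math:: ops, which only ship float32 kernels here.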
...@@ -24,23 +24,23 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() { ...@@ -24,23 +24,23 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
auto* NCdata = nc.template mutable_data<T, Context>(); auto* NCdata = nc.template mutable_data<T, Context>();
auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0]; auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
ctx().template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata); ctx()->template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
// compute mean // compute mean
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasNoTrans, NC, S, CblasNoTrans, NC, S,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0.0, NCdata, &ctx()); 0.0, NCdata, ctx());
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, N, C, CblasTrans, N, C,
1.0, NCdata, MXmult, 1.0, NCdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemv<T, Context>( math::Gemv<T, Context>(
CblasTrans, NS, C, CblasTrans, NS, C,
1.0 / NS, Xdata, MXmult, 1.0 / NS, Xdata, MXmult,
0.0, Tmean, &ctx()); 0.0, Tmean, ctx());
} }
     // subtract mean
@@ -49,51 +49,51 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tmean,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             -1.0, NCdata, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             -1.0, MXmult, Tmean,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     }
     // compute variance
     // note that we use VAR(X) = E((X - EX) ^ 2)
-    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata);
+    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0 / NS, WSdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0 / NS, WSdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
     }
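As the comment above notes, the variance is taken over the already-centered activations, i.e. the biased estimator

    \sigma_c^2 = \frac{1}{NS} \sum_{n,s} \left( x_{n,c,s} - \mu_c \right)^2 ,

computed by squaring `Ydata` (which holds x - mean at this point) and reusing the same multiplier reductions as the mean.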
     // compute moving average
     if (!is_recomputing) {
         // History(X) = (1 - momentum) * Cur(X) + momentum * History(X)
         math::Axpby<T, Context>(mean->count(),
-            1.0 - momentum, Tmean, momentum, Hmean, &ctx());
+            1.0 - momentum, Tmean, momentum, Hmean, ctx());
         math::Axpby<T, Context>(var->count(),
-            1.0 - momentum, Tvar, momentum, Hvar, &ctx());
+            1.0 - momentum, Tvar, momentum, Hvar, ctx());
     }
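`Axpby(n, alpha, x, beta, y)` computes y = alpha * x + beta * y, so the two calls above realize the running-statistics update stated in the comment:

    \hat\mu \leftarrow (1 - m)\,\mu_{\text{batch}} + m\,\hat\mu, \qquad
    \hat\sigma^2 \leftarrow (1 - m)\,\sigma^2_{\text{batch}} + m\,\hat\sigma^2 .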
     // compute stddev
-    math::AddScalar<T, Context>(var->count(), eps, Tvar);
+    math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
-    math::Sqrt<T, Context>(var->count(), Tvar, Tvar);
+    math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
     // divide by stddev
     if (data_format == "NCHW") {
@@ -101,24 +101,25 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tvar,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tvar,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Div<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // store x_norm for backward
     auto* XNorm_data = x_norm->template mutable_data<T, Context>();
-    ctx().template Copy<T, Context, Context>(
+    ctx()->template Copy<T, Context, Context>(
         Output(0)->count(), XNorm_data, Ydata);
     // scale
@@ -127,20 +128,21 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Sdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Sdata,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // shift
     if (data_format == "NCHW") {
@@ -148,18 +150,18 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Bdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Bdata,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     }
 }
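Taken together, the stddev/divide/scale/shift hunks assemble the standard batch-norm transform

    y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta ,

with `AddScalar` and `Sqrt` folding in eps, and the Gemm calls against the ones multiplier broadcasting the per-channel sigma, gamma, and beta over the full N*C*S tensor.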
@@ -182,9 +184,9 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
     auto* Ydata = Output(0)->template mutable_data<T, Context>();
     auto* NCdata = nc.template mutable_data<T, Context>();
     auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
-    ctx().template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata);
+    ctx()->template Copy<T, Context, Context>(Input(0).count(), Ydata, Xdata);
-    ctx().template Copy<T, Context, Context>(mean->count(), Tmean, Hmean);
+    ctx()->template Copy<T, Context, Context>(mean->count(), Tmean, Hmean);
-    ctx().template Copy<T, Context, Context>(var->count(), Tvar, Hvar);
+    ctx()->template Copy<T, Context, Context>(var->count(), Tvar, Hvar);
     // subtract mean
     if (data_format == "NCHW") {
@@ -192,23 +194,23 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tmean,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             -1.0, NCdata, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             -1.0, MXmult, Tmean,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     }
     // compute stddev
-    math::AddScalar<T, Context>(var->count(), eps, Tvar);
+    math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
-    math::Sqrt<T, Context>(var->count(), Tvar, Tvar);
+    math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
     // divide by stddev
     if (data_format == "NCHW") {
@@ -216,20 +218,21 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tvar,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tvar,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Div<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // scale
     if (data_format == "NCHW") {
@@ -237,20 +240,21 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Sdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Sdata,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // shift
     if (data_format == "NCHW") {
@@ -258,18 +262,18 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Bdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Bdata,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     }
 }
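The inference path is the same transform with the tracked history substituted for batch statistics: Tmean and Tvar are copied from Hmean and Hvar, so

    y = \gamma \cdot \frac{x - \hat\mu}{\sqrt{\hat\sigma^2 + \varepsilon}} + \beta

and no statistics update takes place.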
@@ -312,10 +316,7 @@ void FusedBatchNormOp<Context>::RunOnDevice() {
     if (XIsType(Input(0), float)) {
         if (use_global_stats) InferenceRunWithType<float>();
         else TrainingRunWithType<float>();
-    } else if (XIsType(Input(0), float16)) {
-        if (use_global_stats) InferenceRunWithType<float16>();
-        else TrainingRunWithType<float16>();
-    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
@@ -341,21 +342,22 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
     // gradient w.r.t. scale
     if (Output(1)->name() != "ignore") {
         auto* dSdata = Output(1)->template mutable_data<T, Context>();
-        math::Mul<T, Context>(x_norm->count(), XNorm_data, dYdata, WSdata);
+        math::Mul<T, Context>(x_norm->count(),
+            XNorm_data, dYdata, WSdata, ctx());
         if (data_format == "NCHW") {
             math::Gemv<T, Context>(
                 CblasNoTrans, NC, S,
                 1.0, WSdata, MXmult,
-                0.0, NCdata, &ctx());
+                0.0, NCdata, ctx());
             math::Gemv<T, Context>(
                 CblasTrans, N, C,
                 1.0, NCdata, MXmult,
-                1.0, dSdata, &ctx());
+                1.0, dSdata, ctx());
         } else if (data_format == "NHWC") {
             math::Gemv<T, Context>(
                 CblasTrans, NS, C,
                 1.0, WSdata, MXmult,
-                1.0, dSdata, &ctx());
+                1.0, dSdata, ctx());
         }
     }
@@ -366,16 +368,16 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dYdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0, dYdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     }
 }
@@ -387,37 +389,39 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Sdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Sdata,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(x_norm->count(), WSdata, dYdata, WSdata);
+    math::Mul<T, Context>(x_norm->count(),
+        WSdata, dYdata, WSdata, ctx());
     // sum of x_hat * (dl / dx_hat)
-    math::Mul<T, Context>(x_norm->count(), XNorm_data, WSdata, dXdata);
+    math::Mul<T, Context>(x_norm->count(),
+        XNorm_data, WSdata, dXdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dXdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0, dXdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
     }
     // x_hat times the sum
@@ -426,54 +430,55 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tmean,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, dXdata, &ctx());
+            0.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tmean,
-            0.0, dXdata, &ctx());
+            0.0, dXdata, ctx());
     }
-    math::Mul<T, Context>(x_norm->count(), XNorm_data, dXdata, dXdata);
+    math::Mul<T, Context>(x_norm->count(),
+        XNorm_data, dXdata, dXdata, ctx());
     // subtract the average of x_hat times the sum
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, WSdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tmean,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            1.0, dXdata, &ctx());
+            1.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0, WSdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tmean,
-            1.0, dXdata, &ctx());
+            1.0, dXdata, ctx());
     }
     math::Axpby<T, Context>(x_norm->count(),
-        1.0, WSdata, -1.0 / NS, dXdata, &ctx());
+        1.0, WSdata, -1.0 / NS, dXdata, ctx());
     // multiply with the inverse std
     if (data_format == "NCHW") {
@@ -481,21 +486,22 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tvar,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tvar,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
     // divide by stddev
-    math::Div<T, Context>(x_norm->count(), dXdata, WSdata, dXdata);
+    math::Div<T, Context>(x_norm->count(),
+        dXdata, WSdata, dXdata, ctx());
     }
 }
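Writing g for the scale-weighted upstream gradient (`WSdata` after the first Mul, i.e. gamma broadcast onto dE/dY), the sequence above assembles the usual batch-norm input gradient:

    \frac{\partial E}{\partial x_i}
      = \frac{1}{\sqrt{\sigma^2 + \varepsilon}}
        \Big( g_i \;-\; \frac{1}{NS}\sum_j g_j
              \;-\; \frac{\hat x_i}{NS}\sum_j g_j \hat x_j \Big),
    \qquad g = \gamma \odot \frac{\partial E}{\partial y},

with the Axpby forming g - (1/NS)(sum terms) and the final Gemm/Div pair applying the inverse stddev, both sums running over the NS positions of each channel.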
@@ -519,16 +525,16 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dYdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0, dYdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     }
 }
@@ -538,7 +544,7 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
     auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
     // divide scale by stddev
-    math::Div<T, Context>(var->count(), Sdata, Tvar, Tvar);
+    math::Div<T, Context>(var->count(), Sdata, Tvar, Tvar, ctx());
     // compute dE/dY \cdot (scale / std(X))
     if (data_format == "NCHW") {
@@ -546,20 +552,21 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Tvar,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Tvar,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(Output(0)->count(), dYdata, WSdata, dXdata);
+    math::Mul<T, Context>(Output(0)->count(),
+        dYdata, WSdata, dXdata, ctx());
     }
 }
@@ -599,10 +606,7 @@ void FusedBatchNormGradientOp<Context>::RunOnDevice() {
     if (XIsType(Input(0), float)) {
         if (use_global_stats) InferenceRunWithType<float>();
         else TrainingRunWithType<float>();
-    } else if (XIsType(Input(0), float16)) {
-        if (use_global_stats) InferenceRunWithType<float16>();
-        else TrainingRunWithType<float16>();
-    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(FusedBatchNormGradient);
...
@@ -21,14 +21,14 @@ void FusedGroupNormOp<Context>::RunWithType() {
     auto* NCdata = nc.template mutable_data<T, Context>();
     auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
-    ctx().template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
+    ctx()->template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
     // compute mean
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0 / CGS, Xdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
@@ -39,26 +39,26 @@ void FusedGroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             -1.0, Tmean, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // compute variance
     // note that we use VAR(X) = E((X - EX) ^ 2)
-    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata);
+    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0 / CGS, WSdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // compute stddev
-    math::AddScalar<T, Context>(var->count(), eps, Tvar);
+    math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
-    math::Sqrt<T, Context>(var->count(), Tvar, Tvar);
+    math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
     // divide by stddev
     if (data_format == "NCHW") {
@@ -66,15 +66,16 @@ void FusedGroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
-    math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Div<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // store x_norm for backward
     auto* XNorm_data = x_norm->template mutable_data<T, Context>();
-    ctx().template Copy<T, Context, Context>(
+    ctx()->template Copy<T, Context, Context>(
         Output(0)->count(), XNorm_data, Ydata);
     // scale
@@ -83,20 +84,21 @@ void FusedGroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Sdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Sdata,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
     // shift
     if (data_format == "NCHW") {
@@ -104,18 +106,18 @@ void FusedGroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Bdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Bdata,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     }
 }
@@ -157,8 +159,7 @@ void FusedGroupNormOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
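FusedGroupNorm follows the same template, but the statistics are taken per (sample, group): with G groups, `CGS = (C / G) * S`, so each of the NG rows reduced by the Gemv calls covers one group's values,

    \mu_{n,g} = \frac{1}{CGS} \sum_{c \in g} \sum_{s} x_{n,c,s} ,

after which the per-channel gamma and beta are applied exactly as in batch norm. Only NCHW is implemented here; the NHWC branches remain NOT_IMPLEMENTED.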
@@ -184,21 +185,22 @@ void FusedGroupNormGradientOp<Context>::RunWithType() {
     // gradient w.r.t. scale
     if (Output(1)->name() != "ignore") {
         auto* dSdata = Output(1)->template mutable_data<T, Context>();
-        math::Mul<T, Context>(x_norm->count(), XNorm_data, dYdata, WSdata);
+        math::Mul<T, Context>(x_norm->count(),
+            XNorm_data, dYdata, WSdata, ctx());
         if (data_format == "NCHW") {
             math::Gemv<T, Context>(
                 CblasNoTrans, NC, S,
                 1.0, WSdata, MXmult,
-                0.0, NCdata, &ctx());
+                0.0, NCdata, ctx());
             math::Gemv<T, Context>(
                 CblasTrans, N, C,
                 1.0, NCdata, MXmult,
-                1.0, dSdata, &ctx());
+                1.0, dSdata, ctx());
         } else if (data_format == "NHWC") {
             math::Gemv<T, Context>(
                 CblasTrans, NS, C,
                 1.0, WSdata, MXmult,
-                1.0, dSdata, &ctx());
+                1.0, dSdata, ctx());
         }
     }
@@ -209,16 +211,16 @@ void FusedGroupNormGradientOp<Context>::RunWithType() {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dYdata, MXmult,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, N, C,
             1.0, NCdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemv<T, Context>(
             CblasTrans, NS, C,
             1.0, dYdata, MXmult,
-            1.0, dBdata, &ctx());
+            1.0, dBdata, ctx());
     }
 }
@@ -230,28 +232,30 @@ void FusedGroupNormGradientOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             N, C, 1,
             1.0, MXmult, Sdata,
-            0.0, NCdata, &ctx());
+            0.0, NCdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, NCdata, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NS, C, 1,
             1.0, MXmult, Sdata,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     }
-    math::Mul<T, Context>(x_norm->count(), WSdata, dYdata, WSdata);
+    math::Mul<T, Context>(x_norm->count(),
+        WSdata, dYdata, WSdata, ctx());
     // sum of x_hat * (dl / dx_hat)
-    math::Mul<T, Context>(x_norm->count(), XNorm_data, WSdata, dXdata);
+    math::Mul<T, Context>(x_norm->count(),
+        XNorm_data, WSdata, dXdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0, dXdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
@@ -262,28 +266,29 @@ void FusedGroupNormGradientOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tmean, MXmult,
-            0.0, dXdata, &ctx());
+            0.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
-    math::Mul<T, Context>(x_norm->count(), XNorm_data, dXdata, dXdata);
+    math::Mul<T, Context>(x_norm->count(),
+        XNorm_data, dXdata, dXdata, ctx());
     // subtract the average of x_hat times the sum
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0, WSdata, MXmult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tmean, MXmult,
-            1.0, dXdata, &ctx());
+            1.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     math::Axpby<T, Context>(x_norm->count(),
-        1.0, WSdata, -1.0 / CGS, dXdata, &ctx());
+        1.0, WSdata, -1.0 / CGS, dXdata, ctx());
     // multiply with the inverse std
     if (data_format == "NCHW") {
@@ -291,12 +296,13 @@ void FusedGroupNormGradientOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // divide by stddev
-    math::Div<T, Context>(Output(0)->count(), dXdata, WSdata, dXdata);
+    math::Div<T, Context>(Output(0)->count(),
+        dXdata, WSdata, dXdata, ctx());
     }
 }
@@ -337,8 +343,7 @@ void FusedGroupNormGradientOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(FusedGroupNormGradient);
...
@@ -15,14 +15,14 @@ void GroupNormOp<Context>::RunWithType() {
     auto* Ydata = Output(0)->template mutable_data<T, Context>();
     auto* NCdata = nc.template mutable_data<T, Context>();
     auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
-    ctx().template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
+    ctx()->template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
     // compute mean
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0 / CGS, Xdata, MXmult,
-            0, Tmean, &ctx());
+            0, Tmean, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
@@ -33,26 +33,26 @@ void GroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             -1.0, Tmean, MXmult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // compute variance
     // note that we use VAR(X) = E((X - EX) ^ 2)
-    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata);
+    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0 / CGS, WSdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // compute stddev
-    math::AddScalar<T, Context>(var->count(), eps, Tvar);
+    math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
-    math::Sqrt<T, Context>(var->count(), Tvar, Tvar);
+    math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
     // divide by stddev
     if (data_format == "NCHW") {
@@ -60,11 +60,12 @@ void GroupNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
-    math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Div<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
 }
 template <class Context>
@@ -102,8 +103,7 @@ void GroupNormOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(GroupNorm);
@@ -127,43 +127,45 @@ void GroupNormGradientOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     auto* Ydata = Input(1).template data<T, Context>();
-    math::Mul<T, Context>(Output(0)->count(), Ydata, dYdata, dXdata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, dYdata, dXdata, ctx());
     // sum(dE/dY \cdot Y)
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0, dXdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            0.0, dXdata, &ctx());
+            0.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
     // sum(dE/dY \cdot Y) \cdot Y
-    math::Mul<T, Context>(Output(0)->count(), Ydata, dXdata, dXdata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, dXdata, dXdata, ctx());
     // sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NG, CGS,
             1.0, dYdata, MXmult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NG, CGS, 1,
             1.0, Tvar, MXmult,
-            1.0, dXdata, &ctx());
+            1.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         NOT_IMPLEMENTED;
     }
@@ -171,10 +173,11 @@ void GroupNormGradientOp<Context>::RunWithType() {
     // dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y
     // = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y)
     math::Axpby<T, Context>(Output(0)->count(),
-        1.0, dYdata, -1.0 / CGS, dXdata, &ctx());
+        1.0, dYdata, -1.0 / CGS, dXdata, ctx());
     // divide by stddev
-    math::Div<T, Context>(Output(0)->count(), dXdata, WSdata, dXdata);
+    math::Div<T, Context>(Output(0)->count(),
+        dXdata, WSdata, dXdata, ctx());
 }
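Because `Ydata` here is already the normalized output y = (x - mu) / sigma, the gradient can be expressed entirely in terms of y, which is what the commented steps compute:

    \frac{\partial E}{\partial x}
      = \frac{1}{\sigma}\Big( \frac{\partial E}{\partial y}
        - \operatorname{mean}\Big(\frac{\partial E}{\partial y}\Big)
        - y \cdot \operatorname{mean}\Big(\frac{\partial E}{\partial y} \cdot y\Big) \Big),

with both means taken over each (sample, group) block of CGS elements.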
 template <class Context>
@@ -210,8 +213,7 @@ void GroupNormGradientOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(GroupNormGradient);
...
@@ -14,14 +14,14 @@ void InstanceNormOp<Context>::RunWithType() {
     auto* Xdata = Input(0).template data<T, Context>();
     auto* Ydata = Output(0)->template mutable_data<T, Context>();
     auto* WSdata = ws()->template caches<T, Context>({ Input(0).count() })[0];
-    ctx().template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
+    ctx()->template Copy<T, Context, Context>(Output(0)->count(), Ydata, Xdata);
     // compute mean
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0 / S, Xdata, Smult,
-            0.0, Tmean, &ctx());
+            0.0, Tmean, ctx());
     } else if (data_format == "NHWC") {
         auto* x = Xdata;
         auto* tm = Tmean;
@@ -29,7 +29,7 @@ void InstanceNormOp<Context>::RunWithType() {
             math::Gemv<T, Context>(
                 CblasTrans, S, C,
                 1.0 / S, x, Smult,
-                0.0, tm, &ctx());
+                0.0, tm, ctx());
             x += CS;
             tm += C;
         }
@@ -41,7 +41,7 @@ void InstanceNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             -1.0, Tmean, Smult,
-            1.0, Ydata, &ctx());
+            1.0, Ydata, ctx());
     } else if (data_format == "NHWC") {
         auto* y = Ydata;
         auto* tm = Tmean;
@@ -50,7 +50,7 @@ void InstanceNormOp<Context>::RunWithType() {
                 CblasNoTrans, CblasNoTrans,
                 S, C, 1,
                 -1.0, Smult, tm,
-                1.0, y, &ctx());
+                1.0, y, ctx());
             y += CS;
             tm += C;
         }
@@ -58,12 +58,12 @@ void InstanceNormOp<Context>::RunWithType() {
     // compute variance
     // note that we use VAR(X) = E((X - EX) ^ 2)
-    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata);
+    math::Square<T, Context>(Output(0)->count(), Ydata, WSdata, ctx());
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0 / S, WSdata, Smult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
     } else if (data_format == "NHWC") {
         auto* x2 = WSdata;
         auto* tv = Tvar;
@@ -71,15 +71,15 @@ void InstanceNormOp<Context>::RunWithType() {
             math::Gemv<T, Context>(
                 CblasTrans, S, C,
                 1.0 / S, x2, Smult,
-                0.0, tv, &ctx());
+                0.0, tv, ctx());
             x2 += CS;
             tv += C;
         }
     }
     // compute stddev
-    math::AddScalar<T, Context>(var->count(), eps, Tvar);
+    math::AddScalar<T, Context>(var->count(), eps, Tvar, ctx());
-    math::Sqrt<T, Context>(var->count(), Tvar, Tvar);
+    math::Sqrt<T, Context>(var->count(), Tvar, Tvar, ctx());
     // divide by stddev
     if (data_format == "NCHW") {
@@ -87,7 +87,7 @@ void InstanceNormOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, Tvar, Smult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         auto* std = WSdata;
         auto* tv = Tvar;
@@ -96,12 +96,13 @@ void InstanceNormOp<Context>::RunWithType() {
                 CblasNoTrans, CblasNoTrans,
                 S, C, 1,
                 1.0, Smult, tv,
-                0.0, std, &ctx());
+                0.0, std, ctx());
             std += CS;
             tv += C;
         }
     }
-    math::Div<T, Context>(Output(0)->count(), Ydata, WSdata, Ydata);
+    math::Div<T, Context>(Output(0)->count(),
+        Ydata, WSdata, Ydata, ctx());
 }
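In the NHWC branches the reductions cannot be phrased as one large GEMV, so the op walks the batch: each iteration reduces one S x C block and advances the data pointer by CS and the statistics pointer by C. A host-side sketch of the per-instance mean this loop produces, assuming `Smult` is a ones vector:

    #include <vector>

    // Per-instance, per-channel means for NHWC data (CPU sketch of the
    // strided Gemv loop above).
    std::vector<float> instance_mean(const std::vector<float>& x,
                                     int N, int S, int C) {
        std::vector<float> mean(N * C, 0.f);
        for (int n = 0; n < N; ++n)        // x advances by C*S per sample
            for (int s = 0; s < S; ++s)
                for (int c = 0; c < C; ++c)
                    mean[n * C + c] += x[(n * S + s) * C + c] / float(S);
        return mean;                       // Tmean advances by C per sample
    }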
 template <class Context>
@@ -133,8 +134,7 @@ void InstanceNormOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(InstanceNorm);
@@ -157,7 +157,7 @@ void InstanceNormGradientOp<Context>::RunWithType() {
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, Tvar, Smult,
-            0.0, WSdata, &ctx());
+            0.0, WSdata, ctx());
     } else if (data_format == "NHWC") {
         auto* std = WSdata;
         auto* tv = Tvar;
@@ -166,26 +166,27 @@ void InstanceNormGradientOp<Context>::RunWithType() {
                 CblasNoTrans, CblasNoTrans,
                 S, C, 1,
                 1.0, Smult, tv,
-                0.0, std, &ctx());
+                0.0, std, ctx());
             std += CS;
             tv += C;
         }
     }
     auto* Ydata = Input(-2).template data<T, Context>();
-    math::Mul<T, Context>(Output(0)->count(), Ydata, dYdata, dXdata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, dYdata, dXdata, ctx());
     // sum(dE/dY \cdot Y)
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dXdata, Smult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, Tvar, Smult,
-            0.0, dXdata, &ctx());
+            0.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         for (int i = 0; i < N; i++) {
             auto* dx = dXdata;
@@ -194,12 +195,12 @@ void InstanceNormGradientOp<Context>::RunWithType() {
             math::Gemv<T, Context>(
                 CblasTrans, S, C,
                 1.0, dx, Smult,
-                0, tv, &ctx());
+                0, tv, ctx());
             math::Gemm<T, Context>(
                 CblasNoTrans, CblasNoTrans,
                 S, C, 1,
                 1.0, Smult, tv,
-                0.0, dx, &ctx());
+                0.0, dx, ctx());
             dx += CS;
             tv += C;
         }
@@ -207,19 +208,20 @@ void InstanceNormGradientOp<Context>::RunWithType() {
     }
     // sum(dE/dY \cdot Y) \cdot Y
-    math::Mul<T, Context>(Output(0)->count(), Ydata, dXdata, dXdata);
+    math::Mul<T, Context>(Output(0)->count(),
+        Ydata, dXdata, dXdata, ctx());
     // sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y
     if (data_format == "NCHW") {
         math::Gemv<T, Context>(
             CblasNoTrans, NC, S,
             1.0, dYdata, Smult,
-            0.0, Tvar, &ctx());
+            0.0, Tvar, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             NC, S, 1,
             1.0, Tvar, Smult,
-            1.0, dXdata, &ctx());
+            1.0, dXdata, ctx());
     } else if (data_format == "NHWC") {
         for (int i = 0; i < N; i++) {
             auto* dy = dYdata;
@@ -229,12 +231,12 @@ void InstanceNormGradientOp<Context>::RunWithType() {
             math::Gemv<T, Context>(
                 CblasTrans, S, C,
                 1.0, dy, Smult,
-                0, tv, &ctx());
+                0, tv, ctx());
             math::Gemm<T, Context>(
                 CblasNoTrans, CblasNoTrans,
                 S, C, 1,
                 1.0, Smult, tv,
-                1.0, dx, &ctx());
+                1.0, dx, ctx());
             dy += CS;
             dx += CS;
             tv += C;
@@ -245,10 +247,11 @@ void InstanceNormGradientOp<Context>::RunWithType() {
     // dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y
     // = dE/dY - mean(sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y)
     math::Axpby<T, Context>(Output(0)->count(),
-        1.0, dYdata, -1.0 / S, dXdata, &ctx());
+        1.0, dYdata, -1.0 / S, dXdata, ctx());
     // divide by stddev
-    math::Div<T, Context>(Output(0)->count(), dXdata, WSdata, dXdata);
+    math::Div<T, Context>(Output(0)->count(),
+        dXdata, WSdata, dXdata, ctx());
 }
 template <class Context>
@@ -279,8 +282,7 @@ void InstanceNormGradientOp<Context>::RunOnDevice() {
     Setup();
     if (XIsType(Input(0), float)) RunWithType<float>();
-    else if (XIsType(Input(0), float16)) RunWithType<float16>();
-    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+    else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
 }
 DEPLOY_CPU(InstanceNormGradient);
...
@@ -24,35 +24,28 @@ void L2NormOp<Context>::RunWithType() {
     auto* Ydata = Output(0)->template mutable_data<T, Context>();
     auto* Bdata = ws()->template caches<T, Context>({ buffer.count() })[0];
     auto* Ndata = norm->template mutable_data<T, Context>();
-    math::Set<T, Context>(norm->count(), dragon_cast<T, float>(eps), Ndata);
+    math::Set<T, Context>(norm->count(),
+        dragon_cast<T, float>(eps), Ndata, ctx());
     for (int n = 0; n < outer_dim; n++) {
-        if (across_inner) {
-            auto* Ndata_ = norm->template mutable_data<float, CPUContext>();
-            float sum_of_sqr = math::Dot<T, Context>(
-                buffer.count(), Xdata, Xdata, &ctx());
-            if (mode == "MEAN") sum_of_sqr = sum_of_sqr / dim;
-            Ndata_[n] = pow(sum_of_sqr + eps, 0.5);
-            math::Scale<T, Context>(buffer.count(),
-                1.0 / Ndata_[n], Xdata, Ydata, &ctx());
-        } else {
-            math::Square<T, Context>(buffer.count(), Xdata, Bdata);
+        math::Square<T, Context>(buffer.count(),
+            Xdata, Bdata, ctx());
         // compute T1 = \sum_{i} x_{i,j}^{2}
         math::Gemv<T, Context>(
             CblasTrans, dim, inner_dim,
             mode == "MEAN" ? 1.0 / dim : 1.0, Bdata, Dmult,
-            1.0, Ndata, &ctx());
+            1.0, Ndata, ctx());
         // compute T2 = \sqrt{T1}
-        math::Sqrt<T, Context>(inner_dim, Ndata, Ndata);
+        math::Sqrt<T, Context>(inner_dim, Ndata, Ndata, ctx());
         // compute T3 = x / [(T2)]_{dim}
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             dim, inner_dim, 1,
             1.0, Dmult, Ndata,
-            0.0, Bdata, &ctx());
+            0.0, Bdata, ctx());
-        math::Div<T, Context>(buffer.count(), Xdata, Bdata, Ydata);
+        math::Div<T, Context>(buffer.count(),
+            Xdata, Bdata, Ydata, ctx());
         Ndata += inner_dim;
-        }
         Xdata += buffer.count();
         Ydata += buffer.count();
     }
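This hunk also deletes the `across_inner` fast path, which used `math::Dot` plus a host-side `pow` on CPU-mapped norm data; every case now goes through the uniform Square/Gemv/Sqrt pipeline, keeping the whole computation on the context's stream. The per-element quantity produced is

    y_{i,j} = \frac{x_{i,j}}{\sqrt{\varepsilon + \alpha \sum_{i} x_{i,j}^{2}}},
    \qquad \alpha = 1/\text{dim} \text{ if mode = MEAN, else } 1,

since `Ndata` is pre-filled with eps and the Gemv accumulates into it with beta = 1.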
@@ -70,8 +63,6 @@ void L2NormOp<Context>::RunOnDevice() {
     outer_dim = Input(0).count(0, axis);
     dim = Input(0).count(axis, axis + num_axes);
     inner_dim = Input(0).count(axis + num_axes);
-    if (inner_dim == 1) across_inner = true;
-    else across_inner = false;
     Output(0)->ReshapeLike(Input(0));
@@ -96,8 +87,8 @@ void L2NormGradientOp<Context>::RunWithType() {
     for (int i = 0; i < axis; i++) dims[i] = 1;
     buffer.Reshape(dims);
     buffer_inner.Reshape({ inner_dim });
-    vector<T*> BSdata = ws()->template caches<T, Context>({
-        buffer.count(), buffer_inner.count() });
+    vector<T*> BSdata = ws()->template caches<T, Context>(
+        { buffer.count(), buffer_inner.count() });
     auto* Xdata = Input(0).template data<T, Context>();
     auto* dYdata = Input(-1).template data<T, Context>();
@@ -106,48 +97,42 @@ void L2NormGradientOp<Context>::RunWithType() {
     auto* Bdata = BSdata[0], *BInnerdata = BSdata[1];
     for (int n = 0; n < outer_dim; n++) {
-        if (across_inner) {
-            Ndata = norm->template data<T, CPUContext>();
-            T sum_of_x_mul_dy = math::Dot<T, Context>(
-                buffer.count(), Xdata, dYdata, &ctx());
-            if (mode == "MEAN") sum_of_x_mul_dy = sum_of_x_mul_dy / dim;
-            math::Scale<T, Context>(buffer.count(),
-                sum_of_x_mul_dy / Ndata[n] / Ndata[n], Xdata, dXdata, &ctx());
-            math::Sub<T, Context>(buffer.count(), dYdata, dXdata, dXdata);
-            math::Scal<T, Context>(buffer.count(),
-                T(1.0 / Ndata[n]), dXdata, &ctx());
-        } else {
         // compute \sum_{i} x_{i, j}dy_{i, j}
-        math::Mul<T, Context>(buffer.count(), Xdata, dYdata, Bdata);
+        math::Mul<T, Context>(buffer.count(),
+            Xdata, dYdata, Bdata, ctx());
         math::Gemv<T, Context>(
             CblasTrans, dim, inner_dim,
             mode == "MEAN" ? 1.0 / dim : 1.0, Bdata, Dmult,
-            0.0, BInnerdata, &ctx());
+            0.0, BInnerdata, ctx());
         // compute T1 = x[(\sum_{i} x_{i, j}dy_{i, j})]_{dim}
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             dim, inner_dim, 1,
             1.0, Dmult, BInnerdata,
-            0.0, Bdata, &ctx());
+            0.0, Bdata, ctx());
-        math::Mul<T, Context>(buffer.count(), Xdata, Bdata, dXdata);
+        math::Mul<T, Context>(buffer.count(),
+            Xdata, Bdata, dXdata, ctx());
         // compute T2 = T1 / Normalizer^{2}
-        math::Pow<T, Context>(inner_dim, 2.0, Ndata, BInnerdata);
+        math::Pow<T, Context>(inner_dim,
+            2.0, Ndata, BInnerdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             dim, inner_dim, 1,
             1.0, Dmult, BInnerdata,
-            0.0, Bdata, &ctx());
+            0.0, Bdata, ctx());
-        math::Div<T, Context>(buffer.count(), dXdata, Bdata, dXdata);
+        math::Div<T, Context>(buffer.count(),
+            dXdata, Bdata, dXdata, ctx());
         // compute T3 = (dy - T2) / Normalizer
-        math::Sub<T, Context>(buffer.count(), dYdata, dXdata, dXdata);
+        math::Sub<T, Context>(buffer.count(),
+            dYdata, dXdata, dXdata, ctx());
         math::Gemm<T, Context>(
             CblasNoTrans, CblasNoTrans,
             dim, inner_dim, 1,
             1.0, Dmult, Ndata,
-            0.0, Bdata, &ctx());
+            0.0, Bdata, ctx());
-        math::Div<T, Context>(buffer.count(), dXdata, Bdata, dXdata);
+        math::Div<T, Context>(buffer.count(),
+            dXdata, Bdata, dXdata, ctx());
         Ndata += inner_dim;
-        }
         Xdata += buffer.count();
         dYdata += buffer.count();
         dXdata += buffer.count();
@@ -166,8 +151,6 @@ void L2NormGradientOp<Context>::RunOnDevice() {
     outer_dim = Input(0).count(0, axis);
     dim = Input(0).count(axis, axis + num_axes);
     inner_dim = Input(0).count(axis + num_axes);
-    if (inner_dim == 1) across_inner = true;
-    else across_inner = false;
     Output(0)->ReshapeLike(Input(0));
...
@@ -23,20 +23,20 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
     if (!states_initialized) {
         states_initialized = true;
         CUDNN_CHECK(cudnnDropoutGetStatesSize(
-            ctx().cudnn_handle(), &states_size));
+            ctx()->cudnn_handle(), &states_size));
         std::lock_guard<std::mutex> lk(CUDAContext::mutex());
         Tensor* states = ws()->CreateTensor("/share/cudnn/dropout:" +
             dragon_cast<string, unsigned long long>(random_seed) + "/states");
         if (states->count() > 0) {
             auto* Sdata = states->template mutable_data<uint8_t, Context>();
             CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
-                dropout_desc, ctx().cudnn_handle(), dropout_ratio,
+                dropout_desc, ctx()->cudnn_handle(), dropout_ratio,
                 Sdata, states_size, random_seed));
         } else {
             states->Reshape({ (TIndex)states_size });
             auto* Sdata = states->template mutable_data<uint8_t, Context>();
             CUDNN_CHECK(cudnnSetDropoutDescriptor(
-                dropout_desc, ctx().cudnn_handle(), dropout_ratio,
+                dropout_desc, ctx()->cudnn_handle(), dropout_ratio,
                 Sdata, states_size, random_seed));
         }
     }
...@@ -48,7 +48,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -48,7 +48,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// setup rnn // setup rnn
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetRNNDescriptor( CUDNN_CHECK(cudnnSetRNNDescriptor(
ctx().cudnn_handle(), rnn_desc, ctx()->cudnn_handle(), rnn_desc,
hidden_size, num_layers, hidden_size, num_layers,
dropout_desc, dropout_desc,
rnn_input_mode, rnn_direction, rnn_mode, rnn_input_mode, rnn_direction, rnn_mode,
...@@ -68,7 +68,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -68,7 +68,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
xs_desc->Set<T>({ batch_size, input_dim, 1 }, { input_dim, 1, 1 }); xs_desc->Set<T>({ batch_size, input_dim, 1 }, { input_dim, 1, 1 });
ys_desc.reset(new cudnnTensorDescriptors(seq_length)); ys_desc.reset(new cudnnTensorDescriptors(seq_length));
ys_desc->Set<T>({ batch_size, output_dim, 1 }, { output_dim, 1, 1 }); ys_desc->Set<T>({ batch_size, output_dim, 1 }, { output_dim, 1, 1 });
CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx().cudnn_handle(), CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx()->cudnn_handle(),
rnn_desc, seq_length, xs_desc->descs(), &workspace_size)); rnn_desc, seq_length, xs_desc->descs(), &workspace_size));
output_dims = { seq_length, batch_size, output_dim }; output_dims = { seq_length, batch_size, output_dim };
...@@ -82,7 +82,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -82,7 +82,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// setup packed weights // setup packed weights
size_t weights_size; TIndex weights_count; size_t weights_size; TIndex weights_count;
CUDNN_CHECK(cudnnGetRNNParamsSize( CUDNN_CHECK(cudnnGetRNNParamsSize(
ctx().cudnn_handle(), rnn_desc, xs_desc->descs()[0], ctx()->cudnn_handle(), rnn_desc, xs_desc->descs()[0],
&weights_size, CUDNNType<T>::type)); &weights_size, CUDNNType<T>::type));
weights_count = (TIndex)weights_size / sizeof(T); weights_count = (TIndex)weights_size / sizeof(T);
CHECK_EQ(weights_count, Input(1).count()) CHECK_EQ(weights_count, Input(1).count())
...@@ -96,7 +96,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -96,7 +96,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// setup rnn workspace // setup rnn workspace
CUDNN_CHECK(cudnnGetRNNWorkspaceSize( CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
ctx().cudnn_handle(), rnn_desc, seq_length, ctx()->cudnn_handle(), rnn_desc, seq_length,
xs_desc->descs(), &workspace_size)); xs_desc->descs(), &workspace_size));
} }
...@@ -122,7 +122,7 @@ void CuDNNRecurrentOp<Context>::RunWithType() { ...@@ -122,7 +122,7 @@ void CuDNNRecurrentOp<Context>::RunWithType() {
auto* WSdata = ws()->template caches<Context>({ workspace_size })[0]; auto* WSdata = ws()->template caches<Context>({ workspace_size })[0];
auto handle = ctx().cudnn_handle(); auto handle = ctx()->cudnn_handle();
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(handle, CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(handle,
...@@ -157,8 +157,12 @@ void CuDNNRecurrentOp<Context>::RunWithType() { ...@@ -157,8 +157,12 @@ void CuDNNRecurrentOp<Context>::RunWithType() {
template <class Context> template <class Context>
void CuDNNRecurrentOp<Context>::RunOnDevice() { void CuDNNRecurrentOp<Context>::RunOnDevice() {
ctx()->set_stream_id(0); // enforce default stream
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
#ifdef WITH_CUDA_FP16
else if (XIsType(Input(0), float16)) RunWithType<float16>(); else if (XIsType(Input(0), float16)) RunWithType<float16>();
#endif
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
} }
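The new set_stream_id(0) call pins this op to the default CUDA stream before any cuDNN work is issued, presumably because ops with shared state (cached dropout states, workspaces) are not safe to float across the async streams this commit unlocks. A runnable mock of the per-op pattern, with a stand-in context (only set_stream_id / SwitchToDevice are taken from this diff; everything else here is hypothetical):

#include <cstdio>

struct MockContext {
    int stream_id = 1;                 // a scheduler-chosen side stream
    void set_stream_id(int id) { stream_id = id; }
    void SwitchToDevice() { printf("bind device, stream %d\n", stream_id); }
};

int main() {
    MockContext ctx;
    ctx.set_stream_id(0);   // a stateful op enforces the default stream
    ctx.SwitchToDevice();   // then binds device + stream and issues work
    return 0;
}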
@@ -182,7 +186,7 @@ void CuDNNRecurrentGradientOp<Context>::RunWithType() {
    auto* WSdata = ws()->template caches<Context>({ workspace_size })[0];
    // check the reserve space
-   CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx().cudnn_handle(),
+   CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx()->cudnn_handle(),
        rnn_desc, seq_length, xs_desc->descs(), &reserve_size));
    auto* reserveT = ws()->GetTensor("/mnt/" + anchor() + "/rnn/reserve");
    CHECK_EQ(reserve_size, reserveT->nbytes());
@@ -192,7 +196,7 @@ void CuDNNRecurrentGradientOp<Context>::RunWithType() {
    auto* RSdata = reserveT->template data<uint8_t, Context>();
#endif
-   auto handle = ctx().cudnn_handle();
+   auto handle = ctx()->cudnn_handle();
    if (Output(0)->name() != "ignore" ||
        Output(1)->name() != "ignore" ||
@@ -228,13 +232,17 @@ void CuDNNRecurrentGradientOp<Context>::RunWithType() {
template <class Context>
void CuDNNRecurrentGradientOp<Context>::RunOnDevice() {
+   ctx()->set_stream_id(0);  // enforce default stream
    Output(0)->ReshapeLike(Input(0));  // dX
    Output(1)->ReshapeLike(Input(1));  // dW
    Output(2)->ReshapeLike(Input(2));  // dHx
    Output(3)->ReshapeLike(Input(3));  // dCx
    if (XIsType(Input(0), float)) RunWithType<float>();
+#ifdef WITH_CUDA_FP16
    else if (XIsType(Input(0), float16)) RunWithType<float16>();
+#endif
    else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
...
@@ -14,7 +14,7 @@ void LSTMCellOp<Context>::RunWithType() {
    kernel::LSTMCell<T, Context>(Input(1).count(), Input(1).dim(0),
        Input(1).ndim() == 2 ? Input(1).dim(1) : Input(1).dim(2),
-       CXdata, XAdata, Cdata, Hdata);
+       CXdata, XAdata, Cdata, Hdata, ctx());
}

template <class Context>
@@ -44,7 +44,7 @@ void LSTMCellGradientOp<Context>::RunWithType() {
    kernel::LSTMCellGrad<T, Context>(Input(1).count(), Input(1).dim(0),
        Input(1).ndim() == 2 ? Input(1).dim(1) : Input(1).dim(2),
-       CXdata, XAdata, Cdata, dCdata, dHdata, dCXdata, dXdata);
+       CXdata, XAdata, Cdata, dCdata, dHdata, dCXdata, dXdata, ctx());
}

template <class Context>
...
@@ -30,7 +30,7 @@ void RNNParamSetOp<Context>::RunWithType() {
        << "\nExcepted the size of param is " << size
        << ", but got " << Input(0).count();
    offset += param_type == "bias" ? matrix_count : 0;
-   ctx().template Copy<T, Context, Context>(size, Wdata + offset, Pdata);
+   ctx()->template Copy<T, Context, Context>(size, Wdata + offset, Pdata);
}

template <class Context>
...
@@ -5,7 +5,7 @@
namespace dragon {

template <class Context>
-void AdamUpdateOp<Context>::ComputeRunWithFloat() {
+void AdamUpdateOp<Context>::ComputeRunWithFloat32() {
    Tensor* m = ws()->CreateTensor("/mnt/" + Slot() + "/adam/m");
    Tensor* v = ws()->CreateTensor("/mnt/" + Slot() + "/adam/v");
    m->ReshapeLike(Input(0));
@@ -16,12 +16,11 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat() {
    float coeff = sqrt(1. - pow(beta2, t)) / (1. - pow(beta1, t));
    lr = Param("base_lr") * coeff * this->lr_mult;
    auto* dXdata = Input(0).template mutable_data<float, Context>();
-   auto* Mdata = m->mutable_data<float, Context>();
-   auto* Vdata = v->mutable_data<float, Context>();
-   kernel::AdamUpdate<float, Context>(
-       Input(0).count(), lr, beta1, beta2, eps,
-       dXdata, Mdata, Vdata);
+   auto* Mdata = m->mutable_data<float, Context>(ctx());
+   auto* Vdata = v->mutable_data<float, Context>(ctx());
+   kernel::AdamUpdate<float, Context>(Input(0).count(),
+       lr, beta1, beta2, eps, dXdata, Mdata, Vdata, ctx());
}
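For reference, coeff above is the standard Adam bias-correction factor. Assuming kernel::AdamUpdate implements the usual Adam recurrences (the kernel body is not part of this diff), the full step is:

$$m \leftarrow \beta_1 m + (1-\beta_1)\,g, \qquad v \leftarrow \beta_2 v + (1-\beta_2)\,g^2, \qquad g \leftarrow \mathrm{lr} \cdot \frac{m}{\sqrt{v}+\epsilon}$$

with $\mathrm{lr} = \mathrm{base\_lr} \cdot \frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \cdot \mathrm{lr\_mult}$, and the parameter step $x \leftarrow x - g$ applied afterwards in UpdateRunWithFloat32().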
template <class Context>
@@ -35,13 +34,19 @@ void AdamUpdateOp<Context>::ComputeRunWithFloat16() {
    beta1 = Param("beta1"), beta2 = Param("beta2"), eps = Param("eps");
    float coeff = sqrt(1. - pow(beta2, t)) / (1. - pow(beta1, t));
    lr = Param("base_lr") * coeff * this->lr_mult;
-   auto* dXdata = Input(0).template mutable_data<float16, Context>();
-   auto* Mdata = m->mutable_data<float16, Context>();
-   auto* Vdata = v->mutable_data<float16, Context>();
-   kernel::AdamUpdate<float16, Context>(
-       Input(0).count(), lr, beta1, beta2, eps,
-       dXdata, Mdata, Vdata);
+   auto* dX32T = ws()->CreateTensor(Input(0).name() + "/f32");
+   dX32T->ReshapeLike(Input(0));
+   auto* dX32 = dX32T->template mutable_data<float, Context>();
+   auto* dX16 = Input(0).template mutable_data<float16, Context>();
+   auto* M32 = m->mutable_data<float, Context>(ctx());
+   auto* V32 = v->mutable_data<float, Context>(ctx());
+   kernel::TypeA2B<float16, float, Context>(
+       Input(0).count(), dX16, dX32, ctx());
+   kernel::AdamUpdate<float, Context>(Input(0).count(),
+       lr, beta1, beta2, eps, dX32, M32, V32, ctx());
}

DEPLOY_CPU(AdamUpdate);
...
@@ -32,16 +32,17 @@ void CollectiveUpdateOp<Context>::InitNCCL() {
    if (comm_rank == comm_root) NCCL_CHECK(ncclGetUniqueId(&id));
    MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, comm_root, comm);
    NCCL_CHECK(ncclCommInitRank(&nccl_comm, comm_size, id, comm_rank));
-   closure = CUDAClosure<Context>(&ctx());
+   closure = CUDAClosure<Context>(ctx());
#else
    LOG(FATAL) << "NCCL was not compiled.";
#endif
}

-template <class Context>
-void CollectiveUpdateOp<Context>::MPIAllReduceWithFloat() {
-   for (int j = 0; j < InputSize(); j++) {
-       TIndex count = Input(j).count();
+template <class Context> template <typename T>
+void CollectiveUpdateOp<Context>::MPIAllReduce(
+   Tensor* tensor,
+   MPI_Datatype dtype) {
+   TIndex count = tensor->count();
    MPI_Request recv_req;
    TIndex segment_size = count / comm_size;
    TIndex residual = count % comm_size;
@@ -52,11 +53,11 @@ void CollectiveUpdateOp<Context>::MPIAllReduceWithFloat() {
    for (int i = 1; i < segment_ends.size(); i++)
        segment_ends[i] = segment_sizes[i] + segment_ends[i - 1];
#ifdef WITH_MPI_CUDA
-   auto* WSdata = ws()->template caches<float, Context>({ segment_sizes[0] })[0];
-   auto* dXdata = Input(j).template mutable_data<float, Context>();
+   auto* WSdata = ws()->template caches<T, Context>({ segment_sizes[0] })[0];
+   auto* dXdata = tensor->template mutable_data<T, Context>();
#else
-   auto* WSdata = ws()->template caches<float, CPUContext>({ segment_sizes[0] })[0];
-   auto* dXdata = Input(j).template mutable_data<float, CPUContext>();
+   auto* WSdata = ws()->template caches<T, CPUContext>({ segment_sizes[0] })[0];
+   auto* dXdata = tensor->template mutable_data<T, CPUContext>();
#endif  // WITH_MPI_CUDA
    int recv_from = (comm_rank - 1 + comm_size) % comm_size;
    int send_to = (comm_rank + 1) % comm_size;
@@ -66,23 +67,21 @@ void CollectiveUpdateOp<Context>::MPIAllReduceWithFloat() {
        int recv_chunk = (comm_rank - i - 1 + comm_size) % comm_size;
        int send_chunk = (comm_rank - i + comm_size) % comm_size;
        auto* segment_send = &(dXdata[
-           segment_ends[send_chunk] - segment_sizes[send_chunk]
-       ]);
+           segment_ends[send_chunk] - segment_sizes[send_chunk]]);
        MPI_Irecv(WSdata, segment_sizes[recv_chunk],
-           MPI_FLOAT, recv_from, 0, comm, &recv_req);
+           dtype, recv_from, 0, comm, &recv_req);
        MPI_Send(segment_send, segment_sizes[send_chunk],
-           MPI_FLOAT, send_to, 0, comm);
+           dtype, send_to, 0, comm);
        auto* segment_update = &(dXdata[
-           segment_ends[recv_chunk] - segment_sizes[recv_chunk]
-       ]);
+           segment_ends[recv_chunk] - segment_sizes[recv_chunk]]);
        MPI_Wait(&recv_req, MPI_STATUS_IGNORE);
#ifdef WITH_MPI_CUDA
-       math::Axpy<float, Context>(segment_sizes[recv_chunk],
-           1.0, WSdata, segment_update, &ctx());
-       ctx().FinishDeviceCompution();
+       math::Axpy<T, Context>(segment_sizes[recv_chunk],
+           1.0, WSdata, segment_update, ctx());
+       ctx()->FinishDeviceCompution();
#else
-       math::Axpy<float, CPUContext>(segment_sizes[recv_chunk],
-           1.0, WSdata, segment_update, &ctx());
+       math::Axpy<T, CPUContext>(segment_sizes[recv_chunk],
+           1.0, WSdata, segment_update, ctx());
#endif  // WITH_MPI_CUDA
    }
@@ -91,90 +90,117 @@ void CollectiveUpdateOp<Context>::MPIAllReduceWithFloat() {
        int send_chunk = (comm_rank - i + 1 + comm_size) % comm_size;
        int recv_chunk = (comm_rank - i + comm_size) % comm_size;
        auto* segment_send = &(dXdata[
-           segment_ends[send_chunk] - segment_sizes[send_chunk]
-       ]);
+           segment_ends[send_chunk] - segment_sizes[send_chunk]]);
        auto* segment_recv = &(dXdata[
-           segment_ends[recv_chunk] - segment_sizes[recv_chunk]
-       ]);
+           segment_ends[recv_chunk] - segment_sizes[recv_chunk]]);
        MPI_Sendrecv(segment_send, segment_sizes[send_chunk],
-           MPI_FLOAT, send_to, 0,
-           segment_recv, segment_sizes[recv_chunk],
-           MPI_FLOAT, recv_from, 0,
-           comm, MPI_STATUS_IGNORE);
+           dtype, send_to, 0, segment_recv, segment_sizes[recv_chunk],
+           dtype, recv_from, 0, comm, MPI_STATUS_IGNORE);
    }
    // normalization
    if (comm_size > 1) {
#ifdef WITH_MPI_CUDA
-       math::Scal<float, Context>(count,
-           1.f / comm_size, dXdata, &ctx());
+       math::Scal<T, Context>(count, 1.f / comm_size, dXdata, ctx());
#else
-       math::Scal<float, CPUContext>(count,
-           1.f / comm_size, dXdata, &ctx());
+       math::Scal<T, CPUContext>(count, 1.f / comm_size, dXdata, ctx());
#endif  // WITH_MPI_CUDA
    }
-   }
}
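The two loops above are the classic ring all-reduce: a scatter-reduce pass followed by an all-gather pass, each of comm_size - 1 steps, after which every rank holds the full sum (then scaled by 1/comm_size). A self-contained, host-only sketch of the same schedule, with std::vector standing in for the per-rank buffers and plain loops for the MPI transport:

#include <cstdio>
#include <vector>

int main() {
    const int P = 4;         // comm_size
    const int count = 10;    // elements per gradient buffer
    // rank r starts with a buffer full of (r + 1)s
    std::vector<std::vector<float>> buf(P, std::vector<float>(count));
    for (int r = 0; r < P; r++)
        for (int k = 0; k < count; k++) buf[r][k] = float(r + 1);

    // chunk bookkeeping, same as segment_sizes / segment_ends above
    std::vector<int> sizes(P, count / P), ends(P);
    for (int i = 0; i < count % P; i++) sizes[i]++;
    ends[0] = sizes[0];
    for (int i = 1; i < P; i++) ends[i] = ends[i - 1] + sizes[i];

    // scatter-reduce: after P - 1 steps, rank r owns the sum of chunk (r + 1) % P
    for (int i = 0; i < P - 1; i++)
        for (int r = 0; r < P; r++) {
            int to = (r + 1) % P;
            int chunk = (r - i + P) % P;              // send_chunk of rank r
            int off = ends[chunk] - sizes[chunk];
            for (int k = 0; k < sizes[chunk]; k++)
                buf[to][off + k] += buf[r][off + k];  // receiver accumulates
        }

    // all-gather: circulate the reduced chunks to every rank
    for (int i = 0; i < P - 1; i++)
        for (int r = 0; r < P; r++) {
            int to = (r + 1) % P;
            int chunk = (r - i + 1 + P) % P;          // send_chunk of rank r
            int off = ends[chunk] - sizes[chunk];
            for (int k = 0; k < sizes[chunk]; k++)
                buf[to][off + k] = buf[r][off + k];   // receiver overwrites
        }

    // every element is now 1 + 2 + 3 + 4 = 10; the operator then scales by 1/P
    printf("rank 0, element 0: %.0f\n", buf[0][0]);
    return 0;
}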
-template <class Context>
-void CollectiveUpdateOp<Context>::NCCLAllReduceWithFloat() {
-#ifdef WITH_MPI_NCCL
-   auto stream = closure.cuda_stream(0);
-   for (int i = 0; i < InputSize(); i++) {
-       TIndex count = Input(i).count();
-       auto* dXdata = Input(i).template mutable_data<float, Context>();
-       NCCL_CHECK(ncclAllReduce((const void*)dXdata, (void*)dXdata,
-           count, ncclFloat, ncclSum, nccl_comm, stream));
-   }
-   closure.Sync();
-   for (int i = 0; i < InputSize(); i++) {
-       TIndex count = Input(i).count();
-       auto* dXdata = Input(i).template mutable_data<float, Context>();
-       math::Scal<float, Context>(count, 1.f / comm_size, dXdata, &ctx());
-   }
-#endif
-}
-
-template <class Context>
-void CollectiveUpdateOp<Context>::MPIBcastWithFloat() {
-   for (int i = 0; i < InputSize(); i++) {
-       TIndex count = Input(i).count();
-#ifdef WITH_MPI_CUDA
-       auto* dXdata = Input(i).template mutable_data<float, Context>();
-#else
-       auto* dXdata = Input(i).template mutable_data<float, CPUContext>();
-#endif
-       MPI_Bcast(dXdata, count, MPI_FLOAT, comm_root, comm);
-   }
-}
+template <class Context> template <typename T>
+void CollectiveUpdateOp<Context>::MPIBcast(
+   Tensor* tensor,
+   MPI_Datatype dtype) {
+   TIndex count = tensor->count();
+#ifdef WITH_MPI_CUDA
+   auto* dXdata = tensor->template mutable_data<float, Context>();
+#else
+   auto* dXdata = tensor->template mutable_data<float, CPUContext>();
+#endif
+   MPI_Bcast(dXdata, count, dtype, comm_root, comm);
+}
+
+#ifdef WITH_MPI_NCCL
+
+template <class Context> template <typename T>
+void CollectiveUpdateOp<Context>::NCCLAllReduce(
+   Tensor* tensor,
+   ncclDataType_t dtype,
+   cudaStream_t& stream) {
+   TIndex count = tensor->count();
+   auto* dXdata = tensor->template mutable_data<T, Context>();
+   NCCL_CHECK(ncclAllReduce((const void*)dXdata, (void*)dXdata,
+       count, dtype, ncclSum, nccl_comm, stream));
+}
+
+template <class Context> template <typename T>
+void CollectiveUpdateOp<Context>::NCCLBcast(
+   Tensor* tensor,
+   ncclDataType_t dtype,
+   cudaStream_t& stream) {
+   TIndex count = tensor->count();
+   auto* dXdata = tensor->template mutable_data<T, Context>();
+   NCCL_CHECK(ncclBcast((void*)dXdata,
+       count, dtype, comm_root, nccl_comm, stream));
+}
+
+#endif
-template <class Context>
-void CollectiveUpdateOp<Context>::NCCLBcastWithFloat() {
-#ifdef WITH_MPI_NCCL
-   auto stream = closure.cuda_stream(0);
-   for (int i = 0; i < InputSize(); i++) {
-       TIndex count = Input(i).count();
-       auto* dXdata = Input(i).template mutable_data<float, Context>();
-       NCCL_CHECK(ncclBcast((void*)dXdata, count,
-           ncclFloat, comm_root, nccl_comm, stream));
-   }
-   closure.Sync();
-#endif
-}
-
-template <class Context>
-void CollectiveUpdateOp<Context>::RunOnDevice() {
-   if(XIsType(Input(0), float)) {
-       if (mode == "MPI_ALLREDUCE") {
-           MPIAllReduceWithFloat();
-       } else if (mode == "NCCL_ALLREDUCE") {
-           NCCLAllReduceWithFloat();
-       } else if (mode == "MPI_BCAST") {
-           MPIBcastWithFloat();
-       } else if (mode == "NCCL_BCAST") {
-           NCCLBcastWithFloat();
-       } else LOG(FATAL) << "Unsupported collective mode: " << mode;
-   } else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
+template <class Context>
+void CollectiveUpdateOp<Context>::RunOnDevice() {
+   if (mode == "MPI_ALLREDUCE") {
+       for (int i = 0; i < InputSize(); i++) {
+           if (XIsType(Input(i), float))
+               MPIAllReduce<float>(&Input(i), MPI_FLOAT);
+           else if (XIsType(Input(i), float16))
+               MPIAllReduce<float16>(&Input(i), MPI_UNSIGNED_SHORT);
+           else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+       }
+   } else if (mode == "MPI_BCAST") {
+       for (int i = 0; i < InputSize(); i++) {
+           if (XIsType(Input(i), float))
+               MPIBcast<float>(&Input(i), MPI_FLOAT);
+           else if (XIsType(Input(i), float16))
+               MPIBcast<float16>(&Input(i), MPI_UNSIGNED_SHORT);
+           else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+       }
+   }
+#ifdef WITH_MPI_NCCL
+   else if (mode == "NCCL_ALLREDUCE") {
+       auto stream = closure.cuda_stream(1);
+       for (int i = 0; i < InputSize(); i++) {
+           if (XIsType(Input(i), float))
+               NCCLAllReduce<float>(&Input(i), ncclFloat, stream);
+           else if (XIsType(Input(i), float16))
+               NCCLAllReduce<float16>(&Input(i), ncclHalf, stream);
+           else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+       }
+       closure.Sync();
+       for (int i = 0; i < InputSize(); i++) {
+           TIndex count = Input(i).count();
+           if (XIsType(Input(i), float)) {
+               auto* dXdata = Input(i).template mutable_data<float, Context>();
+               math::Scal<float, Context>(count, 1.f / comm_size, dXdata, ctx());
+           }
+           else if (XIsType(Input(i), float16)) {
+               auto* dXdata = Input(i).template mutable_data<float16, Context>();
+               math::Scal<float16, Context>(count, 1.f / comm_size, dXdata, ctx());
+           }
+       }
+   } else if (mode == "NCCL_BCAST") {
+       auto stream = closure.cuda_stream(1);
+       for (int i = 0; i < InputSize(); i++) {
+           if (XIsType(Input(i), float))
+               NCCLBcast<float>(&Input(i), ncclFloat, stream);
+           else if (XIsType(Input(i), float16))
+               NCCLBcast<float16>(&Input(i), ncclHalf, stream);
+           else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
+       }
+       closure.Sync();
+   }
+#endif
+   else LOG(FATAL) << "Unsupported collective mode: " << mode;
}
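Note that the float16 paths ship buffers as MPI_UNSIGNED_SHORT: MPI predefines no half-precision type, so the transport just moves the raw 16-bit payload, and all arithmetic (the Axpy accumulation in the ring and the 1/comm_size scaling) is done by the framework's own float16 math rather than by an MPI reduction op.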
DEPLOY_CPU(CollectiveUpdate);
...
@@ -8,7 +8,7 @@ void MovingAverageOp<Context>::RunWithType() {
    auto* Xdata = Input(0).template data<T, Context>();
    auto* Ydata = Output(0)->template mutable_data<T, Context>();
    math::Axpby<T, Context>(Input(0).count(),
-       1.f - decay, Xdata, decay, Ydata, &ctx());
+       1.f - decay, Xdata, decay, Ydata, ctx());
}

template <class Context>
...
@@ -6,16 +6,16 @@
namespace dragon {

template <class Context>
-void NesterovUpdateOp<Context>::ComputeRunWithFloat() {
+void NesterovUpdateOp<Context>::ComputeRunWithFloat32() {
    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/nesterov/h");
    h->ReshapeLike(Input(0));

    lr = Param("base_lr") * this->lr_mult, momentum = Param("momentum");
    auto* dXdata = Input(0).template mutable_data<float, Context>();
-   auto* Hdata = h->template mutable_data<float, Context>();
+   auto* Hdata = h->template mutable_data<float, Context>(ctx());
    kernel::NesterovUpdate<float, Context>(
-       Input(0).count(), lr, momentum, dXdata, Hdata);
+       Input(0).count(), lr, momentum, dXdata, Hdata, ctx());
}

template <class Context>
@@ -24,11 +24,18 @@ void NesterovUpdateOp<Context>::ComputeRunWithFloat16() {
    h->ReshapeLike(Input(0));

    lr = Param("base_lr") * this->lr_mult, momentum = Param("momentum");
-   auto* dXdata = Input(0).template mutable_data<float16, Context>();
-   auto* Hdata = h->template mutable_data<float16, Context>();
-   kernel::NesterovUpdate<float16, Context>(
-       Input(0).count(), lr, momentum, dXdata, Hdata);
+   auto* dX32T = ws()->CreateTensor(Input(0).name() + "/f32");
+   dX32T->ReshapeLike(Input(0));
+   auto* dX32 = dX32T->template mutable_data<float, Context>();
+   auto* dX16 = Input(0).template mutable_data<float16, Context>();
+   auto* H32 = h->template mutable_data<float, Context>(ctx());
+   kernel::TypeA2B<float16, float, Context>(
+       Input(0).count(), dX16, dX32, ctx());
+   kernel::NesterovUpdate<float, Context>(
+       Input(0).count(), lr, momentum, dX32, H32, ctx());
}

DEPLOY_CPU(NesterovUpdate);
...
@@ -5,17 +5,17 @@
namespace dragon {

template <class Context>
-void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
+void RMSPropUpdateOp<Context>::ComputeRunWithFloat32() {
    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/rmsprop/h");
    h->ReshapeLike(Input(0));

    lr = Param("base_lr") * this->lr_mult;
    decay = Param("decay"), eps = Param("eps");
    auto* dXdata = Input(0).template mutable_data<float, Context>();
-   auto* Hdata = h->template mutable_data<float, Context>();
+   auto* Hdata = h->template mutable_data<float, Context>(ctx());
    kernel::RMSPropUpdate<float, Context>(
-       Input(0).count(), lr, decay, eps, dXdata, Hdata);
+       Input(0).count(), lr, decay, eps, dXdata, Hdata, ctx());
}

template <class Context>
@@ -25,11 +25,18 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat16() {
    lr = Param("base_lr") * this->lr_mult;
    decay = Param("decay"), eps = Param("eps");
-   auto* dXdata = Input(0).template mutable_data<float16, Context>();
-   auto* Hdata = h->template mutable_data<float16, Context>();
-   kernel::RMSPropUpdate<float16, Context>(
-       Input(0).count(), lr, decay, eps, dXdata, Hdata);
+   auto* dX32T = ws()->CreateTensor(Input(0).name() + "/f32");
+   dX32T->ReshapeLike(Input(0));
+   auto* dX32 = dX32T->template mutable_data<float, Context>();
+   auto* dX16 = Input(0).template mutable_data<float16, Context>();
+   auto* H32 = h->template mutable_data<float, Context>(ctx());
+   kernel::TypeA2B<float16, float, Context>(
+       Input(0).count(), dX16, dX32, ctx());
+   kernel::RMSPropUpdate<float, Context>(
+       Input(0).count(), lr, decay, eps, dX32, H32, ctx());
}

DEPLOY_CPU(RMSPropUpdate);
...
@@ -6,7 +6,7 @@
namespace dragon {

template <class Context>
-void SGDUpdateOp<Context>::ComputeRunWithFloat() {
+void SGDUpdateOp<Context>::ComputeRunWithFloat32() {
    Tensor* h = ws()->CreateTensor("/mnt/" + Slot() + "/sgd/h");
    h->ReshapeLike(Input(0));
@@ -14,10 +14,10 @@ void SGDUpdateOp<Context>::ComputeRunWithFloat() {
    // momentum correction, see arXiv:1706.02677
    if (old_lr > 0) { correction = lr / old_lr; } old_lr = lr;
    auto* dXdata = Input(0).template mutable_data<float, Context>();
-   auto* Hdata = h->template mutable_data<float, Context>();
+   auto* Hdata = h->template mutable_data<float, Context>(ctx());
    kernel::SGDUpdate<float, Context>(Input(0).count(),
-       lr, momentum * correction, dXdata, Hdata);
+       lr, momentum * correction, dXdata, Hdata, ctx());
}
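The correction factor keeps the momentum buffer consistent when the learning rate changes (arXiv:1706.02677): the buffer, which carries the old rate, is rescaled by lr_t / lr_{t-1} and the factor is 1 while the rate is constant. Assuming kernel::SGDUpdate implements the usual momentum form (the kernel body is not in this diff):

$$h_t = \mu \cdot \frac{\mathrm{lr}_t}{\mathrm{lr}_{t-1}} \cdot h_{t-1} + \mathrm{lr}_t\, g_t$$

with dXdata then holding h_t and the step x ← x − dXdata applied in UpdateRunWithFloat32().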
template <class Context>
@@ -27,11 +27,18 @@ void SGDUpdateOp<Context>::ComputeRunWithFloat16() {
    lr = Param("base_lr") * this->lr_mult, momentum = Param("momentum");
    if (old_lr > 0) { correction = lr / old_lr; } old_lr = lr;
-   auto* dXdata = Input(0).template mutable_data<float16, Context>();
-   auto* Hdata = h->template mutable_data<float16, Context>();
-   kernel::SGDUpdate<float16, Context>(Input(0).count(),
-       lr, momentum * correction, dXdata, Hdata);
+   auto* dX32T = ws()->CreateTensor(Input(0).name() + "/f32");
+   dX32T->ReshapeLike(Input(0));
+   auto* dX32 = dX32T->template mutable_data<float, Context>();
+   auto* dX16 = Input(0).template mutable_data<float16, Context>();
+   auto* H32 = h->template mutable_data<float, Context>(ctx());
+   kernel::TypeA2B<float16, float, Context>(
+       Input(0).count(), dX16, dX32, ctx());
+   kernel::SGDUpdate<float, Context>(Input(0).count(),
+       lr, momentum * correction, dX32, H32, ctx());
}

DEPLOY_CPU(SGDUpdate);
...
#include "core/workspace.h" #include "core/workspace.h"
#include "utils/cast.h" #include "utils/cast.h"
#include "utils/math_functions.h" #include "utils/math_functions.h"
#include "utils/op_kernel.h"
#include "operators/update/update_op_base.h" #include "operators/update/update_op_base.h"
namespace dragon { namespace dragon {
...@@ -20,22 +21,24 @@ template <class Context> template <typename T> ...@@ -20,22 +21,24 @@ template <class Context> template <typename T>
void UpdateOpBase<Context>::PreprocessRunWithType() { void UpdateOpBase<Context>::PreprocessRunWithType() {
// scale // scale
scale_factor = Param("scale_gradient"); scale_factor = Param("scale_gradient");
if (scale_factor != 1) { if (scale_factor != 1.f) {
auto* dXdata = Input(0).template mutable_data<T, Context>(); auto* dXdata = Input(0).template mutable_data<T, Context>();
math::Scal<T, Context>(Input(0).count(), math::Scal<T, Context>(Input(0).count(),
scale_factor, dXdata, &ctx()); scale_factor, dXdata, ctx());
} }
// clip // clip
clip_thresh = Param("clip_gradient"); clip_thresh = Param("clip_gradient");
if (clip_thresh > 0) { if (clip_thresh > 0) {
auto* dXdata = Input(0).template mutable_data<T, Context>(); auto* dXdata = Input(0).template mutable_data<T, Context>();
float sumsq_grad = math::Dot<T, Context>( T sumsq_grad;
Input(0).count(), dXdata, dXdata, &ctx()); math::Dot<T, Context>(Input(0).count(),
const float l2norm = sqrt(sumsq_grad); dXdata, dXdata, &sumsq_grad, ctx());
const float l2norm = sqrt(
dragon_cast<float, T>(sumsq_grad));
if (l2norm > clip_thresh) { if (l2norm > clip_thresh) {
float norm_factor = clip_thresh / l2norm; float norm_factor = clip_thresh / l2norm;
math::Scal<T, Context>(Input(0).count(), math::Scal<T, Context>(Input(0).count(),
norm_factor, dXdata, &ctx()); norm_factor, dXdata, ctx());
} }
} }
// decay // decay
...@@ -44,34 +47,76 @@ void UpdateOpBase<Context>::PreprocessRunWithType() { ...@@ -44,34 +47,76 @@ void UpdateOpBase<Context>::PreprocessRunWithType() {
auto* dXdata = Input(0).template mutable_data<T, Context>(); auto* dXdata = Input(0).template mutable_data<T, Context>();
auto* Xdata = Output(0)->template data<T, Context>(); auto* Xdata = Output(0)->template data<T, Context>();
math::Axpy<T, Context>(Input(0).count(), math::Axpy<T, Context>(Input(0).count(),
l2_decay, Xdata, dXdata, &ctx()); l2_decay, Xdata, dXdata, ctx());
} }
} }
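In symbols, the three preprocessing stages above apply, in order (with s = scale_gradient, τ = clip_gradient, λ = l2_decay, each only when enabled):

$$g \leftarrow s\,g; \qquad g \leftarrow \frac{\tau}{\lVert g \rVert_2}\,g \ \text{ if } \lVert g \rVert_2 > \tau; \qquad g \leftarrow g + \lambda\,x$$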
-template <class Context> template <typename T>
-void UpdateOpBase<Context>::UpdateRunWithType() {
-   auto* dXdata = Input(0).template mutable_data<T, Context>();
-   auto* Xdata = Output(0)->template mutable_data<T, Context>();
-   math::Axpy<T, Context>(Output(0)->count(), -1, dXdata, Xdata, &ctx());
-   T zeroT = dragon_cast<T, float>(0.f);
-   if (zero_grad) math::Set<T, Context>(Input(0).count(), zeroT, dXdata);
-}
+template <class Context>
+void UpdateOpBase<Context>::UpdateRunWithFloat32() {
+   auto* dXdata = Input(0).template mutable_data<float, Context>();
+   auto* Xdata = Output(0)->template mutable_data<float, Context>();
+   // weights update & zero grads
+   math::Axpy<float, Context>(Output(0)->count(),
+       -1, dXdata, Xdata, ctx());
+   if (zero_grad) math::Set<float, Context>(
+       Input(0).count(), 0.f, dXdata, ctx());
+}
+
+template <class Context>
+void UpdateOpBase<Context>::UpdateRunWithFloat16() {
+   /* ------------------------------------------------
+    *
+    *                Mixed Precision Training
+    *
+    *          http://arxiv.org/abs/1710.03740
+    *
+    * ------------------------------------------------ */
+   // the "master" weights
+   auto* X32T = ws()->CreateTensor(Output(0)->name() + "/f32");
+   X32T->ReshapeLike(Input(0));
+   // the "master" updates
+   auto* dX32T = ws()->GetTensor(Input(0).name() + "/f32");
+   auto* dX32 = dX32T->template data<float, Context>();
+   auto* X16 = Output(0)->template mutable_data<float16, Context>();
+   auto* X32 = X32T->template mutable_data<float, Context>();
+   // X16 -> X32
+   kernel::TypeA2B<float16, float, Context>(
+       Input(0).count(), X16, X32, ctx());
+   // weights update & zero grads
+   math::Axpy<float, Context>(
+       Input(0).count(), -1, dX32, X32, ctx());
+   if (zero_grad) {
+       float16 zero = dragon_cast<float16, float>(0.f);
+       auto* dX16 = Input(0).template mutable_data<float16, Context>();
+       math::Set<float16, Context>(Input(0).count(), zero, dX16, ctx());
+   }
+   // X32 -> X16
+   kernel::TypeA2B<float, float16, Context>(
+       Input(0).count(), X32, X16, ctx());
+}
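The reason for the fp32 round trip: an fp16 step can be smaller than the weight's representable resolution and silently vanish, so updates are accumulated in fp32 "master" copies and only the result is stored back in half precision. A runnable toy illustration of that scheme (not the framework API; lo() fakes low precision by rounding to a coarse grid):

#include <cmath>
#include <cstdio>

float lo(float f) { return std::round(f * 256.f) / 256.f; }  // fake "fp16"

int main() {
    float x16 = lo(1.0f);     // low-precision stored weight
    float x32 = x16;          // fp32 master copy
    const float g = 1e-3f;    // a small gradient step
    for (int step = 0; step < 100; step++) {
        x32 -= g;             // update the master weight in fp32
        x16 = lo(x32);        // write back in low precision
    }
    // a purely low-precision update would have stalled: lo(1.0 - 1e-3) == 1.0
    printf("master: %.4f  stored: %.4f  naive one-step: %.4f\n",
           x32, x16, lo(lo(1.0f) - g));
    return 0;
}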
template <class Context>
void UpdateOpBase<Context>::RunOnDevice() {
-   // skip empty param or grad
+   // skip empty param or grads
    if (Input(0).count() == 0 || Output(0)->count() == 0) return;
    CHECK(Input(0).dims() == Output(0)->dims())
        << "\nTensor and its gradients should have same dims.\nGot "
        << Output(0)->DimString() << " and " << Input(0).DimString();
    if (XIsType(Input(0), float)) {
        PreprocessRunWithType<float>();
-       ComputeRunWithFloat();
-       UpdateRunWithType<float>();
+       ComputeRunWithFloat32();
+       UpdateRunWithFloat32();
    } else if (XIsType(Input(0), float16)) {
        PreprocessRunWithType<float16>();
        ComputeRunWithFloat16();
-       UpdateRunWithType<float16>();
+       UpdateRunWithFloat16();
    } else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
...
@@ -15,7 +15,7 @@ void BiasAddOp<Context>::RunWithType() {
    kernel::BiasAdd<T, Context>(
        Output(0)->count(), outer_dim, dim, inner_dim,
-       data_format, Bdata, multiplier, Ydata, &ctx());
+       data_format, Bdata, multiplier, Ydata, ctx());
}

template <class Context>
@@ -45,19 +45,19 @@ void BiasAddGradientOp<Context>::RunWithType() {
    if (Output(1)->name() != "ignore") {
        DECLARE_MULTIPLIER(multiplier, inner_dim);
        auto* dYdata = Input(-1).template data<T, Context>();
-       auto* dBias = Output(1)->template mutable_data<T, Context>();
+       auto* dBias = Output(1)->template mutable_data<T, Context>(ctx());
        const int y_offset = dim * inner_dim;
        for (int n = 0; n < outer_dim; n++) {
            if (data_format == "NCHW") {
                math::Gemv<T, Context>(
                    CblasNoTrans, dim, inner_dim,
                    1.0, dYdata, multiplier,
-                   1.0, dBias, &ctx());
+                   1.0, dBias, ctx());
            } else if (data_format == "NHWC") {
                math::Gemv<T, Context>(
                    CblasTrans, inner_dim, dim,
                    1.0, dYdata, multiplier,
-                   1.0, dBias, &ctx());
+                   1.0, dBias, ctx());
            }
            dYdata += y_offset;
        }
...
@@ -26,7 +26,7 @@ void BilinearResizeOp<Context>::RunWithType() {
    auto* Ydata = Output(0)->template mutable_data<T, Context>();
    kernel::BilinearResize<T, Context>(Output(0)->count(),
-       n, c, h, w, out_h, out_w, data_format, Xdata, Ydata);
+       n, c, h, w, out_h, out_w, data_format, Xdata, Ydata, ctx());
}

template <class Context>
@@ -77,8 +77,10 @@ void BilinearResizeGradientOp<Context>::RunWithType() {
    auto* dYdata = Input(-1).template data<T, Context>();
    auto* dXdata = Output(0)->template mutable_data<T, Context>();
+   math::Set<T, Context>(Output(0)->count(), 0, dXdata, ctx());
    kernel::BilinearResizeGrad<T, Context>(Input(-1).count(),
-       n, c, h, w, out_h, out_w, data_format, dYdata, dXdata);
+       n, c, h, w, out_h, out_w, data_format, dYdata, dXdata, ctx());
}

template <class Context>
...
@@ -41,7 +41,7 @@ void Conv2dGradientOp<Context>::RunWithType() {
    auto* dYdata = Input(-1).template data<T, Context>();
    if (HasBias()) {
-       T* dBdata = Output(2)->template mutable_data<T, Context>();
+       T* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
        for (int n = 0; n < Input(2).dim(0); n++)
            Db(dYdata + n * y_offset, dBdata);
    }
@@ -49,7 +49,7 @@ void Conv2dGradientOp<Context>::RunWithType() {
    for (int n = 0; n < Input(2).dim(0); n++) {
        if (Output(1)->name() != "ignore") {
            auto* Xdata = Input(0).template data<T, Context>();
-           auto* dWdata = Output(1)->template mutable_data<T, Context>();
+           auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
            Dw(dYdata + n * y_offset, Xdata + n * x_offset, dWdata);
        }
        if (Output(0)->name() != "ignore") {
...
@@ -44,7 +44,7 @@ void Conv2dTransposeGradientOp<Context>::RunWithType() {
    auto* dYdata = Input(-1).template data<T, Context>();
    if (Output(2)->name() != "ignore") {
-       auto* dBdata = Output(2)->template mutable_data<T, Context>();
+       auto* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
        for (int n = 0; n < Input(2).dim(0); n++)
            Db(dYdata + n * y_offset, dBdata);
    }
@@ -52,7 +52,7 @@ void Conv2dTransposeGradientOp<Context>::RunWithType() {
    for (int n = 0; n < Input(2).dim(0); n++) {
        if (Output(1)->name() != "ignore") {
            auto* Xdata = Input(0).template data<T, Context>();
-           auto* dWdata = Output(1)->template mutable_data<T, Context>();
+           auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
            Dw(Xdata + n * x_offset, dYdata + n * y_offset, dWdata);
        }
        if (Output(0)->name() != "ignore") {
...
@@ -77,7 +77,7 @@ void ConvOpBase<Context>::Wx(
                kernel_dim,
                1.0, weights + weight_offset * g,
                col_buffer + col_offset * g,
-               0.0, y + output_offset * g, &ctx());
+               0.0, y + output_offset * g, ctx());
        } else if (data_format == "NHWC") {
            math::Gemm<T, Context>(
                CblasNoTrans, CblasNoTrans,
@@ -86,7 +86,7 @@ void ConvOpBase<Context>::Wx(
                kernel_dim,
                1.0, col_buffer + col_offset * g,
                weights + weight_offset * g,
-               0.0, y + output_offset * g, &ctx());
+               0.0, y + output_offset * g, ctx());
        }
    }
}
@@ -99,13 +99,13 @@ void ConvOpBase<Context>::Pb(const T* bias, T* y) {
            CblasNoTrans, CblasNoTrans,
            num_output, out_spatial_dim, 1,
            1.0, bias, multiplier,
-           1.0, y, &ctx());
+           1.0, y, ctx());
    } else if (data_format == "NHWC") {
        math::Gemm<T, Context>(
            CblasNoTrans, CblasNoTrans,
            out_spatial_dim, num_output, 1,
            1.0, multiplier, bias,
-           1.0, y, &ctx());
+           1.0, y, ctx());
    }
}
@@ -122,7 +122,7 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* weights, T* dx) {
                conv_out_channels / group,
                1.0, weights + weight_offset * g,
                dy + output_offset * g,
-               0.0, col_buffer + col_offset * g, &ctx());
+               0.0, col_buffer + col_offset * g, ctx());
        } else if (data_format == "NHWC") {
            math::Gemm<T, Context>(
                CblasNoTrans, CblasTrans,
@@ -131,7 +131,7 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* weights, T* dx) {
                conv_out_channels / group,
                1.0, dy + output_offset * g,
                weights + weight_offset * g,
-               0.0, col_buffer + col_offset * g, &ctx());
+               0.0, col_buffer + col_offset * g, ctx());
        }
    }
    if (!is_1x1) Col2Im(col_buffer, dx);
@@ -154,7 +154,7 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T *dw) {
                conv_out_spatial_dim,
                1.0, dy + output_offset * g,
                col_buffer + col_offset * g,
-               1.0, dw + weight_offset * g, &ctx());
+               1.0, dw + weight_offset * g, ctx());
        } else if (data_format == "NHWC") {
            math::Gemm<T, Context>(
                CblasTrans, CblasNoTrans,
@@ -163,7 +163,7 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T *dw) {
                conv_out_spatial_dim,
                1.0, col_buffer + col_offset * g,
                dy + output_offset * g,
-               1.0, dw + weight_offset * g, &ctx());
+               1.0, dw + weight_offset * g, ctx());
        }
    }
}
@@ -175,12 +175,12 @@ void ConvOpBase<Context>::Db(const T* dy, T* db) {
        math::Gemv<T, Context>(
            CblasNoTrans, num_output, out_spatial_dim,
            1.0, dy, multiplier,
-           1.0, db, &ctx());
+           1.0, db, ctx());
    } else if (data_format == "NHWC") {
        math::Gemv<T, Context>(
            CblasTrans, out_spatial_dim, num_output,
            1.0, dy, multiplier,
-           1.0, db, &ctx());
+           1.0, db, ctx());
    }
}
...
@@ -54,13 +54,13 @@ void CuDNNConv2dOp<Context>::ResetDesc() {
    }
    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-       ctx().cudnn_handle(), input_desc,
+       ctx()->cudnn_handle(), input_desc,
        filter_desc, conv_desc, output_desc,
        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
        WORKSPACE_LIMIT_BYTES, &fwd_algo));
    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-       ctx().cudnn_handle(), input_desc,
+       ctx()->cudnn_handle(), input_desc,
        filter_desc, conv_desc, output_desc,
        fwd_algo, &fwd_data_size));
}
@@ -78,7 +78,7 @@ void CuDNNConv2dOp<Context>::RunWithType() {
    auto* WSdata = (uint8_t*)ws()->template
        caches<Context>({ fwd_data_size })[0];
-   auto cudnn_handle = ctx().cudnn_handle();
+   auto cudnn_handle = ctx()->cudnn_handle();
    for (int g = 0; g < cudnn_group; g++) {
        CUDNN_CHECK(cudnnConvolutionForward(cudnn_handle,
@@ -104,6 +104,8 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
#endif
    Conv2dOp<Context>::Reshape();
+   ctx()->set_stream_id(0);  // enforce default stream
    if (XIsType(Input(0), float)) {
#if CUDNN_VERSION_MIN(6, 0, 0)
        CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
@@ -199,24 +201,24 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
    }
    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-       ctx().cudnn_handle(), output_desc,
+       ctx()->cudnn_handle(), output_desc,
        input_desc, conv_desc, filter_desc,
        CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
        WORKSPACE_LIMIT_BYTES, &bwd_filter_algo));
    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-       ctx().cudnn_handle(), output_desc,
+       ctx()->cudnn_handle(), output_desc,
        input_desc, conv_desc, filter_desc,
        bwd_filter_algo, &bwd_filter_size));
    CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-       ctx().cudnn_handle(), filter_desc,
+       ctx()->cudnn_handle(), filter_desc,
        input_desc, conv_desc, output_desc,
        CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
        WORKSPACE_LIMIT_BYTES, &bwd_data_algo));
    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-       ctx().cudnn_handle(), filter_desc,
+       ctx()->cudnn_handle(), filter_desc,
        input_desc, conv_desc, output_desc,
        bwd_data_algo, &bwd_data_size));
}
@@ -230,18 +232,18 @@ void CuDNNConv2dGradientOp<Context>::RunWithType() {
    auto* WSdata = ws()->template caches<Context>({
        std::max(bwd_data_size, bwd_filter_size)})[0];
-   auto cudnn_handle = ctx().cudnn_handle();
+   auto cudnn_handle = ctx()->cudnn_handle();
    for (int g = 0; g < cudnn_group; g++) {
        if (Output(2)->name() != "ignore") {
-           T* dBdata = Output(2)->template mutable_data<T, Context>();
+           T* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
            CUDNN_CHECK(cudnnConvolutionBackwardBias(cudnn_handle,
                CUDNNType<T>::one, input_desc, dYdata + y_offset * g,
                CUDNNType<T>::one, bias_desc, dBdata + bias_offset * g));
        }
        if (Output(1)->name() != "ignore") {
            auto* Xdata = Input(0).template data<T, Context>();
-           auto* dWdata = Output(1)->template mutable_data<T, Context>();
+           auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
            CUDNN_CHECK(cudnnConvolutionBackwardFilter(cudnn_handle,
                CUDNNType<T>::one, output_desc, Xdata + x_offset * g,
                input_desc, dYdata + y_offset * g,
@@ -269,6 +271,8 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
#endif
    Conv2dGradientOp<Context>::GradientReshape();
+   ctx()->set_stream_id(0);  // enforce default stream
    if (XIsType(Input(0), float)) {
#if CUDNN_VERSION_MIN(6, 0, 0)
        CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
...
...@@ -54,13 +54,13 @@ void CuDNNConv2dTransposeOp<Context>::ResetDesc() { ...@@ -54,13 +54,13 @@ void CuDNNConv2dTransposeOp<Context>::ResetDesc() {
} }
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx().cudnn_handle(), filter_desc, ctx()->cudnn_handle(), filter_desc,
input_desc, conv_desc, output_desc, input_desc, conv_desc, output_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES, &fwd_algo)); WORKSPACE_LIMIT_BYTES, &fwd_algo));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx().cudnn_handle(), filter_desc, ctx()->cudnn_handle(), filter_desc,
input_desc, conv_desc, output_desc, input_desc, conv_desc, output_desc,
fwd_algo, &fwd_data_size)); fwd_algo, &fwd_data_size));
} }
...@@ -78,7 +78,7 @@ void CuDNNConv2dTransposeOp<Context>::RunWithType() { ...@@ -78,7 +78,7 @@ void CuDNNConv2dTransposeOp<Context>::RunWithType() {
auto* WSdata = (uint8_t*)ws()->template auto* WSdata = (uint8_t*)ws()->template
caches<Context>({ fwd_data_size })[0]; caches<Context>({ fwd_data_size })[0];
auto cudnn_handle = ctx().cudnn_handle(); auto cudnn_handle = ctx()->cudnn_handle();
for (int g = 0; g < cudnn_group; g++) { for (int g = 0; g < cudnn_group; g++) {
CUDNN_CHECK(cudnnConvolutionBackwardData(cudnn_handle, CUDNN_CHECK(cudnnConvolutionBackwardData(cudnn_handle,
...@@ -104,6 +104,8 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() { ...@@ -104,6 +104,8 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
#endif #endif
Conv2dTransposeOp<Context>::Reshape(); Conv2dTransposeOp<Context>::Reshape();
ctx()->set_stream_id(0); // enforce default stream
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
#if CUDNN_VERSION_MIN(6, 0, 0) #if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
...@@ -199,24 +201,24 @@ void CuDNNConv2dTransposeGradientOp<Context>::ResetDesc() { ...@@ -199,24 +201,24 @@ void CuDNNConv2dTransposeGradientOp<Context>::ResetDesc() {
} }
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx().cudnn_handle(), input_desc, ctx()->cudnn_handle(), input_desc,
output_desc, conv_desc, filter_desc, output_desc, conv_desc, filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES, &bwd_filter_algo)); WORKSPACE_LIMIT_BYTES, &bwd_filter_algo));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx().cudnn_handle(), input_desc, ctx()->cudnn_handle(), input_desc,
output_desc, conv_desc, filter_desc, output_desc, conv_desc, filter_desc,
bwd_filter_algo, &bwd_filter_size)); bwd_filter_algo, &bwd_filter_size));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx().cudnn_handle(), input_desc, ctx()->cudnn_handle(), input_desc,
filter_desc, conv_desc, output_desc, filter_desc, conv_desc, output_desc,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES, &bwd_data_algo)); WORKSPACE_LIMIT_BYTES, &bwd_data_algo));
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
ctx().cudnn_handle(), input_desc, ctx()->cudnn_handle(), input_desc,
filter_desc, conv_desc, output_desc, filter_desc, conv_desc, output_desc,
bwd_data_algo, &bwd_data_size)); bwd_data_algo, &bwd_data_size));
} }
...@@ -230,18 +232,18 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() { ...@@ -230,18 +232,18 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
auto* WSdata = ws()->template caches<Context>({ auto* WSdata = ws()->template caches<Context>({
std::max(bwd_data_size, bwd_filter_size) })[0]; std::max(bwd_data_size, bwd_filter_size) })[0];
auto cudnn_handle = ctx().cudnn_handle(); auto cudnn_handle = ctx()->cudnn_handle();
for (int g = 0; g < cudnn_group; g++) { for (int g = 0; g < cudnn_group; g++) {
if (Output(2)->name() != "ignore") { if (Output(2)->name() != "ignore") {
T* dBdata = Output(2)->template mutable_data<T, Context>(); T* dBdata = Output(2)->template mutable_data<T, Context>(ctx());
CUDNN_CHECK(cudnnConvolutionBackwardBias(cudnn_handle, CUDNN_CHECK(cudnnConvolutionBackwardBias(cudnn_handle,
CUDNNType<T>::one, input_desc, dYdata + y_offset * g, CUDNNType<T>::one, input_desc, dYdata + y_offset * g,
CUDNNType<T>::one, bias_desc, dBdata + bias_offset * g)); CUDNNType<T>::one, bias_desc, dBdata + bias_offset * g));
} }
if (Output(1)->name() != "ignore") { if (Output(1)->name() != "ignore") {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* dWdata = Output(1)->template mutable_data<T, Context>(); auto* dWdata = Output(1)->template mutable_data<T, Context>(ctx());
CUDNN_CHECK(cudnnConvolutionBackwardFilter(cudnn_handle, CUDNN_CHECK(cudnnConvolutionBackwardFilter(cudnn_handle,
CUDNNType<T>::one, input_desc, dYdata + y_offset * g, CUDNNType<T>::one, input_desc, dYdata + y_offset * g,
output_desc, Xdata + x_offset * g, output_desc, Xdata + x_offset * g,
...@@ -269,6 +271,8 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() { ...@@ -269,6 +271,8 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
#endif #endif
Conv2dTransposeGradientOp<Context>::GradientReshape(); Conv2dTransposeGradientOp<Context>::GradientReshape();
ctx()->set_stream_id(0); // enforce default stream
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
#if CUDNN_VERSION_MIN(6, 0, 0) #if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
......
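The gradient op above pins itself to the default stream (set_stream_id(0)) before selecting cuDNN algorithms, presumably because descriptor mutation and the algorithm/workspace queries are handle-level shared state that is not safe to touch from several streams at once. Note also that the context is now reached through a pointer, so handle accessors use "->" instead of ".". A minimal sketch of the pattern, assuming only the accessors visible above (other names are illustrative):

    #include <cudnn.h>

    // Sketch: route an op onto stream 0 before touching shared cuDNN
    // state. "Context" stands in for the framework's CUDAContext; the
    // set_stream_id()/cudnn_handle() accessors mirror those used above.
    template <class Context>
    cudnnHandle_t BindDefaultStream(Context* ctx) {
        ctx->set_stream_id(0);         // later launches go to stream 0
        return ctx->cudnn_handle();    // handle bound to the same stream
    }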
...@@ -13,7 +13,7 @@ void CuDNNLRNOp<Context>::RunWithType() { ...@@ -13,7 +13,7 @@ void CuDNNLRNOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnLRNCrossChannelForward( CUDNN_CHECK(cudnnLRNCrossChannelForward(
ctx().cudnn_handle(), norm_desc, ctx()->cudnn_handle(), norm_desc,
CUDNN_LRN_CROSS_CHANNEL_DIM1, CUDNN_LRN_CROSS_CHANNEL_DIM1,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
...@@ -55,7 +55,7 @@ void CuDNNLRNGradientOp<Context>::RunWithType() { ...@@ -55,7 +55,7 @@ void CuDNNLRNGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnLRNCrossChannelBackward( CUDNN_CHECK(cudnnLRNCrossChannelBackward(
ctx().cudnn_handle(), norm_desc, ctx()->cudnn_handle(), norm_desc,
CUDNN_LRN_CROSS_CHANNEL_DIM1, CUDNN_LRN_CROSS_CHANNEL_DIM1,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Xdata, input_desc, dYdata, output_desc, Xdata,
......
...@@ -25,7 +25,7 @@ void CuDNNPooling2dOp<Context>::RunWithType() { ...@@ -25,7 +25,7 @@ void CuDNNPooling2dOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnPoolingForward( CUDNN_CHECK(cudnnPoolingForward(
ctx().cudnn_handle(), pool_desc, ctx()->cudnn_handle(), pool_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
} }
...@@ -69,7 +69,7 @@ void CuDNNPooling2dGradientOp<Context>::RunWithType() { ...@@ -69,7 +69,7 @@ void CuDNNPooling2dGradientOp<Context>::RunWithType() {
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnPoolingBackward( CUDNN_CHECK(cudnnPoolingBackward(
ctx().cudnn_handle(), pool_desc, ctx()->cudnn_handle(), pool_desc,
CUDNNType<T>::one, input_desc, Ydata, CUDNNType<T>::one, input_desc, Ydata,
input_desc, dYdata, output_desc, Xdata, input_desc, dYdata, output_desc, Xdata,
CUDNNType<T>::zero, output_desc, dXdata)); CUDNNType<T>::zero, output_desc, dXdata));
......
...@@ -28,7 +28,7 @@ void DenseConcatGradientOp<Context>::RestoreX1() { ...@@ -28,7 +28,7 @@ void DenseConcatGradientOp<Context>::RestoreX1() {
kernel::ConcatGrad<T, Context>( kernel::ConcatGrad<T, Context>(
count, this->outer_dim, this->inner_dim, count, this->outer_dim, this->inner_dim,
this->x_concat_dim, this->y_concat_dim, this->x_concat_dim, this->y_concat_dim,
0, Ydata, Xdata); 0, Ydata, Xdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -17,11 +17,11 @@ template <class Context> template <typename T> ...@@ -17,11 +17,11 @@ template <class Context> template <typename T>
void LRNOp<Context>::SplitRunWithType() { void LRNOp<Context>::SplitRunWithType() {
sqr_in = ws()->CreateTensor("/mnt/" + anchor() + "/sqr/in"); sqr_in = ws()->CreateTensor("/mnt/" + anchor() + "/sqr/in");
sqr_in->ReshapeLike(Input(0)); sqr_in->ReshapeLike(Input(0));
sqr_in->template CopyFrom<Context>(Input(0)); sqr_in->template CopyFrom<Context>(Input(0), ctx());
prod_in = ws()->CreateTensor("/mnt/" + anchor() + "/prod/in"); prod_in = ws()->CreateTensor("/mnt/" + anchor() + "/prod/in");
prod_in->ReshapeLike(Input(0)); prod_in->ReshapeLike(Input(0));
prod_in->template CopyFrom<Context>(Input(0)); prod_in->template CopyFrom<Context>(Input(0), ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -229,7 +229,7 @@ void LRNGradientOp<Context>::SplitRunWithType() { ...@@ -229,7 +229,7 @@ void LRNGradientOp<Context>::SplitRunWithType() {
auto* data0 = g_sqr_in->template data<T, Context>(); auto* data0 = g_sqr_in->template data<T, Context>();
auto* data1 = g_prod_in->template data<T, Context>(); auto* data1 = g_prod_in->template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Add<T, Context>(Output(0)->count(), data0, data1, dXdata); math::Add<T, Context>(Output(0)->count(), data0, data1, dXdata, ctx());
} }
template <class Context> template <class Context>
......
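A change that recurs through these operator files: tensor copies and math wrappers now take the executing context as a trailing argument (CopyFrom(Input(0), ctx()), math::Add(..., ctx())), so the work is issued on that context's stream rather than an implicit global one. A usage sketch under that assumption:

    // Sketch of the new trailing-context convention; math::Add is the
    // wrapper whose specializations appear later in this commit.
    template <class Context>
    void AccumulateGrad(int count, const float* g0, const float* g1,
                        float* dx, Context* ctx) {
        math::Add<float, Context>(count, g0, g1, dx, ctx);
    }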
...@@ -26,7 +26,7 @@ void NNResizeOp<Context>::RunWithType() { ...@@ -26,7 +26,7 @@ void NNResizeOp<Context>::RunWithType() {
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::NNResize<T, Context>(Output(0)->count(), kernel::NNResize<T, Context>(Output(0)->count(),
n, c, h, w, out_h, out_w, data_format, Xdata, Ydata); n, c, h, w, out_h, out_w, data_format, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -77,8 +77,10 @@ void NNResizeGradientOp<Context>::RunWithType() { ...@@ -77,8 +77,10 @@ void NNResizeGradientOp<Context>::RunWithType() {
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), 0, dXdata, ctx());
kernel::NNResizeGrad<T, Context>(Input(-1).count(), kernel::NNResizeGrad<T, Context>(Input(-1).count(),
n, c, h, w, out_h, out_w, data_format, dYdata, dXdata); n, c, h, w, out_h, out_w, data_format, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
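NNResizeGradientOp now zero-fills dXdata before launching the gradient kernel, presumably because NNResizeGrad accumulates into its output rather than overwriting it; with reused workspace memory, skipping the clear would fold stale values into the gradient. The zero-then-accumulate shape, as a sketch:

    // Sketch: clear the destination on the same stream before any
    // kernel that accumulates into it; both calls are stream-ordered.
    template <class Context>
    void PrepareAccumulator(int count, float* dx, Context* ctx) {
        math::Set<float, Context>(count, 0.f, dx, ctx);
        // ...accumulating gradient kernel is enqueued here...
    }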
...@@ -17,7 +17,7 @@ void Pooling2dOp<Context>::MAXRunWithType() { ...@@ -17,7 +17,7 @@ void Pooling2dOp<Context>::MAXRunWithType() {
kernel::MAXPooling2d<T, Context>(Output(0)->count(), kernel::MAXPooling2d<T, Context>(Output(0)->count(),
n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1], n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1],
stride[0], stride[1], pad[0], pad[1], stride[0], stride[1], pad[0], pad[1],
data_format, Xdata, Mdata, Ydata); data_format, Xdata, Mdata, Ydata, ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -28,7 +28,7 @@ void Pooling2dOp<Context>::AVGRunWithType() { ...@@ -28,7 +28,7 @@ void Pooling2dOp<Context>::AVGRunWithType() {
kernel::AVGPooling2d<T, Context>(Output(0)->count(), kernel::AVGPooling2d<T, Context>(Output(0)->count(),
n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1], n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1],
stride[0], stride[1], pad[0], pad[1], stride[0], stride[1], pad[0], pad[1],
data_format, Xdata, Ydata); data_format, Xdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -127,8 +127,9 @@ void Pooling2dGradientOp<Context>::MAXRunWithType() { ...@@ -127,8 +127,9 @@ void Pooling2dGradientOp<Context>::MAXRunWithType() {
kernel::MAXPooling2dGrad<T, Context>(Output(0)->count(), kernel::MAXPooling2dGrad<T, Context>(Output(0)->count(),
n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1], n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1],
stride[0], stride[1], pad[0], pad[1], stride[0], stride[1], pad[0], pad[1],
data_format, dYdata, Mdata, dXdata); data_format, dYdata, Mdata, dXdata, ctx());
ctx()->FinishDeviceCompution();
mask->Reset(); mask->Reset();
} }
...@@ -140,7 +141,7 @@ void Pooling2dGradientOp<Context>::AVGRunWithType() { ...@@ -140,7 +141,7 @@ void Pooling2dGradientOp<Context>::AVGRunWithType() {
kernel::AVGPooling2dGrad<T, Context>(Output(0)->count(), kernel::AVGPooling2dGrad<T, Context>(Output(0)->count(),
n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1], n, c, h, w, pool_h, pool_w, kernel_size[0], kernel_size[1],
stride[0], stride[1], pad[0], pad[1], stride[0], stride[1], pad[0], pad[1],
data_format, dYdata, dXdata); data_format, dYdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
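MAXPooling2dGrad is now followed by FinishDeviceCompution() before mask->Reset(): with async streams the enqueued kernel may still be reading the mask buffer when the host reaches Reset(), so the op must drain the stream before releasing it. The hazard, sketched (Tensor/Reset as used above):

    // Sketch: never free device memory an enqueued kernel still reads.
    template <class Context>
    void LaunchThenRelease(Tensor* mask, Context* ctx) {
        // (gradient kernel reading mask's data was enqueued above)
        ctx->FinishDeviceCompution();  // block until the stream drains
        mask->Reset();                 // only now is the buffer reclaimable
    }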
...@@ -14,7 +14,8 @@ void ROIAlignOp<Context>::RunWithType() { ...@@ -14,7 +14,8 @@ void ROIAlignOp<Context>::RunWithType() {
kernel::ROIAlign<T, Context>( kernel::ROIAlign<T, Context>(
Output(0)->count(), Input(0).dim(0), Input(0).dim(1), Output(0)->count(), Input(0).dim(0), Input(0).dim(1),
Input(0).dim(2), Input(0).dim(3), pool_h, pool_w, Input(0).dim(2), Input(0).dim(3), pool_h, pool_w,
Input(1).dim(0), spatial_scale, sampling_ratio, Xdata, Rdata, Ydata); Input(1).dim(0), spatial_scale, sampling_ratio,
Xdata, Rdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -38,12 +39,13 @@ void ROIAlignGradientOp<Context>::RunWithType() { ...@@ -38,12 +39,13 @@ void ROIAlignGradientOp<Context>::RunWithType() {
auto* Rdata = Input(1).template data<T, CUDAContext>(); auto* Rdata = Input(1).template data<T, CUDAContext>();
auto* dXdata = Output(0)->template mutable_data<T, Context>(); auto* dXdata = Output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(Output(0)->count(), 0, dXdata); math::Set<T, Context>(Output(0)->count(), 0, dXdata, ctx());
kernel::ROIAlignGrad<T, Context>( kernel::ROIAlignGrad<T, Context>(
Input(-1).count(), Output(0)->dim(0), Output(0)->dim(1), Input(-1).count(), Output(0)->dim(0), Output(0)->dim(1),
Output(0)->dim(2), Output(0)->dim(3), pool_h, pool_w, Output(0)->dim(2), Output(0)->dim(3), pool_h, pool_w,
Input(1).dim(0), spatial_scale, sampling_ratio, dYdata, Rdata, dXdata); Input(1).dim(0), spatial_scale, sampling_ratio,
dYdata, Rdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -19,7 +19,8 @@ void ROIPoolingOp<Context>::RunWithType() { ...@@ -19,7 +19,8 @@ void ROIPoolingOp<Context>::RunWithType() {
kernel::ROIPooling<T, Context>( kernel::ROIPooling<T, Context>(
Output(0)->count(), Input(0).dim(0), Input(0).dim(1), Output(0)->count(), Input(0).dim(0), Input(0).dim(1),
Input(0).dim(2), Input(0).dim(3), pool_h, pool_w, Input(0).dim(2), Input(0).dim(3), pool_h, pool_w,
Input(1).dim(0), spatial_scale, Xdata, Rdata, Mdata, Ydata); Input(1).dim(0), spatial_scale,
Xdata, Rdata, Mdata, Ydata, ctx());
} }
template <class Context> template <class Context>
...@@ -50,7 +51,8 @@ void ROIPoolingGradientOp<Context>::RunWithType() { ...@@ -50,7 +51,8 @@ void ROIPoolingGradientOp<Context>::RunWithType() {
kernel::ROIPoolingGrad<T, Context>( kernel::ROIPoolingGrad<T, Context>(
Output(0)->count(), Output(0)->dim(0), Output(0)->dim(1), Output(0)->count(), Output(0)->dim(0), Output(0)->dim(1),
Output(0)->dim(2), Output(0)->dim(3), pool_h, pool_w, Output(0)->dim(2), Output(0)->dim(3), pool_h, pool_w,
Input(1).dim(0), spatial_scale, dYdata, Rdata, Mdata, dXdata); Input(1).dim(0), spatial_scale,
dYdata, Rdata, Mdata, dXdata, ctx());
} }
template <class Context> template <class Context>
......
...@@ -14,7 +14,8 @@ namespace math { ...@@ -14,7 +14,8 @@ namespace math {
template <> void Set<float, CPUContext>( template <> void Set<float, CPUContext>(
const int n, const int n,
const float alpha, const float alpha,
float* x) { float* x,
CPUContext* ctx) {
if (alpha == 0) { if (alpha == 0) {
memset(x, 0, sizeof(float) * n); memset(x, 0, sizeof(float) * n);
return; return;
...@@ -32,7 +33,8 @@ template <> void Set<float, CPUContext>( ...@@ -32,7 +33,8 @@ template <> void Set<float, CPUContext>(
template <> void Set<int, CPUContext>( template <> void Set<int, CPUContext>(
const int n, const int n,
const int alpha, const int alpha,
int* x) { int* x,
CPUContext* ctx) {
if (alpha == 0) { if (alpha == 0) {
memset(x, 0, sizeof(int) * n); memset(x, 0, sizeof(int) * n);
return; return;
...@@ -50,7 +52,8 @@ template <> void Set<int, CPUContext>( ...@@ -50,7 +52,8 @@ template <> void Set<int, CPUContext>(
template <> void Set<float16, CPUContext>( template <> void Set<float16, CPUContext>(
const int n, const int n,
const float16 alpha, const float16 alpha,
float16* x) { float16* x,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -164,7 +167,8 @@ template <> void Add<float, CPUContext>( ...@@ -164,7 +167,8 @@ template <> void Add<float, CPUContext>(
const int n, const int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::Add<float>(n, a, b, y); sse::Add<float>(n, a, b, y);
#else #else
...@@ -179,7 +183,8 @@ template <> void Add<int, CPUContext>( ...@@ -179,7 +183,8 @@ template <> void Add<int, CPUContext>(
const int n, const int n,
const int* a, const int* a,
const int* b, const int* b,
int* y) { int* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -190,7 +195,8 @@ template <> void Add<float16, CPUContext>( ...@@ -190,7 +195,8 @@ template <> void Add<float16, CPUContext>(
const int n, const int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -198,7 +204,8 @@ template <> void Sub<float, CPUContext>( ...@@ -198,7 +204,8 @@ template <> void Sub<float, CPUContext>(
const int n, const int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::Sub<float>(n, a, b, y); sse::Sub<float>(n, a, b, y);
#else #else
...@@ -213,7 +220,8 @@ template <> void Sub<float16, CPUContext>( ...@@ -213,7 +220,8 @@ template <> void Sub<float16, CPUContext>(
const int n, const int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -221,7 +229,8 @@ template <> void Mul<float, CPUContext>( ...@@ -221,7 +229,8 @@ template <> void Mul<float, CPUContext>(
const int n, const int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::Mul<float>(n, a, b, y); sse::Mul<float>(n, a, b, y);
#else #else
...@@ -236,7 +245,8 @@ template <> void Mul<float16, CPUContext>( ...@@ -236,7 +245,8 @@ template <> void Mul<float16, CPUContext>(
const int n, const int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -244,7 +254,8 @@ template <> void Div<float, CPUContext>( ...@@ -244,7 +254,8 @@ template <> void Div<float, CPUContext>(
const int n, const int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::Div<float>(n, a, b, y); sse::Div<float>(n, a, b, y);
#else #else
...@@ -259,7 +270,8 @@ template <> void Div<float16, CPUContext>( ...@@ -259,7 +270,8 @@ template <> void Div<float16, CPUContext>(
const int n, const int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -267,7 +279,8 @@ template <> void Clip<float, CPUContext>( ...@@ -267,7 +279,8 @@ template <> void Clip<float, CPUContext>(
const int n, const int n,
const float low, const float low,
const float high, const float high,
float* x) { float* x,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -279,7 +292,8 @@ template <> void Clip<float, CPUContext>( ...@@ -279,7 +292,8 @@ template <> void Clip<float, CPUContext>(
template <> void Exp<float, CPUContext>( template <> void Exp<float, CPUContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -289,7 +303,8 @@ template <> void Exp<float, CPUContext>( ...@@ -289,7 +303,8 @@ template <> void Exp<float, CPUContext>(
template <> void Log<float, CPUContext>( template <> void Log<float, CPUContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -299,7 +314,8 @@ template <> void Log<float, CPUContext>( ...@@ -299,7 +314,8 @@ template <> void Log<float, CPUContext>(
template <> void Square<float, CPUContext>( template <> void Square<float, CPUContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -309,14 +325,16 @@ template <> void Square<float, CPUContext>( ...@@ -309,14 +325,16 @@ template <> void Square<float, CPUContext>(
template <> void Square<float16, CPUContext>( template <> void Square<float16, CPUContext>(
int n, int n,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
template <> void Sqrt<float, CPUContext>( template <> void Sqrt<float, CPUContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -326,7 +344,8 @@ template <> void Sqrt<float, CPUContext>( ...@@ -326,7 +344,8 @@ template <> void Sqrt<float, CPUContext>(
template <> void Sqrt<float16, CPUContext>( template <> void Sqrt<float16, CPUContext>(
int n, int n,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -334,7 +353,8 @@ template <> void Pow<float, CPUContext>( ...@@ -334,7 +353,8 @@ template <> void Pow<float, CPUContext>(
int n, int n,
const float alpha, const float alpha,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -345,7 +365,8 @@ template <> void Pow<float16, CPUContext>( ...@@ -345,7 +365,8 @@ template <> void Pow<float16, CPUContext>(
int n, int n,
const float alpha, const float alpha,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -353,7 +374,8 @@ template <> void Inv<float, CPUContext>( ...@@ -353,7 +374,8 @@ template <> void Inv<float, CPUContext>(
const int n, const int n,
const float numerator, const float numerator,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) #pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif #endif
...@@ -364,7 +386,8 @@ template <> void Inv<float16, CPUContext>( ...@@ -364,7 +386,8 @@ template <> void Inv<float16, CPUContext>(
const int n, const int n,
const float numerator, const float numerator,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -423,51 +446,51 @@ template <> void Scale<float, CPUContext>( ...@@ -423,51 +446,51 @@ template <> void Scale<float, CPUContext>(
#endif // WITH_BLAS #endif // WITH_BLAS
} }
template <> float StridedDot<float, CPUContext>( template <> void StridedDot<float, CPUContext>(
const int n, const int n,
const float* a, const float* a,
const int incx, const int incx,
const float* b, const float* b,
const int incy, const int incy,
float* y,
CPUContext* ctx) { CPUContext* ctx) {
#ifdef WITH_BLAS #ifdef WITH_BLAS
return cblas_sdot(n, a, incx, b, incy); float result = cblas_sdot(n, a, incx, b, incy);
#else #else
float ret = 0.f; float result = 0.f;
#ifdef WITH_OMP int cx = 0, cy = 0;
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) for (int i = 0; i < n; ++i) {
#endif result += a[cx] * b[cy];
for (int i = 0; i < n; ++i) ret += a[i] * b[i]; cx += incx; cy += incy;
return ret; }
#endif // WITH_BLAS #endif // WITH_BLAS
*y = result;
} }
template <> float Dot<float, CPUContext>( template <> void Dot<float, CPUContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y,
CPUContext* ctx) { CPUContext* ctx) {
#ifdef WITH_BLAS #ifdef WITH_BLAS
return StridedDot<float, CPUContext>(n, a, 1, b, 1, ctx); StridedDot<float, CPUContext>(n, a, 1, b, 1, y, ctx);
#elif WITH_SSE #elif WITH_SSE
return sse::Dot<float>(n, a, b); *y = sse::Dot<float>(n, a, b);
#else #else
float ret = 0.f; float result = 0.f;
#ifdef WITH_OMP for (int i = 0; i < n; ++i) result += a[i] * b[i];
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) *y = result;
#endif
for (int i = 0; i < n; ++i) ret += a[i] * b[i];
return ret;
#endif // WITH_BLAS #endif // WITH_BLAS
} }
template <> float Dot<float16, CPUContext>( template <> void Dot<float16, CPUContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y,
CPUContext* ctx) { CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
return 0;
} }
template <> float ASum<float, CPUContext>( template <> float ASum<float, CPUContext>(
...@@ -475,22 +498,19 @@ template <> float ASum<float, CPUContext>( ...@@ -475,22 +498,19 @@ template <> float ASum<float, CPUContext>(
const float* x) { const float* x) {
#ifdef WITH_BLAS #ifdef WITH_BLAS
return cblas_sasum(n, x, 1); return cblas_sasum(n, x, 1);
#elif WITH_SSE
return sse::ASum<float>(n, x);
#else #else
float ret = 0.f; float result = 0.f;
#ifdef WITH_OMP for (int i = 0; i < n; ++i)
#pragma omp parallel for num_threads(GET_OMP_THREADS(n)) result += std::abs(x[i]);
#endif return result;
for (int i = 0; i < n; ++i) ret += x[i];
return ret;
#endif // WITH_BLAS #endif // WITH_BLAS
} }
template <> void AddScalar<float, CPUContext>( template <> void AddScalar<float, CPUContext>(
const int n, const int n,
const float alpha, const float alpha,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::AddScalar<float>(n, alpha, y); sse::AddScalar<float>(n, alpha, y);
#else #else
...@@ -504,14 +524,16 @@ template <> void AddScalar<float, CPUContext>( ...@@ -504,14 +524,16 @@ template <> void AddScalar<float, CPUContext>(
template <> void AddScalar<float16, CPUContext>( template <> void AddScalar<float16, CPUContext>(
const int n, const int n,
const float alpha, const float alpha,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
template <> void MulScalar<float, CPUContext>( template <> void MulScalar<float, CPUContext>(
const int n, const int n,
const float alpha, const float alpha,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_SSE #ifdef WITH_SSE
sse::MulScalar<float>(n, alpha, y); sse::MulScalar<float>(n, alpha, y);
#else #else
...@@ -525,7 +547,8 @@ template <> void MulScalar<float, CPUContext>( ...@@ -525,7 +547,8 @@ template <> void MulScalar<float, CPUContext>(
template <> void MulScalar<float16, CPUContext>( template <> void MulScalar<float16, CPUContext>(
const int n, const int n,
const float alpha, const float alpha,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
......
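The other signature change in this file: Dot and StridedDot no longer return the scalar; they write it through an output pointer and take the context. Returning a value forced every backend to be synchronous, while the out-parameter form lets the CUDA implementation run on a stream and synchronize only when the host value is actually needed. Note the rewritten CPU loops also drop the OpenMP pragma, which raced on the shared accumulator without a reduction clause. Usage, sketched for the CPU specialization above:

    // New out-parameter form: on CPU the result is valid immediately;
    // on CUDA the wrapper itself synchronizes before returning.
    float DotExample(int n, const float* a, const float* b, CPUContext* ctx) {
        float result = 0.f;
        math::Dot<float, CPUContext>(n, a, b, &result, ctx);
        return result;
    }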
...@@ -18,7 +18,7 @@ __global__ void _Set( ...@@ -18,7 +18,7 @@ __global__ void _Set(
const int n, const int n,
const T alpha, const T alpha,
T* x) { T* x) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
x[idx] = alpha; x[idx] = alpha;
} }
} }
...@@ -26,27 +26,31 @@ __global__ void _Set( ...@@ -26,27 +26,31 @@ __global__ void _Set(
template <> void Set<float, CUDAContext>( template <> void Set<float, CUDAContext>(
const int n, const int n,
const float alpha, const float alpha,
float* x) { float* x,
if (alpha == 0) { CUDAContext* ctx) {
CUDA_CHECK(cudaMemset(x, 0, sizeof(float) * n)); if (alpha == 0.f) {
return; CUDA_CHECK(cudaMemsetAsync(x, 0,
} sizeof(float) * n, ctx->cuda_stream()));
} else {
_Set<float> _Set<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, alpha, x); 0, ctx->cuda_stream() >> >(n, alpha, x);
}
} }
template <> void Set<int, CUDAContext>( template <> void Set<int, CUDAContext>(
const int n, const int n,
const int alpha, const int alpha,
int* x) { int* x,
CUDAContext* ctx) {
if (alpha == 0) { if (alpha == 0) {
CUDA_CHECK(cudaMemset(x, 0, sizeof(int) * n)); CUDA_CHECK(cudaMemsetAsync(x, 0,
return; sizeof(int) * n, ctx->cuda_stream()));
} } else {
_Set<int> _Set<int>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, alpha, x); 0, ctx->cuda_stream() >> >(n, alpha, x);
}
} }
template <> void RandomUniform<uint32_t, CUDAContext>( template <> void RandomUniform<uint32_t, CUDAContext>(
...@@ -89,7 +93,7 @@ __global__ void _Add( ...@@ -89,7 +93,7 @@ __global__ void _Add(
const T* a, const T* a,
const T* b, const T* b,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = a[idx] + b[idx]; y[idx] = a[idx] + b[idx];
} }
} }
...@@ -98,10 +102,11 @@ template <> void Add<float, CUDAContext>( ...@@ -98,10 +102,11 @@ template <> void Add<float, CUDAContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CUDAContext* ctx) {
_Add<float> _Add<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, a, b, y); 0, ctx->cuda_stream() >> >(n, a, b, y);
} }
template <typename T> template <typename T>
...@@ -110,7 +115,7 @@ __global__ void _Sub( ...@@ -110,7 +115,7 @@ __global__ void _Sub(
const T* a, const T* a,
const T* b, const T* b,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = a[idx] - b[idx]; y[idx] = a[idx] - b[idx];
} }
} }
...@@ -119,10 +124,11 @@ template <> void Sub<float, CUDAContext>( ...@@ -119,10 +124,11 @@ template <> void Sub<float, CUDAContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CUDAContext* ctx) {
_Sub<float> _Sub<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, a, b, y); 0, ctx->cuda_stream() >> >(n, a, b, y);
} }
template <typename T> template <typename T>
...@@ -131,7 +137,7 @@ __global__ void _Mul( ...@@ -131,7 +137,7 @@ __global__ void _Mul(
const T* a, const T* a,
const T* b, const T* b,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = a[idx] * b[idx]; y[idx] = a[idx] * b[idx];
} }
} }
...@@ -140,10 +146,11 @@ template <> void Mul<float, CUDAContext>( ...@@ -140,10 +146,11 @@ template <> void Mul<float, CUDAContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CUDAContext* ctx) {
_Mul<float> _Mul<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, a, b, y); 0, ctx->cuda_stream() >> >(n, a, b, y);
} }
template <typename T> template <typename T>
...@@ -152,7 +159,7 @@ __global__ void _Div( ...@@ -152,7 +159,7 @@ __global__ void _Div(
const T* a, const T* a,
const T* b, const T* b,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = a[idx] / b[idx]; y[idx] = a[idx] / b[idx];
} }
} }
...@@ -161,10 +168,11 @@ template <> void Div<float, CUDAContext>( ...@@ -161,10 +168,11 @@ template <> void Div<float, CUDAContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CUDAContext* ctx) {
_Div<float> _Div<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, a, b, y); 0, ctx->cuda_stream() >> >(n, a, b, y);
} }
template <typename T> template <typename T>
...@@ -173,7 +181,7 @@ __global__ void _Clip( ...@@ -173,7 +181,7 @@ __global__ void _Clip(
const T low, const T low,
const T high, const T high,
T* x) { T* x) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
x[idx] = x[idx] > high ? high : x[idx]; x[idx] = x[idx] > high ? high : x[idx];
x[idx] = x[idx] < low ? low : x[idx]; x[idx] = x[idx] < low ? low : x[idx];
} }
...@@ -183,10 +191,11 @@ template <> void Clip<float, CUDAContext>( ...@@ -183,10 +191,11 @@ template <> void Clip<float, CUDAContext>(
const int n, const int n,
const float low, const float low,
const float high, const float high,
float* x) { float* x,
CUDAContext* ctx) {
_Clip<float> _Clip<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, low, high, x); 0, ctx->cuda_stream() >> >(n, low, high, x);
} }
template <typename T> template <typename T>
...@@ -194,7 +203,7 @@ __global__ void _Exp( ...@@ -194,7 +203,7 @@ __global__ void _Exp(
const int n, const int n,
const T* a, const T* a,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = exp(a[idx]); y[idx] = exp(a[idx]);
} }
} }
...@@ -202,10 +211,11 @@ __global__ void _Exp( ...@@ -202,10 +211,11 @@ __global__ void _Exp(
template <> void Exp<float, CUDAContext>( template <> void Exp<float, CUDAContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Exp<float> _Exp<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, x, y); 0, ctx->cuda_stream() >> >(n, x, y);
} }
template <typename T> template <typename T>
...@@ -213,7 +223,7 @@ __global__ void _Log( ...@@ -213,7 +223,7 @@ __global__ void _Log(
const int n, const int n,
const T* a, const T* a,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = log(a[idx]); y[idx] = log(a[idx]);
} }
} }
...@@ -221,10 +231,11 @@ __global__ void _Log( ...@@ -221,10 +231,11 @@ __global__ void _Log(
template <> void Log<float, CUDAContext>( template <> void Log<float, CUDAContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Log<float> _Log<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, x, y); 0, ctx->cuda_stream() >> >(n, x, y);
} }
template <typename T> template <typename T>
...@@ -232,7 +243,7 @@ __global__ void _Square( ...@@ -232,7 +243,7 @@ __global__ void _Square(
const int n, const int n,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = x[idx] * x[idx]; y[idx] = x[idx] * x[idx];
} }
} }
...@@ -240,10 +251,11 @@ __global__ void _Square( ...@@ -240,10 +251,11 @@ __global__ void _Square(
template <> void Square<float, CUDAContext>( template <> void Square<float, CUDAContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Square<float> _Square<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, x, y); 0, ctx->cuda_stream() >> >(n, x, y);
} }
template <typename T> template <typename T>
...@@ -251,7 +263,7 @@ __global__ void _Sqrt( ...@@ -251,7 +263,7 @@ __global__ void _Sqrt(
const int n, const int n,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = sqrt(x[idx]); y[idx] = sqrt(x[idx]);
} }
} }
...@@ -259,10 +271,11 @@ __global__ void _Sqrt( ...@@ -259,10 +271,11 @@ __global__ void _Sqrt(
template <> void Sqrt<float, CUDAContext>( template <> void Sqrt<float, CUDAContext>(
int n, int n,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Sqrt<float> _Sqrt<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, x, y); 0, ctx->cuda_stream() >> >(n, x, y);
} }
template <typename T> template <typename T>
...@@ -271,7 +284,7 @@ __global__ void _Pow( ...@@ -271,7 +284,7 @@ __global__ void _Pow(
const T alpha, const T alpha,
const T* a, const T* a,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = pow(a[idx], alpha); y[idx] = pow(a[idx], alpha);
} }
} }
...@@ -280,10 +293,11 @@ template <> void Pow<float, CUDAContext>( ...@@ -280,10 +293,11 @@ template <> void Pow<float, CUDAContext>(
int n, int n,
const float alpha, const float alpha,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Pow<float> _Pow<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, alpha, x, y); 0, ctx->cuda_stream() >> >(n, alpha, x, y);
} }
template <typename T> template <typename T>
...@@ -292,7 +306,7 @@ __global__ void _Inv( ...@@ -292,7 +306,7 @@ __global__ void _Inv(
const float numerator, const float numerator,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] = numerator / x[idx]; y[idx] = numerator / x[idx];
} }
} }
...@@ -301,10 +315,11 @@ template <> void Inv<float, CUDAContext>( ...@@ -301,10 +315,11 @@ template <> void Inv<float, CUDAContext>(
const int n, const int n,
const float numerator, const float numerator,
const float* x, const float* x,
float* y) { float* y,
CUDAContext* ctx) {
_Inv<float> _Inv<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, numerator, x, y); 0, ctx->cuda_stream() >> >(n, numerator, x, y);
} }
/******************** Level-2 ********************/ /******************** Level-2 ********************/
...@@ -330,26 +345,27 @@ template <> void Scale<float, CUDAContext>( ...@@ -330,26 +345,27 @@ template <> void Scale<float, CUDAContext>(
ctx->cublas_handle(), n, &alpha, y, 1)); ctx->cublas_handle(), n, &alpha, y, 1));
} }
template <> float StridedDot<float, CUDAContext>( template <> void StridedDot<float, CUDAContext>(
const int n, const int n,
const float* a, const float* a,
const int incx, const int incx,
const float* b, const float* b,
const int incy, const int incy,
float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
float result;
CUBLAS_CHECK(cublasSdot_v2(ctx->cublas_handle(), CUBLAS_CHECK(cublasSdot_v2(ctx->cublas_handle(),
n, a, incx, b, incy, &result)); n, a, incx, b, incy, y));
return result;
} }
template <> float Dot<float, CUDAContext>( template <> void Dot<float, CUDAContext>(
int n, int n,
const float* a, const float* a,
const float* b, const float* b,
float* y,
CUDAContext* ctx) { CUDAContext* ctx) {
return StridedDot<float, CUDAContext>( StridedDot<float, CUDAContext>(
n, a, 1, b, 1, ctx); n, a, 1, b, 1, y, ctx);
ctx->FinishDeviceCompution();
} }
template <> float ASum<float, CUDAContext>( template <> float ASum<float, CUDAContext>(
...@@ -363,7 +379,7 @@ __global__ void _AddScalar( ...@@ -363,7 +379,7 @@ __global__ void _AddScalar(
const int n, const int n,
T alpha, T alpha,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] += alpha; y[idx] += alpha;
} }
} }
...@@ -371,10 +387,11 @@ __global__ void _AddScalar( ...@@ -371,10 +387,11 @@ __global__ void _AddScalar(
template <> void AddScalar<float, CUDAContext>( template <> void AddScalar<float, CUDAContext>(
const int n, const int n,
const float alpha, const float alpha,
float* y) { float* y,
CUDAContext* ctx) {
_AddScalar<float> _AddScalar<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, alpha, y); 0, ctx->cuda_stream() >> >(n, alpha, y);
} }
template <typename T> template <typename T>
...@@ -382,7 +399,7 @@ __global__ void _MulScalar( ...@@ -382,7 +399,7 @@ __global__ void _MulScalar(
const int n, const int n,
T alpha, T alpha,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
y[idx] *= alpha; y[idx] *= alpha;
} }
} }
...@@ -390,10 +407,11 @@ __global__ void _MulScalar( ...@@ -390,10 +407,11 @@ __global__ void _MulScalar(
template <> void MulScalar<float, CUDAContext>( template <> void MulScalar<float, CUDAContext>(
const int n, const int n,
const float alpha, const float alpha,
float* y) { float* y,
CUDAContext* ctx) {
_MulScalar<float> _MulScalar<float>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, alpha, y); 0, ctx->cuda_stream() >> >(n, alpha, y);
} }
template <> void Axpy<float, CUDAContext>( template <> void Axpy<float, CUDAContext>(
...@@ -427,7 +445,7 @@ template <> void RandomUniform<float, CUDAContext>( ...@@ -427,7 +445,7 @@ template <> void RandomUniform<float, CUDAContext>(
ctx->curand_generator(), x, n)); ctx->curand_generator(), x, n));
float range = high - low; float range = high - low;
if (range != 1.f) Scal<float, CUDAContext>(n, range, x, ctx); if (range != 1.f) Scal<float, CUDAContext>(n, range, x, ctx);
if (low != 0.f) AddScalar<float, CUDAContext>(n, low, x); if (low != 0.f) AddScalar<float, CUDAContext>(n, low, x, ctx);
} }
/******************** Level-3 ********************/ /******************** Level-3 ********************/
......
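Every CUDA wrapper in this file now passes an explicit stream in the launch configuration (<<<blocks, threads, 0, ctx->cuda_stream()>>>), and the zero-fill fast path uses cudaMemsetAsync on the same stream, so all work issued through one context stays ordered on that context's stream instead of serializing on the legacy default stream. The fast path, as a sketch:

    // Sketch: an async memset enqueued on the context's stream orders
    // correctly with any kernel launched on that stream afterwards.
    void ZeroOnStream(size_t nbytes, void* ptr, CUDAContext* ctx) {
        CUDA_CHECK(cudaMemsetAsync(ptr, 0, nbytes, ctx->cuda_stream()));
    }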
...@@ -18,7 +18,7 @@ __global__ void _SetHalf( ...@@ -18,7 +18,7 @@ __global__ void _SetHalf(
const int n, const int n,
const T alpha, const T alpha,
T* x) { T* x) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
x[idx] = alpha; x[idx] = alpha;
} }
} }
...@@ -26,16 +26,19 @@ __global__ void _SetHalf( ...@@ -26,16 +26,19 @@ __global__ void _SetHalf(
template <> void Set<float16, CUDAContext>( template <> void Set<float16, CUDAContext>(
const int n, const int n,
const float16 alpha, const float16 alpha,
float16* x) { float16* x,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_SetHalf<half2> _SetHalf<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
dragon_cast<half2, float16>(alpha), dragon_cast<half2, float16>(alpha),
reinterpret_cast<half2*>(x)); reinterpret_cast<half2*>(x));
} else { } else {
_SetHalf<float16> _SetHalf<float16>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, alpha, x); << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n, alpha, x);
} }
#else #else
CUDA_FP16_NOT_COMPILED; CUDA_FP16_NOT_COMPILED;
...@@ -47,7 +50,7 @@ __global__ void _TypeFloat2Half( ...@@ -47,7 +50,7 @@ __global__ void _TypeFloat2Half(
const int n, const int n,
const float* a, const float* a,
half* b) { half* b) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
b[idx] = __float2half(a[idx]); b[idx] = __float2half(a[idx]);
} }
} }
...@@ -64,8 +67,9 @@ template <> void RandomNormal<float16, CUDAContext>( ...@@ -64,8 +67,9 @@ template <> void RandomNormal<float16, CUDAContext>(
CURAND_CHECK(curandGenerateNormal( CURAND_CHECK(curandGenerateNormal(
ctx->curand_generator(), xf32, n, mu, sigma)); ctx->curand_generator(), xf32, n, mu, sigma));
_TypeFloat2Half _TypeFloat2Half
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, xf32, reinterpret_cast<half*>(x)); 0, ctx->cuda_stream() >> >(n,
xf32, reinterpret_cast<half*>(x));
CUDAContext::Delete(xf32); CUDAContext::Delete(xf32);
#else #else
CUDA_FP16_NOT_COMPILED; CUDA_FP16_NOT_COMPILED;
...@@ -81,7 +85,7 @@ __global__ void _AddHalf( ...@@ -81,7 +85,7 @@ __global__ void _AddHalf(
const half* a, const half* a,
const half* b, const half* b,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hadd(a[idx], b[idx]); y[idx] = __hadd(a[idx], b[idx]);
#endif #endif
...@@ -94,7 +98,7 @@ __global__ void _AddHalf2( ...@@ -94,7 +98,7 @@ __global__ void _AddHalf2(
const half2* a, const half2* a,
const half2* b, const half2* b,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hadd2(a[idx], b[idx]); y[idx] = __hadd2(a[idx], b[idx]);
#endif #endif
...@@ -106,17 +110,20 @@ template <> void Add<float16, CUDAContext>( ...@@ -106,17 +110,20 @@ template <> void Add<float16, CUDAContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_AddHalf2<half2> _AddHalf2<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
reinterpret_cast<const half2*>(a), reinterpret_cast<const half2*>(a),
reinterpret_cast<const half2*>(b), reinterpret_cast<const half2*>(b),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_AddHalf<half> _AddHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -133,7 +140,7 @@ __global__ void _SubHalf( ...@@ -133,7 +140,7 @@ __global__ void _SubHalf(
const half* a, const half* a,
const half* b, const half* b,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hsub(a[idx], b[idx]); y[idx] = __hsub(a[idx], b[idx]);
#endif #endif
...@@ -146,7 +153,7 @@ __global__ void _SubHalf2( ...@@ -146,7 +153,7 @@ __global__ void _SubHalf2(
const half2* a, const half2* a,
const half2* b, const half2* b,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hsub2(a[idx], b[idx]); y[idx] = __hsub2(a[idx], b[idx]);
#endif #endif
...@@ -158,17 +165,20 @@ template <> void Sub<float16, CUDAContext>( ...@@ -158,17 +165,20 @@ template <> void Sub<float16, CUDAContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_SubHalf2<half2> _SubHalf2<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
reinterpret_cast<const half2*>(a), reinterpret_cast<const half2*>(a),
reinterpret_cast<const half2*>(b), reinterpret_cast<const half2*>(b),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_SubHalf<half> _SubHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -185,7 +195,7 @@ __global__ void _MulHalf( ...@@ -185,7 +195,7 @@ __global__ void _MulHalf(
const half* a, const half* a,
const half* b, const half* b,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul(a[idx], b[idx]); y[idx] = __hmul(a[idx], b[idx]);
#endif #endif
...@@ -198,7 +208,7 @@ __global__ void _MulHalf2( ...@@ -198,7 +208,7 @@ __global__ void _MulHalf2(
const half2* a, const half2* a,
const half2* b, const half2* b,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul2(a[idx], b[idx]); y[idx] = __hmul2(a[idx], b[idx]);
#endif #endif
...@@ -210,17 +220,20 @@ template <> void Mul<float16, CUDAContext>( ...@@ -210,17 +220,20 @@ template <> void Mul<float16, CUDAContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_MulHalf2<half2> _MulHalf2<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
reinterpret_cast<const half2*>(a), reinterpret_cast<const half2*>(a),
reinterpret_cast<const half2*>(b), reinterpret_cast<const half2*>(b),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_MulHalf<half> _MulHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> > (n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -237,7 +250,7 @@ __global__ void _DivHalf( ...@@ -237,7 +250,7 @@ __global__ void _DivHalf(
const half* a, const half* a,
const half* b, const half* b,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hdiv(a[idx], b[idx]); y[idx] = __hdiv(a[idx], b[idx]);
#endif #endif
...@@ -249,10 +262,12 @@ template <> void Div<float16, CUDAContext>( ...@@ -249,10 +262,12 @@ template <> void Div<float16, CUDAContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_DivHalf<half> _DivHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<const half*>(b), reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -267,7 +282,7 @@ __global__ void _SquareHalf( ...@@ -267,7 +282,7 @@ __global__ void _SquareHalf(
const int n, const int n,
const half* x, const half* x,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul(x[idx], x[idx]); y[idx] = __hmul(x[idx], x[idx]);
#endif #endif
...@@ -279,7 +294,7 @@ __global__ void _SquareHalf2( ...@@ -279,7 +294,7 @@ __global__ void _SquareHalf2(
const int n, const int n,
const half2* x, const half2* x,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul2(x[idx], x[idx]); y[idx] = __hmul2(x[idx], x[idx]);
#endif #endif
...@@ -290,16 +305,19 @@ __global__ void _SquareHalf2( ...@@ -290,16 +305,19 @@ __global__ void _SquareHalf2(
template <> void Square<float16, CUDAContext>( template <> void Square<float16, CUDAContext>(
int n, int n,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_SquareHalf2<half2> _SquareHalf2<half2>
<< < CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_SquareHalf<half> _SquareHalf<half>
<< < CUDA_BLOCKS(n), CUDA_THREADS >> > (n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} }
...@@ -314,7 +332,7 @@ __global__ void _SqrtHalf( ...@@ -314,7 +332,7 @@ __global__ void _SqrtHalf(
int n, int n,
const half* x, const half* x,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = hsqrt(x[idx]); y[idx] = hsqrt(x[idx]);
#endif #endif
...@@ -326,7 +344,7 @@ __global__ void _SqrtHalf2( ...@@ -326,7 +344,7 @@ __global__ void _SqrtHalf2(
const int n, const int n,
const half2* x, const half2* x,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = h2sqrt(x[idx]); y[idx] = h2sqrt(x[idx]);
#endif #endif
...@@ -337,16 +355,19 @@ __global__ void _SqrtHalf2( ...@@ -337,16 +355,19 @@ __global__ void _SqrtHalf2(
template <> void Sqrt<float16, CUDAContext>( template <> void Sqrt<float16, CUDAContext>(
int n, int n,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_SqrtHalf2<half2> _SqrtHalf2<half2>
<< < CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_SqrtHalf<half> _SqrtHalf<half>
<< < CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} }
...@@ -362,7 +383,7 @@ __global__ void _PowHalf( ...@@ -362,7 +383,7 @@ __global__ void _PowHalf(
const float alpha, const float alpha,
const half* a, const half* a,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul(a[idx], a[idx]); y[idx] = __hmul(a[idx], a[idx]);
#endif #endif
...@@ -375,7 +396,7 @@ __global__ void _PowHalf2( ...@@ -375,7 +396,7 @@ __global__ void _PowHalf2(
const float alpha, const float alpha,
const half2* a, const half2* a,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul2(a[idx], a[idx]); y[idx] = __hmul2(a[idx], a[idx]);
#endif #endif
...@@ -387,17 +408,20 @@ template <> void Pow<float16, CUDAContext>( ...@@ -387,17 +408,20 @@ template <> void Pow<float16, CUDAContext>(
int n, int n,
const float alpha, const float alpha,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
CHECK(alpha == float(2)) << "fp16 only supports the power of 2"; CHECK(alpha == float(2)) << "fp16 only supports the power of 2";
if (n % 2 == 0) { if ((n & 1) == 0) {
_PowHalf2<half2> _PowHalf2<half2>
<< < CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
alpha, reinterpret_cast<const half2*>(x), alpha, reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_PowHalf<half> _PowHalf<half>
<< < CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
alpha, reinterpret_cast<const half*>(x), alpha, reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} }
...@@ -413,7 +437,7 @@ __global__ void _InvHalf( ...@@ -413,7 +437,7 @@ __global__ void _InvHalf(
const half numerator, const half numerator,
const half* x, const half* x,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul(hrcp(x[idx]), numerator); y[idx] = __hmul(hrcp(x[idx]), numerator);
#endif #endif
...@@ -426,7 +450,7 @@ __global__ void _InvHalf2( ...@@ -426,7 +450,7 @@ __global__ void _InvHalf2(
const half2 numerator, const half2 numerator,
const half2* x, const half2* x,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul2(h2rcp(x[idx]), numerator); y[idx] = __hmul2(h2rcp(x[idx]), numerator);
#endif #endif
...@@ -438,17 +462,20 @@ template <> void Inv<float16, CUDAContext>( ...@@ -438,17 +462,20 @@ template <> void Inv<float16, CUDAContext>(
const int n, const int n,
const float numerator, const float numerator,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_InvHalf2<half2> _InvHalf2<half2>
<< < CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
dragon_cast<half2, float>(numerator), dragon_cast<half2, float>(numerator),
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_InvHalf<half> _InvHalf<half>
<< < CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
dragon_cast<half, float>(numerator), dragon_cast<half, float>(numerator),
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -482,27 +509,26 @@ template <> void Scale<float16, CUDAContext>( ...@@ -482,27 +509,26 @@ template <> void Scale<float16, CUDAContext>(
const float16* x, const float16* x,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
CUDAContext::Copy<float16, CUDAContext, CUDAContext>(n, y, x); ctx->Copy<float16, CUDAContext, CUDAContext>(n, y, x);
Scal<float16, CUDAContext>(n, alpha, y, ctx); Scal<float16, CUDAContext>(n, alpha, y, ctx);
} }
template <> float Dot<float16, CUDAContext>( template <> void Dot<float16, CUDAContext>(
int n, int n,
const float16* a, const float16* a,
const float16* b, const float16* b,
float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
float16 result;
CUBLAS_CHECK(cublasDotEx( CUBLAS_CHECK(cublasDotEx(
ctx->cublas_handle(), n, ctx->cublas_handle(), n,
a, CUDA_R_16F, 1, a, CUDA_R_16F, 1,
b, CUDA_R_16F, 1, b, CUDA_R_16F, 1,
&result, CUDA_R_16F, y, CUDA_R_16F,
CUDA_R_32F)); CUDA_R_32F));
return dragon_cast<float, float16>(result); ctx->FinishDeviceCompution();
#else #else
CUDA_FP16_NOT_COMPILED; CUDA_FP16_NOT_COMPILED;
return 0.;
#endif #endif
} }
...@@ -512,7 +538,7 @@ __global__ void _AddScalarHalf( ...@@ -512,7 +538,7 @@ __global__ void _AddScalarHalf(
const int n, const int n,
half alpha, half alpha,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hadd(y[idx], alpha); y[idx] = __hadd(y[idx], alpha);
#endif #endif
...@@ -524,7 +550,7 @@ __global__ void _AddScalarHalf2( ...@@ -524,7 +550,7 @@ __global__ void _AddScalarHalf2(
const int n, const int n,
half2 alpha, half2 alpha,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hadd2(y[idx], alpha); y[idx] = __hadd2(y[idx], alpha);
#endif #endif
...@@ -535,16 +561,19 @@ __global__ void _AddScalarHalf2( ...@@ -535,16 +561,19 @@ __global__ void _AddScalarHalf2(
template <> void AddScalar<float16, CUDAContext>( template <> void AddScalar<float16, CUDAContext>(
const int n, const int n,
const float alpha, const float alpha,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_AddScalarHalf2<half2> _AddScalarHalf2<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
dragon_cast<half2, float>(alpha), dragon_cast<half2, float>(alpha),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_AddScalarHalf<half> _AddScalarHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
dragon_cast<half, float>(alpha), dragon_cast<half, float>(alpha),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} }
...@@ -559,7 +588,7 @@ __global__ void _MulScalarHalf( ...@@ -559,7 +588,7 @@ __global__ void _MulScalarHalf(
const int n, const int n,
half alpha, half alpha,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul(y[idx], alpha); y[idx] = __hmul(y[idx], alpha);
#endif #endif
...@@ -571,7 +600,7 @@ __global__ void _MulScalarHalf2( ...@@ -571,7 +600,7 @@ __global__ void _MulScalarHalf2(
const int n, const int n,
half2 alpha, half2 alpha,
half2* y) { half2* y) {
CUDA_KERNEL_LOOP(idx, n) { CUDA_1D_KERNEL_LOOP(idx, n) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hmul2(y[idx], alpha); y[idx] = __hmul2(y[idx], alpha);
#endif #endif
...@@ -582,16 +611,19 @@ __global__ void _MulScalarHalf2( ...@@ -582,16 +611,19 @@ __global__ void _MulScalarHalf2(
template <> void MulScalar<float16, CUDAContext>( template <> void MulScalar<float16, CUDAContext>(
const int n, const int n,
const float alpha, const float alpha,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (n % 2 == 0) { if ((n & 1) == 0) {
_MulScalarHalf2<half2> _MulScalarHalf2<half2>
<< <CUDA_BLOCKS(n / 2), CUDA_THREADS >> >(n / 2, << < CUDA_BLOCKS(n >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n >> 1,
dragon_cast<half2, float>(alpha), dragon_cast<half2, float>(alpha),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_MulScalarHalf<half> _MulScalarHalf<half>
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >(n, << < CUDA_BLOCKS(n), CUDA_THREADS,
0, ctx->cuda_stream() >> >(n,
dragon_cast<half, float>(alpha), dragon_cast<half, float>(alpha),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} }
...@@ -640,11 +672,12 @@ template <> void RandomUniform<float16, CUDAContext>( ...@@ -640,11 +672,12 @@ template <> void RandomUniform<float16, CUDAContext>(
CURAND_CHECK(curandGenerateUniform( CURAND_CHECK(curandGenerateUniform(
ctx->curand_generator(), xf32, n)); ctx->curand_generator(), xf32, n));
_TypeFloat2Half _TypeFloat2Half
<< <CUDA_BLOCKS(n), CUDA_THREADS >> >( << < CUDA_BLOCKS(n), CUDA_THREADS,
n, xf32, reinterpret_cast<half*>(x)); 0, ctx->cuda_stream() >> >(n,
xf32, reinterpret_cast<half*>(x));
float range = high - low; float range = high - low;
if (range != float(1)) Scal<float16, CUDAContext>(n, range, x, ctx); if (range != 1.f) Scal<float16, CUDAContext>(n, range, x, ctx);
if (low != float(0)) AddScalar<float16, CUDAContext>(n, low, x); if (low != 0.f) AddScalar<float16, CUDAContext>(n, low, x, ctx);
ctx->Delete(xf32); ctx->Delete(xf32);
#else #else
CUDA_FP16_NOT_COMPILED; CUDA_FP16_NOT_COMPILED;
......
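
On the fp16 RandomUniform path above, cuRAND samples fp32 in [0, 1), the values are narrowed to half on ctx's stream, and the ctx-threaded Scal/AddScalar apply the affine map to [low, high). Host-side reference for that map (illustrative sketch only):

    // u in [0, 1)  ->  u * (high - low) + low  in [low, high)
    inline float UniformAffineRef(float u, float low, float high) {
        return u * (high - low) + low;
    }
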
...@@ -53,7 +53,8 @@ template<> void Elu<float, CPUContext>( ...@@ -53,7 +53,8 @@ template<> void Elu<float, CPUContext>(
const int count, const int count,
const float alpha, const float alpha,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -68,7 +69,8 @@ template<> void EluGrad<float, CPUContext>( ...@@ -68,7 +69,8 @@ template<> void EluGrad<float, CPUContext>(
const float alpha, const float alpha,
const float* dy, const float* dy,
const float* y, const float* y,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -89,7 +91,8 @@ template<> void PRelu<float, CPUContext>( ...@@ -89,7 +91,8 @@ template<> void PRelu<float, CPUContext>(
const string& data_format, const string& data_format,
const float* x, const float* x,
const float* w, const float* w,
float* y) { float* y,
CPUContext* ctx) {
if (channel_shared) { if (channel_shared) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
...@@ -130,7 +133,8 @@ template<> void PReluGrad<float, CPUContext>( ...@@ -130,7 +133,8 @@ template<> void PReluGrad<float, CPUContext>(
const float* dy, const float* dy,
const float* x, const float* x,
const float* w, const float* w,
float* dx) { float* dx,
CPUContext* ctx) {
if (channel_shared) { if (channel_shared) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
...@@ -184,9 +188,10 @@ template<> void PReluWGrad<float, CPUContext>( ...@@ -184,9 +188,10 @@ template<> void PReluWGrad<float, CPUContext>(
} }
} }
if (channel_shared) { if (channel_shared) {
float w_sum = math::Dot<float, CPUContext>( float w_sum;
channels * dim, bcast_dw, multiplier, ctx); math::Dot<float, CPUContext>(channels * dim,
math::AddScalar<float, CPUContext>(1, w_sum, dw); bcast_dw, multiplier, &w_sum, ctx);
math::AddScalar<float, CPUContext>(1, w_sum, dw, ctx);
} else { } else {
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemv<float, CPUContext>( math::Gemv<float, CPUContext>(
...@@ -208,7 +213,8 @@ template<> void Relu<float, CPUContext>( ...@@ -208,7 +213,8 @@ template<> void Relu<float, CPUContext>(
const int count, const int count,
const float slope, const float slope,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -221,7 +227,8 @@ template<> void Relu<float16, CPUContext>( ...@@ -221,7 +227,8 @@ template<> void Relu<float16, CPUContext>(
const int count, const int count,
const float slope, const float slope,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
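
Every CPU kernel in this file gains a trailing CPUContext* that the body ignores; the point is signature parity with the CUDA specializations. A dispatch sketch (hypothetical wrapper name) of the single-source operator body this enables:

    template <class Context>
    void RunReluSketch(int count, const float* x, float* y, Context* ctx) {
        // The same call now compiles for CPUContext and CUDAContext,
        // since both specializations accept the context argument.
        Relu<float, Context>(count, 0.f, x, y, ctx);
    }
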
...@@ -230,7 +237,8 @@ template<> void ReluGrad<float, CPUContext>( ...@@ -230,7 +237,8 @@ template<> void ReluGrad<float, CPUContext>(
const float slope, const float slope,
const float* dy, const float* dy,
const float* y, const float* y,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -244,7 +252,8 @@ template<> void ReluGrad<float, CPUContext>( ...@@ -244,7 +252,8 @@ template<> void ReluGrad<float, CPUContext>(
template<> void SElu<float, CPUContext>( template<> void SElu<float, CPUContext>(
const int count, const int count,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -258,7 +267,8 @@ template<> void SEluGrad<float, CPUContext>( ...@@ -258,7 +267,8 @@ template<> void SEluGrad<float, CPUContext>(
const int count, const int count,
const float* dy, const float* dy,
const float* y, const float* y,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -276,7 +286,8 @@ T _sigmoid(T x) { return T(1) / (T(1) + exp(-x)); } ...@@ -276,7 +286,8 @@ T _sigmoid(T x) { return T(1) / (T(1) + exp(-x)); }
template<> void Sigmoid<float, CPUContext>( template<> void Sigmoid<float, CPUContext>(
const int count, const int count,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -287,7 +298,8 @@ template<> void SigmoidGrad<float, CPUContext>( ...@@ -287,7 +298,8 @@ template<> void SigmoidGrad<float, CPUContext>(
const int count, const int count,
const float* dy, const float* dy,
const float* y, const float* y,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -310,7 +322,7 @@ template<> void Softmax<float, CPUContext>( ...@@ -310,7 +322,7 @@ template<> void Softmax<float, CPUContext>(
CPUContext* ctx) { CPUContext* ctx) {
const int dim = count / outer_dim; const int dim = count / outer_dim;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, scale, x + i*dim); inner_dim, scale, x + i*dim);
for (int j = 0; j < classes; ++j) { for (int j = 0; j < classes; ++j) {
for (int k = 0; k < inner_dim; k++) for (int k = 0; k < inner_dim; k++)
...@@ -322,13 +334,13 @@ template<> void Softmax<float, CPUContext>( ...@@ -322,13 +334,13 @@ template<> void Softmax<float, CPUContext>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
classes, inner_dim, 1, classes, inner_dim, 1,
-1.0, sum_multiplier, scale, 1.0, y, ctx); -1.0, sum_multiplier, scale, 1.0, y, ctx);
math::Exp<float, CPUContext>(dim, y, y); math::Exp<float, CPUContext>(dim, y, y, ctx);
math::Gemv<float, CPUContext>( math::Gemv<float, CPUContext>(
CblasTrans, classes, inner_dim, CblasTrans, classes, inner_dim,
1.0, y, sum_multiplier, 1.0, y, sum_multiplier,
0.0, scale, ctx); 0.0, scale, ctx);
for (int j = 0; j < classes; ++j) { for (int j = 0; j < classes; ++j) {
math::Div<float, CPUContext>(inner_dim, y, scale, y); math::Div<float, CPUContext>(inner_dim, y, scale, y, ctx);
y += inner_dim; y += inner_dim;
} }
} }
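
Scalar reference for the BLAS pipeline above, per inner position (sketch, needs <cmath> and <algorithm>): subtract the class-wise max, exponentiate, normalize.

    void SoftmaxRef(int classes, const float* x, float* y) {
        float m = x[0];
        for (int j = 1; j < classes; ++j) m = std::max(m, x[j]);
        float sum = 0.f;
        for (int j = 0; j < classes; ++j) sum += (y[j] = std::exp(x[j] - m));
        for (int j = 0; j < classes; ++j) y[j] /= sum;
    }
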
...@@ -348,17 +360,16 @@ template<> void SoftmaxGrad<float, CPUContext>( ...@@ -348,17 +360,16 @@ template<> void SoftmaxGrad<float, CPUContext>(
const int dim = count / outer_dim; const int dim = count / outer_dim;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int k = 0; k < inner_dim; ++k) for (int k = 0; k < inner_dim; ++k)
scale[k] = math::StridedDot<float, CPUContext>( math::StridedDot<float, CPUContext>(classes,
classes,
dx + i * dim + k, inner_dim, dx + i * dim + k, inner_dim,
y + i*dim + k, inner_dim, ctx); y + i * dim + k, inner_dim, scale + k, ctx);
math::Gemm<float, CPUContext>( math::Gemm<float, CPUContext>(
CblasNoTrans, CblasNoTrans, CblasNoTrans, CblasNoTrans,
classes, inner_dim, 1, classes, inner_dim, 1,
-1.0, sum_multiplier, scale, -1.0, sum_multiplier, scale,
1.0, dx + i * dim, ctx); 1.0, dx + i * dim, ctx);
} }
math::Mul<float, CPUContext>(count, dx, y, dx); math::Mul<float, CPUContext>(count, dx, y, dx, ctx);
} }
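
StridedDot now mirrors Dot: the per-position reduction lands in scale + k instead of a return value. Scalar reference for the backward pass it feeds (sketch; dx holds dy on entry, as in the caller above):

    void SoftmaxGradRef(int classes, const float* y, float* dx) {
        float s = 0.f;                                       // StridedDot
        for (int j = 0; j < classes; ++j) s += dx[j] * y[j];
        for (int j = 0; j < classes; ++j)                    // Gemm + Mul
            dx[j] = (dx[j] - s) * y[j];
    }
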
/******************** activation.tanh ********************/ /******************** activation.tanh ********************/
...@@ -366,7 +377,8 @@ template<> void SoftmaxGrad<float, CPUContext>( ...@@ -366,7 +377,8 @@ template<> void SoftmaxGrad<float, CPUContext>(
template<> void Tanh<float, CPUContext>( template<> void Tanh<float, CPUContext>(
const int count, const int count,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -379,7 +391,8 @@ template<> void TanhGrad<float, CPUContext>( ...@@ -379,7 +391,8 @@ template<> void TanhGrad<float, CPUContext>(
const int count, const int count,
const float* dy, const float* dy,
const float* y, const float* y,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -467,7 +480,8 @@ template <> void Clip<float, CPUContext>( ...@@ -467,7 +480,8 @@ template <> void Clip<float, CPUContext>(
const float high, const float high,
const float* x, const float* x,
float* mask, float* mask,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -484,7 +498,8 @@ template <> void Equal<float, CPUContext>( ...@@ -484,7 +498,8 @@ template <> void Equal<float, CPUContext>(
const int count, const int count,
const float* a, const float* a,
const float* b, const float* b,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -497,7 +512,8 @@ template <> void Equal<float, CPUContext>( ...@@ -497,7 +512,8 @@ template <> void Equal<float, CPUContext>(
template<> void AbsGrad<float, CPUContext>( template<> void AbsGrad<float, CPUContext>(
const int count, const int count,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -651,7 +667,8 @@ template<> void SmoothL1<float, CPUContext>( ...@@ -651,7 +667,8 @@ template<> void SmoothL1<float, CPUContext>(
const int count, const int count,
const float beta, const float beta,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -667,7 +684,8 @@ template<> void SmoothL1Grad<float, CPUContext>( ...@@ -667,7 +684,8 @@ template<> void SmoothL1Grad<float, CPUContext>(
const int count, const int count,
const float beta, const float beta,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -686,7 +704,8 @@ template <> void SoftmaxCrossEntropy<float, CPUContext>( ...@@ -686,7 +704,8 @@ template <> void SoftmaxCrossEntropy<float, CPUContext>(
const int count, const int count,
const float* prob, const float* prob,
const float* target, const float* target,
float* loss) { float* loss,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -834,6 +853,20 @@ template <> void SparseSoftmaxCrossEntropy<float, float, CPUContext>( ...@@ -834,6 +853,20 @@ template <> void SparseSoftmaxCrossEntropy<float, float, CPUContext>(
losses, flags); losses, flags);
} }
template <> void SparseSoftmaxCrossEntropy<float16, float, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const float* labels,
const int* ignores,
const int num_ignores,
float* losses,
float* flags,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <> void SparseSoftmaxCrossEntropy<float, int64_t, CPUContext>( template <> void SparseSoftmaxCrossEntropy<float, int64_t, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
...@@ -851,6 +884,20 @@ template <> void SparseSoftmaxCrossEntropy<float, int64_t, CPUContext>( ...@@ -851,6 +884,20 @@ template <> void SparseSoftmaxCrossEntropy<float, int64_t, CPUContext>(
losses, flags); losses, flags);
} }
template <> void SparseSoftmaxCrossEntropy<float16, int64_t, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const int64_t* labels,
const int* ignores,
const int num_ignores,
float* losses,
float* flags,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename Tx, typename Ty> template <typename Tx, typename Ty>
void _SparseSoftmaxCrossEntropyGrad( void _SparseSoftmaxCrossEntropyGrad(
const int outer_dim, const int outer_dim,
...@@ -897,6 +944,20 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, float, CPUContext>( ...@@ -897,6 +944,20 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, float, CPUContext>(
num_ignores, dx, flags); num_ignores, dx, flags);
} }
template<> void SparseSoftmaxCrossEntropyGrad<float16, float, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const float* labels,
const int* ignores,
const int num_ignores,
float16* dx,
float* flags,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CPUContext>( template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CPUContext>(
const int outer_dim, const int outer_dim,
const int axis_dim, const int axis_dim,
...@@ -914,6 +975,20 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CPUContext>( ...@@ -914,6 +975,20 @@ template<> void SparseSoftmaxCrossEntropyGrad<float, int64_t, CPUContext>(
num_ignores, dx, flags); num_ignores, dx, flags);
} }
template<> void SparseSoftmaxCrossEntropyGrad<float16, int64_t, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const int64_t* labels,
const int* ignores,
const int num_ignores,
float16* dx,
float* flags,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
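
The four new fp16 specializations exist only so generic operator code links for every (prob, label) type pair; reaching one at runtime traps. Pattern sketch (hypothetical kernel name; the actual macro body is not shown in this diff):

    template <typename T, class Context>
    void MyLossSketch(int n, const T* prob, float* loss, Context* ctx);
    template <> void MyLossSketch<float16, CPUContext>(
        int n, const float16* prob, float* loss, CPUContext* ctx) {
        LOG(FATAL) << "float16 is not supported on CPUContext.";
    }
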
/******************** misc.astype ********************/ /******************** misc.astype ********************/
template <typename Ta, typename Tb> template <typename Ta, typename Tb>
...@@ -936,7 +1011,8 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) { ...@@ -936,7 +1011,8 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) {
template <> void TypeA2B<type_a, type_b, CPUContext>( \ template <> void TypeA2B<type_a, type_b, CPUContext>( \
const int count, \ const int count, \
const type_a* a, \ const type_a* a, \
type_b* b) { \ type_b* b, \
CPUContext* ctx) { \
_TypeA2B<type_a, type_b>(count, a, b); \ _TypeA2B<type_a, type_b>(count, a, b); \
} }
...@@ -944,7 +1020,8 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) { ...@@ -944,7 +1020,8 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) {
template <> void TypeA2B<type_a, type_b, CPUContext>( \ template <> void TypeA2B<type_a, type_b, CPUContext>( \
const int count, \ const int count, \
const type_a* a, \ const type_a* a, \
type_b* b) { \ type_b* b, \
CPUContext* ctx) { \
_TypeA2B_v2<type_a, type_b>(count, a, b); \ _TypeA2B_v2<type_a, type_b>(count, a, b); \
} }
...@@ -952,13 +1029,15 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) { ...@@ -952,13 +1029,15 @@ void _TypeA2B_v2(const int count, const Ta* a, Tb* b) {
template <> void TypeA2B<float16, type, CPUContext>( \ template <> void TypeA2B<float16, type, CPUContext>( \
const int count, \ const int count, \
const float16* a, \ const float16* a, \
type* b) { \ type* b, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \ CPU_FP16_NOT_SUPPORTED; \
} \ } \
template <> void TypeA2B<type, float16, CPUContext>( \ template <> void TypeA2B<type, float16, CPUContext>( \
const int count, \ const int count, \
const type* a, \ const type* a, \
float16* b) { \ float16* b, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \ CPU_FP16_NOT_SUPPORTED; \
} }
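
Usage sketch for the cast kernels the macros above stamp out (hypothetical buffers; assumes these type pairs are among the instantiated ones):

    TypeA2B<float, int, CPUContext>(count, src_f32, dst_i32, ctx);
    TypeA2B<float, float16, CPUContext>(count, src_f32, dst_f16, ctx);  // traps: CPU fp16 unsupported
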
...@@ -1039,7 +1118,8 @@ template <> void ImageData<float, float, CPUContext>( ...@@ -1039,7 +1118,8 @@ template <> void ImageData<float, float, CPUContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageData_NCHW<float, float>( _ImageData_NCHW<float, float>(
N, C, H, W, mean_values, std_values, x, y); N, C, H, W, mean_values, std_values, x, y);
...@@ -1059,7 +1139,8 @@ template <> void ImageData<uint8_t, float, CPUContext>( ...@@ -1059,7 +1139,8 @@ template <> void ImageData<uint8_t, float, CPUContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const uint8_t* x, const uint8_t* x,
float* y) { float* y,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageData_NCHW<uint8_t, float>( _ImageData_NCHW<uint8_t, float>(
N, C, H, W, mean_values, std_values, x, y); N, C, H, W, mean_values, std_values, x, y);
...@@ -1079,7 +1160,8 @@ template <> void ImageData<float, float16, CPUContext>( ...@@ -1079,7 +1160,8 @@ template <> void ImageData<float, float16, CPUContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const float* x, const float* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -1093,7 +1175,8 @@ template <> void ImageData<uint8_t, float16, CPUContext>( ...@@ -1093,7 +1175,8 @@ template <> void ImageData<uint8_t, float16, CPUContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const uint8_t* x, const uint8_t* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -1103,7 +1186,8 @@ template<> void Arange<float, CPUContext>( ...@@ -1103,7 +1186,8 @@ template<> void Arange<float, CPUContext>(
const int count, const int count,
const int start, const int start,
const int step, const int step,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1114,7 +1198,8 @@ template<> void Arange<int, CPUContext>( ...@@ -1114,7 +1198,8 @@ template<> void Arange<int, CPUContext>(
const int count, const int count,
const int start, const int start,
const int step, const int step,
int* y) { int* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1130,7 +1215,8 @@ template<> void Argmax<float, CPUContext>( ...@@ -1130,7 +1215,8 @@ template<> void Argmax<float, CPUContext>(
const int top_k, const int top_k,
const float* x, const float* x,
int64_t* indices, int64_t* indices,
float* values) { float* values,
CPUContext* ctx) {
vector<pair<float, int> > vec(axis_dim); vector<pair<float, int> > vec(axis_dim);
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
...@@ -1158,7 +1244,8 @@ template<> void Argmin<float, CPUContext>( ...@@ -1158,7 +1244,8 @@ template<> void Argmin<float, CPUContext>(
const int top_k, const int top_k,
const float* x, const float* x,
int64_t* indices, int64_t* indices,
float* values) { float* values,
CPUContext* ctx) {
vector<pair<float, int> > vec(axis_dim); vector<pair<float, int> > vec(axis_dim);
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
...@@ -1182,7 +1269,8 @@ template<> void Argmin<float, CPUContext>( ...@@ -1182,7 +1269,8 @@ template<> void Argmin<float, CPUContext>(
template <> void CanonicalAxis<int, CPUContext>( template <> void CanonicalAxis<int, CPUContext>(
const int count, const int count,
const int dim, const int dim,
int* y) { int* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1198,7 +1286,8 @@ void _Gather( ...@@ -1198,7 +1286,8 @@ void _Gather(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const T* x, const T* x,
T* y) { T* y,
CPUContext* ctx) {
TIndex x_offset, y_offset, x_idx_offset, y_idx_offset; TIndex x_offset, y_offset, x_idx_offset, y_idx_offset;
for (int i = 0; i < y_slice_dim; ++i) { for (int i = 0; i < y_slice_dim; ++i) {
y_idx_offset = i; y_idx_offset = i;
...@@ -1206,7 +1295,7 @@ void _Gather( ...@@ -1206,7 +1295,7 @@ void _Gather(
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
CPUContext::Copy<T, CPUContext, CPUContext>( ctx->Copy<T, CPUContext, CPUContext>(
inner_dim, y + y_offset, x + x_offset); inner_dim, y + y_offset, x + x_offset);
} }
} }
...@@ -1220,9 +1309,10 @@ template <> void Gather<float, CPUContext>( ...@@ -1220,9 +1309,10 @@ template <> void Gather<float, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
_Gather<float>(count, outer_dim, inner_dim, _Gather<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, x, y); x_slice_dim, y_slice_dim, indices, x, y, ctx);
} }
template <> void Gather<int, CPUContext>( template <> void Gather<int, CPUContext>(
...@@ -1233,9 +1323,10 @@ template <> void Gather<int, CPUContext>( ...@@ -1233,9 +1323,10 @@ template <> void Gather<int, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const int* x, const int* x,
int* y) { int* y,
CPUContext* ctx) {
_Gather<int>(count, outer_dim, inner_dim, _Gather<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, x, y); x_slice_dim, y_slice_dim, indices, x, y, ctx);
} }
template <typename T> template <typename T>
...@@ -1247,7 +1338,8 @@ void _GatherGrad( ...@@ -1247,7 +1338,8 @@ void _GatherGrad(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const T* dy, const T* dy,
T* dx) { T* dx,
CPUContext* ctx) {
TIndex x_offset, y_offset, x_idx_offset, y_idx_offset; TIndex x_offset, y_offset, x_idx_offset, y_idx_offset;
for (int i = 0; i < y_slice_dim; ++i) { for (int i = 0; i < y_slice_dim; ++i) {
y_idx_offset = i; y_idx_offset = i;
...@@ -1256,7 +1348,7 @@ void _GatherGrad( ...@@ -1256,7 +1348,7 @@ void _GatherGrad(
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
math::Add<T, CPUContext>(inner_dim, math::Add<T, CPUContext>(inner_dim,
dy + y_offset, dx + x_offset, dx + x_offset); dy + y_offset, dx + x_offset, dx + x_offset, ctx);
} }
} }
} }
...@@ -1269,9 +1361,10 @@ template <> void GatherGrad<float, CPUContext>( ...@@ -1269,9 +1361,10 @@ template <> void GatherGrad<float, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
_GatherGrad<float>(count, outer_dim, inner_dim, _GatherGrad<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, dy, dx); x_slice_dim, y_slice_dim, indices, dy, dx, ctx);
} }
template <> void GatherGrad<int, CPUContext>( template <> void GatherGrad<int, CPUContext>(
...@@ -1282,9 +1375,10 @@ template <> void GatherGrad<int, CPUContext>( ...@@ -1282,9 +1375,10 @@ template <> void GatherGrad<int, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int* indices,
const int* dy, const int* dy,
int* dx) { int* dx,
CPUContext* ctx) {
_GatherGrad<int>(count, outer_dim, inner_dim, _GatherGrad<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, dy, dx); x_slice_dim, y_slice_dim, indices, dy, dx, ctx);
} }
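
GatherGrad accumulates with math::Add rather than copying, because repeated indices must sum their contributions; the caller is expected to clear dx first. Sketch:

    math::Set<float, CPUContext>(x_count /* size of dx, hypothetical name */, 0, dx, ctx);
    GatherGrad<float, CPUContext>(count, outer_dim, inner_dim,
        x_slice_dim, y_slice_dim, indices, dy, dx, ctx);  // dx += per repeated index
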
/******************** ndarray.concat ********************/ /******************** ndarray.concat ********************/
...@@ -1297,12 +1391,13 @@ template <> void Concat<float, CPUContext>( ...@@ -1297,12 +1391,13 @@ template <> void Concat<float, CPUContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim; x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim; y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
x_concat_dim * inner_dim, y + y_offset, x + x_offset); x_concat_dim * inner_dim, y + y_offset, x + x_offset);
} }
} }
...@@ -1315,12 +1410,13 @@ template <> void Concat<float16, CPUContext>( ...@@ -1315,12 +1410,13 @@ template <> void Concat<float16, CPUContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim; x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim; y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
CPUContext::Copy<float16, CPUContext, CPUContext>( ctx->Copy<float16, CPUContext, CPUContext>(
x_concat_dim * inner_dim, y + y_offset, x + x_offset); x_concat_dim * inner_dim, y + y_offset, x + x_offset);
} }
} }
...@@ -1333,12 +1429,13 @@ template <> void ConcatGrad<float, CPUContext>( ...@@ -1333,12 +1429,13 @@ template <> void ConcatGrad<float, CPUContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim; x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim; y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
x_concat_dim * inner_dim, dx + x_offset, dy + y_offset); x_concat_dim * inner_dim, dx + x_offset, dy + y_offset);
} }
} }
...@@ -1351,12 +1448,13 @@ template <> void ConcatGrad<float16, CPUContext>( ...@@ -1351,12 +1448,13 @@ template <> void ConcatGrad<float16, CPUContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float16* dy, const float16* dy,
float16* dx) { float16* dx,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim; x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim; y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
CPUContext::Copy<float16, CPUContext, CPUContext>( ctx->Copy<float16, CPUContext, CPUContext>(
x_concat_dim * inner_dim, dx + x_offset, dy + y_offset); x_concat_dim * inner_dim, dx + x_offset, dy + y_offset);
} }
} }
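
Worked example of the offset arithmetic driving the block copies above (sketch): concatenating a (2, 3, 4) input into a (2, 5, 4) output at concat_offset = 2, outer step n = 1:

    int x_offset = 1 * 3 * 4;        // n * x_concat_dim * inner_dim = 12
    int y_offset = (1 * 5 + 2) * 4;  // (n * y_concat_dim + concat_offset) * inner_dim = 28
    // Copy(3 * 4 = 12 elements, y + 28, x + 12) per outer step
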
...@@ -1371,7 +1469,8 @@ void _Crop1D( ...@@ -1371,7 +1469,8 @@ void _Crop1D(
const int inner_dim, const int inner_dim,
const int start, const int start,
const T* x, const T* x,
T* y) { T* y,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1381,7 +1480,7 @@ void _Crop1D( ...@@ -1381,7 +1480,7 @@ void _Crop1D(
const int o = idx / ex_dim; const int o = idx / ex_dim;
const T* x_ptr = x + (o * dim + ex_d + start) * inner_dim; const T* x_ptr = x + (o * dim + ex_d + start) * inner_dim;
T* y_ptr = y + (o * ex_dim + ex_d) * inner_dim; T* y_ptr = y + (o * ex_dim + ex_d) * inner_dim;
CPUContext::Copy<T, CPUContext, CPUContext>( ctx->Copy<T, CPUContext, CPUContext>(
inner_dim, y_ptr, x_ptr); inner_dim, y_ptr, x_ptr);
} }
} }
...@@ -1393,8 +1492,10 @@ template<> void Crop1D<int, CPUContext>( ...@@ -1393,8 +1492,10 @@ template<> void Crop1D<int, CPUContext>(
const int inner_dim, const int inner_dim,
const int start, const int start,
const int* x, const int* x,
int* y) { int* y,
_Crop1D<int>(count, dim, ex_dim, inner_dim, start, x, y); CPUContext* ctx) {
_Crop1D<int>(count, dim, ex_dim,
inner_dim, start, x, y, ctx);
} }
template<> void Crop1D<float, CPUContext>( template<> void Crop1D<float, CPUContext>(
...@@ -1404,8 +1505,10 @@ template<> void Crop1D<float, CPUContext>( ...@@ -1404,8 +1505,10 @@ template<> void Crop1D<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int start, const int start,
const float* x, const float* x,
float* y) { float* y,
_Crop1D<float>(count, dim, ex_dim, inner_dim, start, x, y); CPUContext* ctx) {
_Crop1D<float>(count, dim, ex_dim,
inner_dim, start, x, y, ctx);
} }
template <typename T> template <typename T>
...@@ -1417,7 +1520,8 @@ void _Crop1DGrad( ...@@ -1417,7 +1520,8 @@ void _Crop1DGrad(
const int start, const int start,
const int end, const int end,
const T* dy, const T* dy,
T* dx) { T* dx,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1430,7 +1534,7 @@ void _Crop1DGrad( ...@@ -1430,7 +1534,7 @@ void _Crop1DGrad(
for (int i = 0; i < inner_dim; ++i) dx_ptr[i] = 0; for (int i = 0; i < inner_dim; ++i) dx_ptr[i] = 0;
} else { } else {
const T* dy_ptr = dy + (o * ex_dim + d - start) * inner_dim; const T* dy_ptr = dy + (o * ex_dim + d - start) * inner_dim;
CPUContext::Copy<T, CPUContext, CPUContext>( ctx->Copy<T, CPUContext, CPUContext>(
inner_dim, dx_ptr, dy_ptr); inner_dim, dx_ptr, dy_ptr);
} }
} }
...@@ -1444,10 +1548,11 @@ template<> void Crop1DGrad<int, CPUContext>( ...@@ -1444,10 +1548,11 @@ template<> void Crop1DGrad<int, CPUContext>(
const int start, const int start,
const int end, const int end,
const int* dy, const int* dy,
int* dx) { int* dx,
CPUContext* ctx) {
_Crop1DGrad<int>( _Crop1DGrad<int>(
count, dim, ex_dim, inner_dim, count, dim, ex_dim, inner_dim,
start, end, dy, dx); start, end, dy, dx, ctx);
} }
template<> void Crop1DGrad<float, CPUContext>( template<> void Crop1DGrad<float, CPUContext>(
...@@ -1458,10 +1563,11 @@ template<> void Crop1DGrad<float, CPUContext>( ...@@ -1458,10 +1563,11 @@ template<> void Crop1DGrad<float, CPUContext>(
const int start, const int start,
const int end, const int end,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
_Crop1DGrad<float>( _Crop1DGrad<float>(
count, dim, ex_dim, inner_dim, count, dim, ex_dim, inner_dim,
start, end, dy, dx); start, end, dy, dx, ctx);
} }
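
Reference loop for the index map behind _Crop1D above (sketch; outer = count / (ex_dim * inner_dim)): output row ex_d along the cropped axis reads input row ex_d + start.

    for (int o = 0; o < outer; ++o)
        for (int ex_d = 0; ex_d < ex_dim; ++ex_d)
            for (int i = 0; i < inner_dim; ++i)
                y[(o * ex_dim + ex_d) * inner_dim + i] =
                    x[(o * dim + ex_d + start) * inner_dim + i];
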
/******************** ndarray.pad ********************/ /******************** ndarray.pad ********************/
...@@ -1474,7 +1580,8 @@ template <> void ConstPad1D<float, CPUContext>( ...@@ -1474,7 +1580,8 @@ template <> void ConstPad1D<float, CPUContext>(
const int pad_l, const int pad_l,
const float value, const float value,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1488,7 +1595,7 @@ template <> void ConstPad1D<float, CPUContext>( ...@@ -1488,7 +1595,7 @@ template <> void ConstPad1D<float, CPUContext>(
for (int i = 0; i < inner_dim; ++i) y_ptr[i] = value; for (int i = 0; i < inner_dim; ++i) y_ptr[i] = value;
} else { } else {
const float* x_ptr = x + (o * dim + d) * inner_dim; const float* x_ptr = x + (o * dim + d) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, y_ptr, x_ptr); inner_dim, y_ptr, x_ptr);
} }
} }
...@@ -1501,7 +1608,8 @@ template <> void ReflectPad1D<float, CPUContext>( ...@@ -1501,7 +1608,8 @@ template <> void ReflectPad1D<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1518,7 +1626,7 @@ template <> void ReflectPad1D<float, CPUContext>( ...@@ -1518,7 +1626,7 @@ template <> void ReflectPad1D<float, CPUContext>(
y_ptr[i] = x[(o * dim + d) * inner_dim + i]; y_ptr[i] = x[(o * dim + d) * inner_dim + i];
} else { } else {
const float* x_ptr = x + (o * dim + d) * inner_dim; const float* x_ptr = x + (o * dim + d) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, y_ptr, x_ptr); inner_dim, y_ptr, x_ptr);
} }
} }
...@@ -1531,7 +1639,8 @@ template <> void EdgePad1D<float, CPUContext>( ...@@ -1531,7 +1639,8 @@ template <> void EdgePad1D<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1546,7 +1655,7 @@ template <> void EdgePad1D<float, CPUContext>( ...@@ -1546,7 +1655,7 @@ template <> void EdgePad1D<float, CPUContext>(
y_ptr[i] = x[(o * dim + d) * inner_dim + i]; y_ptr[i] = x[(o * dim + d) * inner_dim + i];
} else { } else {
const float* x_ptr = x + (o * dim + d) * inner_dim; const float* x_ptr = x + (o * dim + d) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, y_ptr, x_ptr); inner_dim, y_ptr, x_ptr);
} }
} }
...@@ -1559,7 +1668,8 @@ template <> void ConstPad1DGrad<float, CPUContext>( ...@@ -1559,7 +1668,8 @@ template <> void ConstPad1DGrad<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
...@@ -1570,7 +1680,7 @@ template <> void ConstPad1DGrad<float, CPUContext>( ...@@ -1570,7 +1680,7 @@ template <> void ConstPad1DGrad<float, CPUContext>(
const int ex_d = d + pad_l; const int ex_d = d + pad_l;
const float* dy_ptr = dy + (o * ex_dim + ex_d) * inner_dim; const float* dy_ptr = dy + (o * ex_dim + ex_d) * inner_dim;
float* dx_ptr = dx + (o * dim + d) * inner_dim; float* dx_ptr = dx + (o * dim + d) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, dx_ptr, dy_ptr); inner_dim, dx_ptr, dy_ptr);
} }
} }
...@@ -1582,7 +1692,8 @@ template <> void ReflectPad1DGrad<float, CPUContext>( ...@@ -1582,7 +1692,8 @@ template <> void ReflectPad1DGrad<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
for (int idx = 0; idx < count; ++idx) { for (int idx = 0; idx < count; ++idx) {
const int i = idx % inner_dim; const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim; const int ex_d = (idx / inner_dim) % ex_dim;
...@@ -1601,7 +1712,8 @@ template <> void EdgePad1DGrad<float, CPUContext>( ...@@ -1601,7 +1712,8 @@ template <> void EdgePad1DGrad<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int pad_l, const int pad_l,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
const int count_v2 = count / inner_dim; const int count_v2 = count / inner_dim;
for (int idx = 0; idx < count_v2; ++idx) { for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim; const int ex_d = idx % ex_dim;
...@@ -1613,7 +1725,7 @@ template <> void EdgePad1DGrad<float, CPUContext>( ...@@ -1613,7 +1725,7 @@ template <> void EdgePad1DGrad<float, CPUContext>(
dx[(o * dim + d) * inner_dim + i] += dy_ptr[i]; dx[(o * dim + d) * inner_dim + i] += dy_ptr[i];
} else { } else {
float* dx_ptr = dx + (o * dim + d) * inner_dim; float* dx_ptr = dx + (o * dim + d) * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, dx_ptr, dy_ptr); inner_dim, dx_ptr, dy_ptr);
} }
} }
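
In the edge-pad backward above, each padded row clamps to a border row of dx, so border gradients accumulate (+=) while interior rows are a plain Copy. Hypothetical clamp form (needs <algorithm>):

    const int d = std::max(0, std::min(ex_d - pad_l, dim - 1));  // clamp to [0, dim)
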
...@@ -1626,7 +1738,8 @@ template <> void OneHot<float, CPUContext>( ...@@ -1626,7 +1738,8 @@ template <> void OneHot<float, CPUContext>(
const int depth, const int depth,
const int on_value, const int on_value,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1643,7 +1756,8 @@ template<> void Sum<float, CPUContext>( ...@@ -1643,7 +1756,8 @@ template<> void Sum<float, CPUContext>(
const int axis_dim, const int axis_dim,
const int inner_dim, const int inner_dim,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1662,7 +1776,8 @@ template<> void SumGrad<float, CPUContext>( ...@@ -1662,7 +1776,8 @@ template<> void SumGrad<float, CPUContext>(
const int inner_dim, const int inner_dim,
const float coeff, const float coeff,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1682,14 +1797,15 @@ template <> void Repeat<float, CPUContext>( ...@@ -1682,14 +1797,15 @@ template <> void Repeat<float, CPUContext>(
const int inner_dim, const int inner_dim,
const int repeats, const int repeats,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < dim; ++j) { for (int j = 0; j < dim; ++j) {
for (int k = 0; k < repeats; ++k) { for (int k = 0; k < repeats; ++k) {
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, y, x); inner_dim, y, x);
y += inner_dim; y += inner_dim;
} }
...@@ -1709,7 +1825,7 @@ template <> void RepeatGrad<float, CPUContext>( ...@@ -1709,7 +1825,7 @@ template <> void RepeatGrad<float, CPUContext>(
CPUContext* ctx) { CPUContext* ctx) {
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < dim; ++j) { for (int j = 0; j < dim; ++j) {
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
inner_dim, dx, dy); inner_dim, dx, dy);
dy += inner_dim; dy += inner_dim;
for (int k = 1; k < repeats; ++k) { for (int k = 1; k < repeats; ++k) {
...@@ -1732,12 +1848,13 @@ template <> void Slice<float, CPUContext>( ...@@ -1732,12 +1848,13 @@ template <> void Slice<float, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int slice_offset, const int slice_offset,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + slice_offset) * inner_dim; x_offset = (n * x_slice_dim + slice_offset) * inner_dim;
y_offset = n * y_slice_dim * inner_dim; y_offset = n * y_slice_dim * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
y_slice_dim * inner_dim, y + y_offset, x + x_offset); y_slice_dim * inner_dim, y + y_offset, x + x_offset);
} }
} }
...@@ -1750,12 +1867,13 @@ template <> void SliceGrad<float, CPUContext>( ...@@ -1750,12 +1867,13 @@ template <> void SliceGrad<float, CPUContext>(
const int y_slice_dim, const int y_slice_dim,
const int slice_offset, const int slice_offset,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
TIndex x_offset, y_offset; TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + slice_offset) * inner_dim; x_offset = (n * x_slice_dim + slice_offset) * inner_dim;
y_offset = n * y_slice_dim * inner_dim; y_offset = n * y_slice_dim * inner_dim;
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
y_slice_dim * inner_dim, dx + x_offset, dy + y_offset); y_slice_dim * inner_dim, dx + x_offset, dy + y_offset);
} }
} }
...@@ -1768,10 +1886,11 @@ template <> void Tile<float, CPUContext>( ...@@ -1768,10 +1886,11 @@ template <> void Tile<float, CPUContext>(
const int ex_inner_dim, const int ex_inner_dim,
const int multiple, const int multiple,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int t = 0; t < multiple; ++t) { for (int t = 0; t < multiple; ++t) {
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
ex_inner_dim, y, x); ex_inner_dim, y, x);
y += ex_inner_dim; y += ex_inner_dim;
} }
...@@ -1788,7 +1907,7 @@ template <> void TileGrad<float, CPUContext>( ...@@ -1788,7 +1907,7 @@ template <> void TileGrad<float, CPUContext>(
float* dx, float* dx,
CPUContext* ctx) { CPUContext* ctx) {
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
CPUContext::Copy<float, CPUContext, CPUContext>( ctx->Copy<float, CPUContext, CPUContext>(
ex_inner_dim, dx, dy); ex_inner_dim, dx, dy);
dy += ex_inner_dim; dy += ex_inner_dim;
for (int t = 1; t < multiple; ++t) { for (int t = 1; t < multiple; ++t) {
...@@ -1809,7 +1928,8 @@ template <> void Transpose<float, CPUContext>( ...@@ -1809,7 +1928,8 @@ template <> void Transpose<float, CPUContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1831,7 +1951,8 @@ template <> void Transpose<float16, CPUContext>( ...@@ -1831,7 +1951,8 @@ template <> void Transpose<float16, CPUContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float16* x, const float16* x,
float16* y) { float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -1842,7 +1963,8 @@ template <> void TransposeGrad<float, CPUContext>( ...@@ -1842,7 +1963,8 @@ template <> void TransposeGrad<float, CPUContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
...@@ -1864,7 +1986,8 @@ template <> void TransposeGrad<float16, CPUContext>( ...@@ -1864,7 +1986,8 @@ template <> void TransposeGrad<float16, CPUContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float16* dy, const float16* dy,
float16* dx) { float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -1877,7 +2000,8 @@ template <> void LSTMCell<float, CPUContext>( ...@@ -1877,7 +2000,8 @@ template <> void LSTMCell<float, CPUContext>(
const float* cx, const float* cx,
float* xact, float* xact,
float* c, float* c,
float* h) { float* h,
CPUContext* ctx) {
float i, f, o, c_; float i, f, o, c_;
int f_offset = C, o_offset = 2 * C, c_offset = 3 * C, x_offset = 4 * C; int f_offset = C, o_offset = 2 * C, c_offset = 3 * C, x_offset = 4 * C;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
...@@ -1903,7 +2027,8 @@ template <> void LSTMCellGrad<float, CPUContext>( ...@@ -1903,7 +2027,8 @@ template <> void LSTMCellGrad<float, CPUContext>(
const float* dc, const float* dc,
const float* dh, const float* dh,
float* dcx, float* dcx,
float* dx) { float* dx,
CPUContext* ctx) {
float i, f, o, g, tanh_c, dcx_sum_term; float i, f, o, g, tanh_c, dcx_sum_term;
int f_offset = C, int f_offset = C,
o_offset = 2 * C, o_offset = 2 * C,
...@@ -1964,7 +2089,8 @@ template <> void AdamUpdate<float, CPUContext>( ...@@ -1964,7 +2089,8 @@ template <> void AdamUpdate<float, CPUContext>(
const float eps, const float eps,
float* g, float* g,
float* m, float* m,
float* v) { float* v,
CPUContext* ctx) {
_AdamUpdate<float>(count, lr, beta1, beta2, eps, g, m, v); _AdamUpdate<float>(count, lr, beta1, beta2, eps, g, m, v);
} }
...@@ -1976,7 +2102,8 @@ template <> void AdamUpdate<float16, CPUContext>( ...@@ -1976,7 +2102,8 @@ template <> void AdamUpdate<float16, CPUContext>(
const float eps, const float eps,
float16* g, float16* g,
float16* m, float16* m,
float16* v) { float16* v,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
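
The _AdamUpdate body is elided in the hunk above; a conventional Adam step of the shape its signature implies (sketch, assuming lr already folds in the bias correction; needs <cmath>):

    for (int i = 0; i < count; ++i) {
        m[i] = beta1 * m[i] + (1.f - beta1) * g[i];
        v[i] = beta2 * v[i] + (1.f - beta2) * g[i] * g[i];
        g[i] = lr * m[i] / (std::sqrt(v[i]) + eps);  // g becomes the applied step
    }
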
...@@ -2004,7 +2131,8 @@ template <> void NesterovUpdate<float, CPUContext>( ...@@ -2004,7 +2131,8 @@ template <> void NesterovUpdate<float, CPUContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float* g, float* g,
float* h) { float* h,
CPUContext* ctx) {
_NesterovUpdate<float>(count, lr, momentum, g, h); _NesterovUpdate<float>(count, lr, momentum, g, h);
} }
...@@ -2013,7 +2141,8 @@ template <> void NesterovUpdate<float16, CPUContext>( ...@@ -2013,7 +2141,8 @@ template <> void NesterovUpdate<float16, CPUContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float16* g, float16* g,
float16* h) { float16* h,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -2043,7 +2172,8 @@ template <> void RMSPropUpdate<float, CPUContext>( ...@@ -2043,7 +2172,8 @@ template <> void RMSPropUpdate<float, CPUContext>(
const float decay, const float decay,
const float eps, const float eps,
float* g, float* g,
float* h) { float* h,
CPUContext* ctx) {
_RMSPropUpdate<float>(count, lr, decay, eps, g, h); _RMSPropUpdate<float>(count, lr, decay, eps, g, h);
} }
...@@ -2053,7 +2183,8 @@ template <> void RMSPropUpdate<float16, CPUContext>( ...@@ -2053,7 +2183,8 @@ template <> void RMSPropUpdate<float16, CPUContext>(
const float decay, const float decay,
const float eps, const float eps,
float16* g, float16* g,
float16* h) { float16* h,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -2080,7 +2211,8 @@ template <> void SGDUpdate<float, CPUContext>( ...@@ -2080,7 +2211,8 @@ template <> void SGDUpdate<float, CPUContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float* g, float* g,
float* h) { float* h,
CPUContext* ctx) {
_SGDUpdate<float>(count, lr, momentum, g, h); _SGDUpdate<float>(count, lr, momentum, g, h);
} }
...@@ -2089,7 +2221,8 @@ template <> void SGDUpdate<float16, CPUContext>( ...@@ -2089,7 +2221,8 @@ template <> void SGDUpdate<float16, CPUContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float16* g, float16* g,
float16* h) { float16* h,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -2217,7 +2350,8 @@ template <> void BilinearResize<float, CPUContext>( ...@@ -2217,7 +2350,8 @@ template <> void BilinearResize<float, CPUContext>(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
const float scale_h = (float)H / out_h; const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w; const float scale_w = (float)W / out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -2326,10 +2460,10 @@ template <> void BilinearResizeGrad<float, CPUContext>( ...@@ -2326,10 +2460,10 @@ template <> void BilinearResizeGrad<float, CPUContext>(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
const float scale_h = (float)H / out_h; const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w; const float scale_w = (float)W / out_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
if (data_format == "NCHW") { if (data_format == "NCHW") {
_BilinearResizeGrad_NCHW<float>( _BilinearResizeGrad_NCHW<float>(
N, C, H, W, out_h, out_w, N, C, H, W, out_h, out_w,
...@@ -2439,7 +2573,8 @@ template <> void Im2Col2d<float, CPUContext>( ...@@ -2439,7 +2573,8 @@ template <> void Im2Col2d<float, CPUContext>(
const int dilation_w, const int dilation_w,
const string& data_format, const string& data_format,
const float* im, const float* im,
float* col) { float* col,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
const int count = (C * col_h * col_w); const int count = (C * col_h * col_w);
_Im2Col2d_NCHW<float>( _Im2Col2d_NCHW<float>(
...@@ -2471,8 +2606,9 @@ void _Col2Im2d_NCHW( ...@@ -2471,8 +2606,9 @@ void _Col2Im2d_NCHW(
const int dilation_h, const int dilation_h,
const int dilation_w, const int dilation_w,
const T* col, const T* col,
T* im) { T* im,
math::Set<float, CPUContext>(C * H * W, 0, im); CPUContext* ctx) {
math::Set<float, CPUContext>(C * H * W, 0, im, ctx);
const int im_offset = H * W; const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_offset) { for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < kernel_h; ++kh) {
...@@ -2512,8 +2648,9 @@ void _Col2Im2d_NHWC( ...@@ -2512,8 +2648,9 @@ void _Col2Im2d_NHWC(
const int dilation_h, const int dilation_h,
const int dilation_w, const int dilation_w,
const T* col, const T* col,
T* im) { T* im,
math::Set<float, CPUContext>(C * H * W, 0, im); CPUContext* ctx) {
math::Set<float, CPUContext>(C * H * W, 0, im, ctx);
for (int output_h = 0; output_h < col_h; ++output_h) { for (int output_h = 0; output_h < col_h; ++output_h) {
const int base_h = -pad_h + stride_h * output_h; const int base_h = -pad_h + stride_h * output_h;
for (int output_w = 0; output_w < col_w; ++output_w) { for (int output_w = 0; output_w < col_w; ++output_w) {
...@@ -2552,19 +2689,20 @@ template<> void Col2Im2d<float, CPUContext>( ...@@ -2552,19 +2689,20 @@ template<> void Col2Im2d<float, CPUContext>(
const int dilation_w, const int dilation_w,
const string& data_format, const string& data_format,
const float* col, const float* col,
float* im) { float* im,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
const int count = (C * H * W); const int count = (C * H * W);
_Col2Im2d_NCHW<float>( _Col2Im2d_NCHW<float>(
C, H, W, col_h, col_w, kernel_h, kernel_w, C, H, W, col_h, col_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, col, im); dilation_h, dilation_w, col, im, ctx);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
const int count = (H * W * C); const int count = (H * W * C);
_Col2Im2d_NHWC<float>( _Col2Im2d_NHWC<float>(
C, H, W, col_h, col_w, kernel_h, kernel_w, C, H, W, col_h, col_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, stride_h, stride_w, pad_h, pad_w,
dilation_h, dilation_w, col, im); dilation_h, dilation_w, col, im, ctx);
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
} }
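
Both _Col2Im2d variants zero the image first (now through the ctx-threaded math::Set) and then scatter additively, since overlapping kernel windows contribute to the same pixel. Sketch of the invariant:

    math::Set<float, CPUContext>(C * H * W, 0, im, ctx);  // clear once per image
    // then, for every column entry that maps inside the image:
    //   im[pixel] += col[entry];   (never '=', because windows overlap)
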
...@@ -2632,7 +2770,8 @@ template <> void NNResize<float, CPUContext>( ...@@ -2632,7 +2770,8 @@ template <> void NNResize<float, CPUContext>(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
const float scale_h = (float)H / out_h; const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w; const float scale_w = (float)W / out_w;
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -2708,10 +2847,10 @@ template <> void NNResizeGrad<float, CPUContext>( ...@@ -2708,10 +2847,10 @@ template <> void NNResizeGrad<float, CPUContext>(
const int out_w, const int out_w,
const string& data_format, const string& data_format,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
const float scale_h = (float)H / out_h; const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w; const float scale_w = (float)W / out_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
if (data_format == "NCHW") { if (data_format == "NCHW") {
_NNResizeGrad_NCHW<float>( _NNResizeGrad_NCHW<float>(
N, C, H, W, out_h, out_w, N, C, H, W, out_h, out_w,
...@@ -2847,7 +2986,8 @@ template<> void MAXPooling2d<float, CPUContext>( ...@@ -2847,7 +2986,8 @@ template<> void MAXPooling2d<float, CPUContext>(
const string& data_format, const string& data_format,
const float* x, const float* x,
int* mask, int* mask,
float* y) { float* y,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_MAXPooling2d_NCHW<float>( _MAXPooling2d_NCHW<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
...@@ -2966,7 +3106,8 @@ template<> void AVGPooling2d<float, CPUContext>( ...@@ -2966,7 +3106,8 @@ template<> void AVGPooling2d<float, CPUContext>(
const int pad_w, const int pad_w,
const string& data_format, const string& data_format,
const float* x, const float* x,
float* y) { float* y,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_AVGPooling2d_NCHW<float>( _AVGPooling2d_NCHW<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
...@@ -2994,10 +3135,11 @@ void _MAXPooling2dGrad_NCHW( ...@@ -2994,10 +3135,11 @@ void _MAXPooling2dGrad_NCHW(
const int pad_w, const int pad_w,
const float* dy, const float* dy,
const int* mask, const int* mask,
float* dx) { float* dx,
CPUContext* ctx) {
int x_offset = H * W; int x_offset = H * W;
int y_offset = pool_h * pool_w; int y_offset = pool_h * pool_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx); math::Set<float, CPUContext>(N * C * H * W, 0, dx, ctx);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) { for (int ph = 0; ph < pool_h; ++ph) {
...@@ -3030,10 +3172,11 @@ void _MAXPooling2dGrad_NHWC( ...@@ -3030,10 +3172,11 @@ void _MAXPooling2dGrad_NHWC(
const int pad_w, const int pad_w,
const float* dy, const float* dy,
const int* mask, const int* mask,
float* dx) { float* dx,
CPUContext* ctx) {
int x_offset = H * W * C; int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C; int y_offset = pool_h * pool_w * C;
math::Set<float, CPUContext>(N * H * W * C, 0, dx); math::Set<float, CPUContext>(N * H * W * C, 0, dx, ctx);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) { for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) { for (int pw = 0; pw < pool_w; ++pw) {
...@@ -3067,15 +3210,16 @@ template<> void MAXPooling2dGrad<float, CPUContext>( ...@@ -3067,15 +3210,16 @@ template<> void MAXPooling2dGrad<float, CPUContext>(
const string& data_format, const string& data_format,
const float* dy, const float* dy,
const int* mask, const int* mask,
float* dx) { float* dx,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_MAXPooling2dGrad_NCHW<float>( _MAXPooling2dGrad_NCHW<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dy, mask, dx); stride_h, stride_w, pad_h, pad_w, dy, mask, dx, ctx);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_MAXPooling2dGrad_NHWC<float>( _MAXPooling2dGrad_NHWC<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dy, mask, dx); stride_h, stride_w, pad_h, pad_w, dy, mask, dx, ctx);
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
} }
...@@ -3094,10 +3238,11 @@ void _AVGPooling2dGrad_NCHW( ...@@ -3094,10 +3238,11 @@ void _AVGPooling2dGrad_NCHW(
const int pad_h, const int pad_h,
const int pad_w, const int pad_w,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
int x_offset = H * W; int x_offset = H * W;
int y_offset = pool_h * pool_w; int y_offset = pool_h * pool_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx); math::Set<float, CPUContext>(N * C * H * W, 0, dx, ctx);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) { for (int ph = 0; ph < pool_h; ++ph) {
...@@ -3141,10 +3286,11 @@ void _AVGPooling2dGrad_NHWC( ...@@ -3141,10 +3286,11 @@ void _AVGPooling2dGrad_NHWC(
const int pad_h, const int pad_h,
const int pad_w, const int pad_w,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
int x_offset = H * W * C; int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C; int y_offset = pool_h * pool_w * C;
math::Set<float, CPUContext>(N * H * W * C, 0, dx); math::Set<float, CPUContext>(N * H * W * C, 0, dx, ctx);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) { for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) { for (int pw = 0; pw < pool_w; ++pw) {
...@@ -3187,15 +3333,16 @@ template<> void AVGPooling2dGrad<float, CPUContext>( ...@@ -3187,15 +3333,16 @@ template<> void AVGPooling2dGrad<float, CPUContext>(
const int pad_w, const int pad_w,
const string& data_format, const string& data_format,
const float* dy, const float* dy,
float* dx) { float* dx,
CPUContext* ctx) {
if (data_format == "NCHW") { if (data_format == "NCHW") {
_AVGPooling2dGrad_NCHW<float>( _AVGPooling2dGrad_NCHW<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dy, dx); stride_h, stride_w, pad_h, pad_w, dy, dx, ctx);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_AVGPooling2dGrad_NHWC<float>( _AVGPooling2dGrad_NHWC<float>(
N, C, H, W, pool_h, pool_w, kernel_h, kernel_w, N, C, H, W, pool_h, pool_w, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dy, dx); stride_h, stride_w, pad_h, pad_w, dy, dx, ctx);
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
} }
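
Both pooling backward passes zero dx with math::Set before their main loops because they scatter-accumulate: with overlapping windows (stride smaller than kernel), several pool cells contribute to the same dx element, so the buffer cannot start with stale values. A hedged 1-D simplification (ignoring padding; the real loop bodies are elided in this diff):

    // Hedged 1-D sketch of why dx must be zeroed first; illustrative only.
    void AvgPool1dGradRef(int W, int pool_w, int kernel_w, int stride_w,
                          const float* dy, float* dx) {
        for (int i = 0; i < W; ++i) dx[i] = 0.f;       // the math::Set above
        for (int pw = 0; pw < pool_w; ++pw) {
            int start = pw * stride_w, end = start + kernel_w;
            if (end > W) end = W;
            for (int w = start; w < end; ++w)
                dx[w] += dy[pw] / float(end - start);  // += needs a clean dx
        }
    }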
...@@ -3214,12 +3361,11 @@ template<> void ROIPooling<float, CPUContext>( ...@@ -3214,12 +3361,11 @@ template<> void ROIPooling<float, CPUContext>(
const float* x, const float* x,
const float* rois, const float* rois,
int* mask, int* mask,
float* y) { float* y,
CPUContext* ctx) {
const TIndex x_offset = H * W, const TIndex x_offset = H * W,
y_offset = pool_h * pool_w, y_offset = pool_h * pool_w,
im_offset = C * H * W; im_offset = C * H * W;
math::Set<float, CPUContext>(count, -FLT_MAX, y);
math::Set<int, CPUContext>(count, -1, mask);
for (int n = 0; n < num_rois; ++n) { for (int n = 0; n < num_rois; ++n) {
int im_idx = rois[0]; int im_idx = rois[0];
int x1 = round(rois[1] * spatial_scale); int x1 = round(rois[1] * spatial_scale);
...@@ -3248,10 +3394,10 @@ template<> void ROIPooling<float, CPUContext>( ...@@ -3248,10 +3394,10 @@ template<> void ROIPooling<float, CPUContext>(
end_w = std::min(end_w, W); end_w = std::min(end_w, W);
bool is_empty = (end_h == start_h) || (end_w == start_w); bool is_empty = (end_h == start_h) || (end_w == start_w);
const int pool_idx = ph * pool_w + pw; const int pool_idx = ph * pool_w + pw;
if (is_empty) { if (is_empty || im_idx < 0) y[pool_idx] = 0;
y[pool_idx] = 0; else y[pool_idx] = -FLT_MAX;
mask[pool_idx] = -1; mask[pool_idx] = -1;
} if (im_idx < 0) continue;
for (int h = start_h; h < end_h; ++h) { for (int h = start_h; h < end_h; ++h) {
for (int w = start_w; w < end_w; ++w) { for (int w = start_w; w < end_w; ++w) {
const int idx = h * W + w; const int idx = h * W + w;
...@@ -3286,7 +3432,8 @@ template<> void ROIPoolingGrad<float, CPUContext>( ...@@ -3286,7 +3432,8 @@ template<> void ROIPoolingGrad<float, CPUContext>(
const float* dy, const float* dy,
const float* rois, const float* rois,
const int* mask, const int* mask,
float* dx) { float* dx,
CPUContext* ctx) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
...@@ -3305,7 +3452,8 @@ template<> void ROIAlign<float, CPUContext>( ...@@ -3305,7 +3452,8 @@ template<> void ROIAlign<float, CPUContext>(
const int sampling_ratio, const int sampling_ratio,
const float* x, const float* x,
const float* rois, const float* rois,
float* y) { float* y,
CPUContext* ctx) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
...@@ -3322,7 +3470,8 @@ template<> void ROIAlignGrad<float, CPUContext>( ...@@ -3322,7 +3470,8 @@ template<> void ROIAlignGrad<float, CPUContext>(
const int sampling_ratio, const int sampling_ratio,
const float* dy, const float* dy,
const float* rois, const float* rois,
float* dx) { float* dx,
CPUContext* ctx) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
......
This diff could not be displayed because it is too large.
...@@ -23,7 +23,7 @@ __global__ void _ReluHalf( ...@@ -23,7 +23,7 @@ __global__ void _ReluHalf(
const half* x, const half* x,
half* y) { half* y) {
const half kZero = __float2half(0.f); const half kZero = __float2half(0.f);
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hgt(x[idx], kZero) ? y[idx] = __hgt(x[idx], kZero) ?
x[idx] : __hmul(x[idx], slope); x[idx] : __hmul(x[idx], slope);
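
Throughout these CUDA files the loop macro is renamed from CUDA_KERNEL_LOOP to CUDA_1D_KERNEL_LOOP, matching Caffe2's naming. The macro body is not shown in this diff; a typical Caffe2-style grid-stride definition, given here as an assumption, is:

    // Assumed definition (Caffe2-style grid-stride loop); the repository's
    // actual macro may differ, e.g. in the index type.
    #define CUDA_1D_KERNEL_LOOP(i, n) \
        for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
             i < (n); \
             i += blockDim.x * gridDim.x)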
...@@ -38,7 +38,7 @@ __global__ void _ReluHalf2( ...@@ -38,7 +38,7 @@ __global__ void _ReluHalf2(
const half2* x, const half2* x,
half2* y) { half2* y) {
const half2 kZero = __float2half2_rn(0.f); const half2 kZero = __float2half2_rn(0.f);
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
y[idx] = __hbgt2(x[idx], kZero) ? y[idx] = __hbgt2(x[idx], kZero) ?
x[idx] : __hmul2(x[idx], slope); x[idx] : __hmul2(x[idx], slope);
...@@ -51,17 +51,20 @@ template<> void Relu<float16, CUDAContext>( ...@@ -51,17 +51,20 @@ template<> void Relu<float16, CUDAContext>(
const int count, const int count,
const float slope, const float slope,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (count % 2 == 0) { if ((count & 1) == 0) {
_ReluHalf2<half2> _ReluHalf2<half2>
<< < CUDA_BLOCKS(count), CUDA_THREADS >> > (count / 2, << < CUDA_BLOCKS(count >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> > (count >> 1,
dragon_cast<half2, float>(slope), dragon_cast<half2, float>(slope),
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_ReluHalf<half> _ReluHalf<half>
<< < CUDA_BLOCKS(count), CUDA_THREADS >> >(count, << < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >(count,
dragon_cast<half, float>(slope), dragon_cast<half, float>(slope),
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
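
Every launch in this commit gains the `0, ctx->cuda_stream()` launch-config arguments, so kernels now run on the context's (possibly non-default) stream instead of stream 0. A hedged sketch of what that implies for callers; the device-id constructor, SwitchToDevice(stream_id), and FinishDeviceCompution() are assumed to exist on CUDAContext, mirroring the CPUContext interface:

    // Hedged sketch: with async streams, the host must synchronize before
    // reading results produced by these kernels.
    CUDAContext ctx(0);                    // assumed device-id constructor
    ctx.SwitchToDevice(1);                 // device 0, stream 1 (assumed API)
    kernel::Relu<float16, CUDAContext>(count, 0.f, x, y, &ctx);
    ctx.FinishDeviceCompution();           // drains the stream (assumed API)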
...@@ -82,7 +85,7 @@ __global__ void _AffineWithOBiasHalf( ...@@ -82,7 +85,7 @@ __global__ void _AffineWithOBiasHalf(
const half* x, const half* x,
const half* alpha, const half* alpha,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
const int scale_idx = (idx / inner_dim) % scale_dim; const int scale_idx = (idx / inner_dim) % scale_dim;
y[idx] = __hmul(alpha[scale_idx], x[idx]); y[idx] = __hmul(alpha[scale_idx], x[idx]);
...@@ -99,7 +102,7 @@ __global__ void _AffineWithBiasHalf( ...@@ -99,7 +102,7 @@ __global__ void _AffineWithBiasHalf(
const half* alpha, const half* alpha,
const half* beta, const half* beta,
half* y) { half* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
const int scale_idx = (idx / inner_dim) % scale_dim; const int scale_idx = (idx / inner_dim) % scale_dim;
y[idx] = __hadd( y[idx] = __hadd(
...@@ -125,16 +128,18 @@ template<> void Affine<float16, CUDAContext>( ...@@ -125,16 +128,18 @@ template<> void Affine<float16, CUDAContext>(
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (beta != nullptr) { if (beta != nullptr) {
_AffineWithBiasHalf<float> _AffineWithBiasHalf<float>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, scale_dim, inner_dim, 0, ctx->cuda_stream() >> >(count,
scale_dim, inner_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(alpha), reinterpret_cast<const half*>(alpha),
reinterpret_cast<const half*>(beta), reinterpret_cast<const half*>(beta),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
} else { } else {
_AffineWithOBiasHalf<float> _AffineWithOBiasHalf<float>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, scale_dim, inner_dim, 0, ctx->cuda_stream() >> >(count,
scale_dim, inner_dim,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(alpha), reinterpret_cast<const half*>(alpha),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -144,6 +149,163 @@ template<> void Affine<float16, CUDAContext>( ...@@ -144,6 +149,163 @@ template<> void Affine<float16, CUDAContext>(
#endif #endif
} }
/******************** loss.sparse_softmax_cross_entropy ********************/
template <typename Ty>
__global__ void _SparseSoftmaxCrossEntropyHalf(
const int count,
const int axis_dim,
const int inner_dim,
const half* prob,
const Ty* labels,
const int* ignores,
const int num_ignores,
float* losses,
float* flags) {
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
const int oix = idx / inner_dim;
const int iix = idx % inner_dim;
const int label = labels[oix * inner_dim + iix];
int k;
for (k = 0; k < num_ignores; k++) {
if (label == ignores[k]) {
losses[idx] = flags[idx] = 0;
break;
}
}
if (k == num_ignores) {
const half kMIN = __float2half(HFLT_MIN);
half loss = __hneg(
hlog(
__hgt(prob[(oix * axis_dim + label)
* inner_dim + iix], kMIN) ?
prob[(oix * axis_dim + label)
* inner_dim + iix] : kMIN
)
);
losses[idx] = __half2float(loss);
flags[idx] = 1;
}
#endif
}
}
template <> void SparseSoftmaxCrossEntropy<float16, float, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const float* labels,
const int* ignores,
const int num_ignores,
float* losses,
float* flags,
CUDAContext* ctx) {
const int num_preds = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyHalf<float>
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >(
num_preds, axis_dim, inner_dim,
reinterpret_cast<const half*>(prob), labels,
ignores, num_ignores, losses, flags);
}
template <> void SparseSoftmaxCrossEntropy<float16, int64_t, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const int64_t* labels,
const int* ignores,
const int num_ignores,
float* losses,
float* flags,
CUDAContext* ctx) {
const int num_preds = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyHalf<int64_t>
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >(
num_preds, axis_dim, inner_dim,
reinterpret_cast<const half*>(prob), labels,
ignores, num_ignores, losses, flags);
}
template <typename Ty>
__global__ void _SparseSoftmaxCrossEntropyGradHalf(
const int count,
const int axis_dim,
const int inner_dim,
const half* prob,
const Ty* labels,
const int* ignores,
const int num_ignores,
half* dx,
float* flags) {
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
const int oix = idx / inner_dim;
const int iix = idx % inner_dim;
const int label = labels[oix * inner_dim + iix];
int k;
for (k = 0; k < num_ignores; k++)
if (label == ignores[k]) break;
if (k != num_ignores) {
for (int c = 0; c < axis_dim; c++)
dx[(oix * axis_dim + c) * inner_dim + iix]
= __float2half(0.f);
flags[idx] = 0;
} else {
const int x_idx = (oix * axis_dim + label) * inner_dim + iix;
dx[x_idx] = __hsub(dx[x_idx], __float2half(1.f));
flags[idx] = 1;
}
#endif
}
}
template<> void SparseSoftmaxCrossEntropyGrad<float16, float, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const float* labels,
const int* ignores,
const int num_ignores,
float16* dx,
float* flags,
CUDAContext* ctx) {
const int num_preds = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyGradHalf<float>
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >(
num_preds, axis_dim, inner_dim,
reinterpret_cast<const half*>(prob), labels,
ignores, num_ignores,
reinterpret_cast<half*>(dx), flags);
}
template<> void SparseSoftmaxCrossEntropyGrad<float16, int64_t, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const float16* prob,
const int64_t* labels,
const int* ignores,
const int num_ignores,
float16* dx,
float* flags,
CUDAContext* ctx) {
const int num_preds = outer_dim * inner_dim;
_SparseSoftmaxCrossEntropyGradHalf<int64_t>
<< < CUDA_BLOCKS(num_preds), CUDA_THREADS,
0, ctx->cuda_stream() >> >(
num_preds, axis_dim, inner_dim,
reinterpret_cast<const half*>(prob), labels,
ignores, num_ignores,
reinterpret_cast<half*>(dx), flags);
}
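
The new FP16 cross-entropy kernels write a per-position loss into losses[] and set flags[] to 1 for non-ignored positions (0 otherwise). The final scalar loss is then the flag-normalized sum; a hedged host-side reference of that reduction (the real one presumably runs on the device):

    // Hedged reference reduction over the kernel outputs above.
    float loss_sum = 0.f, valid = 0.f;
    for (int i = 0; i < num_preds; ++i) {
        loss_sum += losses[i];
        valid    += flags[i];
    }
    const float loss = valid > 0.f ? loss_sum / valid : 0.f;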
/******************** misc.astype ********************/ /******************** misc.astype ********************/
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
...@@ -151,7 +313,7 @@ __global__ void _TypeHalf2Float( ...@@ -151,7 +313,7 @@ __global__ void _TypeHalf2Float(
const int count, const int count,
const half* a, const half* a,
float* b) { float* b) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
b[idx] = __half2float(a[idx]); b[idx] = __half2float(a[idx]);
} }
} }
...@@ -159,7 +321,7 @@ __global__ void _TypeFloat2Half( ...@@ -159,7 +321,7 @@ __global__ void _TypeFloat2Half(
const int count, const int count,
const float* a, const float* a,
half* b) { half* b) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
b[idx] = __float2half(a[idx]); b[idx] = __float2half(a[idx]);
} }
} }
...@@ -168,7 +330,7 @@ __global__ void _TypeHalf2Half( ...@@ -168,7 +330,7 @@ __global__ void _TypeHalf2Half(
const int count, const int count,
const half* a, const half* a,
half* b) { half* b) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
b[idx] = a[idx]; b[idx] = a[idx];
} }
} }
...@@ -178,14 +340,16 @@ __global__ void _TypeHalf2Half( ...@@ -178,14 +340,16 @@ __global__ void _TypeHalf2Half(
template <> void TypeA2B<float16, type, CUDAContext>( \ template <> void TypeA2B<float16, type, CUDAContext>( \
const int count, \ const int count, \
const float16* a, \ const float16* a, \
type* b) { \ type* b, \
CUDAContext* ctx) { \
LOG(FATAL) << "CUDAContext has not implemented: float16 -> " \ LOG(FATAL) << "CUDAContext has not implemented: float16 -> " \
<< TypeMetaToString(TypeMeta::Make<type>()); \ << TypeMetaToString(TypeMeta::Make<type>()); \
} \ } \
template <> void TypeA2B<type, float16, CUDAContext>( \ template <> void TypeA2B<type, float16, CUDAContext>( \
const int count, \ const int count, \
const type* a, \ const type* a, \
float16* b) { \ float16* b, \
CUDAContext* ctx) { \
LOG(FATAL) << "CUDAContext has not implemented: " \ LOG(FATAL) << "CUDAContext has not implemented: " \
<< TypeMetaToString(TypeMeta::Make<type>()) << " -> float16"; \ << TypeMetaToString(TypeMeta::Make<type>()) << " -> float16"; \
} }
...@@ -194,27 +358,33 @@ __global__ void _TypeHalf2Half( ...@@ -194,27 +358,33 @@ __global__ void _TypeHalf2Half(
template <> void TypeA2B<float16, float, CUDAContext>( \ template <> void TypeA2B<float16, float, CUDAContext>( \
const int count, \ const int count, \
const float16* a, \ const float16* a, \
float* b) { \ float* b, \
CUDAContext* ctx) { \
_TypeHalf2Float \ _TypeHalf2Float \
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( \ << < CUDA_BLOCKS(count), CUDA_THREADS, \
count, reinterpret_cast<const half*>(a), b); \ 0, ctx->cuda_stream() >> >(count, \
reinterpret_cast<const half*>(a), b); \
} \ } \
template <> void TypeA2B<float, float16, CUDAContext>( \ template <> void TypeA2B<float, float16, CUDAContext>( \
const int count, \ const int count, \
const float* a, \ const float* a, \
float16* b) { \ float16* b, \
CUDAContext* ctx) { \
_TypeFloat2Half \ _TypeFloat2Half \
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( \ << < CUDA_BLOCKS(count), CUDA_THREADS, \
count, a, reinterpret_cast<half*>(b)); \ 0, ctx->cuda_stream() >> >(count, \
a, reinterpret_cast<half*>(b)); \
} }
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
template <> void TypeA2B<float16, float16, CUDAContext>( template <> void TypeA2B<float16, float16, CUDAContext>(
const int count, const int count,
const float16* a, const float16* a,
float16* b) { float16* b,
CUDAContext* ctx) {
_TypeHalf2Half _TypeHalf2Half
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >(count, << < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >(count,
reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(a),
reinterpret_cast<half*>(b)); reinterpret_cast<half*>(b));
} }
...@@ -227,7 +397,8 @@ DEFINE_TYPE_DISABLE_FP16(uint8_t); ...@@ -227,7 +397,8 @@ DEFINE_TYPE_DISABLE_FP16(uint8_t);
template <> void TypeA2B<float16, float16, CUDAContext>( template <> void TypeA2B<float16, float16, CUDAContext>(
const int count, const int count,
const float16* a, const float16* a,
float16* b) { float16* b,
CUDAContext* ctx) {
LOG(FATAL) << "CUDAContext has not implemented: float16 -> float16"; LOG(FATAL) << "CUDAContext has not implemented: float16 -> float16";
} }
DEFINE_TYPE_DISABLE_FP16(float); DEFINE_TYPE_DISABLE_FP16(float);
...@@ -251,7 +422,7 @@ __global__ void _ImageDataHalf_NCHW( ...@@ -251,7 +422,7 @@ __global__ void _ImageDataHalf_NCHW(
const float* std_values, const float* std_values,
const Tx* x, const Tx* x,
Ty* y) { Ty* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
const int w = idx % W; const int w = idx % W;
const int h = (idx / W) % H; const int h = (idx / W) % H;
const int c = (idx / W / H) % C; const int c = (idx / W / H) % C;
...@@ -274,7 +445,7 @@ __global__ void _ImageDataHalf_NHWC( ...@@ -274,7 +445,7 @@ __global__ void _ImageDataHalf_NHWC(
const float* std_values, const float* std_values,
const Tx* x, const Tx* x,
Ty* y) { Ty* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
const int c = idx % C; const int c = idx % C;
float raw_value = x[idx]; float raw_value = x[idx];
if (mean_values) raw_value -= mean_values[c]; if (mean_values) raw_value -= mean_values[c];
...@@ -294,17 +465,20 @@ template <> void ImageData<float, float16, CUDAContext>( ...@@ -294,17 +465,20 @@ template <> void ImageData<float, float16, CUDAContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const float* x, const float* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataHalf_NCHW<float, half> _ImageDataHalf_NCHW<float, half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, N, C, H, W, mean_values, std_values, 0, ctx->cuda_stream() >> >(count,
N, C, H, W, mean_values, std_values,
x, reinterpret_cast<half*>(y)); x, reinterpret_cast<half*>(y));
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataHalf_NHWC<float, half> _ImageDataHalf_NHWC<float, half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, N, C, H, W, mean_values, std_values, 0, ctx->cuda_stream() >> >(count,
N, C, H, W, mean_values, std_values,
x, reinterpret_cast<half*>(y)); x, reinterpret_cast<half*>(y));
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
#else #else
...@@ -322,17 +496,20 @@ template <> void ImageData<uint8_t, float16, CUDAContext>( ...@@ -322,17 +496,20 @@ template <> void ImageData<uint8_t, float16, CUDAContext>(
const float* std_values, const float* std_values,
const string& data_format, const string& data_format,
const uint8_t* x, const uint8_t* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (data_format == "NCHW") { if (data_format == "NCHW") {
_ImageDataHalf_NCHW<uint8_t, half> _ImageDataHalf_NCHW<uint8_t, half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, N, C, H, W, mean_values, std_values, 0, ctx->cuda_stream() >> >(count,
N, C, H, W, mean_values, std_values,
x, reinterpret_cast<half*>(y)); x, reinterpret_cast<half*>(y));
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
_ImageDataHalf_NHWC<uint8_t, half> _ImageDataHalf_NHWC<uint8_t, half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, N, C, H, W, mean_values, std_values, 0, ctx->cuda_stream() >> >(count,
N, C, H, W, mean_values, std_values,
x, reinterpret_cast<half*>(y)); x, reinterpret_cast<half*>(y));
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
#else #else
...@@ -352,7 +529,7 @@ __global__ void _ConcatHalf( ...@@ -352,7 +529,7 @@ __global__ void _ConcatHalf(
const int concat_offset, const int concat_offset,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
const int tmp = x_concat_dim * inner_dim; const int tmp = x_concat_dim * inner_dim;
const int outer_idx = idx / tmp; const int outer_idx = idx / tmp;
const int concat_idx = idx % tmp; const int concat_idx = idx % tmp;
...@@ -370,11 +547,13 @@ template <> void Concat<float16, CUDAContext>( ...@@ -370,11 +547,13 @@ template <> void Concat<float16, CUDAContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_ConcatHalf<half> _ConcatHalf<half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, outer_dim, inner_dim, 0, ctx->cuda_stream() >> >(count,
outer_dim, inner_dim,
x_concat_dim, y_concat_dim, concat_offset, x_concat_dim, y_concat_dim, concat_offset,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
...@@ -393,7 +572,7 @@ __global__ void _ConcatGradHalf( ...@@ -393,7 +572,7 @@ __global__ void _ConcatGradHalf(
const int concat_offset, const int concat_offset,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
const int tmp = x_concat_dim * inner_dim; const int tmp = x_concat_dim * inner_dim;
const int outer_idx = idx / tmp; const int outer_idx = idx / tmp;
const int concat_idx = idx % tmp; const int concat_idx = idx % tmp;
...@@ -411,11 +590,13 @@ template <> void ConcatGrad<float16, CUDAContext>( ...@@ -411,11 +590,13 @@ template <> void ConcatGrad<float16, CUDAContext>(
const int y_concat_dim, const int y_concat_dim,
const int concat_offset, const int concat_offset,
const float16* dy, const float16* dy,
float16* dx) { float16* dx,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_ConcatGradHalf<half> _ConcatGradHalf<half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, outer_dim, inner_dim, 0, ctx->cuda_stream() >> >(count,
outer_dim, inner_dim,
x_concat_dim, y_concat_dim, concat_offset, x_concat_dim, y_concat_dim, concat_offset,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
...@@ -435,7 +616,7 @@ __global__ void _TransposeHalf( ...@@ -435,7 +616,7 @@ __global__ void _TransposeHalf(
const int* new_steps, const int* new_steps,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
int x_idx = 0, y_idx = idx; int x_idx = 0, y_idx = idx;
for (int j = 0; j < ndim; ++j) { for (int j = 0; j < ndim; ++j) {
int k = order[j]; int k = order[j];
...@@ -453,11 +634,13 @@ template <> void Transpose<float16, CUDAContext>( ...@@ -453,11 +634,13 @@ template <> void Transpose<float16, CUDAContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float16* x, const float16* x,
float16* y) { float16* y,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_TransposeHalf<half> _TransposeHalf<half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, ndim, order, old_steps, new_steps, 0, ctx->cuda_stream() >> >(count,
ndim, order, old_steps, new_steps,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y)); reinterpret_cast<half*>(y));
#else #else
...@@ -474,7 +657,7 @@ __global__ void _TransposeGradHalf( ...@@ -474,7 +657,7 @@ __global__ void _TransposeGradHalf(
const int* new_steps, const int* new_steps,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(idx, count) {
int x_idx = 0, y_idx = idx; int x_idx = 0, y_idx = idx;
for (int j = 0; j < ndim; ++j) { for (int j = 0; j < ndim; ++j) {
int k = order[j]; int k = order[j];
...@@ -492,11 +675,13 @@ template <> void TransposeGrad<float16, CUDAContext>( ...@@ -492,11 +675,13 @@ template <> void TransposeGrad<float16, CUDAContext>(
const int* old_steps, const int* old_steps,
const int* new_steps, const int* new_steps,
const float16* dy, const float16* dy,
float16* dx) { float16* dx,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_TransposeGradHalf<half> _TransposeGradHalf<half>
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, ndim, order, old_steps, new_steps, 0, ctx->cuda_stream() >> >(count,
ndim, order, old_steps, new_steps,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
#else #else
...@@ -516,7 +701,7 @@ __global__ void _AdamUpdateHalf( ...@@ -516,7 +701,7 @@ __global__ void _AdamUpdateHalf(
half* g, half* g,
half* m, half* m,
half* v) { half* v) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half gi = g[i]; half gi = g[i];
half kOne = __float2half(1.f); half kOne = __float2half(1.f);
...@@ -545,11 +730,13 @@ template <> void AdamUpdate<float16, CUDAContext>( ...@@ -545,11 +730,13 @@ template <> void AdamUpdate<float16, CUDAContext>(
const float eps, const float eps,
float16* g, float16* g,
float16* m, float16* m,
float16* v) { float16* v,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_AdamUpdateHalf _AdamUpdateHalf
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, dragon_cast<half, float>(lr), 0, ctx->cuda_stream() >> >(count,
dragon_cast<half, float>(lr),
dragon_cast<half, float>(beta1), dragon_cast<half, float>(beta1),
dragon_cast<half, float>(beta2), dragon_cast<half, float>(beta2),
dragon_cast<half, float>(eps), dragon_cast<half, float>(eps),
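
The Adam kernel body is mostly elided above; for orientation, a standard fp32 Adam step with the same parameter names, given as a hedged reference (any bias correction is assumed to be folded into lr by the caller):

    // Hedged fp32 reference; not copied from the elided half kernel.
    void AdamUpdateRef(int n, float lr, float beta1, float beta2, float eps,
                       float* g, float* m, float* v) {
        for (int i = 0; i < n; ++i) {
            const float gi = g[i];
            m[i] = beta1 * m[i] + (1.f - beta1) * gi;
            v[i] = beta2 * v[i] + (1.f - beta2) * gi * gi;
            g[i] = lr * m[i] / (std::sqrt(v[i]) + eps);   // needs <cmath>
        }
    }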
...@@ -570,7 +757,7 @@ __global__ void _NesterovUpdateHalf( ...@@ -570,7 +757,7 @@ __global__ void _NesterovUpdateHalf(
const half momentum, const half momentum,
half* g, half* g,
half* h) { half* h) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half hi = h[i]; half hi = h[i];
half hi_new = h[i] = __hadd( half hi_new = h[i] = __hadd(
...@@ -592,7 +779,7 @@ __global__ void _NesterovUpdateHalf2( ...@@ -592,7 +779,7 @@ __global__ void _NesterovUpdateHalf2(
const half2 momentum, const half2 momentum,
half2* g, half2* g,
half2* h) { half2* h) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half2 hi = h[i]; half2 hi = h[i];
half2 hi_new = h[i] = __hadd2( half2 hi_new = h[i] = __hadd2(
...@@ -614,19 +801,22 @@ template <> void NesterovUpdate<float16, CUDAContext>( ...@@ -614,19 +801,22 @@ template <> void NesterovUpdate<float16, CUDAContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float16* g, float16* g,
float16* h) { float16* h,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (count % 2 == 0) { if ((count & 1) == 0) {
_NesterovUpdateHalf2 _NesterovUpdateHalf2
<< <CUDA_BLOCKS(count / 2), CUDA_THREADS >> >( << < CUDA_BLOCKS(count >> 1), CUDA_THREADS,
count / 2, dragon_cast<half2, float>(lr), 0, ctx->cuda_stream() >> >(count >> 1,
dragon_cast<half2, float>(lr),
dragon_cast<half2, float>(momentum), dragon_cast<half2, float>(momentum),
reinterpret_cast<half2*>(g), reinterpret_cast<half2*>(g),
reinterpret_cast<half2*>(h)); reinterpret_cast<half2*>(h));
} else { } else {
_NesterovUpdateHalf _NesterovUpdateHalf
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, dragon_cast<half, float>(lr), 0, ctx->cuda_stream() >> >(count,
dragon_cast<half, float>(lr),
dragon_cast<half, float>(momentum), dragon_cast<half, float>(momentum),
reinterpret_cast<half*>(g), reinterpret_cast<half*>(g),
reinterpret_cast<half*>(h)); reinterpret_cast<half*>(h));
...@@ -646,7 +836,7 @@ __global__ void _RMSPropUpdateHalf( ...@@ -646,7 +836,7 @@ __global__ void _RMSPropUpdateHalf(
const half eps, const half eps,
half* g, half* g,
half* h) { half* h) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half gi = g[i]; half gi = g[i];
half kOne = __float2half(1.f); half kOne = __float2half(1.f);
...@@ -669,11 +859,13 @@ template <> void RMSPropUpdate<float16, CUDAContext>( ...@@ -669,11 +859,13 @@ template <> void RMSPropUpdate<float16, CUDAContext>(
const float decay, const float decay,
const float eps, const float eps,
float16* g, float16* g,
float16* h) { float16* h,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
_RMSPropUpdateHalf _RMSPropUpdateHalf
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, dragon_cast<half, float>(lr), 0, ctx->cuda_stream() >> >(count,
dragon_cast<half, float>(lr),
dragon_cast<half, float>(decay), dragon_cast<half, float>(decay),
dragon_cast<half, float>(eps), dragon_cast<half, float>(eps),
reinterpret_cast<half*>(g), reinterpret_cast<half*>(g),
...@@ -692,7 +884,7 @@ __global__ void _SGDUpdateHalf( ...@@ -692,7 +884,7 @@ __global__ void _SGDUpdateHalf(
const half momentum, const half momentum,
half* g, half* g,
half* h) { half* h) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half hi = h[i]; half hi = h[i];
g[i] = h[i] = __hadd( g[i] = h[i] = __hadd(
...@@ -709,7 +901,7 @@ __global__ void _SGDUpdateHalf2( ...@@ -709,7 +901,7 @@ __global__ void _SGDUpdateHalf2(
const half2 momentum, const half2 momentum,
half2* g, half2* g,
half2* h) { half2* h) {
CUDA_KERNEL_LOOP(i, count) { CUDA_1D_KERNEL_LOOP(i, count) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
half2 hi = h[i]; half2 hi = h[i];
g[i] = h[i] = __hadd2( g[i] = h[i] = __hadd2(
...@@ -726,19 +918,22 @@ template <> void SGDUpdate<float16, CUDAContext>( ...@@ -726,19 +918,22 @@ template <> void SGDUpdate<float16, CUDAContext>(
const float lr, const float lr,
const float momentum, const float momentum,
float16* g, float16* g,
float16* h) { float16* h,
CUDAContext* ctx) {
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
if (count % 2 == 0) { if ((count & 1) == 0) {
_SGDUpdateHalf2 _SGDUpdateHalf2
<< <CUDA_BLOCKS(count / 2), CUDA_THREADS >> >( << < CUDA_BLOCKS(count >> 1), CUDA_THREADS,
count / 2, dragon_cast<half2, float>(lr), 0, ctx->cuda_stream() >> >(count >> 1,
dragon_cast<half2, float>(lr),
dragon_cast<half2, float>(momentum), dragon_cast<half2, float>(momentum),
reinterpret_cast<half2*>(g), reinterpret_cast<half2*>(g),
reinterpret_cast<half2*>(h)); reinterpret_cast<half2*>(h));
} else { } else {
_SGDUpdateHalf _SGDUpdateHalf
<< <CUDA_BLOCKS(count), CUDA_THREADS >> >( << < CUDA_BLOCKS(count), CUDA_THREADS,
count, dragon_cast<half, float>(lr), 0, ctx->cuda_stream() >> >(count,
dragon_cast<half, float>(lr),
dragon_cast<half, float>(momentum), dragon_cast<half, float>(momentum),
reinterpret_cast<half*>(g), reinterpret_cast<half*>(g),
reinterpret_cast<half*>(h)); reinterpret_cast<half*>(h));
......
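
The visible fragment of the SGD half kernels (g[i] = h[i] = __hadd(...)) corresponds to classic SGD with momentum; an fp32 reference sketch, hedged since the kernel tails are elided:

    // Hedged fp32 reference for _SGDUpdateHalf / _SGDUpdateHalf2:
    // h accumulates the momentum buffer, g is overwritten with the update.
    void SGDUpdateRef(int n, float lr, float momentum, float* g, float* h) {
        for (int i = 0; i < n; ++i)
            g[i] = h[i] = momentum * h[i] + lr * g[i];
    }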
...@@ -162,7 +162,7 @@ template<> void Axpby( ...@@ -162,7 +162,7 @@ template<> void Axpby(
SSE_LOOP2(i, n) y[i] = alpha * x[i] + beta * y[i]; SSE_LOOP2(i, n) y[i] = alpha * x[i] + beta * y[i];
} }
template<> float ASum( template<> float Sum(
const int n, const int n,
const float* x) { const float* x) {
__m128 x1, sum = SSE_FP32_ZERO; __m128 x1, sum = SSE_FP32_ZERO;
......