Commit 6720373b by Ting PAN

Implement RoIAlignV2

Summary:
This commit adds an 'aligned' flag to RoiAlignOp.
RoIAlignV2 samples in a continuous (half-pixel) coordinate system and is now the widely used convention.
1 parent 390d2035
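For orientation before the per-file hunks: the V1/V2 coordinate handling this commit toggles can be sketched in a few lines of Python (a minimal sketch; the `[batch_ind, x1, y1, x2, y2]` RoI row layout is inferred from the kernels below):

```python
def roi_extent(roi, spatial_scale, aligned):
    """Map one RoI row (batch_ind, x1, y1, x2, y2) to feature coordinates."""
    # RoIAlignV2 shifts by half a pixel to treat pixels as continuous points.
    offset = 0.5 if aligned else 0.0
    wstart = roi[1] * spatial_scale - offset
    hstart = roi[2] * spatial_scale - offset
    wend = roi[3] * spatial_scale - offset
    hend = roi[4] * spatial_scale - offset
    if aligned:
        # V2 keeps the exact (possibly zero) extent.
        roi_w, roi_h = wend - wstart, hend - hstart
    else:
        # V1 clamps the extent to at least one pixel.
        roi_w, roi_h = max(wend - wstart, 1.0), max(hend - hstart, 1.0)
    return wstart, hstart, roi_w, roi_h
```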
......@@ -17,8 +17,5 @@ if (EXISTS "${THIRD_PARTY_DIR}/mpi/lib")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
set(MPI_LIBRARIES z)
-if (UNIX AND (NOT APPLE))
-set(MPI_LIBRARIES ${MPI_LIBRARIES} udev)
-endif()
set(MPI_LIBRARIES_SHARED mpi)
set(MPI_LIBRARIES_STATIC mpi open-rte open-pal)
......@@ -53,6 +53,7 @@ void _RoiAlign(
const int num_rois,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const T* x,
const float* rois,
T* y) {
......@@ -71,13 +72,16 @@ void _RoiAlign(
continue;
}
-const float roi_wstart = roi[1] * spatial_scale;
-const float roi_hstart = roi[2] * spatial_scale;
-const float roi_wend = roi[3] * spatial_scale;
-const float roi_hend = roi[4] * spatial_scale;
+const float roi_offset = aligned ? 0.5f : 0.0f;
+const float roi_wstart = roi[1] * spatial_scale - roi_offset;
+const float roi_hstart = roi[2] * spatial_scale - roi_offset;
+const float roi_wend = roi[3] * spatial_scale - roi_offset;
+const float roi_hend = roi[4] * spatial_scale - roi_offset;
-const float roi_w = std::max(roi_wend - roi_wstart, 1.f);
-const float roi_h = std::max(roi_hend - roi_hstart, 1.f);
+const float roi_w =
+aligned ? roi_wend - roi_wstart : std::max(roi_wend - roi_wstart, 1.f);
+const float roi_h =
+aligned ? roi_hend - roi_hstart : std::max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
......@@ -87,7 +91,7 @@ void _RoiAlign(
const int grid_w = sampling_ratio > 0
? sampling_ratio
: int(std::ceil(roi_w / float(out_w)));
-const T num_grids = T(grid_h * grid_w);
+const T num_grids = std::max(T(grid_h * grid_w), T(1));
int yi;
T val;
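Worth noting: with `aligned` enabled an RoI may have zero width or height, so the adaptive grid `ceil(roi_w / out_w)` can evaluate to zero; the new `std::max(..., T(1))` divisor guards the average against division by zero. A sketch of that grid computation (a hypothetical helper mirroring the kernel logic):

```python
import math

def sampling_grid(roi_w, roi_h, out_h, out_w, sampling_ratio):
    """Per-bin sampling grid, mirroring the kernel's adaptive rule."""
    grid_h = sampling_ratio if sampling_ratio > 0 else int(math.ceil(roi_h / out_h))
    grid_w = sampling_ratio if sampling_ratio > 0 else int(math.ceil(roi_w / out_w))
    # A degenerate (zero-extent) aligned RoI gives grid_h * grid_w == 0,
    # hence the clamp to 1 before dividing the accumulated value.
    return grid_h, grid_w, max(grid_h * grid_w, 1)
```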
......@@ -131,6 +135,7 @@ void RoiAlign<float16, CPUContext>(
const int num_rois,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const float16* x,
const float* rois,
float16* y,
......@@ -149,6 +154,7 @@ void RoiAlign<float16, CPUContext>(
const int num_rois, \
const float spatial_scale, \
const int sampling_ratio, \
+const bool aligned, \
const T* x, \
const float* rois, \
T* y, \
......@@ -162,6 +168,7 @@ void RoiAlign<float16, CPUContext>(
num_rois, \
spatial_scale, \
sampling_ratio, \
+aligned, \
x, \
rois, \
y); \
......@@ -178,6 +185,7 @@ void RoiAlign<float16, CPUContext>(
const int num_rois, \
const float spatial_scale, \
const int sampling_ratio, \
+const bool aligned, \
const T* dy, \
const float* rois, \
float* dx, \
......
......@@ -99,6 +99,7 @@ __global__ void _RoiAlign(
const int out_w,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const T* x,
const float* rois,
T* y) {
......@@ -116,13 +117,16 @@ __global__ void _RoiAlign(
continue;
}
-const float roi_wstart = roi[1] * spatial_scale;
-const float roi_hstart = roi[2] * spatial_scale;
-const float roi_wend = roi[3] * spatial_scale;
-const float roi_hend = roi[4] * spatial_scale;
+const float roi_offset = aligned ? 0.5f : 0.0f;
+const float roi_wstart = roi[1] * spatial_scale - roi_offset;
+const float roi_hstart = roi[2] * spatial_scale - roi_offset;
+const float roi_wend = roi[3] * spatial_scale - roi_offset;
+const float roi_hend = roi[4] * spatial_scale - roi_offset;
-const float roi_w = max(roi_wend - roi_wstart, 1.f);
-const float roi_h = max(roi_hend - roi_hstart, 1.f);
+const float roi_w =
+aligned ? roi_wend - roi_wstart : max(roi_wend - roi_wstart, 1.f);
+const float roi_h =
+aligned ? roi_hend - roi_hstart : max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
......@@ -143,7 +147,7 @@ __global__ void _RoiAlign(
val += _RoiAlignIntp(H, W, h, w, offset_x);
}
}
-y[yi] = convert::To<T>(val / AccT(grid_h * grid_w));
+y[yi] = convert::To<T>(val / AccT(max(grid_h * grid_w, 1)));
}
}
......@@ -157,6 +161,7 @@ __global__ void _RoiAlignGrad(
const int out_w,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const T* dy,
const float* rois,
AccT* dx) {
......@@ -171,13 +176,16 @@ __global__ void _RoiAlignGrad(
if (batch_ind < 0) continue;
-const float roi_wstart = roi[1] * spatial_scale;
-const float roi_hstart = roi[2] * spatial_scale;
-const float roi_wend = roi[3] * spatial_scale;
-const float roi_hend = roi[4] * spatial_scale;
+const float roi_offset = aligned ? 0.5f : 0.0f;
+const float roi_wstart = roi[1] * spatial_scale - roi_offset;
+const float roi_hstart = roi[2] * spatial_scale - roi_offset;
+const float roi_wend = roi[3] * spatial_scale - roi_offset;
+const float roi_hend = roi[4] * spatial_scale - roi_offset;
-const float roi_w = max(roi_wend - roi_wstart, 1.f);
-const float roi_h = max(roi_hend - roi_hstart, 1.f);
+const float roi_w =
+aligned ? roi_wend - roi_wstart : max(roi_wend - roi_wstart, 1.f);
+const float roi_h =
+aligned ? roi_hend - roi_hstart : max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
......@@ -188,8 +196,9 @@ __global__ void _RoiAlignGrad(
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_h / float(out_h)));
const int grid_w =
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_w / float(out_w)));
-const float dyi = convert::To<float>(dy[yi]) / float(grid_h * grid_w);
float* offset_dx = dx + (batch_ind * C + c) * H * W;
+const float grad = convert::To<float>(dy[yi]) / float(grid_h * grid_w);
for (int i = 0; i < grid_h; i++) {
const float h = hstart + (i + .5f) * bin_h / grid_h;
......@@ -199,8 +208,8 @@ __global__ void _RoiAlignGrad(
float v, u;
_RoiAlignIntpParam(H, W, h, w, ti, bi, li, ri, v, u);
if (li >= 0 && ri >= 0 && ti >= 0 && bi >= 0) {
-const float db = dyi * v;
-const float dt = dyi * (1.f - v);
+const float db = grad * v;
+const float dt = grad * (1.f - v);
math::utils::AtomicAdd(offset_dx + ti * W + li, (1.f - u) * dt);
math::utils::AtomicAdd(offset_dx + ti * W + ri, u * dt);
math::utils::AtomicAdd(offset_dx + bi * W + li, (1.f - u) * db);
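The backward kernel splits each sampling point's gradient across its four bilinear neighbors; `v` and `u` appear to be the fractional offsets toward the bottom and right neighbors returned by `_RoiAlignIntpParam`. A sketch of that scatter (plain Python over a flat single-channel `H*W` buffer; the fourth corner is elided in the hunk above but follows by symmetry):

```python
def scatter_bilinear_grad(dx, grad, ti, bi, li, ri, v, u, W):
    """Distribute one sampling point's gradient to its 4 neighbors."""
    db = grad * v          # share of the bottom row
    dt = grad * (1.0 - v)  # share of the top row
    dx[ti * W + li] += (1.0 - u) * dt  # top-left
    dx[ti * W + ri] += u * dt          # top-right
    dx[bi * W + li] += (1.0 - u) * db  # bottom-left
    dx[bi * W + ri] += u * db          # bottom-right (elided in the hunk)
```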
......@@ -226,6 +235,7 @@ __global__ void _RoiAlignGrad(
const int num_rois, \
const float spatial_scale, \
const int sampling_ratio, \
+const bool aligned, \
const InputT* x, \
const float* rois, \
OutputT* y, \
......@@ -241,6 +251,7 @@ __global__ void _RoiAlignGrad(
out_w, \
spatial_scale, \
sampling_ratio, \
+aligned, \
reinterpret_cast<const math::ScalarType<InputT>::type*>(x), \
rois, \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
......
......@@ -110,7 +110,7 @@ foreach(_file ${MODULE_INCLUDES})
protobuf_remove_constexpr(${_install_dir}/dragon/${_dir}/${_name})
endif()
endforeach()
-file(COPY ${THIRD_PARTY_DIR}/cub/cub DESTINATION ${_install_dir})
+file(COPY ${THIRD_PARTY_DIR}/cub/cub DESTINATION ${_install_dir}/cub)
file(COPY ${THIRD_PARTY_DIR}/eigen/Eigen DESTINATION ${_install_dir})
file(COPY ${THIRD_PARTY_DIR}/eigen/unsupported/Eigen DESTINATION ${_install_dir}/unsupported)
file(COPY ${THIRD_PARTY_DIR}/pybind11/include/pybind11 DESTINATION ${_install_dir})
......
......@@ -94,7 +94,7 @@ foreach(_file ${MODULE_INCLUDES})
endif()
endforeach()
file(COPY dragon_runtime.h DESTINATION ${_install_dir}/dragon)
-file(COPY ${THIRD_PARTY_DIR}/cub/cub DESTINATION ${_install_dir})
+file(COPY ${THIRD_PARTY_DIR}/cub/cub DESTINATION ${_install_dir}/cub)
file(COPY ${THIRD_PARTY_DIR}/eigen/Eigen DESTINATION ${_install_dir})
file(COPY ${THIRD_PARTY_DIR}/eigen/unsupported/Eigen DESTINATION ${_install_dir}/unsupported)
file(COPY ${THIRD_PARTY_DIR}/pybind11/include/pybind11 DESTINATION ${_install_dir})
......
......@@ -28,12 +28,14 @@ void BatchNormOp<Context>::RunTraining() {
// Compute moments.
if (sync_stats_ > 0) {
#ifdef USE_MPI
+int64_t N = N_;
+AllReduce(&N, &N, 1);
// Compute E(X) and E(X^2)
kernels::BatchNormExpectation(
N_,
C_,
S_,
-float(N_ * S_ * comm_size_),
+float(N * S_),
data_format(),
x,
params,
......@@ -166,8 +168,10 @@ void BatchNormGradientOp<Context>::RunTraining() {
ctx());
}
+int64_t N = N_; // Total batch size.
if (sync_stats_ > 0) {
#ifdef USE_MPI
+AllReduce(&N, &N, 1);
+ctx()->FinishDeviceComputation();
if (enable_nccl_) {
#ifdef USE_NCCL
......@@ -192,11 +196,7 @@ void BatchNormGradientOp<Context>::RunTraining() {
N_,
C_,
S_,
-#ifdef USE_MPI
-float(N_ * S_ * comm_size_),
-#else
-float(N_ * S_),
-#endif
+float(N * S_),
data_format(),
x,
mu,
......
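The batch-norm change replaces the fixed `N_ * comm_size_` estimate, which silently assumed every rank carries the same batch size, with an all-reduced total, so normalization uses the true global count even with uneven per-rank batches. A sketch of the same idea with mpi4py (my assumption being that `AllReduce(&N, &N, 1)` sums one int64 across the group, i.e. MPI_Allreduce with MPI_SUM):

```python
from mpi4py import MPI

def global_batch_size(local_n):
    """Sum per-rank batch sizes across the process group."""
    comm = MPI.COMM_WORLD
    # Replaces the old local_n * comm.size estimate, which assumed
    # identical batch sizes on every rank.
    return comm.allreduce(local_n, op=MPI.SUM)
```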
......@@ -34,13 +34,19 @@ void CuDNNConvOp<Context>::SetOpDesc() {
fwd_algo_ = ConvAlgoSearch<FwdAlgo>().get_deterministic();
return;
}
-fwd_algo_ = ConvAlgoSearch<FwdAlgo>().get(
-ctx()->cudnn_handle(),
-input_desc_,
-filter_desc_,
-conv_desc_,
-output_desc_,
-scratch_max_size_);
+auto fn = [&]() {
+return std::tuple<FwdAlgo, float>(
+ConvAlgoSearch<FwdAlgo>().get(
+ctx()->cudnn_handle(),
+input_desc_,
+filter_desc_,
+conv_desc_,
+output_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = fwd_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+fwd_algo_ = std::get<0>(result);
}
template <class Context>
......@@ -154,20 +160,36 @@ void CuDNNConvGradientOp<Context>::SetOpDesc() {
bwd_filter_algo_ = ConvAlgoSearch<BwdFilterAlgo>().get_deterministic();
return;
}
-bwd_data_algo_ = ConvAlgoSearch<BwdDataAlgo>().get(
-ctx()->cudnn_handle(),
-filter_desc_,
-input_desc_,
-conv_desc_,
-output_desc_,
-scratch_max_size_);
-bwd_filter_algo_ = ConvAlgoSearch<BwdFilterAlgo>().get(
-ctx()->cudnn_handle(),
-output_desc_,
-input_desc_,
-conv_desc_,
-filter_desc_,
-scratch_max_size_);
+{
+auto fn = [&]() {
+return std::tuple<BwdDataAlgo, float>(
+ConvAlgoSearch<BwdDataAlgo>().get(
+ctx()->cudnn_handle(),
+filter_desc_,
+input_desc_,
+conv_desc_,
+output_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+bwd_data_algo_ = std::get<0>(result);
+}
+{
+auto fn = [&]() {
+return std::tuple<BwdFilterAlgo, float>(
+ConvAlgoSearch<BwdFilterAlgo>().get(
+ctx()->cudnn_handle(),
+output_desc_,
+input_desc_,
+conv_desc_,
+filter_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = filter_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+bwd_filter_algo_ = std::get<0>(result);
+}
}
template <class Context>
......
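Across these conv hunks the pattern is the same: the cuDNN algorithm search is wrapped in a lambda and handed to a shape-keyed cache, so the expensive query runs once per (input dims, filter dims, compute type) and repeats are served from memory. A rough Python model of that memoization (the `AlgoCache` name is hypothetical; the real members are `fwd_algo_cache_`, `data_algo_cache_`, and `filter_algo_cache_`):

```python
class AlgoCache:
    """Shape-keyed memoization of an expensive algorithm search."""

    def __init__(self):
        self._results = {}

    def get(self, x_dims, w_dims, compute_type, fn):
        key = (tuple(x_dims), tuple(w_dims), compute_type)
        if key not in self._results:
            # Pay the search cost only the first time this
            # (shapes, dtype) combination is seen.
            self._results[key] = fn()  # fn returns an (algo, time) tuple
        return self._results[key]
```

The same wrapper is applied verbatim to the transposed-convolution operators in the next file.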
......@@ -34,13 +34,19 @@ void CuDNNConvTransposeOp<Context>::SetOpDesc() {
fwd_algo_ = ConvAlgoSearch<FwdAlgo>().get_deterministic();
return;
}
-fwd_algo_ = ConvAlgoSearch<FwdAlgo>().get(
-ctx()->cudnn_handle(),
-filter_desc_,
-input_desc_,
-conv_desc_,
-output_desc_,
-scratch_max_size_);
+auto fn = [&]() {
+return std::tuple<FwdAlgo, float>(
+ConvAlgoSearch<FwdAlgo>().get(
+ctx()->cudnn_handle(),
+filter_desc_,
+input_desc_,
+conv_desc_,
+output_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = fwd_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+fwd_algo_ = std::get<0>(result);
}
template <class Context>
......@@ -154,20 +160,36 @@ void CuDNNConvTransposeGradientOp<Context>::SetOpDesc() {
bwd_filter_algo_ = ConvAlgoSearch<BwdFilterAlgo>().get_deterministic();
return;
}
-bwd_data_algo_ = ConvAlgoSearch<BwdDataAlgo>().get(
-ctx()->cudnn_handle(),
-input_desc_,
-filter_desc_,
-conv_desc_,
-output_desc_,
-scratch_max_size_);
-bwd_filter_algo_ = ConvAlgoSearch<BwdFilterAlgo>().get(
-ctx()->cudnn_handle(),
-input_desc_,
-output_desc_,
-conv_desc_,
-filter_desc_,
-scratch_max_size_);
+{
+auto fn = [&]() {
+return std::tuple<BwdDataAlgo, float>(
+ConvAlgoSearch<BwdDataAlgo>().get(
+ctx()->cudnn_handle(),
+input_desc_,
+filter_desc_,
+conv_desc_,
+output_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+bwd_data_algo_ = std::get<0>(result);
+}
+{
+auto fn = [&]() {
+return std::tuple<BwdFilterAlgo, float>(
+ConvAlgoSearch<BwdFilterAlgo>().get(
+ctx()->cudnn_handle(),
+input_desc_,
+output_desc_,
+conv_desc_,
+filter_desc_,
+scratch_max_size_),
+0.f);
+};
+auto result = filter_algo_cache_.get(X.dims(), W.dims(), compute_type_, fn);
+bwd_filter_algo_ = std::get<0>(result);
+}
}
template <class Context>
......
......@@ -21,6 +21,7 @@ void RoiAlignOp<Context>::DoRunWithType() {
RoI.dim(0),
spatial_scale_,
sampling_ratio_,
+aligned_ > 0,
X.template data<T, Context>(),
RoI.template data<float, Context>(),
Y->template mutable_data<T, Context>(),
......@@ -28,11 +29,6 @@ void RoiAlignOp<Context>::DoRunWithType() {
}
-template <class Context>
-void RoiAlignOp<Context>::RunOnDevice() {
-DispatchHelper<dtypes::Floating>::Call(this, Input(0));
-}
template <class Context>
template <typename T>
void RoiAlignGradientOp<Context>::DoRunWithType() {
auto &RoI = Input(0), &dY = Input(1);
......@@ -60,6 +56,7 @@ void RoiAlignGradientOp<Context>::DoRunWithType() {
RoI.dim(0),
spatial_scale_,
sampling_ratio_,
+aligned_ > 0,
dY.template data<T, Context>(),
RoI.template data<float, Context>(),
dx_acc != nullptr ? dx_acc : reinterpret_cast<float*>(dx),
......@@ -71,11 +68,6 @@ void RoiAlignGradientOp<Context>::DoRunWithType() {
}
}
-template <class Context>
-void RoiAlignGradientOp<Context>::RunOnDevice() {
-DispatchHelper<dtypes::Floating>::Call(this, Input(1));
-}
DEPLOY_CPU_OPERATOR(RoiAlign);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(RoiAlign);
......
......@@ -25,20 +25,23 @@ class RoiAlignOp final : public Operator<Context> {
out_h_(OP_SINGLE_ARG(int64_t, "pooled_h", 0)),
out_w_(OP_SINGLE_ARG(int64_t, "pooled_w", 0)),
spatial_scale_(OP_SINGLE_ARG(float, "spatial_scale", 1.f)),
-sampling_ratio_(OP_SINGLE_ARG(int64_t, "sampling_ratio", 2)) {
+sampling_ratio_(OP_SINGLE_ARG(int64_t, "sampling_ratio", 0)),
+aligned_(OP_SINGLE_ARG(int64_t, "aligned", 0)) {
CHECK_GT(out_h_, 0) << "\npooled_h must > 0";
CHECK_GT(out_w_, 0) << "\npooled_w must > 0";
}
USE_OPERATOR_FUNCTIONS;
-void RunOnDevice() override;
+void RunOnDevice() override {
+DispatchHelper<dtypes::Floating>::Call(this, Input(0));
+}
template <typename T>
void DoRunWithType();
protected:
float spatial_scale_;
-int64_t sampling_ratio_;
+int64_t sampling_ratio_, aligned_;
int64_t out_h_, out_w_;
};
......@@ -50,17 +53,20 @@ class RoiAlignGradientOp final : public Operator<Context> {
out_h_(OP_SINGLE_ARG(int64_t, "pooled_h", 0)),
out_w_(OP_SINGLE_ARG(int64_t, "pooled_w", 0)),
spatial_scale_(OP_SINGLE_ARG(float, "spatial_scale", 1.f)),
-sampling_ratio_(OP_SINGLE_ARG(int64_t, "sampling_ratio", 2)) {}
+sampling_ratio_(OP_SINGLE_ARG(int64_t, "sampling_ratio", 0)),
+aligned_(OP_SINGLE_ARG(int64_t, "aligned", 0)) {}
USE_OPERATOR_FUNCTIONS;
-void RunOnDevice() override;
+void RunOnDevice() override {
+DispatchHelper<dtypes::Floating>::Call(this, Input(1));
+}
template <typename T>
void DoRunWithType();
protected:
float spatial_scale_;
-int64_t sampling_ratio_;
+int64_t sampling_ratio_, aligned_;
int64_t out_h_, out_w_;
};
......
......@@ -494,7 +494,8 @@ def roi_align_args(**kwargs):
'pooled_h': kwargs.get('pooled_h', 7),
'pooled_w': kwargs.get('pooled_w', 7),
'spatial_scale': kwargs.get('spatial_scale', 1.0),
-'sampling_ratio': kwargs.get('sampling_ratio', 2),
+'sampling_ratio': kwargs.get('sampling_ratio', 0),
+'aligned': kwargs.get('aligned', False),
}
......
......@@ -1384,7 +1384,8 @@ def roi_align(
pooled_h,
pooled_w,
spatial_scale=1.0,
-sampling_ratio=2,
+sampling_ratio=-1,
+aligned=False,
**kwargs
):
r"""Apply the average roi align.
......@@ -1412,8 +1413,10 @@ def roi_align(
The output width.
spatial_scale : float, optional, default=1.0
The input scale to the size of ``rois``.
-sampling_ratio : int, optional, default=2
+sampling_ratio : int, optional, default=-1
The number of sampling grids for ``rois``.
+aligned : bool, optional, default=False
+Whether to shift the input coordinates by ``-0.5``.
Returns
-------
......@@ -1429,12 +1432,14 @@ def roi_align(
pooled_h=pooled_h,
pooled_w=pooled_w,
spatial_scale=spatial_scale,
-sampling_ratio=sampling_ratio)
+sampling_ratio=sampling_ratio,
+aligned=aligned)
return OpLib.add('RoiAlign', inputs,
pooled_h=pooled_h,
pooled_w=pooled_w,
spatial_scale=spatial_scale,
sampling_ratio=sampling_ratio,
+aligned=aligned,
**kwargs)
......
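For completeness, a call through the updated front end might look as follows (a usage sketch: the `dragon.vision` module path and the tensor constructors are assumptions on my part; only the keyword arguments are confirmed by this diff):

```python
import dragon

# Hypothetical setup: features is an (N, C, H, W) feature map, rois a (K, 5)
# tensor of [batch_ind, x1, y1, x2, y2] rows in input-image coordinates.
features = dragon.ones((1, 256, 32, 32), dtype='float32')
rois = dragon.constant([[0., 4., 4., 20., 20.]], dtype='float32')

y = dragon.vision.roi_align(   # module path assumed
    [features, rois],
    pooled_h=7,
    pooled_w=7,
    spatial_scale=1.0 / 16.0,  # e.g., a feature stride of 16
    sampling_ratio=2,
    aligned=True,              # request the V2 (half-pixel) behavior
)
```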
......@@ -51,6 +51,8 @@ def include_paths(cuda=False):
cuda_home_include = _join_cuda_path('include')
if cuda_home_include != '/usr/include':
paths.append(cuda_home_include)
+if not _os.path.exists(cuda_home_include + '/cub'):
+paths.append(paths[0] + '/cub')
if CUDNN_HOME is not None:
paths.append(_os.path.join(CUDNN_HOME, 'include'))
return paths
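This mirrors the CMake change above: newer CUDA toolkits (11+) bundle CUB while older ones do not, so when `<CUDA_HOME>/include/cub` is missing the copy installed under the package's own include directory is appended. Roughly (a restatement of the hunk with `os.path` helpers; the function name is hypothetical):

```python
import os

def append_bundled_cub(paths, cuda_home_include):
    """Fall back to the packaged CUB headers for pre-11 CUDA toolkits."""
    if not os.path.exists(os.path.join(cuda_home_include, 'cub')):
        # paths[0] is the package include dir; CMake installed the bundled
        # headers as <include>/cub/cub, so adding <include>/cub lets
        # '#include <cub/cub.cuh>' resolve.
        paths.append(os.path.join(paths[0], 'cub'))
    return paths
```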
......
......@@ -1485,6 +1485,7 @@ void RoiAlign(
const int num_rois,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const T* x,
const float* rois,
T* y,
......@@ -1500,6 +1501,7 @@ void RoiAlignGrad(
const int num_rois,
const float spatial_scale,
const int sampling_ratio,
+const bool aligned,
const T* dy,
const float* rois,
float* dx,
......
-Open MPI: Open Source High Performance Computing
-===================================================
+# Open MPI: Open Source High Performance Computing
https://www.open-mpi.org/
-Note
-----
+## Introduction
This folder is kept for the specified open-mpi.
Following file structure will be considered by our CMakeLists:
-.
-├── bin # Binary files
-├── mpirun
-└── ...
-├── include # Include files
-├── mpi.h
-└── ...
-├── lib # Library files
-├── libmpi.so
-└── ...
-├── src # Source files
-└── ...
-├── build.sh # Build script
-└── README.md
+```
+third_party
+|_ mpi
+| |_ bin # Binary files
+| | |_ mpirun
+| | |_ ...
+| |_ include # Include files
+| | |_ mpi.h
+| | |_ ...
+| |_ lib # Library files
+| | |_ libmpi.so
+| | |_ ...
+| |_ build.sh # Build script
+| |_ README.md
+```
......@@ -16,26 +16,28 @@ INSTALL_PATH=$(cd "$(dirname "$0")/..";pwd)
if [ $USE_CUDA_AWARE -eq 1 ];then
echo "Build with cuda...."
read -p "Press any key to continue." var
./configure CFLAGS=-fPIC \
CXXFLAGS=-fPIC \
--with-cuda \
--with-pic=PIC \
--without-verbs \
--without-ucx \
--disable-libudev \
--enable-shared \
--enable-static \
--enable-mpi-thread-multiple \
--prefix=$INSTALL_PATH
else
echo "Build without cuda...."
read -p "Press any key to continue." var
./configure CFLAGS=-fPIC \
CXXFLAGS=-fPIC \
--with-pic=PIC \
--without-verbs \
--without-ucx \
--disable-libudev \
--enable-shared \
--enable-static \
--enable-mpi-thread-multiple \
--prefix=$INSTALL_PATH
fi
make install -j $(getconf _NPROCESSORS_ONLN)
-Protocol Buffers - Google's data interchange format
-===================================================
+# Protocol Buffers - Google's data interchange format
https://developers.google.com/protocol-buffers/
-Note
-----
+## Introduction
This folder is kept for the specified protobuf, or the released libraries of Visual Studio.
Following file structure will be considered by our CMakeLists:
-.
-├── bin # Binary files
-├── protoc
-└── protoc.exe
-├── include # Include files
-└── google
-└── protobuf
-└── *.h
-├── lib # Library files
-├── libprotobuf.so
-└── protobuf.lib
-└── README.md
+```
+third_party
+|_ protobuf
+| |_ bin # Binary files
+| | |_ protoc
+| | |_ protoc.exe
+| | |_ ...
+| |_ include # Include files
+| | |_ google
+| | | |_ protobuf
+| | | | |_ *.h
+| |_ lib # Library files
+| | |_ libprotobuf.so
+| | |_ protobuf.lib
+| | |_ ...
+| |_ README.md
+```
......@@ -69,9 +69,9 @@ class Affine(Module):
bias : bool, optional, default=True
``True`` to attach a bias.
fix_weight : bool, optional, default=False
-``True`` to frozen the ``weight``.
+``True`` to freeze the ``weight``.
fix_bias : bool, optional, default=False
-``True`` to frozen the ``bias``.
+``True`` to freeze the ``bias``.
inplace : bool, optional, default=False
Whether to do the operation in-place.
......
......@@ -16,7 +16,14 @@ from __future__ import print_function
from dragon.vm import torch
-def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
+def roi_align(
+input,
+boxes,
+output_size,
+spatial_scale=1.0,
+sampling_ratio=-1,
+aligned=False,
+):
r"""Apply the average roi align to input.
`[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
......@@ -37,6 +44,8 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
The input scale to the size of ``boxes``.
sampling_ratio : int, optional, default=-1
The number of sampling grids for ``boxes``.
+aligned : bool, optional, default=False
+Whether to shift the input coordinates by ``-0.5``.
Returns
-------
......@@ -47,4 +56,5 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
return torch.autograd.Function.apply(
'RoiAlign', input.device, [input, boxes],
pooled_h=output_size[0], pooled_w=output_size[1],
-spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
+spatial_scale=spatial_scale,
+sampling_ratio=sampling_ratio, aligned=aligned)
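A quick usage sketch of the function defined above (the tensor constructors are assumed to follow this PyTorch-style front end; only the `roi_align` signature itself is confirmed by the diff):

```python
from dragon.vm import torch

# input: (N, C, H, W) features; boxes: (K, 5) rows of
# [batch_ind, x1, y1, x2, y2] in input coordinates.
input = torch.ones(1, 256, 32, 32)
boxes = torch.Tensor([[0., 4., 4., 20., 20.]])

out = roi_align(
    input, boxes, output_size=(7, 7),
    spatial_scale=1.0, sampling_ratio=2,
    aligned=True,  # V2: sample on half-pixel-shifted coordinates
)
# out has shape (1, 256, 7, 7).
```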