Skip to content

Commit

Permalink
Optimize transpose performance (#3487)
Browse files Browse the repository at this point in the history
Former-commit-id: 809793c
  • Loading branch information
liujuncheng authored Aug 16, 2020
1 parent ea1d417 commit 396bd65
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 51 deletions.
89 changes: 71 additions & 18 deletions oneflow/core/kernel/util/cuda_arithemetic_interface.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,18 @@ template<int32_t NDIMS>
// Maps a flat index into the output (y) onto the matching flat index into the input (x).
//
// y_idx is decomposed one dimension at a time, from dimension NDIMS - 1 down to 0, using
// y_shape[i] as the extent of dimension i. Each per-dimension coordinate is multiplied by
// x_strides[i] and accumulated into the result.
//
//   y_shape   - NDIMS extents used to decompose y_idx.
//   x_strides - NDIMS multipliers applied to the per-dimension coordinates.
//   y_idx     - flat index to translate.
// Returns the accumulated x index.
__device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) {
  int32_t x_idx = 0;
  for (int32_t i = NDIMS - 1; i >= 0; --i) {
    // One integer division per dimension: the remainder is rebuilt from the quotient
    // (y_idx - quotient * extent) instead of issuing a separate modulo.
    const int32_t next_y_idx = y_idx / y_shape[i];
    x_idx += (y_idx - next_y_idx * y_shape[i]) * x_strides[i];
    y_idx = next_y_idx;
  }
  return x_idx;
}

template<int32_t NDIMS, typename T>
__global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<NDIMS> x_strides,
const int32_t elem_cnt, const T* x, T* y) {
__shared__ int32_t x_strides_shared[NDIMS];
__shared__ int32_t y_dims_shared[NDIMS];
const int32_t tid = threadIdx.x;
if (tid < NDIMS) {
y_dims_shared[tid] = y_shape.val[tid];
x_strides_shared[tid] = x_strides.val[tid];
}
__syncthreads();
CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) {
const int32_t x_idx = GetXIndex<NDIMS>(y_dims_shared, x_strides_shared, y_idx);
const int32_t x_idx = GetXIndex<NDIMS>(y_shape.val, x_strides.val, y_idx);
#if __CUDA_ARCH__ >= 350
y[y_idx] = __ldg(x + x_idx);
#else
Expand All @@ -62,7 +55,8 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N

template<int32_t NDIMS, typename T>
void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) {
const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
T* y) {
CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
Int32Array<NDIMS> y_shape_struct;
FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); }
Expand Down Expand Up @@ -95,7 +89,7 @@ struct TransposeUtil final {
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const float* x, float* y) {
TRANSPOSE_CHECK;
TransposeUtil<float>::SwitchTransposeImpl(SwitchCase(num_axis), ctx, x_shape, y_shape,
Expand All @@ -104,7 +98,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const double* x,
double* y) {
TRANSPOSE_CHECK;
Expand All @@ -114,7 +108,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const float16* x,
float16* y) {
TRANSPOSE_CHECK;
Expand All @@ -125,7 +119,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x,
int8_t* y) {
TRANSPOSE_CHECK;
Expand All @@ -135,7 +129,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x,
int32_t* y) {
TRANSPOSE_CHECK;
Expand All @@ -145,7 +139,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int64_t* x,
int64_t* y) {
TRANSPOSE_CHECK;
Expand All @@ -155,6 +149,65 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
#undef TRANSPOSE_CHECK
// float Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const float* x, float* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// double Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const double* x,
                                                double* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// float16 Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const float16* x,
                                                float16* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// int8_t Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int8_t* x,
                                                int8_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// int32_t Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int32_t* x,
                                                int32_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// int64_t Transpose taking the permutation as a PbRf<int32_t>.
// Copies the elements in [permutation.cbegin(), permutation.cend()) into a
// std::vector<int32_t> and forwards every argument to the std::vector<int32_t> overload.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int64_t* x,
                                                int64_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
void ArithemeticIf<DeviceType::kGPU>::InitializeWithConstConf(
DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob) {
WithHostBlobAndStreamSynchronizeEnv(ctx, blob, [&](Blob* host_blob) {
Expand Down
43 changes: 31 additions & 12 deletions oneflow/core/kernel/util/cuda_arithemetic_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,43 @@ class ConstantInitializerConf;

template<>
struct ArithemeticIf<DeviceType::kGPU> {
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int64_t* x, int64_t* y);

static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int64_t* x, int64_t* y);
int64_t elem_cnt, const int64_t* x, int64_t* y);

static void InitializeWithConstConf(DeviceCtx* ctx,
const ConstantInitializerConf& initializer_conf, Blob* blob);
Expand Down
Loading

0 comments on commit 396bd65

Please sign in to comment.