Fused softmax kernel (#3496)
Former-commit-id: 2ec8fc6
liujuncheng authored Aug 20, 2020
1 parent 23cb615 commit b4e377b
Showing 7 changed files with 423 additions and 85 deletions.
32 changes: 26 additions & 6 deletions oneflow/python/test/ops/test_softmax.py
@@ -33,7 +33,6 @@ def compare_with_tensorflow(device_type, x_shape, data_type, axis):
func_config = flow.FunctionConfig()

if data_type == "float16":
func_config.enable_auto_mixed_precision(True)
dtype = flow.float
else:
dtype = type_name_to_flow_type[data_type]
@@ -45,10 +44,16 @@ def SoftmaxJob():
"x",
shape=x_shape,
dtype=dtype,
initializer=flow.random_uniform_initializer(minval=-10, maxval=10),
initializer=flow.random_uniform_initializer(minval=-0.1, maxval=0.1),
trainable=True,
)
loss = flow.nn.softmax(x, axis=axis)
if data_type == "float16":
loss = flow.cast(
flow.nn.softmax(flow.cast(x, dtype=flow.float16), axis=axis),
dtype=flow.float,
)
else:
loss = flow.nn.softmax(x, axis=axis)
flow.optimizer.SGD(
flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
).minimize(loss)
@@ -71,16 +76,31 @@ def SoftmaxJob():

loss_diff = test_global_storage.Get("loss_diff")
tf_x_diff = tape.gradient(tf_out, x, loss_diff)
assert np.allclose(of_out.numpy(), tf_out.numpy(), rtol=1e-5, atol=1e-5)
if data_type == "float16":
tolerance = 1e-3
else:
tolerance = 1e-5
assert np.allclose(of_out.numpy(), tf_out.numpy(), rtol=tolerance, atol=tolerance)
assert np.allclose(
test_global_storage.Get("x_diff"), tf_x_diff.numpy(), rtol=1e-5, atol=1e-5
test_global_storage.Get("x_diff"),
tf_x_diff.numpy(),
rtol=tolerance,
atol=tolerance,
)


def test_softmax(test_case):
arg_dict = OrderedDict()
arg_dict["device_type"] = ["gpu", "cpu"]
arg_dict["x_shape"] = [(10, 10, 20, 30), (10, 20, 30), (10, 20)]
arg_dict["x_shape"] = [
(10, 10, 20, 30),
(10, 20, 30),
(10, 20),
(10, 960),
(10, 4096),
(10, 8092),
(256, 1001),
]
arg_dict["data_type"] = ["float32", "double", "float16"]
arg_dict["axis"] = [-1, 1, 2, 3]
for arg in GenArgList(arg_dict):
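Note on the float16 branch added above: instead of enabling auto mixed precision on the whole job, the test now casts the input to float16, applies softmax, casts back to float, and compares against TensorFlow with a relaxed 1e-3 tolerance. A minimal NumPy sketch (illustrative only, not part of this commit) of why half precision needs the looser bound:

import numpy as np

def stable_softmax(x, axis=-1):
    # Numerically stable softmax: subtract the per-row max before exponentiating.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.uniform(-0.1, 0.1, size=(10, 4096)).astype(np.float32)

ref = stable_softmax(x)                                          # float32 reference
half = stable_softmax(x.astype(np.float16)).astype(np.float32)   # float16 path

# float16 has a 10-bit mantissa (about 3 decimal digits), so the relative error of
# the half-precision result is typically around 1e-4 to 1e-3, too large for the
# 1e-5 tolerance used for float32/double, hence the relaxed 1e-3 bound above.
rel_err = np.max(np.abs(ref - half) / np.maximum(np.abs(ref), 1e-12))
print(rel_err)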
25 changes: 14 additions & 11 deletions oneflow/user/kernels/softmax_cross_entropy_kernel.h
@@ -45,24 +45,27 @@ class SoftmaxCrossEntropyKernel final : public user_op::OpKernel {
const int64_t num_instances = label->shape().Count(0, num_axes - 1);
const int64_t num_classes = label->shape().At(num_axes - 1);
SoftmaxKernelUtil<device_type, T>::ComputeProb(
ctx->device_ctx(), num_instances, num_classes, prediction->dptr<T>(), out->mut_dptr<T>(),
prob->mut_dptr<T>(), tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt() * sizeof(T));
ctx->device_ctx(), num_instances, num_classes, prediction->dptr<T>(), prob->mut_dptr<T>(),
tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt());
CrossEntropyKernelUtil<device_type, T>::ComputeEntropy(ctx->device_ctx(), num_instances,
num_classes, prob->dptr<T>(),
label->dptr<T>(), out->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair) \
REGISTER_USER_KERNEL("softmax_cross_entropy") \
.SetCreateFn<SoftmaxCrossEntropyKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \
& (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \
& (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \
return prediction_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair)); \
#define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair) \
REGISTER_USER_KERNEL("softmax_cross_entropy") \
.SetCreateFn<SoftmaxCrossEntropyKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \
& (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \
& (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \
const int64_t num_classes = prediction_shape->At(prediction_shape->NumAxes() - 1); \
const int64_t num_instances = prediction_shape->Count(0, prediction_shape->NumAxes() - 1); \
return SoftmaxKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>:: \
GetComputeProbTempStorageSizeInBytes(num_instances, num_classes); \
});

template<DeviceType device_type, typename T>
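The registration above no longer sizes the tmp buffer as prediction_shape->elem_cnt() * sizeof(T); it asks SoftmaxKernelUtil<device_type_v, T>::GetComputeProbTempStorageSizeInBytes(num_instances, num_classes) for the size instead. A small Python sketch (illustrative, not from this commit) of the shape arithmetic feeding that call:

from functools import reduce
import operator

def softmax_dims(prediction_shape):
    # Mirrors prediction_shape->At(NumAxes() - 1) and
    # prediction_shape->Count(0, NumAxes() - 1): the last axis is the class
    # dimension, every leading axis is flattened into "instances".
    num_classes = prediction_shape[-1]
    num_instances = reduce(operator.mul, prediction_shape[:-1], 1)
    return num_instances, num_classes

print(softmax_dims((256, 1001)))   # (256, 1001)
print(softmax_dims((10, 20, 30)))  # (200, 30)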
77 changes: 35 additions & 42 deletions oneflow/user/kernels/softmax_kernel.cpp
@@ -25,36 +25,31 @@ template<DeviceType device_type, typename T>
class SoftmaxKernel final : public user_op::OpKernel {
public:
SoftmaxKernel() = default;
~SoftmaxKernel() = default;
~SoftmaxKernel() override = default;

private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0);
const int64_t num_classes = x->shape().At(x->shape().NumAxes() - 1);
const int64_t num_instances = x->shape().elem_cnt() / num_classes;

const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const int64_t num_classes = in->shape().At(in->shape().NumAxes() - 1);
const int64_t num_instances = in->shape().Count(0, in->shape().NumAxes() - 1);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const size_t temp_storage_bytes = x->shape().elem_cnt() * sizeof(T);
const size_t tmp_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes);

T* tmp_ptr = tmp_buffer->mut_dptr<T>();
void* temp_storage_ptr = reinterpret_cast<void*>(tmp_ptr + tmp_bytes / sizeof(T));
const size_t temp_storage_bytes = tmp_buffer->shape().elem_cnt();
SoftmaxKernelUtil<device_type, T>::ComputeProb(ctx->device_ctx(), num_instances, num_classes,
x->dptr<T>(), tmp_ptr, y->mut_dptr<T>(),
temp_storage_ptr, temp_storage_bytes);
in->dptr<T>(), out->mut_dptr<T>(),
tmp_buffer->mut_dptr(), temp_storage_bytes);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<typename T>
user_op::InferTmpSizeFn GenInferTmpSizeFn(const std::string& bn) {
return [bn](user_op::InferContext* ctx) {
const Shape* x = ctx->Shape4ArgNameAndIndex(bn, 0);
const size_t num_classes = x->dim_vec().back();
size_t temp_storage_bytes = GetCudaAlignedSize(x->elem_cnt() * sizeof(T)); // [i][j]
size_t tmp_or_sum_vec_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes); //[i]
return tmp_or_sum_vec_bytes + temp_storage_bytes;
template<DeviceType device_type, typename T>
user_op::InferTmpSizeFn GenFwInferTmpSizeFn() {
return [](user_op::InferContext* ctx) {
const Shape* in_shape = ctx->Shape4ArgNameAndIndex("in", 0);
const int64_t num_classes = in_shape->At(in_shape->NumAxes() - 1);
const int64_t num_instances = in_shape->Count(0, in_shape->NumAxes() - 1);
return SoftmaxKernelUtil<device_type, T>::GetComputeProbTempStorageSizeInBytes(num_instances,
num_classes);
};
}

@@ -63,21 +58,16 @@ user_op::InferTmpSizeFn GenInferTmpSizeFn(const std::string& bn) {
.SetCreateFn<SoftmaxKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("out", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn(GenInferTmpSizeFn<dtype>("in"));
.SetInferTmpSizeFn(GenFwInferTmpSizeFn<device, dtype>());

REGISTER_SOFTMAX_KERNEL(DeviceType::kCPU, float)
REGISTER_SOFTMAX_KERNEL(DeviceType::kCPU, double)
#ifdef WITH_CUDA
REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, float16)
REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, float)
REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, double)
#endif

template<DeviceType device_type, typename T>
class SoftmaxGradKernel final : public user_op::OpKernel {
public:
SoftmaxGradKernel() = default;
~SoftmaxGradKernel() = default;
~SoftmaxGradKernel() override = default;

private:
void Compute(user_op::KernelComputeContext* ctx) const override {
@@ -89,32 +79,35 @@ class SoftmaxGradKernel final : public user_op::OpKernel {
const int64_t num_instances = y->shape().elem_cnt() / num_classes;

user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const size_t temp_storage_bytes = y->shape().elem_cnt() * sizeof(T);
const size_t sum_vec_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes);

T* sum_vec_ptr = tmp_buffer->mut_dptr<T>();
void* temp_storage_ptr = reinterpret_cast<void*>(sum_vec_ptr + sum_vec_bytes / sizeof(T));
SoftmaxKernelUtil<device_type, T>::ComputeDiff(
ctx->device_ctx(), num_instances, num_classes, dy->dptr<T>(), y->dptr<T>(), sum_vec_ptr,
dx->mut_dptr<T>(), temp_storage_ptr, temp_storage_bytes);
const size_t temp_storage_bytes = tmp_buffer->shape().elem_cnt();

SoftmaxKernelUtil<device_type, T>::ComputeDiff(ctx->device_ctx(), num_instances, num_classes,
dy->dptr<T>(), y->dptr<T>(), dx->mut_dptr<T>(),
tmp_buffer->mut_dptr(), temp_storage_bytes);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<DeviceType device_type, typename T>
user_op::InferTmpSizeFn GenBwInferTmpSizeFn() {
return [](user_op::InferContext* ctx) {
const Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0);
const int64_t num_classes = dy_shape->At(dy_shape->NumAxes() - 1);
const int64_t num_instances = dy_shape->Count(0, dy_shape->NumAxes() - 1);
return SoftmaxKernelUtil<device_type, T>::GetComputeDiffTempStorageSizeInBytes(num_instances,
num_classes);
};
}

#define REGISTER_SOFTMAX_GRAD_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("softmax_grad") \
.SetCreateFn<SoftmaxGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn(GenInferTmpSizeFn<dtype>("dx"));
.SetInferTmpSizeFn(GenBwInferTmpSizeFn<device, dtype>());

REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kCPU, float)
REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kCPU, double)
#ifdef WITH_CUDA
REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, float16)
REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, float)
REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, double)
#endif

} // namespace
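The forward kernels in this file reuse ComputeProb (the row-wise, numerically stable softmax sketched under the test file above), while the "softmax_grad" kernels call ComputeDiff with dy, y, and dx. Assuming the standard softmax gradient identity (the identity itself is not shown in this hunk), a NumPy reference for the backward pass looks like:

import numpy as np

def compute_diff(dy, y):
    # Backward pass of softmax: dx = (dy - sum(dy * y over classes)) * y,
    # i.e. the softmax Jacobian-vector product applied row-wise.
    s = (dy * y).sum(axis=-1, keepdims=True)
    return (dy - s) * y

rng = np.random.default_rng(0)
p = rng.random((4, 7)).astype(np.float32)
y = p / p.sum(axis=-1, keepdims=True)   # stand-in softmax output (rows sum to 1)
dy = rng.standard_normal((4, 7)).astype(np.float32)
dx = compute_diff(dy, y)
print(dx.sum(axis=-1))  # each row sums to ~0, a basic sanity check for softmax grads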
