Add identity op to bn grad (#3118)
* Add identity op to bn grad

* rename identity to diff_identity

* SetTensorDescInferFn use TensorDesc

* revert flow.identity api

Former-commit-id: 61a2df9
liujuncheng authored Jul 2, 2020
1 parent e87f38c commit df33356
Showing 3 changed files with 112 additions and 2 deletions.
43 changes: 43 additions & 0 deletions oneflow/customized/kernels/identity_kernel.cpp
@@ -0,0 +1,43 @@
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {

namespace {

template<DeviceType device_type>
class IdentityKernel final : public user_op::OpKernel {
 public:
  IdentityKernel() = default;
  ~IdentityKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape();
    CHECK_EQ(out->shape(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
    Memcpy<device_type>(ctx->device_ctx(), out->mut_dptr<void>(), in->dptr<void>(),
                        in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_IDENTITY_KERNEL(device)                                                        \
  REGISTER_USER_KERNEL("identity")                                                              \
      .SetCreateFn<IdentityKernel<device>>()                                                    \
      .SetIsMatchedHob(user_op::HobDeviceType() == device)                                      \
      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \
                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
        OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));                      \
        return Maybe<void>::Ok();                                                               \
      });

REGISTER_IDENTITY_KERNEL(DeviceType::kCPU)
REGISTER_IDENTITY_KERNEL(DeviceType::kGPU)

} // namespace

} // namespace oneflow
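
For readers outside the OneFlow codebase: the forward kernel above is nothing more than a typed byte copy. It checks that in and out agree on shape and data type, then copies elem_cnt * GetSizeOfDataType(dtype) bytes from one buffer to the other; the SetInplaceProposalFn hook additionally proposes that out may reuse in's buffer, so the runtime can elide the copy when it accepts the in-place pairing. A minimal standalone sketch of that copy on plain host memory (illustrative names only, not the OneFlow API):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Mirrors the Memcpy call in IdentityKernel::Compute: copy elem_cnt elements
// of elem_size bytes each from in to out.
void IdentityCopy(void* out, const void* in, int64_t elem_cnt, size_t elem_size) {
  std::memcpy(out, in, static_cast<size_t>(elem_cnt) * elem_size);
}

int main() {
  std::vector<float> in = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> out(in.size());
  IdentityCopy(out.data(), in.data(), static_cast<int64_t>(in.size()), sizeof(float));
  assert(out == in);  // out is now a bitwise copy of in
  return 0;
}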
51 changes: 51 additions & 0 deletions oneflow/customized/ops/identity_op.cpp
@@ -0,0 +1,51 @@
#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {

REGISTER_USER_OP("identity")
    .Input("in")
    .Output("out")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc* in_desc = ctx->TensorDesc4ArgNameAndIndex("in", 0);
      user_op::TensorDesc* out_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0);
      *out_desc = *in_desc;
      return Maybe<void>::Ok();
    })
    .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> {
      *ctx->BatchAxis4ArgNameAndIndex("out", 0) = *ctx->BatchAxis4ArgNameAndIndex("in", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
        ctx->NewBuilder()
            .Split(user_op::OpArg("in", 0), i)
            .Split(user_op::OpArg("out", 0), i)
            .Build();
      }
      ctx->NewBuilder()
          .PartialSum(user_op::OpArg("in", 0))
          .PartialSum(user_op::OpArg("out", 0))
          .Build();
      return Maybe<void>::Ok();
    });

REGISTER_USER_OP_GRAD("identity")
    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
      if (op.NeedGenGradTensor4OpInput("in", 0)) {
        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
        user_op::UserOpConfWrapper identity_op =
            builder.Op("identity")
                .Input("in", op.GetGradTensorWithOpOutput("out", 0))
                .Output("out")
                .Build();
        op.BindGradTensorWithOpInput(identity_op.output("out", 0), "in", 0);
        AddOp(identity_op);
      }
    });

} // namespace

} // namespace oneflow
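
Because the op is the pure identity map out = in, its backward rule is equally trivial: d(loss)/d(in) = d(loss)/d(out). The REGISTER_USER_OP_GRAD block above encodes exactly that by instantiating another identity op (named <op_name>_grad) whose input is the gradient of out, and binding its output as the gradient of in. The SBP registration likewise lets any parallel layout pass through unchanged: for an N-axis input it declares Split(i) -> Split(i) for every axis i plus PartialSum -> PartialSum, so for a 2-D tensor, for example, the signatures S(0)->S(0), S(1)->S(1), and P->P are all valid.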
20 changes: 18 additions & 2 deletions oneflow/customized/ops/normalization_op.cpp
@@ -283,11 +283,27 @@ REGISTER_USER_OP_GRAD("normalization")
         need_norm_grad_op = true;
       }
       if (op.NeedGenGradTensor4OpInput("gamma", 0)) {
-        op.BindGradTensorWithOpInput(grad_op.output("gamma_diff", 0), "gamma", 0);
+        // TODO(liujuncheng): delete identity op when boxing support separated regsts
+        const auto identity =
+            user_op::UserOpConfWrapperBuilder(op.op_name() + "_grad_gamma_diff_identity")
+                .Op("identity")
+                .Input("in", grad_op.output("gamma_diff", 0))
+                .Output("out")
+                .Build();
+        AddOp(identity);
+        op.BindGradTensorWithOpInput(identity.output("out", 0), "gamma", 0);
         need_norm_grad_op = true;
       }
       if (op.NeedGenGradTensor4OpInput("beta", 0)) {
-        op.BindGradTensorWithOpInput(grad_op.output("beta_diff", 0), "beta", 0);
+        // TODO(liujuncheng): delete identity op when boxing support separated regsts
+        const auto identity =
+            user_op::UserOpConfWrapperBuilder(op.op_name() + "_grad_beta_diff_identity")
+                .Op("identity")
+                .Input("in", grad_op.output("beta_diff", 0))
+                .Output("out")
+                .Build();
+        AddOp(identity);
+        op.BindGradTensorWithOpInput(identity.output("out", 0), "beta", 0);
         need_norm_grad_op = true;
       }
       if (need_norm_grad_op) { AddOp(grad_op); }
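
As the TODO comments note, the identity ops introduced here are a plumbing workaround rather than a change to the math: gamma_diff and beta_diff are still produced unchanged by normalization_grad, but each is now routed through its own single-output identity op (<op_name>_grad_gamma_diff_identity / <op_name>_grad_beta_diff_identity) before being bound as the gradient of gamma or beta. Presumably this gives each diff tensor a regst of its own so boxing does not have to deal with the grad op's combined outputs; once boxing supports separated regsts, the extra copies can be removed as the TODO says.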
