-
Notifications
You must be signed in to change notification settings - Fork 669
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add identity op to bn grad * rename identity to diff_identity * SetTensorDescInferFn use TensorDesc * revert flow.identity api Former-commit-id: 61a2df9
- Loading branch information
1 parent
e87f38c
commit df33356
Showing
3 changed files
with
112 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include "oneflow/core/framework/framework.h" | ||
#include "oneflow/core/kernel/new_kernel_util.h" | ||
|
||
namespace oneflow { | ||
|
||
namespace { | ||
|
||
template<DeviceType device_type> | ||
class IdentityKernel final : public user_op::OpKernel { | ||
public: | ||
IdentityKernel() = default; | ||
~IdentityKernel() override = default; | ||
|
||
private: | ||
void Compute(user_op::KernelComputeContext* ctx) const override { | ||
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); | ||
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); | ||
const ShapeView& in_shape = in->shape(); | ||
CHECK_EQ(out->shape(), in_shape); | ||
const DataType in_data_type = in->data_type(); | ||
CHECK_EQ(out->data_type(), in_data_type); | ||
Memcpy<device_type>(ctx->device_ctx(), out->mut_dptr<void>(), in->dptr<void>(), | ||
in_shape.elem_cnt() * GetSizeOfDataType(in_data_type)); | ||
} | ||
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } | ||
}; | ||
|
||
#define REGISTER_IDENTITY_KERNEL(device) \ | ||
REGISTER_USER_KERNEL("identity") \ | ||
.SetCreateFn<IdentityKernel<device>>() \ | ||
.SetIsMatchedHob(user_op::HobDeviceType() == device) \ | ||
.SetInplaceProposalFn([](const user_op::InferContext&, \ | ||
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \ | ||
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); \ | ||
return Maybe<void>::Ok(); \ | ||
}); | ||
|
||
REGISTER_IDENTITY_KERNEL(DeviceType::kCPU) | ||
REGISTER_IDENTITY_KERNEL(DeviceType::kGPU) | ||
|
||
} // namespace | ||
|
||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#include "oneflow/core/framework/framework.h" | ||
|
||
namespace oneflow { | ||
|
||
namespace { | ||
|
||
REGISTER_USER_OP("identity") | ||
.Input("in") | ||
.Output("out") | ||
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { | ||
const user_op::TensorDesc* in_desc = ctx->TensorDesc4ArgNameAndIndex("in", 0); | ||
user_op::TensorDesc* out_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0); | ||
*out_desc = *in_desc; | ||
return Maybe<void>::Ok(); | ||
}) | ||
.SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> { | ||
*ctx->BatchAxis4ArgNameAndIndex("out", 0) = *ctx->BatchAxis4ArgNameAndIndex("in", 0); | ||
return Maybe<void>::Ok(); | ||
}) | ||
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { | ||
const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); | ||
FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { | ||
ctx->NewBuilder() | ||
.Split(user_op::OpArg("in", 0), i) | ||
.Split(user_op::OpArg("out", 0), i) | ||
.Build(); | ||
} | ||
ctx->NewBuilder() | ||
.PartialSum(user_op::OpArg("in", 0)) | ||
.PartialSum(user_op::OpArg("out", 0)) | ||
.Build(); | ||
return Maybe<void>::Ok(); | ||
}); | ||
|
||
REGISTER_USER_OP_GRAD("identity") | ||
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { | ||
if (op.NeedGenGradTensor4OpInput("in", 0)) { | ||
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); | ||
user_op::UserOpConfWrapper identity_op = | ||
builder.Op("identity") | ||
.Input("in", op.GetGradTensorWithOpOutput("out", 0)) | ||
.Output("out") | ||
.Build(); | ||
op.BindGradTensorWithOpInput(identity_op.output("out", 0), "in", 0); | ||
AddOp(identity_op); | ||
} | ||
}); | ||
|
||
} // namespace | ||
|
||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters