From 57d0d18aeff96a0fdd3c35c887126c483a6a6fed Mon Sep 17 00:00:00 2001 From: Juncheng Date: Fri, 8 Oct 2021 14:05:39 +0800 Subject: [PATCH] Refine acc actor (#6444) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- cmake/cfg.cmake | 2 - oneflow/core/actor/acc_actor.cpp | 47 ++---- oneflow/core/actor/actor.cpp | 4 - oneflow/core/actor/actor.h | 1 - oneflow/core/framework/op_arg_util.h | 7 +- oneflow/core/job/id_manager.cpp | 4 - oneflow/core/job/id_manager.h | 1 - oneflow/core/register/pod.proto | 41 ------ oneflow/core/register/pod_desc.cpp | 207 --------------------------- oneflow/core/register/pod_desc.h | 176 ----------------------- oneflow/core/register/pod_ptr.cpp | 37 ----- oneflow/core/register/pod_ptr.h | 96 ------------- 12 files changed, 16 insertions(+), 607 deletions(-) delete mode 100644 oneflow/core/register/pod.proto delete mode 100644 oneflow/core/register/pod_desc.cpp delete mode 100644 oneflow/core/register/pod_desc.h delete mode 100644 oneflow/core/register/pod_ptr.cpp delete mode 100644 oneflow/core/register/pod_ptr.h diff --git a/cmake/cfg.cmake b/cmake/cfg.cmake index b22ff05ed27..a9e34409f81 100644 --- a/cmake/cfg.cmake +++ b/cmake/cfg.cmake @@ -45,7 +45,6 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) oneflow/core/job/sbp_parallel.proto oneflow/core/graph/boxing/collective_boxing.proto oneflow/core/register/blob_desc.proto - oneflow/core/register/pod.proto oneflow/core/job/scope.proto oneflow/core/job/mirrored_parallel.proto oneflow/core/operator/op_attribute.proto @@ -100,7 +99,6 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) oneflow/core/operator/interface_blob_conf.proto oneflow/core/common/shape.proto oneflow/core/register/blob_desc.proto - oneflow/core/register/pod.proto oneflow/core/operator/op_conf.proto ) diff --git a/oneflow/core/actor/acc_actor.cpp b/oneflow/core/actor/acc_actor.cpp index 676acb99f12..0cacb22943a 100644 --- a/oneflow/core/actor/acc_actor.cpp +++ b/oneflow/core/actor/acc_actor.cpp @@ -23,18 +23,14 @@ class AccActor final : public Actor { AccActor() = default; ~AccActor() override = default; - using Actor::Init; - private: void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void VirtualActorInit(const TaskProto& proto) override; - void Init(const TaskProto&, int32_t max_acc_cnt); - std::function cpy_func_; - int32_t acc_cnt_; - int32_t max_acc_cnt_; + int32_t acc_cnt_{}; + int32_t max_acc_cnt_{}; }; void AccActor::VirtualActorInit(const TaskProto& proto) { @@ -45,44 +41,21 @@ void AccActor::VirtualActorInit(const TaskProto& proto) { ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_GE(in_time_shape.elem_cnt(), out_time_shape.elem_cnt()); - Init(proto, in_time_shape.elem_cnt() / out_time_shape.elem_cnt()); -} - -void AccActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt) { - using namespace std::placeholders; - if (GetDeviceType() == DeviceType::kCPU) { - cpy_func_ = std::bind(Memcpy, _1, _2, _3, _4); - } else { -#ifdef WITH_CUDA - cpy_func_ = std::bind(Memcpy, _1, _2, _3, _4); -#else - UNIMPLEMENTED(); -#endif - } - OF_SET_MSG_HANDLER(&AccActor::HandlerNormal); + max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt(); acc_cnt_ = 0; - max_acc_cnt_ = max_acc_cnt; + OF_SET_MSG_HANDLER(&AccActor::HandlerNormal); } void AccActor::Act() { - Regst* out_regst = GetNaiveCurWriteable("out"); - Regst* in_regst = GetNaiveCurReadable("in"); if (acc_cnt_ == 0) { + Regst* out_regst = GetNaiveCurWriteable("out"); + Regst* in_regst = GetNaiveCurReadable("in"); const Blob* in_blob = in_regst->GetMutSoleBlob(); Blob* out_blob = out_regst->GetMutSoleBlob(); - if (GetDeviceType() == DeviceType::kCPU) { - Memcpy(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(), - out_blob->ByteSizeOfBlobBody()); - } else if (GetDeviceType() == DeviceType::kGPU) { -#ifdef WITH_CUDA - Memcpy(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(), - out_blob->ByteSizeOfBlobBody()); -#else - UNIMPLEMENTED(); -#endif - } else { - UNIMPLEMENTED(); - } + const size_t size = in_blob->ByteSizeOfBlobBody(); + CHECK_EQ(out_blob->ByteSizeOfBlobBody(), size); + AutoMemcpy(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(), size, + out_blob->mem_case(), in_blob->mem_case()); } else { AsyncLaunchKernel(); } diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index c11b27eb5c1..6170560dd22 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -297,10 +297,6 @@ void Actor::ForEachProducedRegst(const std::function& Handler) con } } -DeviceType Actor::GetDeviceType() const { - return Global::Get()->GetDeviceTypeFromActorId(actor_id_); -} - int64_t Actor::Name2SoleRegstDescId(const std::string& name) const { auto find_it = name2regst_desc_id_.find(name); if (find_it != name2regst_desc_id_.end()) { diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index db9860c492a..f2f5b6f351a 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -68,7 +68,6 @@ class Actor : public ActorBase { const ParallelContext* parallel_ctx() const { return parallel_ctx_.get(); } bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; } bool ReceiveEordMsg(int64_t regst_desc_id) const; - DeviceType GetDeviceType() const; virtual void VirtualActorInit(const TaskProto&) {} int64_t Name2SoleRegstDescId(const std::string& name) const; const std::vector& Name2RegstDescIds(const std::string& name) const; diff --git a/oneflow/core/framework/op_arg_util.h b/oneflow/core/framework/op_arg_util.h index 4c9c33918f5..eeaf297ba6d 100644 --- a/oneflow/core/framework/op_arg_util.h +++ b/oneflow/core/framework/op_arg_util.h @@ -25,7 +25,6 @@ limitations under the License. #include "oneflow/core/common/shape.cfg.h" #include "oneflow/core/register/logical_blob_id.cfg.h" #include "oneflow/core/operator/interface_blob_conf.cfg.h" -#include "oneflow/core/register/pod.cfg.h" #include "oneflow/core/register/blob_desc.cfg.h" #include "oneflow/core/operator/op_node_signature.cfg.h" #include "oneflow/core/job/parallel_signature.cfg.h" @@ -42,6 +41,9 @@ class OpArgBlobAttribute { const std::string& logical_blob_name); OpArgBlobAttribute(const OpArgBlobAttribute& op_arg_blob_attr) = default; + OpArgBlobAttribute(OpArgBlobAttribute&& op_arg_blob_attr) = delete; + OpArgBlobAttribute& operator=(const OpArgBlobAttribute&) = delete; + OpArgBlobAttribute& operator=(OpArgBlobAttribute&&) = delete; virtual ~OpArgBlobAttribute() = default; std::shared_ptr blob_desc() const; @@ -78,6 +80,9 @@ class OpArgParallelAttribute { const std::shared_ptr& opt_mirrored_parallel); OpArgParallelAttribute(const OpArgParallelAttribute& op_arg_para_attr) = default; + OpArgParallelAttribute(OpArgParallelAttribute&& op_arg_blob_attr) = delete; + OpArgParallelAttribute& operator=(const OpArgParallelAttribute&) = delete; + OpArgParallelAttribute& operator=(OpArgParallelAttribute&&) = delete; virtual ~OpArgParallelAttribute() = default; std::shared_ptr parallel_desc_symbol() const; diff --git a/oneflow/core/job/id_manager.cpp b/oneflow/core/job/id_manager.cpp index f5b7de0d3d3..faad57f7bea 100644 --- a/oneflow/core/job/id_manager.cpp +++ b/oneflow/core/job/id_manager.cpp @@ -19,10 +19,6 @@ limitations under the License. namespace oneflow { -DeviceType IDMgr::GetDeviceTypeFromActorId(int64_t actor_id) const { - return DeserializeTaskIdFromInt64(actor_id).stream_id().device_id().device_type(); -} - int64_t IDMgr::MachineId4ActorId(int64_t actor_id) const { // TODO: change this inferface semantics, rank does not indicate machine_id in multi-client return DeserializeTaskIdFromInt64(actor_id).stream_id().device_id().rank(); diff --git a/oneflow/core/job/id_manager.h b/oneflow/core/job/id_manager.h index 053f5178118..e996f839437 100644 --- a/oneflow/core/job/id_manager.h +++ b/oneflow/core/job/id_manager.h @@ -35,7 +35,6 @@ class IDMgr final { int64_t NewChunkId() { return chunk_id_count_++; } // Runtime - DeviceType GetDeviceTypeFromActorId(int64_t actor_id) const; int64_t MachineId4ActorId(int64_t actor_id) const; int64_t ThrdId4ActorId(int64_t actor_id) const; diff --git a/oneflow/core/register/pod.proto b/oneflow/core/register/pod.proto deleted file mode 100644 index 928540504c8..00000000000 --- a/oneflow/core/register/pod.proto +++ /dev/null @@ -1,41 +0,0 @@ -syntax = "proto2"; -package oneflow; - -import "oneflow/core/common/shape.proto"; -import "oneflow/core/common/data_type.proto"; -import "oneflow/core/register/logical_blob_id.proto"; - -message TensorPodProto { - required ShapeProto shape = 1; - required DataType data_type = 2; -} - -message StructPodProto { - repeated FieldPodProto field = 1; -} - -enum FieldKey { - kInvalidFieldKey = 0; - kTensorShape = 1; - kFieldKeySize = 2; -} - -message FieldId { - oneof field_id_type { - FieldKey key = 1; - LogicalBlobId lbi = 2; - } -} - -message FieldPodProto { - required FieldId field_id = 1; - required int32 alignment = 2; - required PodProto pod = 3; -} - -message PodProto { - oneof pod_type { - TensorPodProto tensor_pod = 1; - StructPodProto struct_pod = 2; - } -} diff --git a/oneflow/core/register/pod_desc.cpp b/oneflow/core/register/pod_desc.cpp deleted file mode 100644 index bb5f55358d7..00000000000 --- a/oneflow/core/register/pod_desc.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/register/pod_desc.h" - -namespace oneflow { - -namespace { - -std::unique_ptr NewPodDesc(const PodProto& pod) { - if (pod.has_tensor_pod()) { return std::make_unique(pod.tensor_pod()); } - if (pod.has_struct_pod()) { return std::make_unique(pod.struct_pod()); } - // ignore field pod - UNIMPLEMENTED(); - return std::unique_ptr(); -} - -} // namespace - -FieldId NewFieldId(FieldKey key) { - FieldId ret; - ret.set_key(key); - return ret; -} - -FieldId NewFieldId(const LogicalBlobId& lbi) { - FieldId ret; - *ret.mutable_lbi() = lbi; - return ret; -} - -TensorPodDesc::TensorPodDesc() : PodDesc(), shape_(std::make_shared()) {} -TensorPodDesc::TensorPodDesc(const Shape& shape, DataType data_type) - : PodDesc(), shape_(std::make_shared(shape)), data_type_(data_type) {} -TensorPodDesc::TensorPodDesc(const std::shared_ptr& shape, DataType data_type) - : PodDesc(), shape_(shape), data_type_(data_type) {} - -TensorPodDesc::TensorPodDesc(const TensorPodProto& tensor_pod) { - shape_ = std::make_shared(); - InitFromProto(tensor_pod); -} - -TensorPodDesc::TensorPodDesc(const TensorPodDesc& tensor_pod) { - shape_ = std::make_shared(); - PodProto pod_proto; - tensor_pod.ToProto(&pod_proto); - InitFromProto(pod_proto.tensor_pod()); -} - -void TensorPodDesc::InitFromProto(const TensorPodProto& tensor_pod) { - *mut_shape() = Shape(tensor_pod.shape()); - data_type_ = tensor_pod.data_type(); -} - -size_t TensorPodDesc::ByteSize() const { - return shape().elem_cnt() * GetSizeOfDataType(data_type_); -} - -bool TensorPodDesc::operator==(const PodDesc& rhs) const { - const auto* tensor_rhs = dynamic_cast(&rhs); - if (tensor_rhs == nullptr) { return false; } - return shape() == tensor_rhs->shape() && data_type() == tensor_rhs->data_type(); -} - -void TensorPodDesc::ToProto(PodProto* pod_proto) const { ToProto(pod_proto->mutable_tensor_pod()); } - -void TensorPodDesc::ToProto(TensorPodProto* proto) const { - shape().ToProto(proto->mutable_shape()); - proto->set_data_type(data_type_); -} - -FieldPodDesc::FieldPodDesc(const FieldPodProto& field_pod) { - field_id_ = field_pod.field_id(); - pod_ = NewPodDesc(field_pod.pod()); - alignment_ = field_pod.alignment(); -} - -size_t FieldPodDesc::ByteSize() const { return RoundUp(pod_->ByteSize(), alignment_); } - -bool FieldPodDesc::operator==(const PodDesc& rhs) const { - const auto* field_rhs = dynamic_cast(&rhs); - if (field_rhs == nullptr) { return false; } - return field_id() == field_rhs->field_id() && pod() == field_rhs->pod() - && alignment_ == field_rhs->alignment_; -} - -void FieldPodDesc::ToProto(FieldPodProto* field_pod_proto) const { - *field_pod_proto->mutable_field_id() = field_id_; - field_pod_proto->set_alignment(alignment_); - pod_->ToProto(field_pod_proto->mutable_pod()); -} - -StructPodDesc::StructPodDesc(const StructPodProto& struct_pod_proto) { - InitFromProto(struct_pod_proto); -} - -StructPodDesc::StructPodDesc(const StructPodDesc& struct_pod_desc) { *this = struct_pod_desc; } - -void StructPodDesc::InitFromProto(const StructPodProto& struct_pod) { - Clear(); - for (const auto& field : struct_pod.field()) { - std::unique_ptr pod(new FieldPodDesc(field)); - AddField(std::move(pod)); - } -} - -size_t StructPodDesc::ByteSize() const { - size_t size = 0; - for (const auto& field : fields_) { size += field->ByteSize(); } - return size; -} - -bool StructPodDesc::operator==(const PodDesc& rhs) const { - const auto* struct_rhs = dynamic_cast(&rhs); - if (struct_rhs == nullptr) { return false; } - if (field_id2field_idx_ != struct_rhs->field_id2field_idx_) { return false; } - for (int i = 0; i < field_id2field_idx_.size(); ++i) { - if (*fields_.at(i) != *struct_rhs->fields_.at(i)) { return false; } - } - return true; -} - -void StructPodDesc::ToProto(StructPodProto* struct_pod_proto) const { - struct_pod_proto->Clear(); - for (const auto& field : fields_) { field->ToProto(struct_pod_proto->add_field()); } -} - -bool StructPodDesc::HasField(const FieldId& field_id) const { - return field_id2field_idx_.find(field_id) != field_id2field_idx_.end(); -} - -StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id) { - return MutStructField(field_id, 1); -} - -StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id, int32_t alignment) { - if (!HasField(field_id)) { AddField(field_id, std::make_unique(), alignment); } - return MutExistedField(field_id)->MutCast(); -} - -PodDesc* StructPodDesc::MutExistedField(const FieldId& field_id) { - return fields_.at(field_id2field_idx_.at(field_id))->mut_pod(); -} - -const PodDesc& StructPodDesc::Field(const FieldId& field_id) const { - return fields_.at(field_id2field_idx_.at(field_id))->pod(); -} - -void StructPodDesc::AddField(FieldKey field_key, const PodDesc& pod_desc) { - return AddField(NewFieldId(field_key), pod_desc); -} - -void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc) { - return AddField(field_id, pod_desc, 1); -} - -void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment) { - AddField(field_id, pod_desc.Clone(), alignment); -} - -void StructPodDesc::AddField(const FieldId& field_id, std::unique_ptr&& field, - size_t alignment) { - auto* pod = new FieldPodDesc(field_id, std::move(field), alignment); - AddField(std::unique_ptr(pod)); -} - -void StructPodDesc::AddField(std::unique_ptr&& field) { - CHECK(field_id2field_idx_.emplace(field->field_id(), fields_.size()).second); - fields_.emplace_back(std::move(field)); -} - -size_t StructPodDesc::ByteOffset4Field(const FieldId& field_id) const { - CHECK(HasField(field_id)); - size_t offset = 0; - for (int32_t i = 0; i < field_id2field_idx_.at(field_id); ++i) { - offset += fields_.at(i)->ByteSize(); - } - return offset; -} - -StructPodDesc& StructPodDesc::operator=(const StructPodDesc& struct_pod_desc) { - Clear(); - StructPodProto struct_pod_proto; - struct_pod_desc.ToProto(&struct_pod_proto); - InitFromProto(struct_pod_proto); - return *this; -} - -void StructPodDesc::Clear() { - CHECK_EQ(fields_.size(), field_id2field_idx_.size()); - fields_.clear(); - field_id2field_idx_.clear(); -} - -} // namespace oneflow diff --git a/oneflow/core/register/pod_desc.h b/oneflow/core/register/pod_desc.h deleted file mode 100644 index f2c604ad116..00000000000 --- a/oneflow/core/register/pod_desc.h +++ /dev/null @@ -1,176 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_REGISTER_POD_DESC_H_ -#define ONEFLOW_CORE_REGISTER_POD_DESC_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/common/data_type.h" -#include "oneflow/core/common/protobuf.h" -#include "oneflow/core/common/shape.h" -#include "oneflow/core/register/pod.pb.h" - -namespace std { - -template<> -struct hash { - size_t operator()(const oneflow::FieldId& field_id) const { - if (field_id.has_key()) { return std::hash()(field_id.key()); } - if (field_id.has_lbi()) { return std::hash()(field_id.lbi()); } - UNIMPLEMENTED(); - } -}; - -} // namespace std - -namespace oneflow { - -FieldId NewFieldId(FieldKey key); -FieldId NewFieldId(const LogicalBlobId& lbi); -inline bool operator==(const FieldId& lhs, const FieldId& rhs) { - PbMd message_diff; - return message_diff.Equivalent(lhs, rhs); -} - -class PodDesc { - public: - OF_DISALLOW_COPY_AND_MOVE(PodDesc); - PodDesc() = default; - virtual ~PodDesc() = default; - - template - const T& Cast() const; - template - T* MutCast(); - - virtual size_t ByteSize() const = 0; - virtual void ToProto(PodProto* pod_proto) const = 0; - virtual std::unique_ptr Clone() const = 0; - virtual bool operator==(const PodDesc& rhs) const = 0; - bool operator!=(const PodDesc& rhs) const { return !(*this == rhs); } -}; - -class TensorPodDesc final : public PodDesc { - public: - TensorPodDesc(); - TensorPodDesc(const Shape& shape, DataType data_type); - TensorPodDesc(const std::shared_ptr& shape, DataType data_type); - explicit TensorPodDesc(const TensorPodProto& shape_pod_proto); - explicit TensorPodDesc(const TensorPodDesc& shape_pod); - ~TensorPodDesc() = default; - const Shape& shape() const { return *CHECK_NOTNULL(shape_.get()); } - DataType data_type() const { return data_type_; } - DataType* mut_data_type() { return &data_type_; } - Shape* mut_shape() { return CHECK_NOTNULL(shape_.get()); } - void set_data_type(DataType data_type) { data_type_ = data_type; } - - void InitFromProto(const TensorPodProto& shape_pod); - - size_t ByteSize() const override; - void ToProto(PodProto* pod_proto) const override; - void ToProto(TensorPodProto* pod_proto) const; - std::unique_ptr Clone() const override { return std::make_unique(*this); } - bool operator==(const PodDesc& rhs) const override; - - private: - std::shared_ptr shape_; - DataType data_type_; -}; - -class FieldPodDesc; - -class StructPodDesc final : public PodDesc { - public: - StructPodDesc() = default; - explicit StructPodDesc(const StructPodProto&); - explicit StructPodDesc(const StructPodDesc&); - ~StructPodDesc() = default; - - StructPodDesc* MutStructField(const FieldId& field_id); - StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment); - const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); } - const PodDesc& Field(const FieldId& field_id) const; - void AddField(FieldKey field_key, const PodDesc& pod_desc); - void AddField(const FieldId& field_id, const PodDesc& pod_desc); - void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment); - bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); } - bool HasField(const FieldId& field_id) const; - PodDesc* MutExistedField(FieldKey field_key) { return MutExistedField(NewFieldId(field_key)); } - - std::unique_ptr Clone() const override { return std::make_unique(*this); } - void InitFromProto(const StructPodProto& struct_pod); - void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); } - void ToProto(StructPodProto* pod_proto) const; - - size_t ByteOffset4Field(const FieldId& field_name) const; - size_t ByteSize() const override; - - StructPodDesc& operator=(const StructPodDesc&); - bool operator==(const PodDesc& rhs) const override; - - private: - PodDesc* MutExistedField(const FieldId& field_id); - void Clear(); - void AddField(std::unique_ptr&& field); - void AddField(const FieldId& field_id, std::unique_ptr&& field); - void AddField(const FieldId& field_id, std::unique_ptr&& field, size_t alignment); - - std::vector> fields_; - HashMap field_id2field_idx_; -}; - -class FieldPodDesc final : public PodDesc { - public: - OF_DISALLOW_COPY_AND_MOVE(FieldPodDesc); - ~FieldPodDesc() = default; - - private: - friend class StructPodDesc; - FieldPodDesc(const FieldId& field_id, std::unique_ptr&& pod, size_t alignment) - : PodDesc(), field_id_(field_id), pod_(std::move(pod)), alignment_(alignment) {} - explicit FieldPodDesc(const FieldPodProto& field_pod_proto); - - size_t ByteSize() const override; - void ToProto(PodProto* pod_proto) const override { UNIMPLEMENTED(); } - std::unique_ptr Clone() const override { UNIMPLEMENTED(); } - void ToProto(FieldPodProto* field_proto) const; - bool operator==(const PodDesc& rhs) const override; - - const PodDesc& pod() const { return *pod_; } - const FieldId& field_id() const { return field_id_; } - PodDesc* mut_pod() { return pod_.get(); } - - FieldId field_id_; - std::unique_ptr pod_; - size_t alignment_; -}; - -template -const T& PodDesc::Cast() const { - static_assert(std::is_same::value || std::is_same::value, - "only TensorPodDesc and StructPodDesc supported"); - return *dynamic_cast(this); -} - -template -T* PodDesc::MutCast() { - static_assert(std::is_same::value || std::is_same::value, - "only TensorPodDesc and StructPodDesc supported"); - return dynamic_cast(this); -} - -} // namespace oneflow - -#endif // ONEFLOW_CORE_REGISTER_POD_DESC_H_ diff --git a/oneflow/core/register/pod_ptr.cpp b/oneflow/core/register/pod_ptr.cpp deleted file mode 100644 index a91665e3e71..00000000000 --- a/oneflow/core/register/pod_ptr.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/register/pod_ptr.h" - -namespace oneflow { - -PodPtr PodPtrField(const PodDesc* pod_desc, const FieldId& field_id, char* ptr) { - const auto* struct_pod = dynamic_cast(pod_desc); - CHECK_NOTNULL(struct_pod); - return PodPtr(struct_pod->Field(field_id), ptr + struct_pod->ByteOffset4Field(field_id)); -} - -bool PodPtr::HasField(const FieldId& field_id) const { - const auto* struct_pod = dynamic_cast(pod_desc_); - return struct_pod && struct_pod->HasField(field_id); -} - -const PodPtr PodPtr::Field(const FieldId& field_id) const { - return PodPtrField(pod_desc_, field_id, ptr_); -} - -PodPtr PodPtr::MutField(const FieldId& field_id) { return PodPtrField(pod_desc_, field_id, ptr_); } - -} // namespace oneflow diff --git a/oneflow/core/register/pod_ptr.h b/oneflow/core/register/pod_ptr.h deleted file mode 100644 index c8cd88ea0a4..00000000000 --- a/oneflow/core/register/pod_ptr.h +++ /dev/null @@ -1,96 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_REGISTER_POD_PTR_H_ -#define ONEFLOW_CORE_REGISTER_POD_PTR_H_ -#include "oneflow/core/register/pod_desc.h" -#include "oneflow/core/common/util.h" -#include "oneflow/core/common/data_type.h" - -namespace oneflow { - -class PodPtr final { - public: - PodPtr(const PodDesc& pod_desc, char* ptr) : pod_desc_(&pod_desc), ptr_(ptr) {} - PodPtr(const PodPtr&) = default; - ~PodPtr() = default; - - template - const T* TensorPtr() const; - template - const T* TensorPtr(FieldKey field_key) const { - return TensorPtr(field_key, nullptr); - } - template - const T* TensorPtr(FieldKey field_key, const T* default_ptr) const; - - template - T* MutTensorPtr(); - template - T* MutTensorPtr(FieldKey field_key) { - return MutTensorPtr(field_key, nullptr); - } - template - T* MutTensorPtr(FieldKey field_key, T* default_ptr); - - const PodDesc& pod_desc() const { return *pod_desc_; } - char* ptr() const { return ptr_; } - bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); } - const PodPtr Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); } - PodPtr MutField(FieldKey field_key) { return MutField(NewFieldId(field_key)); } - - bool HasField(const FieldId& field_id) const; - const PodPtr Field(const FieldId& field_id) const; - PodPtr MutField(const FieldId& field_id); - - private: - template - void CheckDataType() const { - const auto* tensor_pod = dynamic_cast(pod_desc_); - CHECK_NOTNULL(tensor_pod); - CHECK_EQ(tensor_pod->data_type(), GetDataType::value); - } - - const PodDesc* const pod_desc_; - char* const ptr_; -}; - -template -const T* PodPtr::TensorPtr(FieldKey field_key, const T* default_ptr) const { - if (!HasField(field_key)) { return default_ptr; } - return Field(field_key).template TensorPtr(); -} - -template -T* PodPtr::MutTensorPtr(FieldKey field_key, T* default_ptr) { - if (!HasField(field_key)) { return default_ptr; } - return MutField(field_key).template MutTensorPtr(); -} - -template -const T* PodPtr::TensorPtr() const { - CheckDataType(); - return reinterpret_cast(ptr_); -} - -template -T* PodPtr::MutTensorPtr() { - CheckDataType(); - return reinterpret_cast(ptr_); -} - -} // namespace oneflow - -#endif // ONEFLOW_CORE_REGISTER_POD_PTR_H_