Skip to content

Commit

Permalink
Dev functional batch_gather (#6233)
Browse files Browse the repository at this point in the history
* add broadcast like docs

* unsorted batch segment sum functional

* add unittest

* add docs

* add batch gather docs rst

* fix doc code

Co-authored-by: oneflow-ci-bot <[email protected]>
  • Loading branch information
MARD1NO and oneflow-ci-bot authored Sep 13, 2021
1 parent 8f67c6b commit 1868c19
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 30 deletions.
2 changes: 2 additions & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ oneflow
atan2,
atanh,
bernoulli,
broadcast_like,
batch_gather,
cat,
cast,
ceil,
Expand Down
18 changes: 4 additions & 14 deletions oneflow/core/autograd/gradient_funcs/batch_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {
Expand All @@ -34,17 +32,11 @@ class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
std::shared_ptr<OpExpr> bw_unsorted_batch_segment_sum_op_;
};

// Initializes the gradient function from the forward op expression.
// Returns an error if `op` is not a UserOpExpr; otherwise Maybe<void>::Ok().
Maybe<void> BatchGather::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
// The backward op is named after the forward op via GradientOpName.
const std::string& op_name = fw_op_expr->op_name();
// num_segments is fixed to 1 at construction time.
// NOTE(review): Apply appears to supply the real value per call - confirm.
bw_unsorted_batch_segment_sum_op_ =
JUST(op_expr_helper::UnsortedBatchSegmentSumOp(/*num_segments=*/1, GradientOpName(op_name)));
return Maybe<void>::Ok();
}

Expand All @@ -64,10 +56,8 @@ Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorT
in_grads->resize(2);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& indices = ctx->SavedTensors().at(0);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("num_segments", ctx->num_segments));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*bw_unsorted_batch_segment_sum_op_,
{out_grads.at(0), indices}, attrs));
in_grads->at(0) =
JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments));
return Maybe<void>::Ok();
}

Expand Down
12 changes: 0 additions & 12 deletions oneflow/core/framework/op_expr_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,6 @@ Maybe<one::UserOpExpr> ConcatOp(const int& n, const int64_t& axis, const int64_t
.Build();
}

// Convenience overload: builds the `unsorted_batch_segment_sum` op expression using an op
// name obtained from UniqueOpName, then defers to the (num_segments, name) overload.
Maybe<one::UserOpExpr> UnsortedBatchSegmentSumOp(const int& num_segments) {
  const auto generated_name = UniqueOpName("unsorted_batch_segment_sum");
  return UnsortedBatchSegmentSumOp(num_segments, generated_name);
}
// Builds an `unsorted_batch_segment_sum` user op expression with the given op name.
// The op takes the inputs "data" and "segment_ids", produces the output "out", and carries
// `num_segments` as an int32 attribute named "num_segments".
Maybe<one::UserOpExpr> UnsortedBatchSegmentSumOp(const int& num_segments, const std::string& name) {
return one::OpBuilder("unsorted_batch_segment_sum", name)
.Input("data")
.Input("segment_ids")
.Output("out")
.Attr<int32_t>("num_segments", num_segments)
.Build();
}

// Convenience overload: builds the `scalar_add_by_tensor` op expression using an op name
// obtained from UniqueOpName, then defers to the overload that takes an explicit name.
Maybe<one::UserOpExpr> ScalarAddByTensorOp() {
  const auto generated_name = UniqueOpName("scalar_add_by_tensor");
  return ScalarAddByTensorOp(generated_name);
}
Expand Down
3 changes: 0 additions & 3 deletions oneflow/core/framework/op_expr_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ Maybe<one::UserOpExpr> ConcatOp(const int& n, const int64_t& axis, const int64_t
Maybe<one::UserOpExpr> ConcatOp(const int& n, const int64_t& axis, const int64_t& max_dim_size,
const std::string& name);

Maybe<one::UserOpExpr> UnsortedBatchSegmentSumOp(const int& num_segments);
Maybe<one::UserOpExpr> UnsortedBatchSegmentSumOp(const int& num_segments, const std::string& name);

Maybe<one::UserOpExpr> ScalarAddByTensorOp();
Maybe<one::UserOpExpr> ScalarAddByTensorOp(const std::string& name);

Expand Down
15 changes: 15 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1221,3 +1221,18 @@
- name: "recv"
signature: "Tensor (Int64 src, Shape shape=None, DataType dtype=None, Device device=None, *, Tensor out=None) => Recv"
bind_python: True

# batch_gather(in, indices): implemented by the functor registered as "BatchGather"
# and exposed to Python (bind_python: True). Declared exactly once.
- name: "batch_gather"
  signature:
    "Tensor (Tensor in, Tensor indices) => BatchGather"
  bind_python: True

# unsorted_batch_segment_sum(data, segment_ids, num_segments): implemented by the functor
# registered as "UnsortedBatchSegmentSum". Not bound to Python (bind_python: False); it is
# called from C++ via functional::UnsortedBatchSegmentSum.
- name: "unsorted_batch_segment_sum"
signature:
"Tensor (Tensor data, Tensor segment_ids, Int64 num_segments) => UnsortedBatchSegmentSum"
bind_python: False
38 changes: 38 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,42 @@ class SplitWithSizeFunctor {
}
};

// Functor registered as "BatchGather": dispatches the `batch_gather` user op.
// The op expression (inputs "in" and "indices", output "out") is built once in the
// constructor and reused on every call.
class BatchGatherFunctor {
 public:
  BatchGatherFunctor() {
    auto built_op = CHECK_JUST(one::OpBuilder("batch_gather")
                                   .Input("in")
                                   .Input("indices")
                                   .Output("out")
                                   .Build());
    op_ = built_op;
  }

  // Runs `batch_gather` on (in, indices) and returns the resulting tensor.
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
                           const std::shared_ptr<one::Tensor>& indices) const {
    const OpExpr& gather_op = *op_;
    return OpInterpUtil::Dispatch<Tensor>(gather_op, {in, indices});
  }

 protected:
  std::shared_ptr<OpExpr> op_;
};

// Functor registered as "UnsortedBatchSegmentSum": dispatches the
// `unsorted_batch_segment_sum` user op.
// The op expression (inputs "data" and "segment_ids", output "out") is built once in the
// constructor; `num_segments` is passed per call as an attribute.
class UnsortedBatchSegmentSumFunctor {
 public:
  UnsortedBatchSegmentSumFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("unsorted_batch_segment_sum")
                         .Input("data")
                         .Input("segment_ids")
                         .Output("out")
                         .Build());
  }

  // Runs `unsorted_batch_segment_sum` on (data, segment_ids) with the given number of
  // segments and returns the resulting tensor.
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& data,
                           const std::shared_ptr<one::Tensor>& segment_ids,
                           const int64_t& num_segments) const {
    MutableAttrMap attrs;
    // The "num_segments" attribute of this op is set as int32_t everywhere else it is
    // built or dispatched, so it must be stored as int32_t here as well; storing an
    // int64_t value would not match the attribute's type.
    JUST(attrs.SetAttr<int32_t>("num_segments", static_cast<int32_t>(num_segments)));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {data, segment_ids}, attrs);
  }

 protected:
  std::shared_ptr<OpExpr> op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand Down Expand Up @@ -1708,6 +1744,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SplitFunctor>("Split");
m.add_functor<impl::SplitLikeFunctor>("SplitLike");
m.add_functor<impl::SplitWithSizeFunctor>("SplitWithSize");
m.add_functor<impl::BatchGatherFunctor>("BatchGather");
m.add_functor<impl::UnsortedBatchSegmentSumFunctor>("UnsortedBatchSegmentSum");
};

} // namespace functional
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def atexit_hook(hook):
del oneflow

import oneflow._C
from oneflow._C import tensor
from oneflow._C import tensor, batch_gather

from oneflow.autograd import grad_enable, no_grad, inference_mode, is_grad_enabled
import oneflow.nn.image
Expand Down
46 changes: 46 additions & 0 deletions python/oneflow/framework/docstr/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,49 @@
""",
)

# Attach the user-facing docstring (summary, arguments and two usage examples) to the
# Python-bound oneflow.batch_gather function.
add_docstr(
oneflow.batch_gather,
"""Gather the element in batch dims.
Args:
in (Tensor): the input tensor.
indices (Tensor): the indices tensor, its dtype must be int32/64.
For example:
Example 1:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> x = flow.Tensor(np.array([[1, 2, 3],
... [4, 5, 6]]))
>>> indices = flow.tensor(np.array([1, 0]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
tensor([[4., 5., 6.],
[1., 2., 3.]], dtype=oneflow.float32)
Example 2:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]],
... [[1, 2, 3], [4, 5, 6]]]))
>>> indices = flow.tensor(np.array([[1, 0],
... [0, 1]]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
tensor([[[4., 5., 6.],
[1., 2., 3.]],
[[1., 2., 3.],
[4., 5., 6.]]], dtype=oneflow.float32)
""",
)
23 changes: 23 additions & 0 deletions python/oneflow/nn/modules/broadcast_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,27 @@ def forward(self, x, like_tensor):


def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None):
    """Broadcast tensor `x` to `like_tensor` according to `broadcast_axes`.

    Args:
        x (Tensor): The input Tensor.
        like_tensor (Tensor): The like Tensor.
        broadcast_axes (Optional[Sequence], optional): The axes you want to broadcast.
            Defaults to None.

    Returns:
        [Tensor]: Broadcasted input Tensor.

    For example:

    .. code:: python

        >>> import oneflow as flow
        >>> x = flow.randn(3, 1, 1)
        >>> like_tensor = flow.randn(3, 4, 5)
        >>> broadcast_tensor = flow.broadcast_like(x, like_tensor, broadcast_axes=[1, 2])
        >>> broadcast_tensor.shape
        oneflow.Size([3, 4, 5])

    """
    # Build the module with the requested axes, then apply it to the two tensors.
    broadcast_module = BroadCastLike(broadcast_axes=broadcast_axes)
    return broadcast_module(x, like_tensor)
85 changes: 85 additions & 0 deletions python/oneflow/test/modules/test_batch_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict
import os

import numpy as np
from test_util import GenArgList

import oneflow as flow
import oneflow.unittest


def _test_batch_gather(test_case, shape, device):
    """Check flow.batch_gather against flow.gather(dim=0) on the same random input.

    The 1-D batch index used for batch_gather is reshaped and broadcast to `shape`
    so that flow.gather along dim 0 selects the same elements. Both the forward
    outputs and the gradients w.r.t. the inputs are compared.

    Args:
        test_case: the unittest.TestCase used for assertions.
        shape: shape of the random input, e.g. (3, 2, 2).
        device: device string the tensors are moved to, e.g. "cpu" or "cuda".
    """
    # for example: shape = (3, 2, 2)
    x = np.random.randn(*shape)
    x_tensor = flow.Tensor(x).to(device)
    x_tensor.requires_grad = True
    batchsize = x.shape[0]
    init_index = np.array(
        [np.random.randint(batchsize) for _ in range(batchsize)]
    ).astype(np.int64)

    batch_gather_index = flow.tensor(init_index).to(device)
    batch_gather_out = flow.batch_gather(x_tensor, batch_gather_index)

    x_tensor_gather = flow.Tensor(x).to(device)
    x_tensor_gather.requires_grad = True
    # reshaped_shape = [3] -> [3, 1, 1]
    reshaped_shape = [batchsize] + [1] * (len(x.shape) - 1)

    gather_index = np.reshape(init_index, reshaped_shape)
    gather_index = np.broadcast_to(gather_index, shape).astype(
        np.int64
    )  # [3, 1, 1] -> [3, 2, 2]
    gather_index = flow.tensor(gather_index).to(device)
    gather_out = flow.gather(x_tensor_gather, gather_index, dim=0)

    # A single backward pass through the summed outputs populates both input grads.
    total_out = batch_gather_out.sum() + gather_out.sum()
    total_out.backward()

    test_case.assertTrue(
        np.allclose(batch_gather_out.numpy(), gather_out.numpy(), atol=1e-4, rtol=1e-4)
    )
    test_case.assertTrue(
        np.allclose(
            x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4,
        )
    )


@flow.unittest.skip_unless_1n1d()
class TestBatchGather(flow.unittest.TestCase):
    """Runs _test_batch_gather over every combination of shape and device."""

    def test_batch_gather(test_case):
        arg_dict = OrderedDict(
            [
                ("test_fun", [_test_batch_gather]),
                ("shape", [(3, 2, 2), (3, 2, 4, 2), (3, 3, 4, 2, 2), (4, 2)]),
                ("device", ["cpu", "cuda"]),
            ]
        )

        # Each generated entry is (test function, shape, device).
        for test_fun, *params in GenArgList(arg_dict):
            test_fun(test_case, *params)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1868c19

Please sign in to comment.