Skip to content

Commit

Permalink
fix default stride value (#6248)
Browse files Browse the repository at this point in the history
* fix default stride value

* rename

* Remove numpy

* fix autotest to default

* modify x to input
  • Loading branch information
MARD1NO authored Sep 13, 2021
1 parent 1868c19 commit 1f55e3a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 29 deletions.
12 changes: 6 additions & 6 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -638,23 +638,23 @@

- name: "max_pool1d"
signature:
'TensorTuple (Tensor x, Int32List[1] kernel_size, Int32List[1] stride=1,
'TensorTuple (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None,
Int32List[1] padding=0, Int32List[1] dilation=1,
Bool return_indices=True, Bool ceil_mode=False,
String data_format="channels_first") => Maxpool1D'
bind_python: True

- name: "max_pool2d"
signature:
'TensorTuple (Tensor x, Int32List[2] kernel_size, Int32List[2] stride=1,
'TensorTuple (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None,
Int32List[2] padding=0, Int32List[2] dilation=1,
Bool return_indices=True, Bool ceil_mode=False,
String data_format="channels_first") => Maxpool2D'
bind_python: True

- name: "max_pool3d"
signature:
'TensorTuple (Tensor x, Int32List[3] kernel_size, Int32List[3] stride=1,
'TensorTuple (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None,
Int32List[3] padding=0, Int32List[3] dilation=1,
Bool return_indices=True, Bool ceil_mode=False,
String data_format="channels_first") => Maxpool3D'
Expand Down Expand Up @@ -996,21 +996,21 @@

- name: "avg_pool1d"
signature:
'Tensor (Tensor x, Int32List[1] kernel_size, Int32List[1] stride=1,
'Tensor (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None,
Int32List[1] padding=0, Bool ceil_mode=False, Bool count_include_pad=True,
Int64 divisor_override=0, String data_format="channels_first") => Avgpool1D'
bind_python: True

- name: "avg_pool2d"
signature:
'Tensor (Tensor x, Int32List[2] kernel_size, Int32List[2] stride=1,
'Tensor (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None,
Int32List[2] padding=0, Bool ceil_mode=False, Bool count_include_pad=True,
Int64 divisor_override=0, String data_format="channels_first") => Avgpool2D'
bind_python: True

- name: "avg_pool3d"
signature:
'Tensor (Tensor x, Int32List[3] kernel_size, Int32List[3] stride=1,
'Tensor (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None,
Int32List[3] padding=0, Bool ceil_mode=False, Bool count_include_pad=True,
Int64 divisor_override=0, String data_format="channels_first") => Avgpool3D'
bind_python: True
Expand Down
23 changes: 17 additions & 6 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,20 @@ class PoolingNDFunctor {
virtual ~PoolingNDFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,
const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride,
const Optional<std::vector<int32_t>>& stride,
const std::vector<int32_t>& padding,
const std::vector<int32_t>& dilation, const bool& return_indices,
const bool& ceil_mode, const std::string& data_format) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::string>("data_format", data_format));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
if (stride.has_value()) {
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", *JUST(stride.value())));
} else {
JUST(attrs.SetAttr<std::vector<int32_t>>(
"stride", kernel_size)); // If stride is None, we set it as kernel_size to align Pytorch.
}
JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
JUST(attrs.SetAttr<bool>("return_indices", return_indices));
JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode));
Expand Down Expand Up @@ -820,14 +825,20 @@ class AvgPoolingNDFunctor {
virtual ~AvgPoolingNDFunctor() = default;
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride, const std::vector<int32_t>& padding,
const bool& ceil_mode, const bool& count_include_pad,
const int64_t& divisor_override, const std::string& data_format) const {
const Optional<std::vector<int32_t>>& stride,
const std::vector<int32_t>& padding, const bool& ceil_mode,
const bool& count_include_pad, const int64_t& divisor_override,
const std::string& data_format) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::string>("data_format", data_format));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
if (stride.has_value()) {
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", *JUST(stride.value())));
} else {
JUST(attrs.SetAttr<std::vector<int32_t>>(
"stride", kernel_size)); // If stride is None, we set it as kernel_size to align Pytorch.
}
JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode));
JUST(attrs.SetAttr<bool>("count_include_pad", count_include_pad));
JUST(attrs.SetAttr<int64_t>("divisor_override", divisor_override));
Expand Down
6 changes: 3 additions & 3 deletions python/oneflow/nn/functional/functional_maxpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def max_pool1d(
x,
kernel_size,
stride=1,
stride=None,
padding=0,
dilation=1,
return_indices=False,
Expand All @@ -46,7 +46,7 @@ def max_pool1d(
def max_pool2d(
x,
kernel_size,
stride=1,
stride=None,
padding=0,
dilation=1,
return_indices=False,
Expand All @@ -72,7 +72,7 @@ def max_pool2d(
def max_pool3d(
x,
kernel_size,
stride=1,
stride=None,
padding=0,
dilation=1,
return_indices=False,
Expand Down
56 changes: 53 additions & 3 deletions python/oneflow/test/modules/test_avgpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
import unittest

import oneflow as flow
from oneflow.test_utils.automated_test_util.generators import constant, random_bool
import oneflow.unittest
from automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestAvgPoolingModule(flow.unittest.TestCase):
@autotest(n=100)
@autotest()
def test_avgpool1d_with_random_data(test_case):
m = torch.nn.AvgPool1d(
kernel_size=random(4, 6),
Expand All @@ -38,7 +39,7 @@ def test_avgpool1d_with_random_data(test_case):
y = m(x)
return y

@autotest(n=100)
@autotest()
def test_avgpool2d_with_random_data(test_case):
m = torch.nn.AvgPool2d(
kernel_size=random(4, 6),
Expand All @@ -57,7 +58,7 @@ def test_avgpool2d_with_random_data(test_case):
y = m(x)
return y

@autotest(n=100)
@autotest()
def test_avgpool3d_with_random_data(test_case):
m = torch.nn.AvgPool3d(
kernel_size=random(4, 6),
Expand All @@ -77,5 +78,54 @@ def test_avgpool3d_with_random_data(test_case):
return y


@flow.unittest.skip_unless_1n1d()
class TestAvgPoolingFunctional(flow.unittest.TestCase):
    """Autotest coverage for flow.nn.functional.avg_pool{1,2,3}d vs. PyTorch.

    Each test omits ``stride`` part of the time (``... | nothing()``) so the
    ``stride=None`` default path (fall back to ``kernel_size``) is exercised.
    """

    @autotest()
    def test_avgpool1d_functional(test_case):
        device = random_device()
        x = random_pytorch_tensor(ndim=3, dim2=random(20, 22)).to(device)
        y = torch.nn.functional.avg_pool1d(
            x,
            # padding is drawn from random(1, 3); PyTorch requires
            # pad <= kernel_size / 2, so kernel_size must be drawn from
            # random(4, 6) (as the module tests and maxpool functional
            # tests do) — random(1, 6) could yield an invalid combination.
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int),
            ceil_mode=random_bool(),
            count_include_pad=random_bool(),
        )
        return y

    @autotest()
    def test_avgpool2d_functional(test_case):
        device = random_device()
        x = random_pytorch_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(
            device
        )
        y = torch.nn.functional.avg_pool2d(
            x,
            # kernel_size >= 2 * padding — see test_avgpool1d_functional.
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int),
            ceil_mode=random_bool(),
            count_include_pad=random_bool(),
        )
        return y

    @autotest()
    def test_avgpool3d_functional(test_case):
        device = random_device()
        x = random_pytorch_tensor(
            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)
        ).to(device)
        y = torch.nn.functional.avg_pool3d(
            x,
            # kernel_size >= 2 * padding — see test_avgpool1d_functional.
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int),
            ceil_mode=random_bool(),
            count_include_pad=random_bool(),
        )
        return y


# Allow running this test file directly (outside the CI test driver).
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def unpack_indices(dual_object):

@flow.unittest.skip_unless_1n1d()
class TestMaxPooling(flow.unittest.TestCase):
@autotest(n=100, auto_backward=False)
@autotest(auto_backward=False)
def test_maxpool1d_with_random_data(test_case):
return_indices = random().to(bool).value()
m = torch.nn.MaxPool1d(
Expand All @@ -50,7 +50,7 @@ def test_maxpool1d_with_random_data(test_case):
else:
return y, y.sum().backward()

@autotest(n=100, auto_backward=False)
@autotest(auto_backward=False)
def test_maxpool2d_with_random_data(test_case):
return_indices = random().to(bool).value()
m = torch.nn.MaxPool2d(
Expand All @@ -74,7 +74,7 @@ def test_maxpool2d_with_random_data(test_case):
else:
return y, y.sum().backward()

@autotest(n=100, auto_backward=False)
@autotest(auto_backward=False)
def test_maxpool3d_with_random_data(test_case):
return_indices = random().to(bool).value()
m = torch.nn.MaxPool3d(
Expand All @@ -99,5 +99,72 @@ def test_maxpool3d_with_random_data(test_case):
return y, y.sum().backward()


@flow.unittest.skip_unless_1n1d()
class TestMaxPoolingFunctional(flow.unittest.TestCase):
    """Autotest coverage for flow.nn.functional.max_pool{1,2,3}d vs. PyTorch.

    ``stride``/``padding``/``dilation`` are each omitted part of the time
    (``... | nothing()``) so the keyword defaults — including the
    ``stride=None`` fallback to ``kernel_size`` — are exercised.
    """

    @autotest(auto_backward=False)
    def test_maxpool1d_with_random_data(test_case):
        with_indices = random().to(bool).value()
        dev = random_device()
        inp = random_pytorch_tensor(ndim=3, dim2=random(20, 22)).to(dev)
        out = torch.nn.functional.max_pool1d(
            inp,
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(2, 4).to(int) | nothing(),
            ceil_mode=random().to(bool),
            return_indices=with_indices,
        )
        # With return_indices the result is (values, indices); flatten it
        # for the autotest comparison.  Otherwise compare values + backward.
        if with_indices:
            return unpack_indices(out)
        return out, out.sum().backward()

    @autotest(auto_backward=False)
    def test_maxpool2d_with_random_data(test_case):
        with_indices = random().to(bool).value()
        dev = random_device()
        inp = random_pytorch_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(
            dev
        )
        out = torch.nn.functional.max_pool2d(
            inp,
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(2, 4).to(int) | nothing(),
            ceil_mode=random().to(bool),
            return_indices=with_indices,
        )
        if with_indices:
            return unpack_indices(out)
        return out, out.sum().backward()

    @autotest(auto_backward=False)
    def test_maxpool3d_with_random_data(test_case):
        with_indices = random().to(bool).value()
        dev = random_device()
        inp = random_pytorch_tensor(
            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)
        ).to(dev)
        out = torch.nn.functional.max_pool3d(
            inp,
            kernel_size=random(4, 6).to(int),
            stride=random(1, 3).to(int) | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(2, 4).to(int) | nothing(),
            ceil_mode=random().to(bool),
            return_indices=with_indices,
        )
        if with_indices:
            return unpack_indices(out)
        return out, out.sum().backward()


# Allow running this test file directly (outside the CI test driver).
if __name__ == "__main__":
    unittest.main()
17 changes: 9 additions & 8 deletions tools/functional/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,17 @@ def __init__(self, fmt, keyword_only=False):
sp = self._name.find("=")
if sp != -1:
self._default_value = _normalize(self._name[sp + 1 :])
if self._type.endswith("List"):
_value_list = [
self._default_value for i in range(self._size)
] # For int32List[2] = 1, _value_list will be [1, 1]
self._default_cpp_value = (
"{" + ", ".join(_value_list) + "}"
) # [1, 1] -> "{1, 1}"
elif self._default_value == "None":
if self._default_value == "None":
self._optional = True
self._default_cpp_value = ""
elif self._type.endswith("List"):
if self._default_value != "None":
_value_list = [
self._default_value for i in range(self._size)
] # For int32List[2] = 2, _value_list will be ["2", "2"]
self._default_cpp_value = (
"{" + ", ".join(_value_list) + "}"
) # ["2", "2"] -> "{2, 2}"
elif self._default_value in value_aliases:
self._default_cpp_value = value_aliases[self._default_value]
else:
Expand Down

0 comments on commit 1f55e3a

Please sign in to comment.