From 30e8568eb49ba77c09e5cc06a04d54d5b0bc52a9 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 11:27:46 +0200 Subject: [PATCH 01/23] Raise exception for different size placeholders --- returnn/exceptions/__init__.py | 0 returnn/exceptions/dimension_exception.py | 17 +++++++++++++++++ returnn/tf/util/data.py | 9 ++++----- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 returnn/exceptions/__init__.py create mode 100644 returnn/exceptions/dimension_exception.py diff --git a/returnn/exceptions/__init__.py b/returnn/exceptions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/returnn/exceptions/dimension_exception.py b/returnn/exceptions/dimension_exception.py new file mode 100644 index 0000000000..a917376c20 --- /dev/null +++ b/returnn/exceptions/dimension_exception.py @@ -0,0 +1,17 @@ +class DimensionException(Exception): + """ Dimension Exception error, + Raised if the dimension of 2 values are not the same + before performing specific operation" + + Attributes: + dimA : dimensionA + dimB : dimensionB + message -- explanation of the error + """ + + def __init__(self, dimA,dimB, message="Dimensions are not equals"): + self.dimA = dimA + self.dimB = dimB + self.message = message + super().__init__(self.message) + \ No newline at end of file diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index e804844123..3b34f91b1b 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -12,10 +12,12 @@ import tensorflow as tf import traceback +from returnn.exceptions.dimension_exception import DimensionException from returnn.util.basic import NotSpecified, Entity import returnn.tf.compat as tf_compat + class Dim(object): """ This identifies one axis/dimension, like a time-dimension, etc. @@ -984,11 +986,8 @@ def declare_same_as(self, other): if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: - # Note: Instead of making this a warning, we could also enforce this at some point. - # The user should be able to fix `extern_data` in the config such that this is correct in the first place. - # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes. - print( - "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % ( + if not self.dyn_size == other_same_base: + raise DimensionException(self.dyn_size, other_same_base.dyn_size, "dim tags are same with different size placeholders: %r vs %r please check external_data" % ( self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. 
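Taken together, patch 01 turns what used to be a warning in Dim.declare_same_as into a hard error: if two dim tags are declared the same but hold different dynamic-size placeholders (with equal batch info and control flow context), a DimensionException is raised. A minimal sketch of the failing situation, assuming a RETURNN checkout with this patch applied; the tag descriptions and placeholder names here are made up for illustration:

import tensorflow as tf
from returnn.tf.util.data import Dim
from returnn.exceptions.dimension_exception import DimensionException

tf.compat.v1.disable_eager_execution()  # size placeholders need graph mode

# Two dynamic spatial dim tags, each carrying its own size placeholder.
time_a = Dim(kind=Dim.Types.Spatial, description="time-a", dimension=None)
time_b = Dim(kind=Dim.Types.Spatial, description="time-b", dimension=None)
time_a.dyn_size = tf.compat.v1.placeholder(tf.int32, shape=(None,), name="seq_lens_a")
time_b.dyn_size = tf.compat.v1.placeholder(tf.int32, shape=(None,), name="seq_lens_b")

try:
    time_a.declare_same_as(time_b)  # different size placeholders -> now an error
except DimensionException as exc:
    print(exc.dimA, exc.dimB, exc.message)

Note that the added guard `if not self.dyn_size == other_same_base:` compares the size tensor against the other Dim object rather than against other_same_base.dyn_size; patch 05 below drops that condition again and raises unconditionally in this branch.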
From f5df7897b5c10798d4babddf33b919f49fdf8477 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 11:30:59 +0200 Subject: [PATCH 02/23] fixed typo --- returnn/tf/util/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 3b34f91b1b..b036e57a18 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -987,7 +987,7 @@ def declare_same_as(self, other): if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: if not self.dyn_size == other_same_base: - raise DimensionException(self.dyn_size, other_same_base.dyn_size, "dim tags are same with different size placeholders: %r vs %r please check external_data" % ( + raise DimensionException(self.dyn_size, other_same_base.dyn_size, "Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. From 5374abb6b0f57aa1fc3039a38ec3dbc5ea0f46a1 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 15:24:42 +0200 Subject: [PATCH 03/23] fixed typo for pep8 --- returnn/exceptions/dimension_exception.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/returnn/exceptions/dimension_exception.py b/returnn/exceptions/dimension_exception.py index a917376c20..d718977847 100644 --- a/returnn/exceptions/dimension_exception.py +++ b/returnn/exceptions/dimension_exception.py @@ -1,7 +1,5 @@ class DimensionException(Exception): - """ Dimension Exception error, - Raised if the dimension of 2 values are not the same - before performing specific operation" + """ Dimension Exception error. Raised if the dimension are not the same" Attributes: dimA : dimensionA @@ -9,9 +7,8 @@ class DimensionException(Exception): message -- explanation of the error """ - def __init__(self, dimA,dimB, message="Dimensions are not equals"): + def __init__(self, dimA, dimB, message="Dimensions are not equals"): self.dimA = dimA self.dimB = dimB self.message = message super().__init__(self.message) - \ No newline at end of file From b191839117683c5682d73a282f0d8152050990ec Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 15:30:06 +0200 Subject: [PATCH 04/23] removed dedicated exception --- returnn/exceptions/__init__.py | 0 returnn/exceptions/dimension_exception.py | 14 -------------- returnn/tf/util/data.py | 2 +- 3 files changed, 1 insertion(+), 15 deletions(-) delete mode 100644 returnn/exceptions/__init__.py delete mode 100644 returnn/exceptions/dimension_exception.py diff --git a/returnn/exceptions/__init__.py b/returnn/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/returnn/exceptions/dimension_exception.py b/returnn/exceptions/dimension_exception.py deleted file mode 100644 index d718977847..0000000000 --- a/returnn/exceptions/dimension_exception.py +++ /dev/null @@ -1,14 +0,0 @@ -class DimensionException(Exception): - """ Dimension Exception error. 
Raised if the dimension are not the same" - - Attributes: - dimA : dimensionA - dimB : dimensionB - message -- explanation of the error - """ - - def __init__(self, dimA, dimB, message="Dimensions are not equals"): - self.dimA = dimA - self.dimB = dimB - self.message = message - super().__init__(self.message) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index b036e57a18..1f266abdb8 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -987,7 +987,7 @@ def declare_same_as(self, other): if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: if not self.dyn_size == other_same_base: - raise DimensionException(self.dyn_size, other_same_base.dyn_size, "Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( + raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. From d5acc2e1089325df9e26cd9898bfb73e427c8138 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 15:40:09 +0200 Subject: [PATCH 05/23] removed check for dimension --- returnn/tf/util/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 1f266abdb8..94ccf437c4 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -986,8 +986,7 @@ def declare_same_as(self, other): if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: - if not self.dyn_size == other_same_base: - raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( + raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. From a14f1df08aae92e1ce715bf75e0f2add3c44e03e Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 13 Oct 2022 15:40:49 +0200 Subject: [PATCH 06/23] removed unused import --- returnn/tf/util/data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 94ccf437c4..33461ae13e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -12,12 +12,10 @@ import tensorflow as tf import traceback -from returnn.exceptions.dimension_exception import DimensionException from returnn.util.basic import NotSpecified, Entity import returnn.tf.compat as tf_compat - class Dim(object): """ This identifies one axis/dimension, like a time-dimension, etc. 
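After patches 04-06 the dedicated exception module is gone and the check raises a plain Exception. Patch 07 below then couples the check to RETURNN's behavior-version mechanism via BehaviorVersion.require(condition, message, version): below the given behavior version a failed check only warns, from that version on it becomes a hard error, so existing configs keep working. A simplified sketch of those semantics (an illustration of the convention, not the actual returnn.util.basic implementation):

# Simplified model of BehaviorVersion.require as used in patch 07:
# a failed condition warns for configs below `version` and raises
# for configs at or above it.
class BehaviorVersionSketch:
    _behavior_version = 12  # in RETURNN this comes from the config

    @classmethod
    def get(cls):
        return cls._behavior_version

    @classmethod
    def require(cls, condition, message, version):
        if condition:
            return
        if cls.get() >= version:
            raise Exception("%s (hard error since behavior_version %i)" % (message, version))
        print("WARNING: %s" % message)

# With behavior_version 12, a version-13 requirement only warns;
# with _behavior_version = 13 the same call would raise.
BehaviorVersionSketch.require(False, "dim tags have different size placeholders", 13)

Note also that patch 07 mixes this one functional change with a large-scale reformatting of returnn/tf/util/data.py (454 lines touched for what is essentially a one-call change), and patch 08 further below reverts the commit wholesale.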
From c6b50235bd820bc3e606d950165b2bd4c0c0e39e Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Sun, 16 Oct 2022 09:42:13 +0200 Subject: [PATCH 07/23] implemented behaviorversion exception --- returnn/tf/util/data.py | 454 ++++++++++++++++++++++------------------ 1 file changed, 245 insertions(+), 209 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 47c15e7c2b..fc91b5ea81 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -1,4 +1,3 @@ - """ Provides :class:`Data`, :class:`Dim`, :class:`SearchBeam`. @@ -12,7 +11,7 @@ import tensorflow as tf import traceback -from returnn.util.basic import NotSpecified, Entity +from returnn.util.basic import BehaviorVersion, NotSpecified, Entity import returnn.tf.compat as tf_compat @@ -55,16 +54,24 @@ class Types: _creation_counter = 0 - def __init__(self, kind=Types.Unspecified, description=None, + def __init__(self, + kind=Types.Unspecified, + description=None, dimension=None, vocab=None, - dyn_size=None, dyn_size_ext=None, - undefined=False, generic=False, special=False, + dyn_size=None, + dyn_size_ext=None, + undefined=False, + generic=False, + special=False, auto_generated=False, match_priority=0, - derived_from_tag=None, derived_from_op=None, - batch=None, control_flow_ctx=None, - src_data=None, src_axis=None): + derived_from_tag=None, + derived_from_op=None, + batch=None, + control_flow_ctx=None, + src_data=None, + src_axis=None): """ :param Entity|None kind: :param str|None description: the description should be unique @@ -130,7 +137,8 @@ def __init__(self, kind=Types.Unspecified, description=None, self.auto_generated = auto_generated # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx. # They each have same_as = self. The same_base should have the base (global) batch info. 
- self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim] # nopep8 + self._same_for_batch_ctx = { + } # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim] # nopep8 if dyn_size is not None: assert not dyn_size_ext self.dyn_size = dyn_size @@ -154,9 +162,8 @@ def short_repr(self): desc += "(%i%s)" % (self.dimension, "*" if self.generic else "") else: if self.dyn_size_ext: - desc += "[%s%s]" % ( - ",".join(self.dyn_size_ext.get_batch_axes_short_description(special_axes=False)), - "*" if self.generic else "") + desc += "[%s%s]" % (",".join( + self.dyn_size_ext.get_batch_axes_short_description(special_axes=False)), "*" if self.generic else "") else: desc += "[*]" if self.generic else "[?]" if self.control_flow_ctx: @@ -217,11 +224,14 @@ def copy(self, same_as_self=True, description=None, kind=None, match_priority=No if not same_as_self: assert description is not None, "%s copy with not same_as_self should have a new description" % self tag = Dim( - kind=kind or self.kind, description=description or self.description, + kind=kind or self.kind, + description=description or self.description, match_priority=match_priority if match_priority is not None else self.match_priority, - dimension=self.dimension, dyn_size_ext=self.dyn_size_ext, + dimension=self.dimension, + dyn_size_ext=self.dyn_size_ext, batch=self.batch, - src_data=self.src_data, src_axis=self.src_axis) + src_data=self.src_data, + src_axis=self.src_axis) if same_as_self: tag.same_as = self # not declare_same_as, none of the extra checks needed tag._same_as_tb = traceback.extract_stack() @@ -391,9 +401,12 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False): if not dyn_size_ext and allow_none: return None dim_tag = Dim( - kind=self.kind, description=self.description, dimension=self.dimension, + kind=self.kind, + description=self.description, + dimension=self.dimension, auto_generated=self.auto_generated, - batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, + batch=batch, + control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, dyn_size_ext=dyn_size_ext) dim_tag.same_as = same_base dim_tag._same_as_tb = traceback.extract_stack() @@ -466,8 +479,13 @@ def dyn_size(self, dyn_size): beam = getattr(dyn_size, "_RETURNN_dyn_size_beam", None) self.dyn_size_ext = Data( name=("%s:dyn_size" % self.description) if self.description else dyn_size.op.name, - dtype=Data.size_dtype, placeholder=dyn_size, shape=(), batch_dim_axis=0, - batch=self.batch, beam=beam, control_flow_ctx=self.control_flow_ctx) + dtype=Data.size_dtype, + placeholder=dyn_size, + shape=(), + batch_dim_axis=0, + batch=self.batch, + beam=beam, + control_flow_ctx=self.control_flow_ctx) other = Dim.get_tag_from_size_tensor(dyn_size) if other: self.declare_same_as(other) @@ -587,12 +605,12 @@ def set_tag_on_size_tensor(self, x, batch=None, same_as_before=False): # So for now, just error. from .basic import format_graph_output raise Exception("\n".join([ - "%r (%r) already has size %r, and another incompatible size %r (batch %r) is being assigned." % ( - self, self.description, self.dyn_size, x, batch), - "\nNew size computation graph:", + "%r (%r) already has size %r, and another incompatible size %r (batch %r) is being assigned." % + (self, self.description, self.dyn_size, x, batch), "\nNew size computation graph:", format_graph_output(x, max_depth=3), "\nThis is maybe the result of an incorrect declare_same_as. 
Traceback of declare_same_as:", - "".join(self._same_as_tb.format()) if self._same_as_tb else ("same_as = %s" % self.same_as)])) + "".join(self._same_as_tb.format()) if self._same_as_tb else ("same_as = %s" % self.same_as) + ])) if batch and getattr(x, "_RETURNN_dyn_size_beam", None): assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam") if self.batch and batch: @@ -674,9 +692,7 @@ def _bin_op(a, b): if x.dimension is not None: if y is None: with tf.control_dependencies(None): # this will reset the context - y = Data( - name=y_name, dim_tags=[], dtype="int32", - placeholder=tf.constant(x.dimension)) + y = Data(name=y_name, dim_tags=[], dtype="int32", placeholder=tf.constant(x.dimension)) continue y.placeholder = _bin_op(y.placeholder, x.dimension) continue @@ -709,9 +725,16 @@ def _bin_op(a, b): if y.placeholder is not None: self.set_tag_on_size_tensor(y.placeholder) - def is_equal(self, other, ignore_feature_dim=False, allow_same_feature_dim=False, allow_same_spatial_dim=None, - treat_feature_as_spatial=False, broadcast_matches=False, unknown_spatial_matches=False, - undefined_matches=False, derived_matches=False): + def is_equal(self, + other, + ignore_feature_dim=False, + allow_same_feature_dim=False, + allow_same_spatial_dim=None, + treat_feature_as_spatial=False, + broadcast_matches=False, + unknown_spatial_matches=False, + undefined_matches=False, + derived_matches=False): """ Compares self to other for equality. @@ -826,10 +849,9 @@ def __hash__(self): # This must match the behavior in __eq__, which is is_equal with default options. # I.e. different hash implies not equal (but same hash not necessarily equal). if self.generic: - raise ValueError( - "Hash for generic dim tag %s is not well defined. " % self + - "The generic flag invalidates the transitive property of equivalence relations. " - "Explicitly go through the set or dict of dim tags and check each for equality instead.") + raise ValueError("Hash for generic dim tag %s is not well defined. " % self + + "The generic flag invalidates the transitive property of equivalence relations. " + "Explicitly go through the set or dict of dim tags and check each for equality instead.") if self.special: return hash(id(self)) if self.is_batch_dim(): @@ -984,8 +1006,9 @@ def declare_same_as(self, other): if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: - raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( - self.dyn_size, other_same_base.dyn_size)) + BehaviorVersion.require( + False, ("Dim tags are same with different size placeholders: %r vs %r please check external_data" % + (self.dyn_size, other_same_base.dyn_size)), 13) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. if other_same_base.dyn_size is not None and self.src_data: @@ -1000,15 +1023,15 @@ def declare_same_as(self, other): # Could be unset if it comes from the config, or from prev graph creation. # This is important such that self.can_compare() is sane. 
if other_same_base.dyn_size is None or not other_same_base._validate_in_current_graph(): - other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( - other_same_base.batch, other_same_base.control_flow_ctx) + other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(other_same_base.batch, + other_same_base.control_flow_ctx) other_same_base._maybe_update() if not self.dyn_size_ext or not self._validate_in_current_graph(): self.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) self._maybe_update() elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): - other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( - other_same_base.batch, other_same_base.control_flow_ctx) + other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(other_same_base.batch, + other_same_base.control_flow_ctx) other_same_base._maybe_update() if self.is_dim_known() and other.is_dim_known(): assert self.dimension == other.dimension @@ -1322,6 +1345,7 @@ class Op: """ Op on :class:`Dim` which results in a derived :class:`Dim`. """ + def __init__(self, kind, inputs, attribs=None): """ :param str kind: "add", "sub", "mul", "ceildiv" @@ -1355,6 +1379,7 @@ class _OpMultTerm: """ represents sth like a * b * c """ + @classmethod def from_dim(cls, dim): """ @@ -1570,12 +1595,10 @@ def extend_mul_div_(self, other, kind, right): if term.dimension * other.dimension == 1: self.terms.pop(idx) return - self.terms[idx] = Dim._make_constant_static_dim( - term.dimension * other.dimension, kind=term.kind) + self.terms[idx] = Dim._make_constant_static_dim(term.dimension * other.dimension, kind=term.kind) return if kind.endswith("div") and term.dimension % other.dimension == 0: - self.terms[idx] = Dim._make_constant_static_dim( - term.dimension // other.dimension, kind=term.kind) + self.terms[idx] = Dim._make_constant_static_dim(term.dimension // other.dimension, kind=term.kind) return # Fallback with generic handling. 
if kind.endswith("div"): @@ -1610,8 +1633,7 @@ def new_div_dim(cls, numerator, denominator, kind, right): kind = "floordiv" # for nicer description, and does not matter elif kind == "truediv": if a % b != 0: - raise ValueError( - "%s truediv %s only allowed if the result is an integer" % (numerator, denominator)) + raise ValueError("%s truediv %s only allowed if the result is an integer" % (numerator, denominator)) dim_value = a // b if right: kind = "floordiv" # for nicer description, and does not matter @@ -1620,9 +1642,9 @@ def new_div_dim(cls, numerator, denominator, kind, right): if kind == "floordiv" and right: description = "%s//%s" % (Dim._get_description(numerator), Dim._get_description(denominator)) else: - description = "%s_%s(%s, %s)" % ( - kind, "right" if right else "left", - Dim._get_description(numerator, brackets=False), Dim._get_description(denominator, brackets=False)) + description = "%s_%s(%s, %s)" % (kind, "right" if right else "left", + Dim._get_description(numerator, brackets=False), + Dim._get_description(denominator, brackets=False)) op_kind = kind if a is not None and b is not None and a % b == 0: op_kind = "truediv" # makes some other checks simpler @@ -1644,7 +1666,8 @@ def as_dim(self): return self.terms[0] dim_kind = _get_merged_dim_kind(self.terms) return Dim( - kind=dim_kind, description="*".join(map(Dim._get_description, self.terms)), + kind=dim_kind, + description="*".join(map(Dim._get_description, self.terms)), dimension=self.dimension, derived_from_op=Dim.Op(kind="mul", inputs=list(self.terms))) @@ -1669,6 +1692,7 @@ class _OpLinearTerm: """ represents sth like a * b + c """ + @classmethod def from_dim(cls, dim): """ @@ -1797,8 +1821,9 @@ def extend_mul_div_(self, other, kind, right): return if kind.endswith("div"): if any(not term.divisible(other, right=right) for term in self.terms): - self.terms = [Dim._OpMultTerm.from_dim( - Dim._OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right))] + self.terms = [ + Dim._OpMultTerm.from_dim(Dim._OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)) + ] return for term in self.terms: term.extend_mul_div_(other, kind=kind, right=right) @@ -1844,10 +1869,10 @@ def representative_tag(self): # Global dim tag placeholders. batch_dim = Dim(kind=Dim.Types.Batch, description="global batch") - # Provide some simple wrappers. https://github.com/rwth-i6/returnn/issues/782 # Use CamelCase function names (invalidates PEP8) to make it look like a class instance. + # noinspection PyPep8Naming def FeatureDim(description, dimension, **kwargs): """ @@ -1872,12 +1897,12 @@ def SpatialDim(description, dimension=None, **kwargs): any_feature_dim = FeatureDim("any-feature-dim", None, generic=True) any_spatial_dim = SpatialDim("any-spatial-dim", None, generic=True) - # This indicates to perform a single step execution of some layer which can potentially have recurrent state. single_step_dim = Dim(description="single-step", kind=Dim.Types.Spatial, special=True, dimension=1) class _MarkedDim: + def __init__(self, tag): """ :param Dim tag: @@ -1993,6 +2018,7 @@ class VirtualDimBase(object): """ Represents one virtual dim, flattened into the batch dim. """ + def short_repr(self): """ :rtype: str @@ -2006,6 +2032,7 @@ class FixedDim(VirtualDimBase): """ Represents a dim with fixed size. """ + def __init__(self, size, dim_tag=None): """ :param tf.Tensor|int size: @@ -2026,6 +2053,7 @@ class GlobalBatchDim(FixedDim): """ Represents the global batch dim by the network (minibatch construction from the dataset). 
""" + def short_repr(self): """ :rtype: str @@ -2038,6 +2066,7 @@ class BeamDim(FixedDim): """ Represents a search beam. """ + def __init__(self, beam): """ :param SearchBeam beam: @@ -2055,6 +2084,7 @@ class PaddedDim(FixedDim): """ Represents a dim with variable size, which is flattened with padding (not packed) into the batch. """ + def __init__(self, dim_tag): """ :param Dim dim_tag: @@ -2073,6 +2103,7 @@ class PackedDim(VirtualDimBase): Represents a dim with variable sizes, which is packed (un-padded) into the batch. Variable w.r.t. other dims (must be per batch entry). """ + def __init__(self, dim_tag, key_axes): """ :param Dim dim_tag: @@ -2131,7 +2162,8 @@ def __init__(self, base, new_dim, new_dim_index=None): self._packed_dims_by_dim_tag = {} # type: typing.Dict[Dim,BatchInfo.PackedDim] self.descendants = [] # type: typing.List[BatchInfo] self._descendants_by_beam_name = {} # type: typing.Dict[str,BatchInfo] - self._global_descendants_by_virtual_dims = {} # type: typing.Dict[typing.Tuple[BatchInfo.VirtualDimBase,...],BatchInfo] # noqa + self._global_descendants_by_virtual_dims = { + } # type: typing.Dict[typing.Tuple[BatchInfo.VirtualDimBase,...],BatchInfo] # noqa if base: base.descendants.append(self) if isinstance(new_dim, BatchInfo.BeamDim): @@ -2565,8 +2597,7 @@ def __repr__(self): keys = ["name", "beam_size"] if self.dependency is not NotSpecified: keys.append("dependency") - return "%s(%s)" % ( - self.__class__.__name__, ", ".join(["%s=%r" % (key, getattr(self, key)) for key in keys])) + return "%s(%s)" % (self.__class__.__name__, ", ".join(["%s=%r" % (key, getattr(self, key)) for key in keys])) def __eq__(self, other): """ @@ -2657,15 +2688,15 @@ def get_combined_beam(cls, beam1, beam2=None, *beams): return beam1 if beam2 in l1: return beam2 - raise Exception( - "\n".join([ - "Cannot combine beams:", - " 1: %s (deps: %s, next %s, next deps %s)" % ( - beam1, beam1._get_dependency_list(), - beam1._next_frame, beam1._next_frame._get_dependency_list() if beam1._next_frame else None), - " 2: %s (deps: %s, next %s, next deps %s)" % ( - beam2, beam2._get_dependency_list(), beam2._next_frame, - beam2._next_frame._get_dependency_list() if beam2._next_frame else None)])) + raise Exception("\n".join([ + "Cannot combine beams:", + " 1: %s (deps: %s, next %s, next deps %s)" % + (beam1, beam1._get_dependency_list(), beam1._next_frame, + beam1._next_frame._get_dependency_list() if beam1._next_frame else None), + " 2: %s (deps: %s, next %s, next deps %s)" % + (beam2, beam2._get_dependency_list(), beam2._next_frame, + beam2._next_frame._get_dependency_list() if beam2._next_frame else None) + ])) class Data(object): @@ -2683,8 +2714,10 @@ class Data(object): size_dtype = "int32" - def __init__(self, name, - shape=None, dtype=None, + def __init__(self, + name, + shape=None, + dtype=None, placeholder=None, sparse=None, sparse_dim=NotSpecified, @@ -2799,16 +2832,24 @@ def __init__(self, name, time_dim_axis = _default_time_dim_axis_no_shape( batch_dim_axis=batch_dim_axis, feature_dim_axis=feature_dim_axis) shape, time_dim_axis = _infer_default_shape_and_time( - batch_dim_axis=batch_dim_axis, feature_dim_axis=feature_dim_axis, time_dim_axis=time_dim_axis, - sparse=sparse, dim=dim) + batch_dim_axis=batch_dim_axis, + feature_dim_axis=feature_dim_axis, + time_dim_axis=time_dim_axis, + sparse=sparse, + dim=dim) else: if time_dim_axis is NotSpecified: time_dim_axis = _default_time_dim_axis(batch_dim_axis=batch_dim_axis, shape=shape) dim_tags = _infer_dim_tags_tuple_from_shape( - shape, 
batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis, - size_placeholder=size_placeholder, name=name, + shape, + batch_dim_axis=batch_dim_axis, + time_dim_axis=time_dim_axis, + feature_dim_axis=feature_dim_axis, + size_placeholder=size_placeholder, + name=name, auto_create_placeholders=auto_create_placeholders, - dim_tags=dim_tags, sparse=sparse) + dim_tags=dim_tags, + sparse=sparse) del batch_dim_axis del shape self._dim_tags = dim_tags # type: typing.Tuple[Dim] @@ -2913,7 +2954,8 @@ def template_from_constant(cls, x, name, dtype=None, shape=None, with_batch_dim= assert d == d_ d = Dim( kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature, - description="%s:static:%i" % (name, i), auto_generated=True, + description="%s:static:%i" % (name, i), + auto_generated=True, dimension=d) else: raise TypeError("%r shape[%i] invalid type %r in shape %r" % (name, i, type(d), shape)) @@ -2997,13 +3039,13 @@ def get_runtime_sanity_check_op(self): data = ["Data.get_runtime_sanity_check_op:", str(self), "shape", shape] for i, tag in enumerate(self.dim_tags): if tag.dyn_size is not None: - data += [ - "dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)] + data += ["dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)] checks += [tf.Assert(tf.equal(rank, self.batch_ndim), data + ["-> invalid rank"])] if self.have_batch_axis(): batch_dim_via_info = self.get_batch_dim() checks += [ - tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info])] + tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info]) + ] for i in range(self.batch_ndim): if self.batch_shape[i] is not None: checks += [tf.Assert(tf.equal(shape[i], self.batch_shape[i]), data + ["-> invalid shape[%i]" % i])] @@ -3011,21 +3053,25 @@ def get_runtime_sanity_check_op(self): if dyn_size_ext and dyn_size_ext.placeholder is not None: dyn_size = dyn_size_ext.placeholder if dyn_size_ext.have_batch_axis() and self.have_batch_axis(): - checks += [tf.Assert( - tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim), - data + ["-> invalid axis %i tag dyn size batch dim" % i])] - checks += [tf.Assert( - # Note: in almost all cases, we have equality here. - # However, not strictly in all cases, e.g. DecideLayer, maybe some others... - # But that should not be more than 1 less. - tf.logical_or( - tf.logical_and( - tf.less_equal(tf.reduce_max(dyn_size), shape[i]), - tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)), - # In other rare cases, this might be a broadcast dim - # (e.g. as initial values of att weights for a rec loop). - tf.equal(1, shape[i])), - data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)])] + checks += [ + tf.Assert( + tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim), + data + ["-> invalid axis %i tag dyn size batch dim" % i]) + ] + checks += [ + tf.Assert( + # Note: in almost all cases, we have equality here. + # However, not strictly in all cases, e.g. DecideLayer, maybe some others... + # But that should not be more than 1 less. + tf.logical_or( + tf.logical_and( + tf.less_equal(tf.reduce_max(dyn_size), shape[i]), + tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)), + # In other rare cases, this might be a broadcast dim + # (e.g. as initial values of att weights for a rec loop). 
+ tf.equal(1, shape[i])), + data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)]) + ] checks += [dyn_size_ext.get_runtime_sanity_check_op()] return tf.group(*checks) @@ -3043,8 +3089,8 @@ def verify_out_shape(self, out_shape): self_dim_tags_implicit_only = self.dim_tags_set_implicit_only_wrapped if not out_shape: if self_dim_tags: - raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self, self_dim_tags, out_shape)) + raise VerifyOutShapeException("%s verify_out_shape, with dims %s, does not match empty out_shape %r" % + (self, self_dim_tags, out_shape)) return if not isinstance(out_shape, set): raise TypeError("%s verify_out_shape: expects a set but got %s" % (self, type(out_shape))) @@ -3056,27 +3102,26 @@ def verify_out_shape(self, out_shape): dim_tag = dim.tag if dim not in self_dim_tags_implicit_only: raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % ( - self, self_dim_tags, out_shape, dim)) + "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % + (self, self_dim_tags, out_shape, dim)) elif isinstance(dim, OptionalDim): dim_tag = dim.tag if dim_tag not in remaining: continue else: - raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % ( - self, out_shape, type(dim))) + raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % + (self, out_shape, type(dim))) if dim_tag not in remaining: if dim_tag in self_dim_tags: # can happen e.g. if specified once as implicit dim and then also as explicit raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % ( - self, self_dim_tags, out_shape, dim)) - raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % ( - self, self_dim_tags, out_shape, dim)) + "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % + (self, self_dim_tags, out_shape, dim)) + raise VerifyOutShapeException("%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % + (self, self_dim_tags, out_shape, dim)) remaining.discard(dim_tag) if remaining: - raise VerifyOutShapeException( - "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self, remaining, out_shape)) + raise VerifyOutShapeException("%s verify_out_shape, dims %s are not specified in out_shape %s" % + (self, remaining, out_shape)) def get_placeholder_kwargs(self, with_batch=True): """ @@ -3212,12 +3257,8 @@ def get_compare_key(self): Note that this order is not totally fixed, and might change. 
:rtype: object """ - return ( - self.dtype, - self.shape, - self.batch_dim_axis, self.feature_dim_axis, self.time_dim_axis, - self.dim_tags, - self.batch, self.beam) + return (self.dtype, self.shape, self.batch_dim_axis, self.feature_dim_axis, self.time_dim_axis, self.dim_tags, + self.batch, self.beam) def __repr__(self): return self.get_description(catch_exceptions=True) @@ -3234,8 +3275,7 @@ def __getstate__(self): def _adapt_batch_consistent_dim_tags(self): if not self.batch: # uninitialized return - self._dim_tags = tuple( - tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) for tag in self._dim_tags) + self._dim_tags = tuple(tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) for tag in self._dim_tags) def copy(self, name=None): """ @@ -3467,8 +3507,7 @@ def copy_add_batch_dim(self, batch_dim_axis, batch=None, dim_tag=None): assert dim_tag.dimension == batch.static_dim or dim_tag.dimension is None assert dim_tag.batch == batch else: - dim_tag = Dim( - kind=Dim.Types.Batch, description="batch", dimension=batch.static_dim, batch=batch) + dim_tag = Dim(kind=Dim.Types.Batch, description="batch", dimension=batch.static_dim, batch=batch) dim_tags.insert(batch_dim_axis, dim_tag) data_opts["dim_tags"] = dim_tags data_opts["batch"] = batch @@ -3566,7 +3605,8 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None): else: batch_info = BatchInfo.make_global_broadcast_batch_info() return self.copy_add_batch_dim( - batch_dim_axis=axis, batch=batch_info, + batch_dim_axis=axis, + batch=batch_info, dim_tag=dim_tag if (dim_tag.dimension == 1 and dim_tag.batch == batch_info) else None) data_opts = self.get_kwargs() @@ -3575,7 +3615,9 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None): dim_tag = dim_tag.copy(same_as_self=True, kind=Dim.Types.Spatial) if not unbroadcast and dim_tag.dimension != 1: dim_tag = Dim( - kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1, + kind=dim_tag.kind, + description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), + dimension=1, auto_generated=True) data_opts["dim_tags"] = self.dim_tags[:axis] + (dim_tag,) + self.dim_tags[axis:] other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) @@ -3611,15 +3653,17 @@ def copy_split_feature_dim(self, new_feature_dim): new_feature_dim_axis = self.feature_dim_axis + 1 data_opts = self.get_kwargs(include_special_axes=False) dim_tag_split_rem = Dim( - kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem, auto_generated=True, + kind=Dim.Types.Spatial, + description="feature_split_rem_%i" % feature_dim_rem, + auto_generated=True, dimension=feature_dim_rem) dim_tag_new = Dim( kind=self.dim_tags[self.feature_dim_axis].kind, - description="feature_split_new_%i" % new_feature_dim, auto_generated=True, + description="feature_split_new_%i" % new_feature_dim, + auto_generated=True, dimension=new_feature_dim) dim_tags = ( - self.dim_tags[:self.feature_dim_axis] + - (dim_tag_split_rem, dim_tag_new) + + self.dim_tags[:self.feature_dim_axis] + (dim_tag_split_rem, dim_tag_new) + self.dim_tags[self.feature_dim_axis + 1:]) data_opts["dim_tags"] = dim_tags other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) @@ -3630,9 +3674,7 @@ def copy_split_feature_dim(self, new_feature_dim): self.placeholder.set_shape(self.batch_shape) old_shape = get_shape(self.placeholder) new_shape = ( - old_shape[:self.feature_dim_axis] + - 
[feature_dim_rem, new_feature_dim] + - old_shape[self.feature_dim_axis + 1:]) + old_shape[:self.feature_dim_axis] + [feature_dim_rem, new_feature_dim] + old_shape[self.feature_dim_axis + 1:]) data_opts["placeholder"] = tf.reshape(self.placeholder, new_shape, name="copy_split_feature_dim") return Data(**data_opts) @@ -3693,8 +3735,14 @@ def copy_extend_batch(self, batch): data.placeholder = x return data - def copy_compatible_to(self, data, add_dims=True, unbroadcast=False, except_feature=False, except_axis=None, - check_sparse=True, check_dtype=True): + def copy_compatible_to(self, + data, + add_dims=True, + unbroadcast=False, + except_feature=False, + except_axis=None, + check_sparse=True, + check_dtype=True): """ :param Data data: other data which the returned tensor should be compatible to It would add any missing axes with a dim 1 axis for automatic broadcasting (with add_dims=True). @@ -3730,12 +3778,10 @@ def copy_compatible_to(self, data, add_dims=True, unbroadcast=False, except_feat new_v_axis = min(target_axis, v.batch_ndim) if target_axis not in mapped_axes.values(): if not add_dims: - raise ValueError( - "%s.copy_compatible_to(%s) not allowed, axis %i (%s) not in source" % ( - self, data, target_axis, data.dim_tags[target_axis])) + raise ValueError("%s.copy_compatible_to(%s) not allowed, axis %i (%s) not in source" % + (self, data, target_axis, data.dim_tags[target_axis])) # Dim in data, but not in v - unbroadcast_axis = unbroadcast and not ( - except_feature and data.feature_dim_axis == target_axis) and not ( + unbroadcast_axis = unbroadcast and not (except_feature and data.feature_dim_axis == target_axis) and not ( except_axis_int is not None and except_axis_int == target_axis) v = v.copy_add_dim_by_tag(data.get_dim_tag(target_axis), axis=new_v_axis, unbroadcast=unbroadcast_axis) # Keep mapped_axes consistent @@ -3840,8 +3886,7 @@ def copy_merge_into_batch(self, axes): if axis == data.batch_dim_axis: batch_idx = len(batch.virtual_dims) # add all remaining axes behind continue - batch = batch.copy_extend_with_padded_or_fixed_dim_tag( - dim_tag=data.dim_tags[axis], new_dim_idx=batch_idx) + batch = batch.copy_extend_with_padded_or_fixed_dim_tag(dim_tag=data.dim_tags[axis], new_dim_idx=batch_idx) batch_idx += 1 for axis in reversed(sorted(axes)): if axis != data.batch_dim_axis: @@ -3869,8 +3914,7 @@ def copy_squeeze_axes(self, axes): data_opts = self.get_kwargs(include_special_axes=False) if self.placeholder is not None: data_opts["placeholder"] = tf.squeeze( - self.placeholder, axes, - name="%s_squeeze_axes" % get_valid_scope_name_from_str(self.name)) + self.placeholder, axes, name="%s_squeeze_axes" % get_valid_scope_name_from_str(self.name)) data_opts["dim_tags"] = [tag for (i, tag) in enumerate(self.dim_tags) if i not in axes] if self.time_dim_axis is not None: if self.time_dim_axis in axes: @@ -4037,9 +4081,11 @@ def copy_template_replace_dim(self, axis, new_dim, new_size=None): assert new_dim is None return self.copy_template() # nothing to do dim_tag = Dim( - kind=dim_tag.kind, description="%s_replaced" % (dim_tag.description or "unnamed"), + kind=dim_tag.kind, + description="%s_replaced" % (dim_tag.description or "unnamed"), auto_generated=True, - dimension=new_dim, dyn_size=new_size) + dimension=new_dim, + dyn_size=new_size) return self.copy_template_replace_dim_tag(axis=axis, new_dim_tag=dim_tag) def copy_template_new_dim_tags(self, new_dim_tags, name=None, keep_special_axes=False): @@ -4362,8 +4408,10 @@ def _default_feature_dim_axis(self): :rtype: int|None """ return 
_default_feature_dim_axis( - batch_dim_axis=self.batch_dim_axis, time_dim_axis=self.time_dim_axis, - batch_shape=self.batch_shape, sparse=self.sparse) + batch_dim_axis=self.batch_dim_axis, + time_dim_axis=self.time_dim_axis, + batch_shape=self.batch_shape, + sparse=self.sparse) @property def feature_dim_axis(self): @@ -4597,8 +4645,8 @@ def get_placeholder_time_flattened(self): # flatten_with_seq_len_mask only works if either time_dim_axis or batch_dim_axis is 0: assert 0 in [self.time_dim_axis, self.batch_dim_axis] seq_lens = self.size_placeholder[self.time_dim_axis_excluding_batch] - return flatten_with_seq_len_mask(self.placeholder, seq_lens, batch_dim_axis=self.batch_dim_axis, - time_dim_axis=self.time_dim_axis) + return flatten_with_seq_len_mask( + self.placeholder, seq_lens, batch_dim_axis=self.batch_dim_axis, time_dim_axis=self.time_dim_axis) def get_placeholder_flattened(self, keepdims=False): """ @@ -4623,15 +4671,12 @@ def get_placeholder_flattened(self, keepdims=False): x = self.get_placeholder_time_flattened() removed_axis = max(self.time_dim_axis, self.batch_dim_axis) dyn_axes.remove(removed_axis) - dyn_axes = [(i if (i < removed_axis) else (i - 1)) - for i in dyn_axes] + dyn_axes = [(i if (i < removed_axis) else (i - 1)) for i in dyn_axes] ndim -= 1 if len(dyn_axes) > 1: shape = tf.shape(x) - x = tf.reshape( - x, - [tf.reduce_prod([shape[i] for i in dyn_axes])] + - [shape[i] for i in range(ndim) if i not in dyn_axes]) + x = tf.reshape(x, [tf.reduce_prod([shape[i] for i in dyn_axes])] + + [shape[i] for i in range(ndim) if i not in dyn_axes]) dyn_axes = [0] assert dyn_axes == [0] if keepdims and orig_num_dyn_axes >= 2: @@ -4684,9 +4729,7 @@ def _verify_axis_order_dependent(cls): """ from returnn.util import BehaviorVersion BehaviorVersion.require( - condition=False, - message="Do not specify axis or axes in a way that depends on the order of the axes.", - version=7) + condition=False, message="Do not specify axis or axes in a way that depends on the order of the axes.", version=7) def _make_valid_int_axis(self, axis): """ @@ -4721,9 +4764,9 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified): if len(dims) > 1: max_match_priority = max(self.dim_tags[i].match_priority for i in dims) dims = [i for i in dims if self.dim_tags[i].match_priority == max_match_priority] - assert len(dims) <= 1, ( - "%s: matching dim %s must be unique," - " use `match_priority` to resolve the matching order of ambiguous dimensions" % (self, axes)) + assert len(dims) <= 1, ("%s: matching dim %s must be unique," + " use `match_priority` to resolve the matching order of ambiguous dimensions" % + (self, axes)) return dims if isinstance(axes, int): self._verify_axis_int_from_description(allow_int=allow_int) @@ -4883,9 +4926,8 @@ def get_description_from_axis(self, axis): name = dim_tag.description matching_axes = self.get_axes_by_tag_name(name, spatial_only=True) assert axis in matching_axes - return ( - "stag-single:%i:%s" % ( - matching_axes.index(axis) - len(matching_axes), name)) # negative because this is likely more robust + return ("stag-single:%i:%s" % (matching_axes.index(axis) - len(matching_axes), name) + ) # negative because this is likely more robust def has_axis(self, axis): """ @@ -4904,13 +4946,15 @@ def get_axes_by_tag_name(self, name, spatial_only=False): """ dim_tags = self.get_batch_shape_dim_tags() matching_dim_tags = [ - (axis, tag) for axis, tag in enumerate(dim_tags) - if name.lower() in tag.description.lower() - or name.lower() in 
tag.get_same_base().description.lower()] + (axis, tag) + for axis, tag in enumerate(dim_tags) + if name.lower() in tag.description.lower() or name.lower() in tag.get_same_base().description.lower() + ] if spatial_only: spatial_axes = self.get_spatial_batch_axes() matching_dim_tags = [ - (axis, tag) for axis, tag in matching_dim_tags if axis in spatial_axes or tag.is_spatial_dim()] + (axis, tag) for axis, tag in matching_dim_tags if axis in spatial_axes or tag.is_spatial_dim() + ] return [ax for ax, _ in matching_dim_tags] def get_axis_by_tag_name(self, name, spatial_only=False): @@ -4920,8 +4964,8 @@ def get_axis_by_tag_name(self, name, spatial_only=False): :rtype: int """ matching_dim_tags = self.get_axes_by_tag_name(name, spatial_only) - assert len(matching_dim_tags) > 0, "%r: no %stag found with name %r" % ( - self, "spatial " if spatial_only else "", name) + assert len(matching_dim_tags) > 0, "%r: no %stag found with name %r" % (self, "spatial " if spatial_only else "", + name) assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % ( self, name, self.get_batch_shape_dim_tags()) return matching_dim_tags[0] @@ -4984,9 +5028,9 @@ def is_time_axis_dynamic(self): return self.batch_shape[self.time_dim_axis_excluding_batch] is None if self.time_dim_axis_excluding_batch in self.size_placeholder: return True - assert isinstance(self.shape[self.time_dim_axis_excluding_batch], int), ( - "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % ( - self, self.time_dim_axis, self.size_placeholder)) + assert isinstance(self.shape[self.time_dim_axis_excluding_batch], + int), ("%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % + (self, self.time_dim_axis, self.size_placeholder)) return False def is_axis_dynamic(self, axis): @@ -5069,16 +5113,14 @@ def get_dynamic_axes(self): :return: list of axes, counted with batch-dim axis (but we exclude the batch dim axis itself) :rtype: list[int] """ - return [axis for axis, dim in enumerate(self.batch_shape) - if axis != self.batch_dim_axis and dim is None] + return [axis for axis, dim in enumerate(self.batch_shape) if axis != self.batch_dim_axis and dim is None] def get_static_axes(self): """ :return: list of axes, counted with batch-dim axis (but we exclude the batch dim axis itself) :rtype: list[int] """ - return [axis for axis, dim in enumerate(self.batch_shape) - if axis != self.batch_dim_axis and dim is not None] + return [axis for axis, dim in enumerate(self.batch_shape) if axis != self.batch_dim_axis and dim is not None] def mark_same_time(self, tags, must_match=False): """ @@ -5125,8 +5167,7 @@ def get_sequence_lengths(self): return self.size_placeholder[self.time_dim_axis_excluding_batch] assert self.shape[self.time_dim_axis_excluding_batch] is not None with same_control_flow_ctx(self.placeholder), tf.name_scope("fixed_seq_len"): - return expand_dims_unbroadcast( - self.shape[self.time_dim_axis_excluding_batch], axis=0, dim=self.get_batch_dim()) + return expand_dims_unbroadcast(self.shape[self.time_dim_axis_excluding_batch], axis=0, dim=self.get_batch_dim()) def get_sequence_mask(self): """ @@ -5272,12 +5313,9 @@ def get_spatial_batch_axes(self): counted with batch-dim. 
""" return [ - axis - for axis in range(self.batch_ndim) - if axis != self.batch_dim_axis - and (axis != self.feature_dim_axis or - axis == self.time_dim_axis or - self.batch_shape[axis] is None)] + axis for axis in range(self.batch_ndim) if axis != self.batch_dim_axis and + (axis != self.feature_dim_axis or axis == self.time_dim_axis or self.batch_shape[axis] is None) + ] def get_spatial_axes(self): """ @@ -5315,8 +5353,7 @@ def get_special_axes_dict(self, counted_with_batch_dim=True, only_available=Fals axes = list(self.SpecialAxesNames) d = {k: getattr(self, k) for k in axes} if not counted_with_batch_dim: - d = {k: self.get_batch_axis_excluding_batch(v) if (v is not None) else None - for (k, v) in d.items()} + d = {k: self.get_batch_axis_excluding_batch(v) if (v is not None) else None for (k, v) in d.items()} if only_available: d = {k: v for (k, v) in d.items() if v is not None} if self._feature_dim_axis is NotSpecified: # special rule @@ -5331,8 +5368,7 @@ def get_bc_spatial_batch_shape(self): dyn_axes = self.get_spatial_batch_axes() if self.batch_dim_axis is not None: dyn_axes += [self.batch_dim_axis] - return tuple([1 if (axis in dyn_axes) else dim - for axis, dim in enumerate(self.batch_shape)]) + return tuple([1 if (axis in dyn_axes) else dim for axis, dim in enumerate(self.batch_shape)]) def get_bc_shape(self, opts=None): """ @@ -5358,8 +5394,8 @@ def get_bc_shape(self, opts=None): value = None key_axes = self.get_axes_from_description(key) for key_axis in key_axes: - assert key_axis not in axes_map, ( - "%r get_bc_shape: axis %i is defined multiple times in opts %r" % (self, key_axis, opts)) + assert key_axis not in axes_map, ("%r get_bc_shape: axis %i is defined multiple times in opts %r" % + (self, key_axis, opts)) assert 0 <= key_axis < self.batch_ndim, "%r get_bc_shape: invalid axis %i in opts %r" % (self, key_axis, opts) (axes_map if key != "*" else default_axes_map)[key_axis] = ( self.batch_shape[key_axis] if value is None else value) @@ -5460,9 +5496,11 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_ # Note: we don't use copy_extend_with_beam because we don't want to create any ops in the TF graph at this point. 
common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources]) is_equal_opts = dict( - ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True, + ignore_feature_dim=ignore_feature_dim, + treat_feature_as_spatial=True, allow_same_spatial_dim=True, - undefined_matches=True, derived_matches=True) + undefined_matches=True, + derived_matches=True) if BehaviorVersion.get() < 11: is_equal_opts["broadcast_matches"] = True all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts) @@ -5506,8 +5544,7 @@ def find_matching_dim_map(self, other, other_axes, is_equal_opts=None): :rtype: dict[int,int] """ if is_equal_opts is None: - is_equal_opts = dict( - allow_same_feature_dim=True, allow_same_spatial_dim=True, treat_feature_as_spatial=True) + is_equal_opts = dict(allow_same_feature_dim=True, allow_same_spatial_dim=True, treat_feature_as_spatial=True) def map_other_axis_to_self(other_axis, taken_self_axes): """ @@ -5532,14 +5569,15 @@ def map_other_axis_to_self(other_axis, taken_self_axes): is_equal_opts_[opt] = True matching = [ self_axis for self_axis in self.find_matching_dims(other_axis_dim_tag, is_equal_opts_) - if self_axis not in taken_self_axes] + if self_axis not in taken_self_axes + ] if opt == "unknown_spatial_matches": assert len(matching) <= 1, 'cannot match axes %s from %s to %s, failed at other %s, not unique after %s' % ( other_axes, other, self, other_axis, opt) if matching: break - assert matching, 'cannot match the axes %s from %s to %s. Failing at axis %s' % ( - other_axes, other, self, other_axis) + assert matching, 'cannot match the axes %s from %s to %s. Failing at axis %s' % (other_axes, other, self, + other_axis) # If there are multiple matches (e.g. because two axes have the same feature dim), leave their order intact. 
# We do this by always choosing the first unused match which is the smallest axes return matching[0] @@ -5705,7 +5743,8 @@ def _create_size_placeholder(name, axis_wo_b, tag, batch_dim): from .basic import reuse_name_scope with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True): dyn_size_ext = Data( - "%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, + "%s_dim%i_size" % (name, axis_wo_b), + dtype=Data.size_dtype, dim_tags=[batch_dim] if batch_dim else [], batch=batch_dim.batch if batch_dim else None) dyn_size = tf_compat.v1.placeholder( @@ -5721,15 +5760,8 @@ def _create_size_placeholder(name, axis_wo_b, tag, batch_dim): tag.set_tag_on_size_tensor(dyn_size) -def _infer_dim_tags_tuple_from_shape( - shape, - batch_dim_axis, time_dim_axis, feature_dim_axis, - sparse, - size_placeholder, - dim_tags, - name, - auto_create_placeholders -): +def _infer_dim_tags_tuple_from_shape(shape, batch_dim_axis, time_dim_axis, feature_dim_axis, sparse, size_placeholder, + dim_tags, name, auto_create_placeholders): """ :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis :param int|None batch_dim_axis: @@ -5770,12 +5802,9 @@ def _infer_dim_tags_tuple_from_shape( dim_tags[axis] = tag # See Data.get_spatial_batch_axes spatial_axes = [ - axis - for axis in range(len(batch_shape)) - if axis != batch_dim_axis - and (axis != feature_dim_axis or - axis == time_dim_axis or - batch_shape[axis] is None)] + axis for axis in range(len(batch_shape)) + if axis != batch_dim_axis and (axis != feature_dim_axis or axis == time_dim_axis or batch_shape[axis] is None) + ] for axis in range(len(batch_shape)): tag = dim_tags.get(axis) axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis) @@ -5803,8 +5832,11 @@ def _infer_dim_tags_tuple_from_shape( continue if axis == feature_dim_axis and dyn_size is None and axis != time_dim_axis: tag = Dim( - kind=Dim.Types.Feature, dimension=dim, description="feature:%s" % name, - undefined=dim is None, auto_generated=True) + kind=Dim.Types.Feature, + dimension=dim, + description="feature:%s" % name, + undefined=dim is None, + auto_generated=True) else: assert axis in spatial_axes description = "time" if axis == time_dim_axis else "spatial%i" % spatial_axes.index(axis) @@ -5818,8 +5850,12 @@ def _infer_dim_tags_tuple_from_shape( description += ":static%i" % dim description += ":%s" % name tag = Dim( - kind=Dim.Types.Spatial, description=description, dimension=dim, dyn_size=dyn_size, - undefined=dim is None and dyn_size is None, auto_generated=True) + kind=Dim.Types.Spatial, + description=description, + dimension=dim, + dyn_size=dyn_size, + undefined=dim is None and dyn_size is None, + auto_generated=True) dim_tags[axis] = tag assert sorted(dim_tags.keys()) == list(range(len(batch_shape))) return tuple(dim_tags[axis] for axis in range(len(batch_shape))) From 28b398585cc50eb5b72570b4261f3e320755c0c3 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Sun, 16 Oct 2022 09:44:34 +0200 Subject: [PATCH 08/23] Revert "implemented behaviorversion exception" This reverts commit c6b50235bd820bc3e606d950165b2bd4c0c0e39e. --- returnn/tf/util/data.py | 454 ++++++++++++++++++---------------------- 1 file changed, 209 insertions(+), 245 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index fc91b5ea81..47c15e7c2b 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -1,3 +1,4 @@ + """ Provides :class:`Data`, :class:`Dim`, :class:`SearchBeam`. 
@@ -11,7 +12,7 @@ import tensorflow as tf import traceback -from returnn.util.basic import BehaviorVersion, NotSpecified, Entity +from returnn.util.basic import NotSpecified, Entity import returnn.tf.compat as tf_compat @@ -54,24 +55,16 @@ class Types: _creation_counter = 0 - def __init__(self, - kind=Types.Unspecified, - description=None, + def __init__(self, kind=Types.Unspecified, description=None, dimension=None, vocab=None, - dyn_size=None, - dyn_size_ext=None, - undefined=False, - generic=False, - special=False, + dyn_size=None, dyn_size_ext=None, + undefined=False, generic=False, special=False, auto_generated=False, match_priority=0, - derived_from_tag=None, - derived_from_op=None, - batch=None, - control_flow_ctx=None, - src_data=None, - src_axis=None): + derived_from_tag=None, derived_from_op=None, + batch=None, control_flow_ctx=None, + src_data=None, src_axis=None): """ :param Entity|None kind: :param str|None description: the description should be unique @@ -137,8 +130,7 @@ def __init__(self, self.auto_generated = auto_generated # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx. # They each have same_as = self. The same_base should have the base (global) batch info. - self._same_for_batch_ctx = { - } # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim] # nopep8 + self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim] # nopep8 if dyn_size is not None: assert not dyn_size_ext self.dyn_size = dyn_size @@ -162,8 +154,9 @@ def short_repr(self): desc += "(%i%s)" % (self.dimension, "*" if self.generic else "") else: if self.dyn_size_ext: - desc += "[%s%s]" % (",".join( - self.dyn_size_ext.get_batch_axes_short_description(special_axes=False)), "*" if self.generic else "") + desc += "[%s%s]" % ( + ",".join(self.dyn_size_ext.get_batch_axes_short_description(special_axes=False)), + "*" if self.generic else "") else: desc += "[*]" if self.generic else "[?]" if self.control_flow_ctx: @@ -224,14 +217,11 @@ def copy(self, same_as_self=True, description=None, kind=None, match_priority=No if not same_as_self: assert description is not None, "%s copy with not same_as_self should have a new description" % self tag = Dim( - kind=kind or self.kind, - description=description or self.description, + kind=kind or self.kind, description=description or self.description, match_priority=match_priority if match_priority is not None else self.match_priority, - dimension=self.dimension, - dyn_size_ext=self.dyn_size_ext, + dimension=self.dimension, dyn_size_ext=self.dyn_size_ext, batch=self.batch, - src_data=self.src_data, - src_axis=self.src_axis) + src_data=self.src_data, src_axis=self.src_axis) if same_as_self: tag.same_as = self # not declare_same_as, none of the extra checks needed tag._same_as_tb = traceback.extract_stack() @@ -401,12 +391,9 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False): if not dyn_size_ext and allow_none: return None dim_tag = Dim( - kind=self.kind, - description=self.description, - dimension=self.dimension, + kind=self.kind, description=self.description, dimension=self.dimension, auto_generated=self.auto_generated, - batch=batch, - control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, + batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, dyn_size_ext=dyn_size_ext) dim_tag.same_as = same_base dim_tag._same_as_tb = traceback.extract_stack() @@ -479,13 +466,8 @@ def dyn_size(self, 
dyn_size): beam = getattr(dyn_size, "_RETURNN_dyn_size_beam", None) self.dyn_size_ext = Data( name=("%s:dyn_size" % self.description) if self.description else dyn_size.op.name, - dtype=Data.size_dtype, - placeholder=dyn_size, - shape=(), - batch_dim_axis=0, - batch=self.batch, - beam=beam, - control_flow_ctx=self.control_flow_ctx) + dtype=Data.size_dtype, placeholder=dyn_size, shape=(), batch_dim_axis=0, + batch=self.batch, beam=beam, control_flow_ctx=self.control_flow_ctx) other = Dim.get_tag_from_size_tensor(dyn_size) if other: self.declare_same_as(other) @@ -605,12 +587,12 @@ def set_tag_on_size_tensor(self, x, batch=None, same_as_before=False): # So for now, just error. from .basic import format_graph_output raise Exception("\n".join([ - "%r (%r) already has size %r, and another incompatible size %r (batch %r) is being assigned." % - (self, self.description, self.dyn_size, x, batch), "\nNew size computation graph:", + "%r (%r) already has size %r, and another incompatible size %r (batch %r) is being assigned." % ( + self, self.description, self.dyn_size, x, batch), + "\nNew size computation graph:", format_graph_output(x, max_depth=3), "\nThis is maybe the result of an incorrect declare_same_as. Traceback of declare_same_as:", - "".join(self._same_as_tb.format()) if self._same_as_tb else ("same_as = %s" % self.same_as) - ])) + "".join(self._same_as_tb.format()) if self._same_as_tb else ("same_as = %s" % self.same_as)])) if batch and getattr(x, "_RETURNN_dyn_size_beam", None): assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam") if self.batch and batch: @@ -692,7 +674,9 @@ def _bin_op(a, b): if x.dimension is not None: if y is None: with tf.control_dependencies(None): # this will reset the context - y = Data(name=y_name, dim_tags=[], dtype="int32", placeholder=tf.constant(x.dimension)) + y = Data( + name=y_name, dim_tags=[], dtype="int32", + placeholder=tf.constant(x.dimension)) continue y.placeholder = _bin_op(y.placeholder, x.dimension) continue @@ -725,16 +709,9 @@ def _bin_op(a, b): if y.placeholder is not None: self.set_tag_on_size_tensor(y.placeholder) - def is_equal(self, - other, - ignore_feature_dim=False, - allow_same_feature_dim=False, - allow_same_spatial_dim=None, - treat_feature_as_spatial=False, - broadcast_matches=False, - unknown_spatial_matches=False, - undefined_matches=False, - derived_matches=False): + def is_equal(self, other, ignore_feature_dim=False, allow_same_feature_dim=False, allow_same_spatial_dim=None, + treat_feature_as_spatial=False, broadcast_matches=False, unknown_spatial_matches=False, + undefined_matches=False, derived_matches=False): """ Compares self to other for equality. @@ -849,9 +826,10 @@ def __hash__(self): # This must match the behavior in __eq__, which is is_equal with default options. # I.e. different hash implies not equal (but same hash not necessarily equal). if self.generic: - raise ValueError("Hash for generic dim tag %s is not well defined. " % self + - "The generic flag invalidates the transitive property of equivalence relations. " - "Explicitly go through the set or dict of dim tags and check each for equality instead.") + raise ValueError( + "Hash for generic dim tag %s is not well defined. " % self + + "The generic flag invalidates the transitive property of equivalence relations. 
" + "Explicitly go through the set or dict of dim tags and check each for equality instead.") if self.special: return hash(id(self)) if self.is_batch_dim(): @@ -1006,9 +984,8 @@ def declare_same_as(self, other): if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: - BehaviorVersion.require( - False, ("Dim tags are same with different size placeholders: %r vs %r please check external_data" % - (self.dyn_size, other_same_base.dyn_size)), 13) + raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( + self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. if other_same_base.dyn_size is not None and self.src_data: @@ -1023,15 +1000,15 @@ def declare_same_as(self, other): # Could be unset if it comes from the config, or from prev graph creation. # This is important such that self.can_compare() is sane. if other_same_base.dyn_size is None or not other_same_base._validate_in_current_graph(): - other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(other_same_base.batch, - other_same_base.control_flow_ctx) + other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( + other_same_base.batch, other_same_base.control_flow_ctx) other_same_base._maybe_update() if not self.dyn_size_ext or not self._validate_in_current_graph(): self.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) self._maybe_update() elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): - other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(other_same_base.batch, - other_same_base.control_flow_ctx) + other_same_base.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( + other_same_base.batch, other_same_base.control_flow_ctx) other_same_base._maybe_update() if self.is_dim_known() and other.is_dim_known(): assert self.dimension == other.dimension @@ -1345,7 +1322,6 @@ class Op: """ Op on :class:`Dim` which results in a derived :class:`Dim`. """ - def __init__(self, kind, inputs, attribs=None): """ :param str kind: "add", "sub", "mul", "ceildiv" @@ -1379,7 +1355,6 @@ class _OpMultTerm: """ represents sth like a * b * c """ - @classmethod def from_dim(cls, dim): """ @@ -1595,10 +1570,12 @@ def extend_mul_div_(self, other, kind, right): if term.dimension * other.dimension == 1: self.terms.pop(idx) return - self.terms[idx] = Dim._make_constant_static_dim(term.dimension * other.dimension, kind=term.kind) + self.terms[idx] = Dim._make_constant_static_dim( + term.dimension * other.dimension, kind=term.kind) return if kind.endswith("div") and term.dimension % other.dimension == 0: - self.terms[idx] = Dim._make_constant_static_dim(term.dimension // other.dimension, kind=term.kind) + self.terms[idx] = Dim._make_constant_static_dim( + term.dimension // other.dimension, kind=term.kind) return # Fallback with generic handling. 
if kind.endswith("div"): @@ -1633,7 +1610,8 @@ def new_div_dim(cls, numerator, denominator, kind, right): kind = "floordiv" # for nicer description, and does not matter elif kind == "truediv": if a % b != 0: - raise ValueError("%s truediv %s only allowed if the result is an integer" % (numerator, denominator)) + raise ValueError( + "%s truediv %s only allowed if the result is an integer" % (numerator, denominator)) dim_value = a // b if right: kind = "floordiv" # for nicer description, and does not matter @@ -1642,9 +1620,9 @@ def new_div_dim(cls, numerator, denominator, kind, right): if kind == "floordiv" and right: description = "%s//%s" % (Dim._get_description(numerator), Dim._get_description(denominator)) else: - description = "%s_%s(%s, %s)" % (kind, "right" if right else "left", - Dim._get_description(numerator, brackets=False), - Dim._get_description(denominator, brackets=False)) + description = "%s_%s(%s, %s)" % ( + kind, "right" if right else "left", + Dim._get_description(numerator, brackets=False), Dim._get_description(denominator, brackets=False)) op_kind = kind if a is not None and b is not None and a % b == 0: op_kind = "truediv" # makes some other checks simpler @@ -1666,8 +1644,7 @@ def as_dim(self): return self.terms[0] dim_kind = _get_merged_dim_kind(self.terms) return Dim( - kind=dim_kind, - description="*".join(map(Dim._get_description, self.terms)), + kind=dim_kind, description="*".join(map(Dim._get_description, self.terms)), dimension=self.dimension, derived_from_op=Dim.Op(kind="mul", inputs=list(self.terms))) @@ -1692,7 +1669,6 @@ class _OpLinearTerm: """ represents sth like a * b + c """ - @classmethod def from_dim(cls, dim): """ @@ -1821,9 +1797,8 @@ def extend_mul_div_(self, other, kind, right): return if kind.endswith("div"): if any(not term.divisible(other, right=right) for term in self.terms): - self.terms = [ - Dim._OpMultTerm.from_dim(Dim._OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)) - ] + self.terms = [Dim._OpMultTerm.from_dim( + Dim._OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right))] return for term in self.terms: term.extend_mul_div_(other, kind=kind, right=right) @@ -1869,10 +1844,10 @@ def representative_tag(self): # Global dim tag placeholders. batch_dim = Dim(kind=Dim.Types.Batch, description="global batch") + # Provide some simple wrappers. https://github.com/rwth-i6/returnn/issues/782 # Use CamelCase function names (invalidates PEP8) to make it look like a class instance. - # noinspection PyPep8Naming def FeatureDim(description, dimension, **kwargs): """ @@ -1897,12 +1872,12 @@ def SpatialDim(description, dimension=None, **kwargs): any_feature_dim = FeatureDim("any-feature-dim", None, generic=True) any_spatial_dim = SpatialDim("any-spatial-dim", None, generic=True) + # This indicates to perform a single step execution of some layer which can potentially have recurrent state. single_step_dim = Dim(description="single-step", kind=Dim.Types.Spatial, special=True, dimension=1) class _MarkedDim: - def __init__(self, tag): """ :param Dim tag: @@ -2018,7 +1993,6 @@ class VirtualDimBase(object): """ Represents one virtual dim, flattened into the batch dim. """ - def short_repr(self): """ :rtype: str @@ -2032,7 +2006,6 @@ class FixedDim(VirtualDimBase): """ Represents a dim with fixed size. """ - def __init__(self, size, dim_tag=None): """ :param tf.Tensor|int size: @@ -2053,7 +2026,6 @@ class GlobalBatchDim(FixedDim): """ Represents the global batch dim by the network (minibatch construction from the dataset). 
""" - def short_repr(self): """ :rtype: str @@ -2066,7 +2038,6 @@ class BeamDim(FixedDim): """ Represents a search beam. """ - def __init__(self, beam): """ :param SearchBeam beam: @@ -2084,7 +2055,6 @@ class PaddedDim(FixedDim): """ Represents a dim with variable size, which is flattened with padding (not packed) into the batch. """ - def __init__(self, dim_tag): """ :param Dim dim_tag: @@ -2103,7 +2073,6 @@ class PackedDim(VirtualDimBase): Represents a dim with variable sizes, which is packed (un-padded) into the batch. Variable w.r.t. other dims (must be per batch entry). """ - def __init__(self, dim_tag, key_axes): """ :param Dim dim_tag: @@ -2162,8 +2131,7 @@ def __init__(self, base, new_dim, new_dim_index=None): self._packed_dims_by_dim_tag = {} # type: typing.Dict[Dim,BatchInfo.PackedDim] self.descendants = [] # type: typing.List[BatchInfo] self._descendants_by_beam_name = {} # type: typing.Dict[str,BatchInfo] - self._global_descendants_by_virtual_dims = { - } # type: typing.Dict[typing.Tuple[BatchInfo.VirtualDimBase,...],BatchInfo] # noqa + self._global_descendants_by_virtual_dims = {} # type: typing.Dict[typing.Tuple[BatchInfo.VirtualDimBase,...],BatchInfo] # noqa if base: base.descendants.append(self) if isinstance(new_dim, BatchInfo.BeamDim): @@ -2597,7 +2565,8 @@ def __repr__(self): keys = ["name", "beam_size"] if self.dependency is not NotSpecified: keys.append("dependency") - return "%s(%s)" % (self.__class__.__name__, ", ".join(["%s=%r" % (key, getattr(self, key)) for key in keys])) + return "%s(%s)" % ( + self.__class__.__name__, ", ".join(["%s=%r" % (key, getattr(self, key)) for key in keys])) def __eq__(self, other): """ @@ -2688,15 +2657,15 @@ def get_combined_beam(cls, beam1, beam2=None, *beams): return beam1 if beam2 in l1: return beam2 - raise Exception("\n".join([ - "Cannot combine beams:", - " 1: %s (deps: %s, next %s, next deps %s)" % - (beam1, beam1._get_dependency_list(), beam1._next_frame, - beam1._next_frame._get_dependency_list() if beam1._next_frame else None), - " 2: %s (deps: %s, next %s, next deps %s)" % - (beam2, beam2._get_dependency_list(), beam2._next_frame, - beam2._next_frame._get_dependency_list() if beam2._next_frame else None) - ])) + raise Exception( + "\n".join([ + "Cannot combine beams:", + " 1: %s (deps: %s, next %s, next deps %s)" % ( + beam1, beam1._get_dependency_list(), + beam1._next_frame, beam1._next_frame._get_dependency_list() if beam1._next_frame else None), + " 2: %s (deps: %s, next %s, next deps %s)" % ( + beam2, beam2._get_dependency_list(), beam2._next_frame, + beam2._next_frame._get_dependency_list() if beam2._next_frame else None)])) class Data(object): @@ -2714,10 +2683,8 @@ class Data(object): size_dtype = "int32" - def __init__(self, - name, - shape=None, - dtype=None, + def __init__(self, name, + shape=None, dtype=None, placeholder=None, sparse=None, sparse_dim=NotSpecified, @@ -2832,24 +2799,16 @@ def __init__(self, time_dim_axis = _default_time_dim_axis_no_shape( batch_dim_axis=batch_dim_axis, feature_dim_axis=feature_dim_axis) shape, time_dim_axis = _infer_default_shape_and_time( - batch_dim_axis=batch_dim_axis, - feature_dim_axis=feature_dim_axis, - time_dim_axis=time_dim_axis, - sparse=sparse, - dim=dim) + batch_dim_axis=batch_dim_axis, feature_dim_axis=feature_dim_axis, time_dim_axis=time_dim_axis, + sparse=sparse, dim=dim) else: if time_dim_axis is NotSpecified: time_dim_axis = _default_time_dim_axis(batch_dim_axis=batch_dim_axis, shape=shape) dim_tags = _infer_dim_tags_tuple_from_shape( - shape, - 
batch_dim_axis=batch_dim_axis, - time_dim_axis=time_dim_axis, - feature_dim_axis=feature_dim_axis, - size_placeholder=size_placeholder, - name=name, + shape, batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis, + size_placeholder=size_placeholder, name=name, auto_create_placeholders=auto_create_placeholders, - dim_tags=dim_tags, - sparse=sparse) + dim_tags=dim_tags, sparse=sparse) del batch_dim_axis del shape self._dim_tags = dim_tags # type: typing.Tuple[Dim] @@ -2954,8 +2913,7 @@ def template_from_constant(cls, x, name, dtype=None, shape=None, with_batch_dim= assert d == d_ d = Dim( kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature, - description="%s:static:%i" % (name, i), - auto_generated=True, + description="%s:static:%i" % (name, i), auto_generated=True, dimension=d) else: raise TypeError("%r shape[%i] invalid type %r in shape %r" % (name, i, type(d), shape)) @@ -3039,13 +2997,13 @@ def get_runtime_sanity_check_op(self): data = ["Data.get_runtime_sanity_check_op:", str(self), "shape", shape] for i, tag in enumerate(self.dim_tags): if tag.dyn_size is not None: - data += ["dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)] + data += [ + "dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)] checks += [tf.Assert(tf.equal(rank, self.batch_ndim), data + ["-> invalid rank"])] if self.have_batch_axis(): batch_dim_via_info = self.get_batch_dim() checks += [ - tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info]) - ] + tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info])] for i in range(self.batch_ndim): if self.batch_shape[i] is not None: checks += [tf.Assert(tf.equal(shape[i], self.batch_shape[i]), data + ["-> invalid shape[%i]" % i])] @@ -3053,25 +3011,21 @@ def get_runtime_sanity_check_op(self): if dyn_size_ext and dyn_size_ext.placeholder is not None: dyn_size = dyn_size_ext.placeholder if dyn_size_ext.have_batch_axis() and self.have_batch_axis(): - checks += [ - tf.Assert( - tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim), - data + ["-> invalid axis %i tag dyn size batch dim" % i]) - ] - checks += [ - tf.Assert( - # Note: in almost all cases, we have equality here. - # However, not strictly in all cases, e.g. DecideLayer, maybe some others... - # But that should not be more than 1 less. - tf.logical_or( - tf.logical_and( - tf.less_equal(tf.reduce_max(dyn_size), shape[i]), - tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)), - # In other rare cases, this might be a broadcast dim - # (e.g. as initial values of att weights for a rec loop). - tf.equal(1, shape[i])), - data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)]) - ] + checks += [tf.Assert( + tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim), + data + ["-> invalid axis %i tag dyn size batch dim" % i])] + checks += [tf.Assert( + # Note: in almost all cases, we have equality here. + # However, not strictly in all cases, e.g. DecideLayer, maybe some others... + # But that should not be more than 1 less. + tf.logical_or( + tf.logical_and( + tf.less_equal(tf.reduce_max(dyn_size), shape[i]), + tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)), + # In other rare cases, this might be a broadcast dim + # (e.g. as initial values of att weights for a rec loop). 
+ tf.equal(1, shape[i])), + data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)])] checks += [dyn_size_ext.get_runtime_sanity_check_op()] return tf.group(*checks) @@ -3089,8 +3043,8 @@ def verify_out_shape(self, out_shape): self_dim_tags_implicit_only = self.dim_tags_set_implicit_only_wrapped if not out_shape: if self_dim_tags: - raise VerifyOutShapeException("%s verify_out_shape, with dims %s, does not match empty out_shape %r" % - (self, self_dim_tags, out_shape)) + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self, self_dim_tags, out_shape)) return if not isinstance(out_shape, set): raise TypeError("%s verify_out_shape: expects a set but got %s" % (self, type(out_shape))) @@ -3102,26 +3056,27 @@ def verify_out_shape(self, out_shape): dim_tag = dim.tag if dim not in self_dim_tags_implicit_only: raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % - (self, self_dim_tags, out_shape, dim)) + "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % ( + self, self_dim_tags, out_shape, dim)) elif isinstance(dim, OptionalDim): dim_tag = dim.tag if dim_tag not in remaining: continue else: - raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % - (self, out_shape, type(dim))) + raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % ( + self, out_shape, type(dim))) if dim_tag not in remaining: if dim_tag in self_dim_tags: # can happen e.g. if specified once as implicit dim and then also as explicit raise VerifyOutShapeException( - "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % - (self, self_dim_tags, out_shape, dim)) - raise VerifyOutShapeException("%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % - (self, self_dim_tags, out_shape, dim)) + "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % ( + self, self_dim_tags, out_shape, dim)) + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % ( + self, self_dim_tags, out_shape, dim)) remaining.discard(dim_tag) if remaining: - raise VerifyOutShapeException("%s verify_out_shape, dims %s are not specified in out_shape %s" % - (self, remaining, out_shape)) + raise VerifyOutShapeException( + "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self, remaining, out_shape)) def get_placeholder_kwargs(self, with_batch=True): """ @@ -3257,8 +3212,12 @@ def get_compare_key(self): Note that this order is not totally fixed, and might change. 
:rtype: object """ - return (self.dtype, self.shape, self.batch_dim_axis, self.feature_dim_axis, self.time_dim_axis, self.dim_tags, - self.batch, self.beam) + return ( + self.dtype, + self.shape, + self.batch_dim_axis, self.feature_dim_axis, self.time_dim_axis, + self.dim_tags, + self.batch, self.beam) def __repr__(self): return self.get_description(catch_exceptions=True) @@ -3275,7 +3234,8 @@ def __getstate__(self): def _adapt_batch_consistent_dim_tags(self): if not self.batch: # uninitialized return - self._dim_tags = tuple(tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) for tag in self._dim_tags) + self._dim_tags = tuple( + tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) for tag in self._dim_tags) def copy(self, name=None): """ @@ -3507,7 +3467,8 @@ def copy_add_batch_dim(self, batch_dim_axis, batch=None, dim_tag=None): assert dim_tag.dimension == batch.static_dim or dim_tag.dimension is None assert dim_tag.batch == batch else: - dim_tag = Dim(kind=Dim.Types.Batch, description="batch", dimension=batch.static_dim, batch=batch) + dim_tag = Dim( + kind=Dim.Types.Batch, description="batch", dimension=batch.static_dim, batch=batch) dim_tags.insert(batch_dim_axis, dim_tag) data_opts["dim_tags"] = dim_tags data_opts["batch"] = batch @@ -3605,8 +3566,7 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None): else: batch_info = BatchInfo.make_global_broadcast_batch_info() return self.copy_add_batch_dim( - batch_dim_axis=axis, - batch=batch_info, + batch_dim_axis=axis, batch=batch_info, dim_tag=dim_tag if (dim_tag.dimension == 1 and dim_tag.batch == batch_info) else None) data_opts = self.get_kwargs() @@ -3615,9 +3575,7 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None): dim_tag = dim_tag.copy(same_as_self=True, kind=Dim.Types.Spatial) if not unbroadcast and dim_tag.dimension != 1: dim_tag = Dim( - kind=dim_tag.kind, - description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), - dimension=1, + kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1, auto_generated=True) data_opts["dim_tags"] = self.dim_tags[:axis] + (dim_tag,) + self.dim_tags[axis:] other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) @@ -3653,17 +3611,15 @@ def copy_split_feature_dim(self, new_feature_dim): new_feature_dim_axis = self.feature_dim_axis + 1 data_opts = self.get_kwargs(include_special_axes=False) dim_tag_split_rem = Dim( - kind=Dim.Types.Spatial, - description="feature_split_rem_%i" % feature_dim_rem, - auto_generated=True, + kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem, auto_generated=True, dimension=feature_dim_rem) dim_tag_new = Dim( kind=self.dim_tags[self.feature_dim_axis].kind, - description="feature_split_new_%i" % new_feature_dim, - auto_generated=True, + description="feature_split_new_%i" % new_feature_dim, auto_generated=True, dimension=new_feature_dim) dim_tags = ( - self.dim_tags[:self.feature_dim_axis] + (dim_tag_split_rem, dim_tag_new) + + self.dim_tags[:self.feature_dim_axis] + + (dim_tag_split_rem, dim_tag_new) + self.dim_tags[self.feature_dim_axis + 1:]) data_opts["dim_tags"] = dim_tags other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) @@ -3674,7 +3630,9 @@ def copy_split_feature_dim(self, new_feature_dim): self.placeholder.set_shape(self.batch_shape) old_shape = get_shape(self.placeholder) new_shape = ( - old_shape[:self.feature_dim_axis] + 
[feature_dim_rem, new_feature_dim] + old_shape[self.feature_dim_axis + 1:]) + old_shape[:self.feature_dim_axis] + + [feature_dim_rem, new_feature_dim] + + old_shape[self.feature_dim_axis + 1:]) data_opts["placeholder"] = tf.reshape(self.placeholder, new_shape, name="copy_split_feature_dim") return Data(**data_opts) @@ -3735,14 +3693,8 @@ def copy_extend_batch(self, batch): data.placeholder = x return data - def copy_compatible_to(self, - data, - add_dims=True, - unbroadcast=False, - except_feature=False, - except_axis=None, - check_sparse=True, - check_dtype=True): + def copy_compatible_to(self, data, add_dims=True, unbroadcast=False, except_feature=False, except_axis=None, + check_sparse=True, check_dtype=True): """ :param Data data: other data which the returned tensor should be compatible to It would add any missing axes with a dim 1 axis for automatic broadcasting (with add_dims=True). @@ -3778,10 +3730,12 @@ def copy_compatible_to(self, new_v_axis = min(target_axis, v.batch_ndim) if target_axis not in mapped_axes.values(): if not add_dims: - raise ValueError("%s.copy_compatible_to(%s) not allowed, axis %i (%s) not in source" % - (self, data, target_axis, data.dim_tags[target_axis])) + raise ValueError( + "%s.copy_compatible_to(%s) not allowed, axis %i (%s) not in source" % ( + self, data, target_axis, data.dim_tags[target_axis])) # Dim in data, but not in v - unbroadcast_axis = unbroadcast and not (except_feature and data.feature_dim_axis == target_axis) and not ( + unbroadcast_axis = unbroadcast and not ( + except_feature and data.feature_dim_axis == target_axis) and not ( except_axis_int is not None and except_axis_int == target_axis) v = v.copy_add_dim_by_tag(data.get_dim_tag(target_axis), axis=new_v_axis, unbroadcast=unbroadcast_axis) # Keep mapped_axes consistent @@ -3886,7 +3840,8 @@ def copy_merge_into_batch(self, axes): if axis == data.batch_dim_axis: batch_idx = len(batch.virtual_dims) # add all remaining axes behind continue - batch = batch.copy_extend_with_padded_or_fixed_dim_tag(dim_tag=data.dim_tags[axis], new_dim_idx=batch_idx) + batch = batch.copy_extend_with_padded_or_fixed_dim_tag( + dim_tag=data.dim_tags[axis], new_dim_idx=batch_idx) batch_idx += 1 for axis in reversed(sorted(axes)): if axis != data.batch_dim_axis: @@ -3914,7 +3869,8 @@ def copy_squeeze_axes(self, axes): data_opts = self.get_kwargs(include_special_axes=False) if self.placeholder is not None: data_opts["placeholder"] = tf.squeeze( - self.placeholder, axes, name="%s_squeeze_axes" % get_valid_scope_name_from_str(self.name)) + self.placeholder, axes, + name="%s_squeeze_axes" % get_valid_scope_name_from_str(self.name)) data_opts["dim_tags"] = [tag for (i, tag) in enumerate(self.dim_tags) if i not in axes] if self.time_dim_axis is not None: if self.time_dim_axis in axes: @@ -4081,11 +4037,9 @@ def copy_template_replace_dim(self, axis, new_dim, new_size=None): assert new_dim is None return self.copy_template() # nothing to do dim_tag = Dim( - kind=dim_tag.kind, - description="%s_replaced" % (dim_tag.description or "unnamed"), + kind=dim_tag.kind, description="%s_replaced" % (dim_tag.description or "unnamed"), auto_generated=True, - dimension=new_dim, - dyn_size=new_size) + dimension=new_dim, dyn_size=new_size) return self.copy_template_replace_dim_tag(axis=axis, new_dim_tag=dim_tag) def copy_template_new_dim_tags(self, new_dim_tags, name=None, keep_special_axes=False): @@ -4408,10 +4362,8 @@ def _default_feature_dim_axis(self): :rtype: int|None """ return _default_feature_dim_axis( - 
batch_dim_axis=self.batch_dim_axis, - time_dim_axis=self.time_dim_axis, - batch_shape=self.batch_shape, - sparse=self.sparse) + batch_dim_axis=self.batch_dim_axis, time_dim_axis=self.time_dim_axis, + batch_shape=self.batch_shape, sparse=self.sparse) @property def feature_dim_axis(self): @@ -4645,8 +4597,8 @@ def get_placeholder_time_flattened(self): # flatten_with_seq_len_mask only works if either time_dim_axis or batch_dim_axis is 0: assert 0 in [self.time_dim_axis, self.batch_dim_axis] seq_lens = self.size_placeholder[self.time_dim_axis_excluding_batch] - return flatten_with_seq_len_mask( - self.placeholder, seq_lens, batch_dim_axis=self.batch_dim_axis, time_dim_axis=self.time_dim_axis) + return flatten_with_seq_len_mask(self.placeholder, seq_lens, batch_dim_axis=self.batch_dim_axis, + time_dim_axis=self.time_dim_axis) def get_placeholder_flattened(self, keepdims=False): """ @@ -4671,12 +4623,15 @@ def get_placeholder_flattened(self, keepdims=False): x = self.get_placeholder_time_flattened() removed_axis = max(self.time_dim_axis, self.batch_dim_axis) dyn_axes.remove(removed_axis) - dyn_axes = [(i if (i < removed_axis) else (i - 1)) for i in dyn_axes] + dyn_axes = [(i if (i < removed_axis) else (i - 1)) + for i in dyn_axes] ndim -= 1 if len(dyn_axes) > 1: shape = tf.shape(x) - x = tf.reshape(x, [tf.reduce_prod([shape[i] for i in dyn_axes])] + - [shape[i] for i in range(ndim) if i not in dyn_axes]) + x = tf.reshape( + x, + [tf.reduce_prod([shape[i] for i in dyn_axes])] + + [shape[i] for i in range(ndim) if i not in dyn_axes]) dyn_axes = [0] assert dyn_axes == [0] if keepdims and orig_num_dyn_axes >= 2: @@ -4729,7 +4684,9 @@ def _verify_axis_order_dependent(cls): """ from returnn.util import BehaviorVersion BehaviorVersion.require( - condition=False, message="Do not specify axis or axes in a way that depends on the order of the axes.", version=7) + condition=False, + message="Do not specify axis or axes in a way that depends on the order of the axes.", + version=7) def _make_valid_int_axis(self, axis): """ @@ -4764,9 +4721,9 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified): if len(dims) > 1: max_match_priority = max(self.dim_tags[i].match_priority for i in dims) dims = [i for i in dims if self.dim_tags[i].match_priority == max_match_priority] - assert len(dims) <= 1, ("%s: matching dim %s must be unique," - " use `match_priority` to resolve the matching order of ambiguous dimensions" % - (self, axes)) + assert len(dims) <= 1, ( + "%s: matching dim %s must be unique," + " use `match_priority` to resolve the matching order of ambiguous dimensions" % (self, axes)) return dims if isinstance(axes, int): self._verify_axis_int_from_description(allow_int=allow_int) @@ -4926,8 +4883,9 @@ def get_description_from_axis(self, axis): name = dim_tag.description matching_axes = self.get_axes_by_tag_name(name, spatial_only=True) assert axis in matching_axes - return ("stag-single:%i:%s" % (matching_axes.index(axis) - len(matching_axes), name) - ) # negative because this is likely more robust + return ( + "stag-single:%i:%s" % ( + matching_axes.index(axis) - len(matching_axes), name)) # negative because this is likely more robust def has_axis(self, axis): """ @@ -4946,15 +4904,13 @@ def get_axes_by_tag_name(self, name, spatial_only=False): """ dim_tags = self.get_batch_shape_dim_tags() matching_dim_tags = [ - (axis, tag) - for axis, tag in enumerate(dim_tags) - if name.lower() in tag.description.lower() or name.lower() in tag.get_same_base().description.lower() - ] + (axis, tag) for 
axis, tag in enumerate(dim_tags) + if name.lower() in tag.description.lower() + or name.lower() in tag.get_same_base().description.lower()] if spatial_only: spatial_axes = self.get_spatial_batch_axes() matching_dim_tags = [ - (axis, tag) for axis, tag in matching_dim_tags if axis in spatial_axes or tag.is_spatial_dim() - ] + (axis, tag) for axis, tag in matching_dim_tags if axis in spatial_axes or tag.is_spatial_dim()] return [ax for ax, _ in matching_dim_tags] def get_axis_by_tag_name(self, name, spatial_only=False): @@ -4964,8 +4920,8 @@ def get_axis_by_tag_name(self, name, spatial_only=False): :rtype: int """ matching_dim_tags = self.get_axes_by_tag_name(name, spatial_only) - assert len(matching_dim_tags) > 0, "%r: no %stag found with name %r" % (self, "spatial " if spatial_only else "", - name) + assert len(matching_dim_tags) > 0, "%r: no %stag found with name %r" % ( + self, "spatial " if spatial_only else "", name) assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % ( self, name, self.get_batch_shape_dim_tags()) return matching_dim_tags[0] @@ -5028,9 +4984,9 @@ def is_time_axis_dynamic(self): return self.batch_shape[self.time_dim_axis_excluding_batch] is None if self.time_dim_axis_excluding_batch in self.size_placeholder: return True - assert isinstance(self.shape[self.time_dim_axis_excluding_batch], - int), ("%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % - (self, self.time_dim_axis, self.size_placeholder)) + assert isinstance(self.shape[self.time_dim_axis_excluding_batch], int), ( + "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % ( + self, self.time_dim_axis, self.size_placeholder)) return False def is_axis_dynamic(self, axis): @@ -5113,14 +5069,16 @@ def get_dynamic_axes(self): :return: list of axes, counted with batch-dim axis (but we exclude the batch dim axis itself) :rtype: list[int] """ - return [axis for axis, dim in enumerate(self.batch_shape) if axis != self.batch_dim_axis and dim is None] + return [axis for axis, dim in enumerate(self.batch_shape) + if axis != self.batch_dim_axis and dim is None] def get_static_axes(self): """ :return: list of axes, counted with batch-dim axis (but we exclude the batch dim axis itself) :rtype: list[int] """ - return [axis for axis, dim in enumerate(self.batch_shape) if axis != self.batch_dim_axis and dim is not None] + return [axis for axis, dim in enumerate(self.batch_shape) + if axis != self.batch_dim_axis and dim is not None] def mark_same_time(self, tags, must_match=False): """ @@ -5167,7 +5125,8 @@ def get_sequence_lengths(self): return self.size_placeholder[self.time_dim_axis_excluding_batch] assert self.shape[self.time_dim_axis_excluding_batch] is not None with same_control_flow_ctx(self.placeholder), tf.name_scope("fixed_seq_len"): - return expand_dims_unbroadcast(self.shape[self.time_dim_axis_excluding_batch], axis=0, dim=self.get_batch_dim()) + return expand_dims_unbroadcast( + self.shape[self.time_dim_axis_excluding_batch], axis=0, dim=self.get_batch_dim()) def get_sequence_mask(self): """ @@ -5313,9 +5272,12 @@ def get_spatial_batch_axes(self): counted with batch-dim. 
""" return [ - axis for axis in range(self.batch_ndim) if axis != self.batch_dim_axis and - (axis != self.feature_dim_axis or axis == self.time_dim_axis or self.batch_shape[axis] is None) - ] + axis + for axis in range(self.batch_ndim) + if axis != self.batch_dim_axis + and (axis != self.feature_dim_axis or + axis == self.time_dim_axis or + self.batch_shape[axis] is None)] def get_spatial_axes(self): """ @@ -5353,7 +5315,8 @@ def get_special_axes_dict(self, counted_with_batch_dim=True, only_available=Fals axes = list(self.SpecialAxesNames) d = {k: getattr(self, k) for k in axes} if not counted_with_batch_dim: - d = {k: self.get_batch_axis_excluding_batch(v) if (v is not None) else None for (k, v) in d.items()} + d = {k: self.get_batch_axis_excluding_batch(v) if (v is not None) else None + for (k, v) in d.items()} if only_available: d = {k: v for (k, v) in d.items() if v is not None} if self._feature_dim_axis is NotSpecified: # special rule @@ -5368,7 +5331,8 @@ def get_bc_spatial_batch_shape(self): dyn_axes = self.get_spatial_batch_axes() if self.batch_dim_axis is not None: dyn_axes += [self.batch_dim_axis] - return tuple([1 if (axis in dyn_axes) else dim for axis, dim in enumerate(self.batch_shape)]) + return tuple([1 if (axis in dyn_axes) else dim + for axis, dim in enumerate(self.batch_shape)]) def get_bc_shape(self, opts=None): """ @@ -5394,8 +5358,8 @@ def get_bc_shape(self, opts=None): value = None key_axes = self.get_axes_from_description(key) for key_axis in key_axes: - assert key_axis not in axes_map, ("%r get_bc_shape: axis %i is defined multiple times in opts %r" % - (self, key_axis, opts)) + assert key_axis not in axes_map, ( + "%r get_bc_shape: axis %i is defined multiple times in opts %r" % (self, key_axis, opts)) assert 0 <= key_axis < self.batch_ndim, "%r get_bc_shape: invalid axis %i in opts %r" % (self, key_axis, opts) (axes_map if key != "*" else default_axes_map)[key_axis] = ( self.batch_shape[key_axis] if value is None else value) @@ -5496,11 +5460,9 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_ # Note: we don't use copy_extend_with_beam because we don't want to create any ops in the TF graph at this point. 
common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources]) is_equal_opts = dict( - ignore_feature_dim=ignore_feature_dim, - treat_feature_as_spatial=True, + ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True, allow_same_spatial_dim=True, - undefined_matches=True, - derived_matches=True) + undefined_matches=True, derived_matches=True) if BehaviorVersion.get() < 11: is_equal_opts["broadcast_matches"] = True all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts) @@ -5544,7 +5506,8 @@ def find_matching_dim_map(self, other, other_axes, is_equal_opts=None): :rtype: dict[int,int] """ if is_equal_opts is None: - is_equal_opts = dict(allow_same_feature_dim=True, allow_same_spatial_dim=True, treat_feature_as_spatial=True) + is_equal_opts = dict( + allow_same_feature_dim=True, allow_same_spatial_dim=True, treat_feature_as_spatial=True) def map_other_axis_to_self(other_axis, taken_self_axes): """ @@ -5569,15 +5532,14 @@ def map_other_axis_to_self(other_axis, taken_self_axes): is_equal_opts_[opt] = True matching = [ self_axis for self_axis in self.find_matching_dims(other_axis_dim_tag, is_equal_opts_) - if self_axis not in taken_self_axes - ] + if self_axis not in taken_self_axes] if opt == "unknown_spatial_matches": assert len(matching) <= 1, 'cannot match axes %s from %s to %s, failed at other %s, not unique after %s' % ( other_axes, other, self, other_axis, opt) if matching: break - assert matching, 'cannot match the axes %s from %s to %s. Failing at axis %s' % (other_axes, other, self, - other_axis) + assert matching, 'cannot match the axes %s from %s to %s. Failing at axis %s' % ( + other_axes, other, self, other_axis) # If there are multiple matches (e.g. because two axes have the same feature dim), leave their order intact. 
# We do this by always choosing the first unused match which is the smallest axes return matching[0] @@ -5743,8 +5705,7 @@ def _create_size_placeholder(name, axis_wo_b, tag, batch_dim): from .basic import reuse_name_scope with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True): dyn_size_ext = Data( - "%s_dim%i_size" % (name, axis_wo_b), - dtype=Data.size_dtype, + "%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, dim_tags=[batch_dim] if batch_dim else [], batch=batch_dim.batch if batch_dim else None) dyn_size = tf_compat.v1.placeholder( @@ -5760,8 +5721,15 @@ def _create_size_placeholder(name, axis_wo_b, tag, batch_dim): tag.set_tag_on_size_tensor(dyn_size) -def _infer_dim_tags_tuple_from_shape(shape, batch_dim_axis, time_dim_axis, feature_dim_axis, sparse, size_placeholder, - dim_tags, name, auto_create_placeholders): +def _infer_dim_tags_tuple_from_shape( + shape, + batch_dim_axis, time_dim_axis, feature_dim_axis, + sparse, + size_placeholder, + dim_tags, + name, + auto_create_placeholders +): """ :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis :param int|None batch_dim_axis: @@ -5802,9 +5770,12 @@ def _infer_dim_tags_tuple_from_shape(shape, batch_dim_axis, time_dim_axis, featu dim_tags[axis] = tag # See Data.get_spatial_batch_axes spatial_axes = [ - axis for axis in range(len(batch_shape)) - if axis != batch_dim_axis and (axis != feature_dim_axis or axis == time_dim_axis or batch_shape[axis] is None) - ] + axis + for axis in range(len(batch_shape)) + if axis != batch_dim_axis + and (axis != feature_dim_axis or + axis == time_dim_axis or + batch_shape[axis] is None)] for axis in range(len(batch_shape)): tag = dim_tags.get(axis) axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis) @@ -5832,11 +5803,8 @@ def _infer_dim_tags_tuple_from_shape(shape, batch_dim_axis, time_dim_axis, featu continue if axis == feature_dim_axis and dyn_size is None and axis != time_dim_axis: tag = Dim( - kind=Dim.Types.Feature, - dimension=dim, - description="feature:%s" % name, - undefined=dim is None, - auto_generated=True) + kind=Dim.Types.Feature, dimension=dim, description="feature:%s" % name, + undefined=dim is None, auto_generated=True) else: assert axis in spatial_axes description = "time" if axis == time_dim_axis else "spatial%i" % spatial_axes.index(axis) @@ -5850,12 +5818,8 @@ def _infer_dim_tags_tuple_from_shape(shape, batch_dim_axis, time_dim_axis, featu description += ":static%i" % dim description += ":%s" % name tag = Dim( - kind=Dim.Types.Spatial, - description=description, - dimension=dim, - dyn_size=dyn_size, - undefined=dim is None and dyn_size is None, - auto_generated=True) + kind=Dim.Types.Spatial, description=description, dimension=dim, dyn_size=dyn_size, + undefined=dim is None and dyn_size is None, auto_generated=True) dim_tags[axis] = tag assert sorted(dim_tags.keys()) == list(range(len(batch_shape))) return tuple(dim_tags[axis] for axis in range(len(batch_shape))) From 6ea93f1b8cfd546a6e71750dfd9674fca092bfe8 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Sun, 16 Oct 2022 09:47:33 +0200 Subject: [PATCH 09/23] Behavioral implementation for exception --- returnn/tf/util/data.py | 6 +++--- returnn/util/basic.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 47c15e7c2b..1ba8dd089e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -12,7 +12,7 @@ import tensorflow as tf import traceback -from returnn.util.basic import 
NotSpecified, Entity
+from returnn.util.basic import BehaviorVersion, NotSpecified, Entity
 import returnn.tf.compat as tf_compat
 
 
@@ -984,8 +984,8 @@ def declare_same_as(self, other):
     if self.dyn_size is not None and other_same_base.dyn_size is not None:
       if self.dyn_size is not other_same_base.dyn_size:
         if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
-          raise Exception("Dim tags are same with different size placeholders: %r vs %r please check external_data" % (
-            self.dyn_size, other_same_base.dyn_size))
+          BehaviorVersion.require(False,("Dim tags are same with different size placeholders: %r vs %r please check external_data" % (
+            self.dyn_size, other_same_base.dyn_size)),14)
       # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
       # maybe we can overtake the size_placeholder now.
       if other_same_base.dyn_size is not None and self.src_data:
diff --git a/returnn/util/basic.py b/returnn/util/basic.py
index b84deafec4..c9f1fe83ca 100644
--- a/returnn/util/basic.py
+++ b/returnn/util/basic.py
@@ -238,7 +238,7 @@ class BehaviorVersion:
   The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
   """
 
-  _latest_behavior_version = 13
+  _latest_behavior_version = 14
   _behavior_version = None  # type: typing.Optional[int]
 
   @classmethod

From da2fed7496d369d6f795248671540a4a7c2c76af Mon Sep 17 00:00:00 2001
From: JeremyNgu108 
Date: Mon, 17 Oct 2022 09:13:39 +0200
Subject: [PATCH 10/23] added documentation for behavior

---
 docs/configuration_reference/behavior_version.rst | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst
index 79285bd2bc..31b3820cf9 100644
--- a/docs/configuration_reference/behavior_version.rst
+++ b/docs/configuration_reference/behavior_version.rst
@@ -22,6 +22,13 @@ and not listing legacy/deprecated parameters.
 Version History
 ---------------
 
+Behavior version 14 (2022-10-17)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Dimensions tags were previously assumed to be the same with different size place_holders when processing matrixes. From now on, The dimension tags are now not allowed to be different and it will raise an exception.
+
+See issue `#1141 `__.
+
 Behavior version 13 (2022-10-13)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

From 59bae84fed78c1bf059ca84149a945ad5937f16e Mon Sep 17 00:00:00 2001
From: JeremyNgu108 
Date: Mon, 17 Oct 2022 09:24:59 +0200
Subject: [PATCH 11/23] formatting contribution

---
 returnn/tf/util/data.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 1ba8dd089e..66c2e77154 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -1,4 +1,3 @@
-
 """
 Provides :class:`Data`, :class:`Dim`, :class:`SearchBeam`. 
@@ -984,8 +983,12 @@ def declare_same_as(self, other): if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: - BehaviorVersion.require(False,("Dim tags are same with different size placeholders: %r vs %r please check external_data" % ( - self.dyn_size, other_same_base.dyn_size)),14) + BehaviorVersion.require( + False, + ("Dim tags are same with different size placeholders please check external_data" % + (self.dyn_size, other_same_base.dyn_size)), + 14 + ) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. if other_same_base.dyn_size is not None and self.src_data: From 49d996067de4ba84155a938b2dbb90a1d007fa88 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 17 Oct 2022 09:33:28 +0200 Subject: [PATCH 12/23] Update returnn/tf/util/data.py --- returnn/tf/util/data.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 66c2e77154..fe2efef90c 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -985,10 +985,9 @@ def declare_same_as(self, other): if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: BehaviorVersion.require( False, - ("Dim tags are same with different size placeholders please check external_data" % - (self.dyn_size, other_same_base.dyn_size)), - 14 - ) + "Dim tags are same with different size placeholders (%r vs %r), please check external_data" % ( + self.dyn_size, other_same_base.dyn_size), + 14) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. if other_same_base.dyn_size is not None and self.src_data: From 3820bc7ba5c8fed4653189d427fe0a6880c9d78d Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 17 Oct 2022 09:33:36 +0200 Subject: [PATCH 13/23] Update docs/configuration_reference/behavior_version.rst --- docs/configuration_reference/behavior_version.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst index 31b3820cf9..ef6e2d1893 100644 --- a/docs/configuration_reference/behavior_version.rst +++ b/docs/configuration_reference/behavior_version.rst @@ -25,7 +25,11 @@ Version History Behavior version 14 (2022-10-17) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Dimensions tags were previously assumed to be the same with different size place_holders when processing matrixes. From now on, The dimension tags are now not allowed to be different and it will raise an exception. +Dimension tags with different dynamic size tensors +can not be merged anymore via `declare_same_as`. +This can happen when the user did not set the dim tags +correctly in `extern_data`. +Otherwise it is likely a bug. See issue `#1141 `__. 
From 8ae0cbd6aa781a1dcdac56c31db1d46b74d74e3d Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 17 Oct 2022 09:34:14 +0200 Subject: [PATCH 14/23] Update returnn/tf/util/data.py --- returnn/tf/util/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index fe2efef90c..0a366f51d4 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -1,3 +1,4 @@ + """ Provides :class:`Data`, :class:`Dim`, :class:`SearchBeam`. From 2ebc9fcc6d207b77af0a7690eaf217e1d8f5b9f1 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Tue, 18 Oct 2022 11:28:54 +0200 Subject: [PATCH 15/23] removed auto_create_placeholders --- tests/test_TFUtil.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_TFUtil.py b/tests/test_TFUtil.py index 8acacf5ad3..c41eb981ee 100644 --- a/tests/test_TFUtil.py +++ b/tests/test_TFUtil.py @@ -663,8 +663,8 @@ def test_Data_copy_compatible_to_bias_to_batch_time_spatial_feature(): def test_Data_get_common_data_extra_static_spatial(): - d1 = Data(name='t', shape=(None, 32, 128), dtype='float32', auto_create_placeholders=True) - d2 = Data(name='r', shape=(None, 32, 128), dtype='float32', auto_create_placeholders=True) + d1 = Data(name='t', shape=(None, 32, 128), dtype='float32') + d2 = Data(name='r', shape=(None, 32, 128), dtype='float32') d2.get_size_dim_tag(0).declare_same_as(d1.get_size_dim_tag(0)) common = Data.get_common_data([d1, d2]) assert d1.shape == common.shape @@ -678,8 +678,8 @@ def test_Data_get_common_data_broadcast_multiple(): def test_Data_get_common_data_extra2_static_spatial(): - d1 = Data(name='t', shape=(None, 32, 32, 128), dtype='float32', auto_create_placeholders=True) - d2 = Data(name='r', shape=(None, 32, 32, 128), dtype='float32', auto_create_placeholders=True) + d1 = Data(name='t', shape=(None, 32, 32, 128), dtype='float32') + d2 = Data(name='r', shape=(None, 32, 32, 128), dtype='float32') d2.get_size_dim_tag(0).declare_same_as(d1.get_size_dim_tag(0)) common = Data.get_common_data([d1, d2]) assert d1.shape == common.shape From ac940573b9cd6d28babbe17ad7f30e86165464a7 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Wed, 19 Oct 2022 16:05:53 +0200 Subject: [PATCH 16/23] corrected documentation for behavior --- .../behavior_version.rst | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst index aa0061eace..138db1ad50 100644 --- a/docs/configuration_reference/behavior_version.rst +++ b/docs/configuration_reference/behavior_version.rst @@ -25,17 +25,6 @@ Version History Behavior version 15 (2022-10-19) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Dimension tags with different dynamic size tensors -can not be merged anymore via `declare_same_as`. -This can happen when the user did not set the dim tags -correctly in `extern_data`. -Otherwise it is likely a bug. - -See issue `#1141 `__. - -Behavior version 14 (2022-10-19) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - The dim matching in :class:`DotLayer` is now more strict for the case that ``var1`` and ``var2`` are not provided, to figure out the common dims. @@ -46,6 +35,17 @@ then just specify ``var1`` and ``var2`` explicitly. See issue `#1154 `__. +Behavior version 14 (2022-10-17) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Dimension tags with different dynamic size tensors +can not be merged anymore via `declare_same_as`. +This can happen when the user did not set the dim tags +correctly in `extern_data`. 
+Otherwise it is likely a bug. + +See issue `#1141 `__. + Behavior version 13 (2022-10-13) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From b6883e255e524c9b3a21047997dccbf5f255446f Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Wed, 19 Oct 2022 16:42:45 +0200 Subject: [PATCH 17/23] moved dim tags to behavior version 15 --- .../behavior_version.rst | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst index 138db1ad50..aa0061eace 100644 --- a/docs/configuration_reference/behavior_version.rst +++ b/docs/configuration_reference/behavior_version.rst @@ -25,6 +25,17 @@ Version History Behavior version 15 (2022-10-19) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Dimension tags with different dynamic size tensors +can not be merged anymore via `declare_same_as`. +This can happen when the user did not set the dim tags +correctly in `extern_data`. +Otherwise it is likely a bug. + +See issue `#1141 `__. + +Behavior version 14 (2022-10-19) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + The dim matching in :class:`DotLayer` is now more strict for the case that ``var1`` and ``var2`` are not provided, to figure out the common dims. @@ -35,17 +46,6 @@ then just specify ``var1`` and ``var2`` explicitly. See issue `#1154 `__. -Behavior version 14 (2022-10-17) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Dimension tags with different dynamic size tensors -can not be merged anymore via `declare_same_as`. -This can happen when the user did not set the dim tags -correctly in `extern_data`. -Otherwise it is likely a bug. - -See issue `#1141 `__. - Behavior version 13 (2022-10-13) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 33d7ebf6dae0112a56eba3cbd69d69d748ee28c2 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 20 Oct 2022 09:56:40 +0200 Subject: [PATCH 18/23] tflayernetwork test compatibility warning-to-exception --- tests/test_TFNetworkLayer.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 61dd1fcf7f..25a61d7643 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -1178,15 +1178,13 @@ def test_CombineLayer_two_time_dims(): name="in0", shape=(None, None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) in1 = Data( # same time as first in in0 - name="in1", shape=(None, n_dim), auto_create_placeholders=True) + name="in1", dim_tags=[in0.dim_tags[i] for i in (1, 0, 3)], auto_create_placeholders=True) in2 = Data( # same time as in second in in0 - name="in2", shape=(None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) + name="in2", dim_tags=[in0.dim_tags[i] for i in (2, 1, 3)], auto_create_placeholders=True) extern_data.register_data(in0) extern_data.register_data(in1) extern_data.register_data(in2) - in1.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(0)) - in2.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(1)) print("ExternData all dimension tags (allow_same_feature_dim=True):") pprint(extern_data.get_all_dimension_tags(allow_same_feature_dim=True)) network = TFNetwork(config=config, extern_data=extern_data, train_flag=True) @@ -1232,15 +1230,13 @@ def test_CombineLayer_two_time_dims_first_not_most_generic(): name="in0", shape=(None, None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) in1 = Data( # same time as first in in0 - name="in1", shape=(None, n_dim), auto_create_placeholders=True) + name="in1", dim_tags=[in0.dim_tags[i] for i in (1, 
0, 3)], auto_create_placeholders=True) in2 = Data( # same time as in second in in0 - name="in2", shape=(None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) + name="in2", dim_tags=[in0.dim_tags[i] for i in (2, 1, 3)], auto_create_placeholders=True) extern_data.register_data(in0) extern_data.register_data(in1) extern_data.register_data(in2) - in1.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(0)) - in2.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(1)) print("ExternData all dimension tags (allow_same_feature_dim=True):") pprint(extern_data.get_all_dimension_tags(allow_same_feature_dim=True)) network = TFNetwork(config=config, extern_data=extern_data, train_flag=True) @@ -1286,15 +1282,13 @@ def test_CombineLayer_two_time_dims_first_not_most_generic_with_n_out(): name="in0", shape=(None, None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) in1 = Data( # same time as first in in0 - name="in1", shape=(None, n_dim), auto_create_placeholders=True) + name="in1", dim_tags=[in0.dim_tags[i] for i in (1, 0, 3)], auto_create_placeholders=True) in2 = Data( # same time as in second in in0 - name="in2", shape=(None, n_dim), batch_dim_axis=1, auto_create_placeholders=True) + name="in2", dim_tags=[in0.dim_tags[i] for i in (2, 1, 3)], auto_create_placeholders=True) extern_data.register_data(in0) extern_data.register_data(in1) extern_data.register_data(in2) - in1.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(0)) - in2.get_size_dim_tag(0).declare_same_as(in0.get_size_dim_tag(1)) print("ExternData all dimension tags (allow_same_feature_dim=True):") pprint(extern_data.get_all_dimension_tags(allow_same_feature_dim=True)) network = TFNetwork(config=config, extern_data=extern_data, train_flag=True) From 41608dbb7e3cda429a498b1259285d425bf02c40 Mon Sep 17 00:00:00 2001 From: JeremyNgu108 Date: Thu, 20 Oct 2022 16:09:22 +0200 Subject: [PATCH 19/23] changed the order of the if statement --- returnn/tf/util/data.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 2b3329b5a7..91da82a9ad 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -983,6 +983,16 @@ def declare_same_as(self, other): if self_derived_bases.issubset(other_derived_bases): # Avoid cycles on derived_from_tag. https://github.com/rwth-i6/returnn/issues/1054 return other.declare_same_as(self) + if self.dyn_size is not None and other_same_base.dyn_size is not None: + if self.dyn_size is not other_same_base.dyn_size: + if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: + BehaviorVersion.require( + False, + "Dim tags are same with different size placeholders (%r vs %r), please check external_data" % ( + self.dyn_size, other_same_base.dyn_size), + 15) + # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, + # maybe we can overtake the size_placeholder now. 
From 41608dbb7e3cda429a498b1259285d425bf02c40 Mon Sep 17 00:00:00 2001
From: JeremyNgu108
Date: Thu, 20 Oct 2022 16:09:22 +0200
Subject: [PATCH 19/23] moved the dyn_size check earlier in declare_same_as

---
 returnn/tf/util/data.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 2b3329b5a7..91da82a9ad 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -983,6 +983,16 @@ def declare_same_as(self, other):
       if self_derived_bases.issubset(other_derived_bases):
         # Avoid cycles on derived_from_tag. https://github.com/rwth-i6/returnn/issues/1054
         return other.declare_same_as(self)
+    if self.dyn_size is not None and other_same_base.dyn_size is not None:
+      if self.dyn_size is not other_same_base.dyn_size:
+        if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
+          BehaviorVersion.require(
+            False,
+            "Dim tags are the same with different size placeholders (%r vs %r), please check extern_data" % (
+              self.dyn_size, other_same_base.dyn_size),
+            15)
+    # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
+    # maybe we can overtake the size_placeholder now.
       if self_same_as is not self:
         assert not self_same_as.same_as
         if self_same_as is other_same_base:
@@ -995,16 +1005,6 @@ def declare_same_as(self, other):
       self.same_as = other_same_base
       self._same_as_tb = traceback.extract_stack()
       self._maybe_update()
-    if self.dyn_size is not None and other_same_base.dyn_size is not None:
-      if self.dyn_size is not other_same_base.dyn_size:
-        if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
-          BehaviorVersion.require(
-            False,
-            "Dim tags are the same with different size placeholders (%r vs %r), please check extern_data" % (
-              self.dyn_size, other_same_base.dyn_size),
-            15)
-    # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
-    # maybe we can overtake the size_placeholder now.
       if other_same_base.dyn_size is not None and self.src_data:
         assert isinstance(self.src_axis, int)
         # Maybe it changed in the meanwhile, so check.

From 7aa2303eba2ab05f4aeb6e8be772b92612da9ad2 Mon Sep 17 00:00:00 2001
From: JeremyNgu108
Date: Thu, 20 Oct 2022 16:39:09 +0200
Subject: [PATCH 20/23] restored comment to its rightful place

---
 returnn/tf/util/data.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 91da82a9ad..ac1b845216 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -991,8 +991,6 @@ def declare_same_as(self, other):
             "Dim tags are the same with different size placeholders (%r vs %r), please check extern_data" % (
               self.dyn_size, other_same_base.dyn_size),
             15)
-    # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
-    # maybe we can overtake the size_placeholder now.
       if self_same_as is not self:
         assert not self_same_as.same_as
         if self_same_as is other_same_base:
@@ -1005,6 +1003,8 @@ def declare_same_as(self, other):
       self.same_as = other_same_base
       self._same_as_tb = traceback.extract_stack()
       self._maybe_update()
+    # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
+    # maybe we can overtake the size_placeholder now.
       if other_same_base.dyn_size is not None and self.src_data:
         assert isinstance(self.src_axis, int)
         # Maybe it changed in the meanwhile, so check.
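Patches 19 and 20 route the failure through `BehaviorVersion.require`, which gates it on the configured behavior version: from the given version on, an unmet condition is a hard error, while below it only a warning is emitted. A simplified sketch of that contract (illustration only; the real class lives in `returnn.util.basic` and differs in details):

  class BehaviorVersionSketch:
    """Illustrative stand-in for RETURNN's BehaviorVersion, not the real class."""

    _behavior_version = 15  # in RETURNN this comes from the user config

    @classmethod
    def require(cls, condition, message, version):
      # No-op if the requirement is met.
      if condition:
        return
      if cls._behavior_version >= version:
        raise Exception("%s (required since behavior version %i)" % (message, version))
      print("Warning: %s (will be an error from behavior version %i on)" % (message, version))

The remaining patches below then generalize this check over all batch/control-flow variants of the two tags.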
From c407835cdede707278bf9bf81eaf2e25c8daccb9 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Thu, 20 Oct 2022 17:48:08 +0200
Subject: [PATCH 21/23] Update returnn/tf/util/data.py

---
 returnn/tf/util/data.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index ac1b845216..e4e7728a36 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -983,14 +983,16 @@ def declare_same_as(self, other):
       if self_derived_bases.issubset(other_derived_bases):
         # Avoid cycles on derived_from_tag. https://github.com/rwth-i6/returnn/issues/1054
         return other.declare_same_as(self)
-    if self.dyn_size is not None and other_same_base.dyn_size is not None:
-      if self.dyn_size is not other_same_base.dyn_size:
-        if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
-          BehaviorVersion.require(
-            False,
-            "Dim tags are the same with different size placeholders (%r vs %r), please check extern_data" % (
-              self.dyn_size, other_same_base.dyn_size),
-            15)
+    if any(
+        self._same_for_batch_ctx[key].dyn_size is not None and
+        other._same_for_batch_ctx[key].dyn_size is not None and
+        self._same_for_batch_ctx[key].dyn_size is not other._same_for_batch_ctx[key].dyn_size
+        for key in set(self._same_for_batch_ctx.keys()).intersection(other._same_for_batch_ctx.keys())):
+      BehaviorVersion.require(
+        False,
+        "%s declare_same_as %s: Invalid with different size placeholders, please check extern_data" % (
+          self, other_same_base),
+        15)
       if self_same_as is not self:
         assert not self_same_as.same_as

From 353810845c9886b2269d80340b00c4359e8afdf2 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Thu, 20 Oct 2022 23:50:28 +0200
Subject: [PATCH 22/23] Update returnn/tf/util/data.py

---
 returnn/tf/util/data.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index e4e7728a36..dafcc8964a 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -983,6 +983,8 @@ def declare_same_as(self, other):
       if self_derived_bases.issubset(other_derived_bases):
         # Avoid cycles on derived_from_tag. https://github.com/rwth-i6/returnn/issues/1054
         return other.declare_same_as(self)
+    self._maybe_update()
+    other._maybe_update()
     if any(
         self._same_for_batch_ctx[key].dyn_size is not None and
         other._same_for_batch_ctx[key].dyn_size is not None and

From 7df2a58a0adffd2a06ec2414376db919af7ef4d0 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Fri, 21 Oct 2022 00:08:34 +0200
Subject: [PATCH 23/23] Update returnn/tf/util/data.py

---
 returnn/tf/util/data.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index dafcc8964a..2bb8eb9916 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -985,15 +985,17 @@ def declare_same_as(self, other):
       return other.declare_same_as(self)
     self._maybe_update()
     other._maybe_update()
-    if any(
-        self._same_for_batch_ctx[key].dyn_size is not None and
-        other._same_for_batch_ctx[key].dyn_size is not None and
-        self._same_for_batch_ctx[key].dyn_size is not other._same_for_batch_ctx[key].dyn_size
-        for key in set(self._same_for_batch_ctx.keys()).intersection(other._same_for_batch_ctx.keys())):
+    for key in set(self._same_for_batch_ctx.keys()).intersection(other._same_for_batch_ctx.keys()):
+      self_ = self._same_for_batch_ctx[key]
+      other_ = other._same_for_batch_ctx[key]
+      if not self_._validate_in_current_graph() or not other_._validate_in_current_graph():
+        continue
+      if self_.dyn_size is None or other_.dyn_size is None:
+        continue
       BehaviorVersion.require(
-        False,
-        "%s declare_same_as %s: Invalid with different size placeholders, please check extern_data" % (
-          self, other_same_base),
+        self_.dyn_size is other_.dyn_size,
+        "%s declare_same_as %s: Invalid with different size placeholders (%r vs %r), please check extern_data" % (
+          self, other, self_.dyn_size, other_.dyn_size),
         15)
       if self_same_as is not self:
         assert not self_same_as.same_as
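In its final form, the check compares every batch/control-flow variant of the two tags and rejects a merge as soon as any variant pair carries different dynamic size placeholders. A rough sketch of the situation that now fails fast (a hypothetical minimal setup, assuming behavior version >= 15; below that version it only warns):

  import tensorflow as tf
  from returnn.tf.util.data import Data, batch_dim, SpatialDim, FeatureDim

  time_a_dim = SpatialDim("time_a")
  time_b_dim = SpatialDim("time_b")
  feature_dim = FeatureDim("feature", 5)

  with tf.Graph().as_default():
    a = Data(
      name="a", dim_tags=[batch_dim, time_a_dim, feature_dim],
      auto_create_placeholders=True)
    b = Data(
      name="b", dim_tags=[batch_dim, time_b_dim, feature_dim],
      auto_create_placeholders=True)
    # Each time tag now owns its own size placeholder; merging them is exactly
    # the case the new check rejects instead of silently assuming equal sizes.
    time_b_dim.declare_same_as(time_a_dim)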