Skip to content

Commit

Permalink
Fix cloning of Sequential models w. input_tensors argument (#20550)
Browse files Browse the repository at this point in the history
* Fix cloning for Sequential w. input tensor

* Add missing test for input_tensor argument

* Add Sequential wo. Input to test, build model to ensure defined inputs
  • Loading branch information
Carbyne authored Nov 26, 2024
1 parent 553521e commit b6d305f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
9 changes: 7 additions & 2 deletions keras/src/models/cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
input_dtype = None
input_batch_shape = None

if input_tensors:
if input_tensors is not None:
if isinstance(input_tensors, (list, tuple)):
if len(input_tensors) != 1:
raise ValueError(
Expand All @@ -310,7 +310,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
"Argument `input_tensors` must be a KerasTensor. "
f"Received invalid value: input_tensors={input_tensors}"
)
inputs = Input(tensor=input_tensors, name=input_name)
inputs = Input(
tensor=input_tensors,
batch_shape=input_tensors.shape,
dtype=input_tensors.dtype,
name=input_name,
)
new_layers = [inputs] + new_layers
else:
if input_batch_shape is not None:
Expand Down
26 changes: 26 additions & 0 deletions keras/src/models/cloning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def get_sequential_model(explicit_input=True):
return model


def get_cnn_sequential_model(explicit_input=True):
model = models.Sequential()
if explicit_input:
model.add(layers.Input(shape=(7, 3)))
model.add(layers.Conv1D(2, 2, padding="same"))
model.add(layers.Conv1D(2, 2, padding="same"))
return model


def get_subclassed_model():
class ExampleModel(models.Model):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -124,6 +133,23 @@ def clone_function(layer):
if not isinstance(l1, layers.InputLayer):
self.assertEqual(l2.name, l1.name + "_custom")

@parameterized.named_parameters(
("cnn_functional", get_cnn_functional_model),
("cnn_sequential", get_cnn_sequential_model),
(
"cnn_sequential_noinputlayer",
lambda: get_cnn_sequential_model(explicit_input=False),
),
)
def test_input_tensors(self, model_fn):
ref_input = np.random.random((2, 7, 3))
model = model_fn()
model(ref_input) # Maybe needed to get model inputs if no Input layer
input_tensor = model.inputs[0]
new_model = clone_model(model, input_tensors=input_tensor)
tree.assert_same_structure(model.inputs, new_model.inputs)
tree.assert_same_structure(model.outputs, new_model.outputs)

def test_shared_layers_cloning(self):
model = get_mlp_functional_model(shared_layers=True)
new_model = clone_model(model)
Expand Down

0 comments on commit b6d305f

Please sign in to comment.