diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 5c90c9f6975..be8a4fee224 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -367,7 +367,7 @@ def test_saved_module_paths_and_class_names(self): ) self.assertEqual( config_dict["compile_config"]["loss"]["config"], - "my_mean_squared_error", + "my_custom_package>my_mean_squared_error", ) @pytest.mark.requires_trainable_backend diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index 3adc832884e..cf8eb327fb4 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -366,7 +366,7 @@ def _get_class_or_fn_config(obj): """Return the object's config depending on its type.""" # Functions / lambdas: if isinstance(obj, types.FunctionType): - return obj.__name__ + return object_registration.get_registered_name(obj) # All classes: if hasattr(obj, "get_config"): config = obj.get_config() @@ -781,15 +781,6 @@ def _retrieve_class_or_fn( if obj is not None: return obj - # Retrieval of registered custom function in a package - filtered_dict = { - k: v - for k, v in custom_objects.items() - if k.endswith(full_config["config"]) - } - if filtered_dict: - return next(iter(filtered_dict.values())) - # Otherwise, attempt to retrieve the class object given the `module` # and `class_name`. Import the module, find the class. try: