From 9a9b0fa408dc7cd0a2b72350071e2190bb1e6d72 Mon Sep 17 00:00:00 2001
From: wejoncy <247153481@qq.com>
Date: Fri, 22 Nov 2024 02:53:15 +0000
Subject: [PATCH 1/6] resolve 3060

---
 src/accelerate/hooks.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index 14d57e33661..2da8796b164 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -436,7 +436,8 @@ def attach_execution_device_hook(
         return
 
     for child in module.children():
-        attach_execution_device_hook(child, execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map)
+        attach_execution_device_hook(child, execution_device, skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes, tied_params_map=tied_params_map)
 
 
 def attach_align_device_hook(

From 1078cd230b30f4114595f0cd5bc7eb1893779617 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Tue, 26 Nov 2024 13:22:35 +0800
Subject: [PATCH 2/6] format

---
 src/accelerate/hooks.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index 2da8796b164..50098eaac01 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -436,8 +436,13 @@ def attach_execution_device_hook(
         return
 
     for child in module.children():
-        attach_execution_device_hook(child, execution_device, skip_keys=skip_keys,
-            preload_module_classes=preload_module_classes, tied_params_map=tied_params_map)
+        attach_execution_device_hook(
+            child,
+            execution_device,
+            skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
+        )
 
 
 def attach_align_device_hook(

From 3213231104855419bbeadfedd2d98c4045fad945 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Tue, 26 Nov 2024 14:58:31 +0800
Subject: [PATCH 3/6] add tests

---
 tests/test_accelerator.py | 62 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 9b18fe5c909..f774b52ba5f 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -762,3 +762,65 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights
         assert torch.allclose(original_linear1, new_linear1)
         assert torch.allclose(original_batchnorm, new_batchnorm)
         assert torch.allclose(original_linear2, new_linear2)
+
+    @require_cuda
+    def test_nested_hook(self, use_safetensors):
+        from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
+
+        class MyLinear(torch.nn.Module):
+            def __init__(self, device=None, dtype=None):
+                factory_kwargs = {"device": device, "dtype": dtype}
+                super().__init__()
+                self.centroid = torch.nn.Embedding(1, 2)
+                self.indices = torch.nn.Parameter(torch.empty((1, 2, 2), **factory_kwargs))
+
+            def forward(self, x):
+                orig_shape = x.shape
+                x = torch.abs(x + self.indices).long()
+                x = x % 2
+                x = x.sum(-1)
+                x = (self.centroid.weight + x).reshape(orig_shape)
+                return x
+
+        class MySubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = MyLinear()
+
+            def forward(self, x):
+                return self.layer(x)
+
+        class MyModel(PreTrainedModel):
+            def __init__(self, config):
+                super().__init__(config)
+                self.layer = torch.nn.ModuleList([MySubModel() for i in range(4)])
+
+            def forward(self, x):
+                for layer in self.layer:
+                    x = layer(x)
+                return x
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            check_point = tmpdirname
+            offload_folder = check_point + "/offload"
+            os.makedirs(offload_folder, exist_ok=True)
+            config = PretrainedConfig()
+            m = MyModel(config)
+            m.save_pretrained(check_point)
+
+            with init_empty_weights():
+                my_model = MyModel(config)
+            my_model = load_checkpoint_and_dispatch(
+                my_model,
+                checkpoint=check_point,
+                max_memory={"cpu": 60, 0: 60},
+                device_map="auto",
+                no_split_module_classes=["MySubModel"],
+                offload_folder=offload_folder,
+                preload_module_classes=["VQuantLinear"],
+            )
+            x = torch.randn(1, 2)
+            print(my_model(x))
+            # before fix, this would raise an error
+            # weight is on the meta device, we need a `value` to put in on 0
+            my_model(x)

From a0a847c9fabc9299da262eed9f49275a68de4abd Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Tue, 26 Nov 2024 15:12:09 +0800
Subject: [PATCH 4/6] fix

---
 tests/test_accelerator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index f774b52ba5f..41475c57f52 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -819,8 +819,7 @@ def forward(self, x):
                 offload_folder=offload_folder,
                 preload_module_classes=["VQuantLinear"],
             )
-            x = torch.randn(1, 2)
-            print(my_model(x))
             # before fix, this would raise an error
             # weight is on the meta device, we need a `value` to put in on 0
+            x = torch.randn(1, 2)
             my_model(x)

From 2ebc1ab88863d3a03af2c8d698c852d33b560308 Mon Sep 17 00:00:00 2001
From: wejoncy <247153481@qq.com>
Date: Wed, 27 Nov 2024 00:45:30 +0000
Subject: [PATCH 5/6] fix

---
 tests/test_accelerator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 41475c57f52..bb939e15021 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -32,6 +32,7 @@
     require_bnb,
     require_multi_gpu,
     require_non_cpu,
+    require_huggingface_suite,
     require_transformer_engine,
     slow,
     torch_device,
@@ -764,6 +765,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights
         assert torch.allclose(original_linear2, new_linear2)
 
     @require_cuda
+    @require_huggingface_suite
     def test_nested_hook(self, use_safetensors):
         from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 
@@ -817,7 +819,7 @@ def forward(self, x):
                 device_map="auto",
                 no_split_module_classes=["MySubModel"],
                 offload_folder=offload_folder,
-                preload_module_classes=["VQuantLinear"],
+                preload_module_classes=["MyLinear"],
             )
             # before fix, this would raise an error
             # weight is on the meta device, we need a `value` to put in on 0

From 6642ecf0acea50d292afe3a1fb37622781428924 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Tue, 3 Dec 2024 11:01:41 +0800
Subject: [PATCH 6/6] format

---
 tests/test_accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index bb939e15021..651dc17da5e 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -30,9 +30,9 @@
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
+    require_huggingface_suite,
     require_multi_gpu,
     require_non_cpu,
-    require_huggingface_suite,
     require_transformer_engine,
     slow,
     torch_device,