From 6a4e10cd484092b150bdd518807634c36b730194 Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Thu, 30 Mar 2023 09:12:28 +0800
Subject: [PATCH] Fix int8 training to be more compatible with low-end GPUs

---
 finetune.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/finetune.py b/finetune.py
index 54bb1ca..3a98c9a 100644
--- a/finetune.py
+++ b/finetune.py
@@ -4,7 +4,7 @@ from modeling_chatglm import ChatGLMForConditionalGeneration
 import torch
 import torch.nn as nn
-from peft import get_peft_model, LoraConfig, TaskType
+from peft import get_peft_model, prepare_model_for_int8_training, LoraConfig, TaskType
 from dataclasses import dataclass, field
 import datasets
 import os
 
@@ -20,11 +20,6 @@ class FinetuneArguments:
     lora_rank: int = field(default=8)
 
 
-class CastOutputToFloat(nn.Sequential):
-    def forward(self, x):
-        return super().forward(x).to(torch.float32)
-
-
 def get_masks_and_position_ids(
     seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
 ):
@@ -126,10 +121,10 @@ def main():
     model.enable_input_require_grads()
     model.is_parallelizable = True
     model.model_parallel = True
-    model.lm_head = CastOutputToFloat(model.lm_head)
     model.config.use_cache = (
         False  # silence the warnings. Please re-enable for inference!
     )
+    model = prepare_model_for_int8_training(model)
 
     # setup peft
     peft_config = LoraConfig(
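
Note (not part of the patch): prepare_model_for_int8_training comes from peft
(later deprecated in favour of prepare_model_for_kbit_training) and subsumes the
hand-rolled CastOutputToFloat wrapper removed above: it freezes the int8 base
weights, casts layer-norm parameters and the output embedding (lm_head) to
fp32, and by default enables gradient checkpointing plus input gradients.
Below is a minimal sketch of the resulting setup, assuming an 8-bit ChatGLM
checkpoint loaded through transformers; the model id, target_modules, and LoRA
hyperparameters are illustrative assumptions, not values taken from finetune.py.

    import torch
    from transformers import AutoModel
    from peft import (
        LoraConfig,
        TaskType,
        get_peft_model,
        prepare_model_for_int8_training,
    )

    # Load the base model in 8-bit (requires bitsandbytes). "THUDM/chatglm-6b"
    # is an assumed checkpoint id; the patched script imports a local
    # modeling_chatglm module instead.
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b",
        trust_remote_code=True,
        load_in_8bit=True,
        device_map="auto",
    )
    model.config.use_cache = False  # re-enable for inference

    # Replaces CastOutputToFloat: freezes int8 weights, casts layer norms and
    # lm_head to fp32, enables gradient checkpointing and input grads.
    model = prepare_model_for_int8_training(model)

    # Setup peft. Hyperparameters mirror common LoRA defaults, not the values
    # hard-coded in finetune.py; target_modules is an assumed ChatGLM
    # attention projection name.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query_key_value"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()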