Skip to content

Commit

Permalink
"|" was replaced with Union to support Python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
ZFTurbo committed Dec 15, 2024
1 parent 343fd7f commit cc01264
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
warnings.filterwarnings("ignore")


def parse_args(args: List[str] | None) -> argparse.Namespace:
def parse_args(args: Union[List[str], None]) -> argparse.Namespace:
"""
Parse command-line arguments for configuring the model, dataset, and training parameters.
Expand Down Expand Up @@ -237,7 +237,7 @@ def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module) -> N
model.load_state_dict(torch.load(args.start_check_point))


def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[torch.device | str, torch.nn.Module]:
def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[Union[torch.device, str], torch.nn.Module]:
"""
Initialize the model and assign it to the appropriate device (GPU or CPU).
Expand Down Expand Up @@ -607,10 +607,10 @@ def train_model(args: argparse.Namespace) -> None:
print(
f"Instruments: {config.training.instruments}\n"
f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
f"Patience: {config.training.patience}\n"
f"Patience: {config.training.patience} "
f"Reduce factor: {config.training.reduce_factor}\n"
f"Batch size: {batch_size}\n"
f"Grad accum steps: {gradient_accumulation_steps}\n"
f"Batch size: {batch_size} "
f"Grad accum steps: {gradient_accumulation_steps} "
f"Effective batch size: {batch_size * gradient_accumulation_steps}\n"
f"Optimizer: {config.training.optimizer}"
)
Expand Down

0 comments on commit cc01264

Please sign in to comment.