diff --git a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py index 1efc8e4ce8..61190c42d7 100644 --- a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py +++ b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py @@ -2,7 +2,7 @@ import logging import os from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Tuple from hydra.core.singleton import Singleton from hydra.core.utils import JobReturn, filter_overrides, run_job, setup_globals @@ -43,12 +43,8 @@ def setup( def __call__( self, - sweep_overrides: List[str], - job_dir_key: str, - job_num: int, - job_id: str, - singleton_state: Dict[type, Singleton], - ) -> JobReturn: + job_params: List[Tuple[List[str], str, int, str, Dict[type, Singleton]]], + ) -> Optional[JobReturn]: # lazy import to ensure plugin discovery remains fast import submitit @@ -56,6 +52,17 @@ def __call__( assert self.config is not None assert self.task_function is not None + job_env = submitit.JobEnvironment() + task_id = job_env.global_rank + if task_id >= len(job_params): + # May happen on the last job if the total number of tasks is not a multiple + # of `tasks_per_node`. + return None + + sweep_overrides, job_dir_key, job_num, job_id, singleton_state = job_params[ + task_id + ] + Singleton.set_state(singleton_state) setup_globals() sweep_config = self.hydra_context.config_loader.load_sweep_config( @@ -64,7 +71,7 @@ def __call__( with open_dict(sweep_config.hydra.job) as job: # Populate new job variables - job.id = submitit.JobEnvironment().job_id # type: ignore + job.id = f"{job_env.job_id}_{job_env.global_rank}" # type: ignore [attr-defined] sweep_config.hydra.job.num = job_num return run_job( @@ -141,8 +148,24 @@ def launch( ) ) - jobs = executor.map_array(self, *zip(*job_params)) - return [j.results()[0] for j in jobs] + # Create groups of parameters of size `tasks_per_node`, so that each task + # can get assigned its own set of parameters. + tasks_per_node = params.get("tasks_per_node", 1) + job_params = [ + job_params[start_idx : start_idx + tasks_per_node] + for start_idx in range(0, len(job_params), tasks_per_node) + ] + + # We need at least two jobs, otherwise submitit will create a single job instead + # of a job array, which will cause issues down the line. + # We create a new job with empty parameters (=> will terminate immediately). + if len(job_params) == 1: + job_params.append([]) + + jobs = executor.map_array(self, job_params) + return [ + result for job in jobs for result in job.results() if result is not None + ] class LocalLauncher(BaseSubmititLauncher):