Skip to content

Commit

Permalink
Use tasks_per_node to split sweep across tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Apr 4, 2023
1 parent 744318b commit e119fcd
Showing 1 changed file with 37 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,31 @@ def setup(

def __call__(
self,
sweep_overrides: List[str],
job_dir_key: str,
job_num: int,
job_id: str,
singleton_state: Dict[type, Singleton],
) -> JobReturn:
job_params,
) -> Optional[JobReturn]:
# lazy import to ensure plugin discovery remains fast
import submitit

assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None

job_env = submitit.JobEnvironment()
task_id = job_env.global_rank
if task_id >= len(job_params):
# May happen on the last job if the total number of tasks is not a multiple
# of `tasks_per_node`.
return None

sweep_overrides: List[str]
job_dir_key: str
job_num: int
job_id: str
singleton_state: Dict[type, Singleton]
sweep_overrides, job_dir_key, job_num, job_id, singleton_state = job_params[
task_id
]

Singleton.set_state(singleton_state)
setup_globals()
sweep_config = self.hydra_context.config_loader.load_sweep_config(
Expand All @@ -64,7 +76,7 @@ def __call__(

with open_dict(sweep_config.hydra.job) as job:
# Populate new job variables
job.id = submitit.JobEnvironment().job_id # type: ignore
job.id = f"{job_env.job_id}_{job_env.global_rank}" # type: ignore [attr-defined]
sweep_config.hydra.job.num = job_num

return run_job(
Expand Down Expand Up @@ -141,8 +153,24 @@ def launch(
)
)

jobs = executor.map_array(self, *zip(*job_params))
return [j.results()[0] for j in jobs]
# Create groups of parameters of size `tasks_per_node`, so that each task
# can get assigned its own set of parameters.
tasks_per_node = params.get("tasks_per_node", 1)
job_params = [
job_params[start_idx : start_idx + tasks_per_node]
for start_idx in range(0, len(job_params), tasks_per_node)
]

# We need at least two jobs, otherwise submitit will create a single job instead
# of a job array, which will cause issues down the line.
# We create a new job with empty parameters (=> will terminate immediately).
if len(job_params) == 1:
job_params.append([])

jobs = executor.map_array(self, job_params)
return [
result for job in jobs for result in job.results() if result is not None
]


class LocalLauncher(BaseSubmititLauncher):
Expand Down

0 comments on commit e119fcd

Please sign in to comment.