Skip to content

Commit

Permalink
remove destroy threads
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Aug 30, 2023
1 parent 1c1e6dc commit b2c92c1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
31 changes: 21 additions & 10 deletions pipegoose/nn/pipeline_parallel2/_worker.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
import threading
from queue import Queue
from threading import Thread
from typing import Callable, List

from pipegoose.constants import PIPELINE_MAX_WORKERS, PIPELINE_MIN_WORKERS
from pipegoose.nn.pipeline_parallel2._utils import sleep


class Worker(Thread):
class Worker(threading.Thread):
def __init__(self, selected_jobs: Queue, *args, **kwargs):
super().__init__(*args, **kwargs)
self._selected_jobs = selected_jobs
self._running = False
# self._stop_event = threading.Event()

@property
def running(self) -> bool:
    """Whether this worker is currently executing a job (toggled in run())."""
    return self._running

# def stop(self):
# self._stop_event.set()

# def stopped(self):
# return self._stop_event.is_set()

def run(self):
while True:
job = self._selected_jobs.get()
Expand All @@ -24,7 +31,7 @@ def run(self):
self._running = False


class WorkerPoolWatcher(Thread):
class WorkerPoolWatcher(threading.Thread):
def __init__(self, worker_pool: List[Worker], min_workers: int, max_workers: int, spawn_worker: Callable, *args, **kwargs):
super().__init__(*args, **kwargs)
self.worker_pool = worker_pool
Expand All @@ -49,7 +56,7 @@ def _num_working_workers(self) -> int:
return num_working


class JobSelector(Thread):
class JobSelector(threading.Thread):
def __init__(self, pending_jobs: Queue, selected_jobs: Queue, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pending_job = pending_jobs
Expand All @@ -69,7 +76,7 @@ def _select_job(self):
return job


class WorkerManager(Thread):
class WorkerManager(threading.Thread):
def __init__(
self,
num_workers: int = PIPELINE_MIN_WORKERS,
Expand Down Expand Up @@ -132,9 +139,13 @@ def spawn(self):
self._spawn_pool_watcher()

def destroy(self):
# TODO: why can't we join() here?
for worker in self.worker_pool:
# TODO: wait for workers finish their jobs
# before joining
# Create a copy of the worker pool to iterate over
worker_pool_copy = self.worker_pool.copy()

for worker in worker_pool_copy:
# Terminate the worker thread
worker.stop()
worker.join()
# self.worker_pool.remove(worker)

# Remove the worker from the original worker pool
# self.worker_pool.remove(worker)
3 changes: 3 additions & 0 deletions tests/nn/pipeline_parallel_2/test_worker2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ def test_worker_manager():
for worker in worker_manager.worker_pool:
assert worker.running is False
assert worker.is_alive() is True

# TODO: add this
# worker_manager.destroy()

0 comments on commit b2c92c1

Please sign in to comment.