mle-infrastructure / mle-toolbox

Lightweight Tool to Manage Distributed ML Experiments 🛠
https://mle-infrastructure.github.io/mle_toolbox/toolbox/
MIT License

Asynchronous scheduling of jobs #8

Closed by RobertTLange 3 years ago

RobertTLange commented 3 years ago

Currently experiments are scheduled in batches: once an entire batch has completed, we submit a new batch of experiments. But some nodes are faster than others, so this synchronicity slows us down. Add an option to also run jobs asynchronously and schedule new ones as running jobs finish, given a maximum budget of concurrently running jobs.
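
A minimal sketch of the requested behaviour (purely illustrative, not the toolbox implementation) could look like this, assuming each job is a plain command line launched as a subprocess; the `run_async` name and its parameters are placeholders:

```python
import subprocess
import time


def run_async(job_cmds, max_running_jobs=4, poll_interval=1.0):
    """Keep at most `max_running_jobs` jobs alive and launch the next
    queued job as soon as one of the running ones finishes."""
    pending = list(job_cmds)  # commands not yet launched
    running = []              # currently active subprocesses
    while pending or running:
        # top up to the budget of concurrently running jobs
        while pending and len(running) < max_running_jobs:
            running.append(subprocess.Popen(pending.pop(0)))
        # drop finished jobs so new ones can be scheduled on the next pass
        running = [p for p in running if p.poll() is None]
        time.sleep(poll_interval)
```

For example, `run_async([["python", "train.py", "--seed", str(s)] for s in range(10)], max_running_jobs=3)` would never keep more than three runs active at once, regardless of how unevenly the individual jobs finish.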

denisalevi commented 3 years ago

I once wrote a threading class that collects tasks in a queue and submits jobs to the grid engine such that a fixed number of jobs is running at any time. We could possibly reuse / adapt it.

Here is the script, just copy-pasted, but it gives an idea of what it does:

```python
import os
import queue
import shlex
import signal
import socket
import subprocess
import sys
import threading
import time

from tqdm import tqdm

# hostname of the machine this runs on (used to detect the SGE cluster)
hostname = socket.gethostname()


class SgeQrshThread(threading.Thread):
    def __init__(
        self,
        run_script,
        thread_idx,
        parameter_queue,
        running_queue,
        log_stream=None,
        **kwargs,
    ):
        """
        Thread worker that spawns a single task using `qrsh` on a cluster
        running sun grid engine. Tasks are collected from `parameter_queue`
        and once started, task information including `thread_idx` and
        `hostname` is put into `running_queue` for update reports.
        """
        threading.Thread.__init__(self, **kwargs)
        # the script that is run in `qrsh`
        self.run_script = run_script
        # the idx of this thread
        self.thread_idx = thread_idx
        # queue of parameters for the next run
        self.parameter_queue = parameter_queue
        # queue of currently active simulations
        self.running_queue = running_queue
        # event to stop executing this thread
        self.stop_request = threading.Event()
        if log_stream is None:
            log_stream = sys.stderr
        self.log_stream = log_stream
        # are we on a sun grid engine cluster?
        self.sge = False
        if hostname == "vnc" or hostname.startswith("cognition"):
            tqdm.write(
                "INFO detected cluster. Using qrsh to submit jobs",
                file=self.log_stream,
            )
            self.sge = True

    def run(self):
        while not self.stop_request.is_set():
            try:
                # block=False raises Empty immediately if the queue is empty;
                # set block=True and a timeout to wait if the queue is empty
                params = self.parameter_queue.get(block=False)
            except queue.Empty:
                # tasks are done
                return
            progress_file = params.pop("progress_file")
            model_name = params.pop("model_name")
            num_cycles = params.pop("num_consolidation_cycles")
            params_file = params.pop("params_file")
            self.running_queue.put(
                (self.thread_idx, progress_file, model_name, num_cycles)
            )
            logfile = params.pop("logfile")
            # TODO: getcwd with os instead of PWD
            cmd = (
                f"{sys.executable} {self.run_script} "
                f" --params_file {params_file}"
                f" --thread_idx {self.thread_idx}"
            )
            if self.sge:
                cwd = None
                cmd = "cd $PWD && conda activate tf && " + cmd
                args = ["qrsh", "-N", f"{model_name}", cmd]
            else:
                cwd = os.path.dirname(__file__)
                args = shlex.split(cmd)
            tqdm.write(f"DEBUG cwd {cwd}, args {args}", file=self.log_stream)
            stdout = open(logfile + "_stdout", "w+")
            stderr = open(logfile + "_stderr", "w+")
            p = subprocess.Popen(args, stdout=stdout, stderr=stderr, cwd=cwd)
            # poll the subprocess state periodically
            terminated = 0
            while p.poll() is None:
                # terminate subprocess when stop signal is sent from main thread
                if self.stop_request.is_set():
                    if terminated < 5:
                        # try sending CTRL-C first
                        tqdm.write(
                            "Stop request found: sending SIGINT",
                            file=self.log_stream,
                        )
                        p.send_signal(signal.SIGINT)
                    else:
                        # kill after several unsuccessful SIGINT attempts
                        tqdm.write(
                            "Stop request found: sending SIGKILL",
                            file=self.log_stream,
                        )
                        p.kill()
                    terminated += 1
                time.sleep(0.1)
            tqdm.write(f"Task finished thread {self.thread_idx}", file=self.log_stream)
            stdout.close()
            stderr.close()
            self.parameter_queue.task_done()

    def terminate(self):
        # send request to stop reading queues before blocking on join
        self.stop_request.set()
        threading.Thread.join(self)
```

I also wrote a tool that reports the progress of each task (with a bunch of stacked tqdm progress bars) by watching progress files that the tasks create. I found it useful for seeing when specific tasks got stuck. But I guess that progress can be reported much more easily using TensorBoard, for example. Still putting it here in case you think it could be useful for something. :P

TqdmFileWatch:

```python
import os
import queue
import sys
import time

from tqdm import tqdm
from watchdog.events import FileSystemEventHandler
from watchdog.observers.polling import PollingObserver


class TqdmFileWatch(FileSystemEventHandler):
    def __init__(
        self,
        path,
        num_bars,
        num_tasks,
        watched_files,
        task_queue,
        log_stream=None,
        **tqdm_kwargs,
    ):
        self.num_tasks = num_tasks
        self.main_progress = tqdm(
            total=num_tasks, position=1, leave=True, desc="Tasks completed"
        )
        self.main_lastiter = 0
        self.watched_files = watched_files
        self.task_queue = task_queue
        self.tqdm_kwargs = tqdm_kwargs
        if log_stream is None:
            log_stream = sys.stderr
        self.log_stream = log_stream
        assert os.path.isdir(path)
        self.path = path
        self.observer = PollingObserver(timeout=0.1)
        tqdm.write(
            f"INFO watching '{path}' for modified progress files",
            file=self.log_stream,
        )
        self.observer.schedule(self, path=path, recursive=False)
        self.observer.start()
        self.update_queue = queue.Queue()
        self.progress_bars = []
        self.active_tasks = []
        self.last_iterations = []
        for i in range(num_bars):
            self.progress_bars.append(None)
            self.active_tasks.append(None)
            self.last_iterations.append(0)

    def on_any_event(self, event):
        tqdm.write(f"INFO detected file system event {event}", file=self.log_stream)

    def on_modified(self, event):
        """
        This function is called in a separate observer thread when any file
        in `self.observer.path` is modified. It puts the filename into the
        `update_queue` for the main thread to update the tqdm progress bars.
        """
        tqdm.write(
            f"INFO ON MODIFIED detected file system event {event}",
            file=self.log_stream,
        )
        file = event.src_path
        if file in self.watched_files:
            self.update_queue.put(file)
            tqdm.write(f"INFO UPDATED QUEUE {file}", file=self.log_stream)
        elif os.path.basename(file) != os.path.basename(self.path):
            tqdm.write(
                f"WARNING detected modification of {file}, which is not "
                f"in watched_files.",
                file=self.log_stream,
            )
            self.log_stream.flush()
        else:
            tqdm.write(
                f"WARNING ON MOD NOTHING DONE of {file}, which is not "
                f"in watched_files.",
                file=self.log_stream,
            )

    def watch_for_updates(self, poll_interval=0.1):
        """
        Watch for new updates by checking every `poll_interval` seconds if
        there is a new element in `update_queue`. If so, update the progress
        bars. Terminates when `task_queue` is empty.
        """
        while (
            self.task_queue.unfinished_tasks > 0
            or self.update_queue.unfinished_tasks > 0
        ):
            try:
                # for CTRL-C to work, don't use block
                file = self.update_queue.get(block=False)
                self._update_next_bar(file)
                self.update_queue.task_done()
            except queue.Empty:
                time.sleep(poll_interval)

    def _update_next_bar(self, file):
        with open(file, "r") as f:
            new = f.read()
        try:
            # TODO get model_name, thread_idx from watched_files dicts
            model_name, thread_idx, hostname, iteration, state = new.split(",")
        except ValueError:
            tqdm.write(
                f"ERROR Couldn't split '{new}'. Ignoring modified file "
                f"'{os.path.basename(file)}'.",
                file=self.log_stream,
            )
            return
        if state not in ["DONE", "PLOTTING", "RUNNING"]:
            tqdm.write(
                f"ERROR state is '{state}', should be 'DONE', 'PLOTTING' or 'RUNNING'. "
                f"Ignoring modified file {os.path.basename(file)}",
                file=self.log_stream,
            )
            return
        thread_idx = int(thread_idx)
        iteration = int(iteration)
        if self.progress_bars[thread_idx] is not None:
            # UPDATE BAR
            if self.active_tasks[thread_idx] is None:
                self.active_tasks[thread_idx] = file
            if self.active_tasks[thread_idx] != file:
                tqdm.write(
                    f"WARNING got new task progress from thread {thread_idx} for file "
                    f"{os.path.basename(file)}, but previous file "
                    f"{os.path.basename(self.active_tasks[thread_idx])} is not DONE yet "
                    f"(last iteration {self.last_iterations[thread_idx]}). Maybe that "
                    f"task had an error? Closing not DONE bar.",
                    file=self.log_stream,
                )
                self._reset_progress_bar(thread_idx)
                # TODO check how often the same situation comes up and then ignore it?
                # self.update_queue.put(file)
                # return
            last_iter = self.last_iterations[thread_idx]
            increment = iteration - last_iter
            if increment > 0:
                self.progress_bars[thread_idx].update(increment)
                self.last_iterations[thread_idx] = iteration
                tqdm.write(
                    f"INFO updated {os.path.basename(file)} to {iteration}",
                    file=self.log_stream,
                )
                self.log_stream.flush()
            else:
                tqdm.write(
                    f"INFO didn't update {os.path.basename(file)} because increment is "
                    f"{increment}",
                    file=self.log_stream,
                )
            if state == "PLOTTING":
                self.progress_bars[thread_idx].set_postfix(s="plotting")
                # TODO doesn't work?
                # self.progress_bars[thread_idx].set_postfix_str("Plotting...")
            if state == "DONE":
                self.main_progress.update(1)
                self._reset_progress_bar(thread_idx)
        elif self.progress_bars[thread_idx] is None:
            # INIT BAR
            self.progress_bars[thread_idx] = tqdm(
                total=self.watched_files[file],
                position=thread_idx + 3,
                leave=True,
                postfix={"s": "running"},
                desc=f"{hostname}: {model_name}",
                **self.tqdm_kwargs,
            )

    def _reset_progress_bar(self, thread_idx):
        assert self.progress_bars[thread_idx] is not None
        # CLEAN UP BAR
        self.progress_bars[thread_idx].set_postfix(s="finished")
        self.progress_bars[thread_idx].close()
        self.progress_bars[thread_idx] = None
        self.active_tasks[thread_idx] = None
        self.last_iterations[thread_idx] = 0

    def close(self):
        for i, bar in enumerate(self.progress_bars):
            if bar is None:
                tqdm.write(
                    f"ERROR trying to close uninitialized progress bar of thread {i}",
                    file=self.log_stream,
                )
            else:
                bar.close()
        self.main_progress.close()
        self.observer.stop()
        self.observer.join()
```

RobertTLange commented 3 years ago

Finally got this sorted and cleaned up in ec317dfaa5eb4f46df577a6bac74353a9227ec85, 234c8cecb978578b3b3ff722ea54449630d5fa60 and a7ebf9e9e63bf73e730f5a567c272a7bec314ea6. All experiment launching and monitoring is now handled by the ExperimentQueue:

https://github.com/RobertTLange/mle-toolbox/blob/a7ebf9e9e63bf73e730f5a567c272a7bec314ea6/mle_toolbox/experiment/experiment_queue.py#L14-L28

The queue starts by launching jobs until the max_running_jobs limit is reached (or all jobs in the queue are scheduled). Afterwards the running jobs are monitored. Once a job finishes, the next one in the queue gets launched. If all seeds for one hyperparameter evaluation have finished, the queue will merge their .hdf5 logs together.
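
The per-configuration merge of seed logs could be sketched roughly as follows with `h5py` (only an illustration of the idea, not the toolbox's actual merge routine; the one-group-per-seed layout and the `merge_seed_logs` name are assumptions):

```python
import os

import h5py


def merge_seed_logs(merged_path, seed_log_paths):
    """Collect the .hdf5 logs of all seeds of one configuration into a
    single file, with one top-level group per seed log."""
    with h5py.File(merged_path, "w") as merged:
        for seed_path in seed_log_paths:
            group_name = os.path.splitext(os.path.basename(seed_path))[0]
            seed_group = merged.create_group(group_name)
            with h5py.File(seed_path, "r") as seed_log:
                for key in seed_log.keys():
                    # copy each top-level group/dataset of the seed log
                    seed_log.copy(seed_log[key], seed_group)
```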

The queue is used for all core experiment types, including both the async and sync search experiments. For the sync batch case, we simply launch multiple queues sequentially and wait for each one to complete before starting the next.

Furthermore, these commits get rid of all the multiprocessing shenanigans, where each job was monitored in its own subprocess. That could cause problems whenever a process randomly died! Now, in principle, all of this can run directly on the head node of the cluster without taking up more than a single core.