# Source code for adaptive_scheduler._executor

from __future__ import annotations

import abc
import asyncio
import contextlib
import copy
import datetime
import functools
import os
import time
import uuid
from concurrent.futures import Executor, Future
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

import cloudpickle
from adaptive import SequenceLearner

import adaptive_scheduler

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable

    from adaptive_scheduler.utils import (
        _DATAFRAME_FORMATS,
        EXECUTOR_TYPES,
        LOKY_START_METHODS,
        GoalTypes,
    )


class AdaptiveSchedulerExecutorBase(Executor):
    _run_manager: adaptive_scheduler.RunManager | None

    @abc.abstractmethod
    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
        """Submit a task to the executor."""

    @abc.abstractmethod
    def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager | None:
        """Finalize the executor and return the RunManager.

        Returns None if no learners were submitted.
        """

    def map(  # type: ignore[override]
        self,
        fn: Callable[..., Any],
        /,
        *iterables: Iterable[Any],
        timeout: float | None = None,
        chunksize: int = 1,
    ) -> list[Future]:
        tasks = []
        if timeout is not None:
            msg = "Timeout not implemented"
            raise NotImplementedError(msg)
        if chunksize != 1:
            msg = "Chunksize not implemented"
            raise NotImplementedError(msg)
        for args in zip(*iterables, strict=True):
            task = self.submit(fn, *args)
            tasks.append(task)
        return tasks

    def shutdown(
        self,
        wait: bool = True,  # noqa: FBT001, FBT002
        *,
        cancel_futures: bool = False,
    ) -> None:
        if not wait:
            msg = "Non-waiting shutdown not implemented"
            raise NotImplementedError(msg)
        if cancel_futures:
            msg = "Cancelling futures not implemented"
            raise NotImplementedError(msg)
        if self._run_manager is not None:
            self._run_manager.cancel()


class TaskID(NamedTuple):
    """Identifier that ``SlurmExecutor.submit`` attaches to one submitted call.

    Despite its name, ``learner_index`` holds the id ``submit`` assigns to the
    submitted *function* (its ``_sequence_mapping`` entry, in first-submission
    order), and ``sequence_index`` is the call's position among all calls
    submitted for that function.  The actual learner index, which differs
    once ``size_per_learner`` splits a function's calls over several learners,
    is looked up through ``SlurmExecutor._task_mapping``.
    """

    # Per-function id assigned on the first ``submit`` of that function.
    learner_index: int
    # Position of this call in that function's list of submitted calls.
    sequence_index: int


class SlurmTask(Future):
    """A `Future` that loads the result from a `SequenceLearner`.

    ``task_id`` is the ``(func_id, global_index)`` pair produced by
    ``SlurmExecutor.submit``.  After ``finalize()`` the executor's
    ``_task_mapping`` translates it into ``(learner_index, local_index)``:
    which learner holds this call and at which position of its sequence.
    """

    def __init__(self, executor: SlurmExecutor, task_id: TaskID) -> None:
        super().__init__()
        self.executor = executor
        self.task_id = task_id

    def _learner_index_and_local_index(self) -> tuple[int, int]:
        """Translate ``task_id`` into ``(learner_index, local_index)``.

        Raises
        ------
        RuntimeError
            If the executor has no mapping for this task (``finalize()`` has
            not been called yet).
        """
        func_id, global_index = self.task_id
        try:
            learner_idx, local_index = self.executor._task_mapping[(func_id, global_index)]
        except KeyError as e:
            msg = "Task mapping not found; finalize() must be called first."
            raise RuntimeError(msg) from e
        return learner_idx, local_index

    def _get(self) -> Any | None:
        """Updates the state of the task and returns the result if the task is finished.

        This method is non-blocking and only checks in-memory data.
        The centralized file monitor is responsible for loading data from disk.
        """
        if self.done():
            return super().result(timeout=0)

        # Can't check anything if finalize() hasn't been called yet
        if self.executor._run_manager is None:
            return None

        learner_idx, local_index = self._learner_index_and_local_index()
        learner = self.executor._run_manager.learners[learner_idx]
        if local_index in learner.data:
            result = learner.data[local_index]
            self.set_result(result)
            return result
        return None

    @functools.cached_property
    def _learner_and_fname(self) -> tuple[SequenceLearner, str | Path]:
        """Return the learner holding this task and that learner's file name.

        The learner index comes from the executor's task mapping, not from
        ``task_id[0]``: the latter is the function id, which only coincides with
        the learner index when ``size_per_learner`` does not split sequences.
        """
        run_manager = self.executor._run_manager
        assert run_manager is not None, "RunManager not initialized"
        idx_learner, _ = self._learner_index_and_local_index()
        learner: SequenceLearner = run_manager.learners[idx_learner]  # type: ignore[index]
        fname = run_manager.fnames[idx_learner]
        return learner, fname

    def result(self, timeout: float | None = None) -> Any:
        """Return the result of the future if available.

        Since this is an async task, this method will only return if the result
        is immediately available. Use `await task` to wait for the result.

        Raises
        ------
        NotImplementedError
            If ``timeout`` is given.
        RuntimeError
            If ``finalize()`` has not been called, or the result is not
            available yet.
        """
        if timeout is not None:
            msg = "Timeout not implemented for SLURMTask"
            raise NotImplementedError(msg)

        if self.executor._run_manager is None:
            msg = "RunManager not initialized. Call finalize() first."
            raise RuntimeError(msg)

        # Do one check
        self._get()

        if not self.done():
            msg = (
                "Result not immediately available. "
                "Use 'await task' to wait for the result asynchronously."
            )
            raise RuntimeError(msg)

        return super().result(timeout=0)  # timeout=0 makes it non-blocking

    def __repr__(self) -> str:
        # Refresh from in-memory learner data so the reported state is current.
        if not self.done():
            self._get()
        return f"SLURMTask(task_id={self.task_id}, state={self._state})"

    def __str__(self) -> str:
        return self.__repr__()

    def __await__(self) -> Any:
        """Allow using 'await task' to wait for the result."""
        return asyncio.wrap_future(self).__await__()
def _uuid_with_datetime() -> str:
    """Return a unique identifier prefixed with the current local time.

    The layout is ``YYYYMMDD-HHMMSS-<32 hex digits of a random UUID>``.
    """
    now = datetime.datetime.now()  # noqa: DTZ005
    return f"{now:%Y%m%d-%H%M%S}-{uuid.uuid4().hex}"


class _SerializableFunctionSplatter:
    """Adapter that calls ``func`` with a single tuple spread into positional args.

    Used as the ``SequenceLearner`` function in ``SlurmExecutor._to_learners``,
    where each sequence element is an argument tuple recorded by ``submit``.
    Pickling goes through `cloudpickle`, so ``func`` may be a lambda or closure.
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func

    def __call__(self, args: Any) -> Any:
        target = self.func
        return target(*args)

    def __getstate__(self) -> bytes:
        payload = cloudpickle.dumps(self.func)
        return payload

    def __setstate__(self, state: bytes) -> None:
        self.func = cloudpickle.loads(state)
@dataclass
class SlurmExecutor(AdaptiveSchedulerExecutorBase):
    """An executor that runs jobs on SLURM.

    Similar to `concurrent.futures.Executor`, but for SLURM.
    A key difference is that ``submit()`` returns a `SlurmTask` instead of a
    `Future` and that ``finalize()`` must be called in order to start the jobs.
    All tasks must be submitted before ``finalize()`` is called.

    Parameters
    ----------
    name
        The name of the job.
    folder
        The folder to save the adaptive_scheduler files such as logs, database,
        ``.sbatch``, pickled tasks, and results files in. If the folder exists
        and has results, the results will be loaded! If None, a new folder
        ``.adaptive_scheduler/<YYYYMMDD-HHMMSS>-<uuid>`` inside the current
        working directory is used.
    partition
        The partition to use. If None, then the default partition will be used.
        (The one marked with a * in `sinfo`). Use
        `adaptive_scheduler.scheduler.slurm_partitions` to see the available partitions.
    nodes
        The number of nodes to use.
    cores_per_node
        The number of cores per node to use. If None, then all cores on the partition
        will be used.
    memory
        Memory per job, e.g. ``"4GB"`` or ``"500MB"``. Adds ``--mem`` to the SBATCH options.
    num_threads
        The number of threads to use.
    exclusive
        Whether to use exclusive nodes, adds ``"--exclusive"`` if True.
    executor_type
        The executor that is used, by default `concurrent.futures.ProcessPoolExecutor`
        is used. One can use ``"ipyparallel"``, ``"dask-mpi"``, ``"mpi4py"``,
        ``"loky"``, ``"sequential"``, or ``"process-pool"``.
    extra_scheduler
        Extra ``#SLURM`` (depending on scheduler type) arguments, e.g.
        ``["--exclusive=user", "--time=1"]`` or a tuple of lists, e.g.
        ``(["--time=10"], ["--time=20"])`` for two jobs.
    extra_env_vars
        Extra environment variables that are exported in the job script.
        e.g. ``["TMPDIR='/scratch'", "PYTHONPATH='my_dir:$PYTHONPATH'"]``.
    goal
        The goal passed to the `adaptive.Runner`. Note that this function will
        be serialized and pasted in the ``job_script``. Can be a smart-goal
        that accepts
        ``Callable[[adaptive.BaseLearner], bool] | float | datetime | timedelta | None``.
        See `adaptive_scheduler.utils.smart_goal` for more information.
    check_goal_on_start
        Checks whether a learner is already done. Only works if the learner is loaded.
    runner_kwargs
        Extra keyword argument to pass to the `adaptive.Runner`. Note that this dict
        will be serialized and pasted in the ``job_script``.
    url
        The url of the database manager, with the format
        ``tcp://ip_of_this_machine:allowed_port.``. If None, a correct url will be chosen.
    save_interval
        Time in seconds between saving of the learners.
    log_interval
        Time in seconds between log entries.
    job_manager_interval
        Time in seconds between checking and starting jobs.
    kill_interval
        Check for `kill_on_error` string inside the log-files every `kill_interval` seconds.
    kill_on_error
        If ``kill_on_error`` is a string and is found in the log files, the job
        will be cancelled and restarted. If it is a callable, it is applied to
        the log text. Must take a single argument, a list of strings, and return
        True if the job has to be killed, or False if not. Set to None if no
        `KillManager` is needed.
    overwrite_db
        Overwrite the existing database.
    job_manager_kwargs
        Keyword arguments for the `JobManager` function that aren't set in
        ``__init__`` here.
    kill_manager_kwargs
        Keyword arguments for the `KillManager` function that aren't set in
        ``__init__`` here.
    loky_start_method
        Loky start method, by default "loky".
    cleanup_first
        Cancel all previous jobs generated by the same RunManager and clean logfiles.
    save_dataframe
        Whether to periodically save the learner's data as a `pandas.DataFrame`.
    dataframe_format
        The format in which to save the `pandas.DataFrame`. See the type hint for
        the options.
    max_log_lines
        The maximum number of lines to display in the log viewer widget.
    max_fails_per_job
        Maximum number of times that a job can fail. This is here as a fail switch
        because a job might fail instantly because of a bug inside your code.
        The job manager will stop when
        ``n_jobs * total_number_of_jobs_failed > max_fails_per_job`` is true.
    max_simultaneous_jobs
        Maximum number of simultaneously running jobs. By default no more than 100
        jobs will be running. Keep in mind that if you do not specify a
        ``runner.goal``, jobs will run forever, resulting in the jobs that were
        not initially started (because of this `max_simultaneous_jobs` condition)
        to not ever start.
    quiet
        Whether to suppress the progress bar shown when creating learner files.
    extra_run_manager_kwargs
        Extra keyword arguments to pass to the `RunManager`.
    extra_scheduler_kwargs
        Extra keyword arguments to pass to the `adaptive_scheduler.scheduler.SLURM`.
    size_per_learner
        Maximum number of submitted calls per learner: each function's calls are
        split into chunks of this size, one learner per chunk. Must be a positive
        integer. If None, all calls to a function go into a single learner.
    """

    # Same as slurm_run, except it has no learners, fnames, dependencies and initializers.

    # slurm_run: Specific to slurm_run
    name: str = "adaptive-scheduler"
    folder: str | Path | None = None  # `slurm_run` defaults to None

    # slurm_run: SLURM scheduler arguments
    partition: str | tuple[str | Callable[[], str], ...] | None = None
    nodes: int | tuple[int | None | Callable[[], int | None], ...] | None = 1
    cores_per_node: int | tuple[int | None | Callable[[], int | None], ...] | None = (
        1  # `slurm_run` defaults to `None`
    )
    memory: str | tuple[str | None | Callable[[], str | None], ...] | None = None
    num_threads: int | tuple[int | Callable[[], int], ...] = 1
    exclusive: bool | tuple[bool | Callable[[], bool], ...] = False
    executor_type: EXECUTOR_TYPES | tuple[EXECUTOR_TYPES | Callable[[], EXECUTOR_TYPES], ...] = (
        "process-pool"
    )
    extra_scheduler: list[str] | tuple[list[str] | Callable[[], list[str]], ...] | None = None
    extra_env_vars: list[str] | tuple[list[str] | Callable[[], list[str]], ...] | None = None

    # slurm_run: Same as RunManager below (except dependencies and initializers)
    goal: GoalTypes | None = None
    check_goal_on_start: bool = True
    runner_kwargs: dict | None = None
    url: str | None = None
    save_interval: float = 300
    log_interval: float = 300
    job_manager_interval: float = 60
    kill_interval: float = 60
    kill_on_error: str | Callable[[list[str]], bool] | None = "srun: error:"
    overwrite_db: bool = True
    job_manager_kwargs: dict[str, Any] | None = None
    kill_manager_kwargs: dict[str, Any] | None = None
    loky_start_method: LOKY_START_METHODS = "loky"
    cleanup_first: bool = True
    save_dataframe: bool = False  # `slurm_run` defaults to `True`
    dataframe_format: _DATAFRAME_FORMATS = "pickle"
    max_log_lines: int = 500
    max_fails_per_job: int = 50
    max_simultaneous_jobs: int = 100
    quiet: bool = True  # `slurm_run` defaults to `False`

    # slurm_run: RunManager arguments
    extra_run_manager_kwargs: dict[str, Any] | None = None
    extra_scheduler_kwargs: dict[str, Any] | None = None

    # Executor-specific (field order is part of the positional __init__ signature)
    size_per_learner: int | None = None

    # Internal
    # Submitted argument tuples per function, in submission order.
    _sequences: dict[Callable[..., Any], list[Any]] = field(default_factory=dict)
    # Function -> function id (first-submission order); used as TaskID[0].
    _sequence_mapping: dict[Callable[..., Any], int] = field(default_factory=dict)
    # Function -> on-disk pickled copy used inside the learners.
    _disk_func_mapping: dict[Callable[..., Any], _DiskFunction] = field(default_factory=dict)
    _run_manager: adaptive_scheduler.RunManager | None = None
    # (func_id, global_index) -> (learner_index, local_index); filled by finalize().
    _task_mapping: dict[tuple[int, int], tuple[int, int]] = field(default_factory=dict)
    _file_monitor_task: asyncio.Task | None = field(default=None, init=False, repr=False)
    # Learner index -> file size (bytes, from os.path.getsize) at its last load.
    _last_size: dict[int, int] = field(default_factory=dict, init=False, repr=False)
    # Learner index -> minimum seconds between two loads of its file.
    _min_load_interval: dict[int, float] = field(default_factory=dict, init=False, repr=False)
    _pending_tasks: dict[int, list[SlurmTask]] = field(default_factory=dict, init=False, repr=False)
    _all_tasks: list[SlurmTask] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self) -> None:
        if self.size_per_learner is not None and self.size_per_learner < 1:
            # ``range(0, n, size)`` in ``_to_learners`` raises for 0 and silently
            # yields no chunks (dropping every task) for negative sizes.
            msg = "size_per_learner must be a positive integer or None."
            raise ValueError(msg)
        if self.folder is None:
            self.folder = Path.cwd() / ".adaptive_scheduler" / _uuid_with_datetime()  # type: ignore[operator]
        else:
            self.folder = Path(self.folder)

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> SlurmTask:
        """Record the call ``fn(*args)`` and return a `SlurmTask` for its result.

        Nothing runs until ``finalize()`` is called. The first submission of a
        given ``fn`` writes a pickled copy of it into ``folder`` (via
        ``_DiskFunction``).

        Raises
        ------
        ValueError
            If keyword arguments are passed.
        RuntimeError
            If ``finalize()`` has already created the RunManager; calls submitted
            afterwards could never be scheduled.
        """
        if kwargs:
            msg = "Keyword arguments are not supported"
            raise ValueError(msg)
        if self._run_manager is not None:
            msg = "Cannot submit after finalize(). Create a new SlurmExecutor instance."
            raise RuntimeError(msg)
        if fn not in self._sequence_mapping:
            self._sequence_mapping[fn] = len(self._sequence_mapping)
            assert fn not in self._disk_func_mapping
            assert isinstance(self.folder, Path)
            self._disk_func_mapping[fn] = _DiskFunction(
                fn,
                self.folder / f"{_name(fn)}-{uuid.uuid4().hex}.pickle",
            )
        sequence = self._sequences.setdefault(fn, [])
        i = len(sequence)
        sequence.append(args)
        task_id = TaskID(self._sequence_mapping[fn], i)
        task = SlurmTask(self, task_id)
        self._all_tasks.append(task)
        return task

    async def _monitor_files(self) -> None:
        """Single background task that monitors all learner files for changes.

        Polls once per second until the RunManager's task is cancelled (or the
        RunManager is unset). Errors during a polling round are suppressed so a
        failed read does not stop monitoring; `asyncio.CancelledError` is not an
        `Exception` subclass, so cancelling this task still stops the loop.
        """
        while self._run_manager is not None:
            with contextlib.suppress(Exception):
                await asyncio.sleep(1)
                if self._run_manager.task is not None and self._run_manager.task.cancelled():
                    break

                # Check each learner file that has pending tasks
                for learner_idx, learner in enumerate(self._run_manager.learners):
                    await self._check_and_update_learner(learner_idx, learner)

    async def _check_and_update_learner(self, learner_idx: int, learner: SequenceLearner) -> None:
        """Check a learner file and update all pending tasks for that learner.

        The file is (re)loaded only when the learner still has pending tasks, at
        least the per-learner minimum interval (default 1 s) has passed since the
        last load, the file exists, and its size differs from the last load.
        After a load the interval becomes ``max(1, 20 * load_time)`` seconds so
        that slow loads are throttled.
        """
        if learner_idx not in self._pending_tasks:
            return

        # Filter out completed tasks and update with any results already in memory
        pending_tasks = _update_pending_tasks(self._pending_tasks[learner_idx], learner)
        self._pending_tasks[learner_idx] = pending_tasks
        if not pending_tasks:
            return

        # Check if we should load: timing, file existence, and size
        assert self._run_manager is not None
        last_load_time = self._run_manager._last_load_time.get(learner_idx, 0)
        min_interval = self._min_load_interval.get(learner_idx, 1.0)
        if time.monotonic() - last_load_time < min_interval:
            return

        fname = self._run_manager.fnames[learner_idx]
        try:
            size = await asyncio.to_thread(os.path.getsize, fname)
        except FileNotFoundError:
            return
        if self._last_size.get(learner_idx, 0) == size:
            # Same size as at the last load: treated as unchanged (a heuristic).
            return

        # Load file, update state, and update pending tasks
        load_start = time.monotonic()
        await asyncio.to_thread(learner.load, fname)
        load_time = time.monotonic() - load_start
        self._last_size[learner_idx] = size
        self._min_load_interval[learner_idx] = max(1.0, 20.0 * load_time)
        self._run_manager._last_load_time[learner_idx] = time.monotonic()
        self._pending_tasks[learner_idx] = _update_pending_tasks(pending_tasks, learner)

    def _register_task(self, task: SlurmTask) -> None:
        """Register a task to be monitored by the file monitor."""
        learner_idx, _ = task._learner_index_and_local_index()
        self._pending_tasks.setdefault(learner_idx, []).append(task)

    def _to_learners(
        self,
    ) -> tuple[
        list[SequenceLearner],
        list[Path],
        dict[tuple[int, int], tuple[int, int]],
    ]:
        """Build one `SequenceLearner` per chunk of submitted calls.

        Returns
        -------
        learners
            One learner per function, or one per ``size_per_learner``-sized chunk
            of that function's calls when ``size_per_learner`` is set.
        fnames
            A unique ``.pickle`` path inside ``folder`` for each learner.
        task_mapping
            ``(func_id, global_index) -> (learner_index, local_index)``.
        """
        learners = []
        fnames = []
        task_mapping = {}
        learner_idx = 0
        for func, args_list in self._sequences.items():
            func_id = self._sequence_mapping[func]
            # Chunk the sequence if size_per_learner is set; otherwise one chunk.
            if self.size_per_learner is not None:
                chunked_args = [
                    args_list[i : i + self.size_per_learner]
                    for i in range(0, len(args_list), self.size_per_learner)
                ]
            else:
                chunked_args = [args_list]

            global_index = 0  # global index for tasks of this function
            for chunk in chunked_args:
                # Map each task in the chunk: global index -> (current learner, local index)
                for local_index in range(len(chunk)):
                    task_mapping[(func_id, global_index)] = (learner_idx, local_index)
                    global_index += 1

                disk_func = self._disk_func_mapping[func]
                ser_func = _SerializableFunctionSplatter(disk_func)
                learner = SequenceLearner(ser_func, chunk)
                learners.append(learner)
                name = _name(func)
                assert isinstance(self.folder, Path)
                fnames.append(self.folder / f"{name}-{learner_idx}-{uuid.uuid4().hex}.pickle")
                learner_idx += 1
        return learners, fnames, task_mapping

    def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager | None:
        """Create the learners and the RunManager for all submitted calls.

        Parameters
        ----------
        start
            Whether to call ``RunManager.start()`` right away.

        Returns
        -------
        The RunManager, or None if nothing was submitted.

        Raises
        ------
        RuntimeError
            If this executor was already finalized, or if no asyncio event loop is
            running (the background file monitor is scheduled with
            `asyncio.create_task`). The event-loop check happens before
            ``slurm_run`` so that a failure leaves no half-initialized state.
        """
        if self._run_manager is not None:
            msg = "RunManager already initialized. Create a new SlurmExecutor instance."
            raise RuntimeError(msg)

        learners, fnames, self._task_mapping = self._to_learners()
        if not learners:
            return None

        try:
            asyncio.get_running_loop()
        except RuntimeError as e:
            msg = (
                "finalize() must be called while an asyncio event loop is running "
                "(it starts a background file monitor)."
            )
            raise RuntimeError(msg) from e

        assert self.folder is not None
        self._run_manager = adaptive_scheduler.slurm_run(
            learners=learners,
            fnames=fnames,
            # Specific to slurm_run
            name=self.name,
            folder=self.folder,
            # SLURM scheduler arguments
            partition=self.partition,
            nodes=self.nodes,
            cores_per_node=self.cores_per_node,
            memory=self.memory,
            num_threads=self.num_threads,
            exclusive=self.exclusive,
            executor_type=self.executor_type,
            extra_scheduler=self.extra_scheduler,
            extra_env_vars=self.extra_env_vars,
            # Same as RunManager below (except job_name, move_old_logs_to, and db_fname)
            goal=self.goal,
            check_goal_on_start=self.check_goal_on_start,
            runner_kwargs=self.runner_kwargs,
            url=self.url,
            save_interval=self.save_interval,
            log_interval=self.log_interval,
            job_manager_interval=self.job_manager_interval,
            kill_interval=self.kill_interval,
            kill_on_error=self.kill_on_error,
            overwrite_db=self.overwrite_db,
            job_manager_kwargs=self.job_manager_kwargs,
            kill_manager_kwargs=self.kill_manager_kwargs,
            loky_start_method=self.loky_start_method,
            cleanup_first=self.cleanup_first,
            save_dataframe=self.save_dataframe,
            dataframe_format=self.dataframe_format,
            max_log_lines=self.max_log_lines,
            max_fails_per_job=self.max_fails_per_job,
            max_simultaneous_jobs=self.max_simultaneous_jobs,
            quiet=self.quiet,
            # RunManager arguments
            extra_run_manager_kwargs=self.extra_run_manager_kwargs,
            extra_scheduler_kwargs=self.extra_scheduler_kwargs,
        )

        # Register all tasks now that task mapping is available
        for task in self._all_tasks:
            self._register_task(task)

        # Start the file monitoring task
        self._file_monitor_task = asyncio.create_task(self._monitor_files())

        if start:
            self._run_manager.start()
        return self._run_manager

    def cleanup(self) -> None:
        """Call ``RunManager.cleanup(remove_old_logs_folder=True)``.

        Must be called after ``finalize()``.
        """
        assert self._run_manager is not None
        self._run_manager.cleanup(remove_old_logs_folder=True)

    def shutdown(
        self,
        wait: bool = True,  # noqa: FBT001, FBT002
        *,
        cancel_futures: bool = False,
    ) -> None:
        """Cancel the RunManager and stop the background file monitor.

        The base class validates the arguments first, so an unsupported call
        (``wait=False`` or ``cancel_futures=True``) raises `NotImplementedError`
        before the file monitor is touched.
        """
        super().shutdown(wait, cancel_futures=cancel_futures)
        if self._file_monitor_task is not None:
            self._file_monitor_task.cancel()
            self._file_monitor_task = None

    def new(self, update: dict[str, Any] | None = None) -> SlurmExecutor:
        """Create a new SlurmExecutor with the same parameters.

        Every public (non ``_``-prefixed) dataclass field is deep-copied and then
        overridden by ``update``; submitted tasks and other internal state are
        not carried over. ``folder`` is copied as well, so the new executor uses
        the same folder unless ``update`` sets a different one.
        """
        data = {}
        for key in SlurmExecutor.__dataclass_fields__:
            if key.startswith("_"):
                continue
            data[key] = copy.deepcopy(getattr(self, key))
        if update is not None:
            data.update(update)
        return SlurmExecutor(**data)
def _update_pending_tasks(
    pending_tasks: list[SlurmTask],
    learner: SequenceLearner,
) -> list[SlurmTask]:
    """Resolve pending tasks whose results are already present in ``learner.data``.

    Tasks that are already done are dropped. Tasks whose local index is a key of
    ``learner.data`` get that value as their result and are dropped too. Returns
    the tasks that are still waiting, in their original order.
    """
    new_pending_tasks = []
    for task in pending_tasks:
        if task.done():
            continue  # Filter out already done tasks
        _, local_idx = task._learner_index_and_local_index()
        if local_idx in learner.data:
            task.set_result(learner.data[local_idx])
        else:
            new_pending_tasks.append(task)
    return new_pending_tasks


def _name(func: Callable[..., Any]) -> str:
    """Return ``func.__name__``, or ``"func"`` for callables without one (e.g. `functools.partial`)."""
    return func.__name__ if hasattr(func, "__name__") else "func"


class _DiskFunction:
    """Callable proxy that stores ``func`` in a cloudpickle file and loads it lazily.

    Pickling an instance is meant to carry only ``fname``; the function itself is
    read back from disk on first call and then cached on the instance.
    """

    def __init__(self, func: Callable[..., Any], fname: str | Path) -> None:
        # Write the function to disk right away (creating parent directories).
        self.fname = Path(fname)
        self.fname.parent.mkdir(parents=True, exist_ok=True)
        with self.fname.open("wb") as f:
            cloudpickle.dump(func, f)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)

    def __getstate__(self) -> dict[str, Any]:
        # ``functools.cached_property`` stores the loaded function in the instance
        # ``__dict__``; drop it so a pickled proxy never embeds the function,
        # even after it has been called in this process.
        state = self.__dict__.copy()
        state.pop("func", None)
        return state

    @functools.cached_property
    def func(self) -> Callable[..., Any]:
        """Load the function from ``fname`` (once; cached on the instance)."""
        return cloudpickle.loads(self.fname.read_bytes())