"""The DatabaseManager."""
from __future__ import annotations
import asyncio
import json
import pickle
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
import pandas as pd
import zmq
import zmq.asyncio
import zmq.ssh
from adaptive_scheduler.utils import (
_deserialize,
_now,
_serialize,
cloudpickle_learners,
)
from .base_manager import BaseManager
from .common import log
if TYPE_CHECKING:
from collections.abc import Callable
import adaptive
from adaptive_scheduler.scheduler import BaseScheduler
# Module-wide asyncio-aware ZeroMQ context; `DatabaseManager._manage` creates
# its REP socket from it.
ctx = zmq.asyncio.Context()
# A single filename (`str` or `Path`) or a flat list of filenames.
FnameType = str | Path | list[str] | list[Path]
# A flat list of filenames, or a list of filename lists (nested one level).
FnamesTypes = list[str] | list[Path] | list[list[str]] | list[list[Path]]
class JobIDExistsInDbError(Exception):
    """Signal that a job id is already assigned to a database entry.

    Raised by `DatabaseManager._start_request` when a worker sends a
    ``start`` request with a ``job_id`` that some entry already holds.
    """
def _ensure_str(
fnames: str | Path | FnamesTypes,
) -> str | list[str] | list[list[str]]:
"""Make sure that `pathlib.Path`s are converted to strings."""
if isinstance(fnames, str | Path):
return str(fnames)
if isinstance(fnames, list | tuple):
if len(fnames) == 0:
return [] # type: ignore[return-value]
if isinstance(fnames[0], str | Path):
return [str(f) for f in fnames]
if isinstance(fnames[0], list):
return [[str(f) for f in sublist] for sublist in fnames] # type: ignore[union-attr]
msg = (
"Invalid input: expected a string/Path, or list of"
" strings/Paths, a list of lists of strings/Paths."
)
raise ValueError(msg)
@dataclass
class _DBEntry:
fname: str | list[str]
job_id: str | None = None
is_pending: bool = False
is_done: bool = False
log_fname: str | None = None
job_name: str | None = None
output_logs: list[str] = field(default_factory=list)
start_time: float | None = None
depends_on: list[int] = field(default_factory=list)
def _asdict_fast(entry: _DBEntry) -> dict[str, Any]:
"""Fast version of `dataclasses.asdict` for `_DBEntry`.
About 10x faster than `asdict`, which surprisingly is a bottleneck.
"""
return {
"fname": entry.fname,
"job_id": entry.job_id,
"is_pending": entry.is_pending,
"is_done": entry.is_done,
"log_fname": entry.log_fname,
"job_name": entry.job_name,
"output_logs": entry.output_logs,
"start_time": entry.start_time,
"depends_on": entry.depends_on,
}
class SimpleDatabase:
    """Minimal in-memory list of `_DBEntry` records persisted to a JSON file.

    The file has the layout ``{"data": [...], "meta": {...}}``. Mutating
    methods write the complete database back to `db_fname`, either
    immediately or debounced (see `_save_debounced`).

    Parameters
    ----------
    db_fname
        Path of the JSON file.
    clear_existing
        If True, an existing `db_fname` is deleted instead of loaded.
    """

    def __init__(self, db_fname: str | Path, *, clear_existing: bool = False) -> None:
        self.db_fname = Path(db_fname)
        self._data: list[_DBEntry] = []
        self._meta: dict[str, Any] = {}
        self._save_task: asyncio.Task | None = None
        self._last_save_time: float = 0.0

        if not self.db_fname.exists():
            return
        if clear_existing:
            self.db_fname.unlink()
            return
        with self.db_fname.open() as f:
            raw_data = json.load(f)
        self._data = [_DBEntry(**entry) for entry in raw_data["data"]]
        # `_save_now` persists "meta" as well, so restore it instead of
        # silently dropping it on the next save.
        self._meta = raw_data.get("meta", {})

    def all(self) -> list[_DBEntry]:
        """Return the internal list of entries (not a copy)."""
        return self._data

    def insert_multiple(self, entries: list[_DBEntry]) -> None:
        """Append `entries` and save to disk immediately."""
        self._data.extend(entries)
        self._save_now()

    def update(self, update_dict: dict, indices: list[int] | None = None) -> None:
        """Set the attributes in `update_dict` on the selected entries.

        Parameters
        ----------
        update_dict
            Mapping of `_DBEntry` attribute name to its new value.
        indices
            Positions of the entries to update; ``None`` updates all entries.
            Indices that do not correspond to an entry are ignored.
        """
        if indices is None:
            targets = self._data
        else:
            # A set keeps the selection O(n) even for long index lists.
            wanted = set(indices)
            targets = [entry for i, entry in enumerate(self._data) if i in wanted]
        for entry in targets:
            for key, value in update_dict.items():
                assert hasattr(entry, key)
                setattr(entry, key, value)
        self._save_debounced()

    def count(self, condition: Callable[[_DBEntry], bool]) -> int:
        """Return the number of entries for which `condition` is true."""
        return sum(1 for entry in self._data if condition(entry))

    def get_with_index(
        self,
        condition: Callable[[_DBEntry], bool],
    ) -> tuple[int, _DBEntry] | None:
        """Return ``(index, entry)`` of the first match, or None if there is none."""
        for index, entry in enumerate(self._data):
            if condition(entry):
                return index, entry
        return None

    def get(self, condition: Callable[[_DBEntry], bool]) -> _DBEntry | None:
        """Return the first entry matching `condition`, or None if there is none."""
        index_entry = self.get_with_index(condition)
        if index_entry is None:
            return None
        _, entry = index_entry
        return entry

    def get_all(
        self,
        condition: Callable[[_DBEntry], bool],
    ) -> list[tuple[int, _DBEntry]]:
        """Return ``(index, entry)`` for every entry matching `condition`."""
        return [(i, entry) for i, entry in enumerate(self._data) if condition(entry)]

    def contains(self, condition: Callable[[_DBEntry], bool]) -> bool:
        """Return True if at least one entry matches `condition`."""
        return any(condition(entry) for entry in self._data)

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return all entries as plain dictionaries."""
        return [_asdict_fast(entry) for entry in self._data]

    def _cancel_save_task(self) -> None:
        """Cancel the pending delayed save, if any."""
        if self._save_task is not None:
            self._save_task.cancel()
            self._save_task = None

    def _save_debounced(self, delay: float = 2.0) -> None:
        """Debounced save to prevent excessive disk I/O during rapid updates.

        This implements a throttling mechanism that ensures saves don't happen
        more frequently than once per `delay` seconds. The logic is:

        1. If enough time has passed since the last save (>= delay), save
           immediately.
        2. If the last save was recent (< delay), schedule a delayed save that
           will execute after the remaining time in the delay period.
        3. Each new call cancels any pending save and reschedules, ensuring
           that rapid successive calls result in only one final save.

        Parameters
        ----------
        delay : float
            Minimum time in seconds between saves. Defaults to 2.0 seconds.

        Notes
        -----
        Falls back to an immediate save if no asyncio event loop is running,
        making this method safe to call from both sync and async contexts.
        """
        time_since_last_save = time.monotonic() - self._last_save_time
        if time_since_last_save >= delay:  # Enough time has passed, save immediately
            self._save_now()
            return

        # Look for a running loop *before* creating the coroutine; otherwise a
        # sync caller leaves an un-awaited coroutine behind (RuntimeWarning).
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running (e.g., in tests or sync context)
            # Fall back to immediate save to ensure data isn't lost
            self._save_now()
            return

        # Last save was recent, schedule a delayed save for the remaining time
        remaining_delay = delay - time_since_last_save

        async def delayed_save() -> None:
            await asyncio.sleep(remaining_delay)
            # This task is the pending save; forget it first so `_save_now`
            # does not cancel the task it is running in.
            self._save_task = None
            self._save_now()

        self._cancel_save_task()
        self._save_task = loop.create_task(delayed_save())

    def _save_now(self) -> None:
        """Immediately save to disk."""
        self._cancel_save_task()
        self._last_save_time = time.monotonic()
        with self.db_fname.open("w") as f:
            json.dump({"data": self.as_dicts(), "meta": self._meta}, f, indent=4)

    def close(self) -> None:
        """Clean up and save immediately."""
        self._save_now()  # Cancels any pending save task

    def dependencies_satisfied(self, entry: _DBEntry) -> bool:
        """Return True if every entry listed in ``entry.depends_on`` is done."""
        return all(self._data[i].is_done for i in entry.depends_on)
class DatabaseManager(BaseManager):
    """Database manager.

    Keeps track of which learner file (``fname``) is handled by which job and
    answers the ``start``/``stop`` requests that workers send over a ZeroMQ
    ``REP`` socket.

    Parameters
    ----------
    url
        The url of the database manager, with the format
        ``tcp://ip_of_this_machine:allowed_port.``. Use `get_allowed_url`
        to get a `url` that will work.
    scheduler
        A scheduler instance from `adaptive_scheduler.scheduler`.
    db_fname
        Filename of the database, e.g. 'running.json'.
    learners
        List of `learners` corresponding to `fnames`.
    fnames
        List of `fnames` corresponding to `learners`.
    dependencies
        Dictionary of dependencies, e.g., ``{1: [0]}`` means that the ``learners[1]``
        depends on the ``learners[0]``. This means that the ``learners[1]`` will only
        start when the ``learners[0]`` is done.
    overwrite_db
        Overwrite the existing database upon starting.
    initializers
        List of functions that are called before the job starts, can populate
        a cache.
    with_progress_bar
        Forwarded to `cloudpickle_learners` when the learners are pickled.

    Attributes
    ----------
    failed : list
        Dictionaries of the entries whose job was no longer found in the
        scheduler's queue; those entries are reset in the database.

    """

    def __init__(
        self,
        url: str,
        scheduler: BaseScheduler,
        db_fname: str | Path,
        learners: list[adaptive.BaseLearner],
        fnames: FnamesTypes,
        *,
        dependencies: dict[int, list[int]] | None = None,
        overwrite_db: bool = True,
        initializers: list[Callable[[], None]] | None = None,
        with_progress_bar: bool = True,
    ) -> None:
        super().__init__()
        self.url = url
        self.scheduler = scheduler
        self.db_fname = Path(db_fname)
        self.learners = learners
        self.fnames = fnames
        self.dependencies = dependencies or {}
        self.overwrite_db = overwrite_db
        self.initializers = initializers
        self.with_progress_bar = with_progress_bar

        # Last request/reply handled by `_manage`, kept for inspection.
        self._last_reply: str | list[str] | Exception | None = None
        self._last_request: tuple[str, ...] | None = None
        self.failed: list[dict[str, Any]] = []
        self._pickling_time: float | None = None
        self._total_learner_size: int | None = None
        self._db: SimpleDatabase | None = None

    def _setup(self) -> None:
        """Load or create the database and cloudpickle the learners.

        With ``overwrite_db=False`` and an existing database file, that file
        is loaded and nothing is pickled. Otherwise a fresh database is
        created and the learners are cloudpickled.
        """
        if self.db_fname.exists() and not self.overwrite_db:
            # Load the existing file; without this `self._db` would stay None
            # and every method asserting on it would fail.
            if self._db is None:
                self._db = SimpleDatabase(self.db_fname)
            return
        self.create_empty_db()
        self._total_learner_size, self._pickling_time = cloudpickle_learners(
            self.learners,
            self.fnames,
            initializers=self.initializers,
            with_progress_bar=self.with_progress_bar,
        )

    def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
        """If the ``job_id`` isn't running anymore, replace it with None.

        Entries whose ``job_name`` is not among the job names in `queue` are
        appended (as dictionaries) to ``self.failed`` and have their
        ``job_id``, ``job_name`` and ``is_pending`` reset.

        Parameters
        ----------
        queue
            The scheduler's queue; if None it is requested with
            ``self.scheduler.queue(me_only=True)``.
        """
        if self._db is None:
            return
        if queue is None:
            queue = self.scheduler.queue(me_only=True)
        # A set makes the per-entry membership test O(1).
        job_names_in_queue = {x["job_name"] for x in queue.values()}
        failed = self._db.get_all(
            lambda e: e.job_name is not None and e.job_name not in job_names_in_queue,
        )
        if not failed:
            # Nothing changed, so do not trigger a needless save.
            return
        self.failed.extend(_asdict_fast(entry) for _, entry in failed)
        self._db.update(
            {"job_id": None, "job_name": None, "is_pending": False},
            [index for index, _ in failed],
        )

    def n_done(self) -> int:
        """Return the number of jobs that are done."""
        if self._db is None:
            return 0
        return self._db.count(lambda e: e.is_done)

    def n_unscheduled(self) -> int:
        """Return the number of jobs that are not scheduled."""
        if self._db is None:
            return 0
        return self._db.count(lambda e: not e.is_done and not e.is_pending)

    def is_done(self) -> bool:
        """Return True if all jobs are done."""
        return self.n_done() == len(self.fnames)

    def create_empty_db(self) -> None:
        """Create an empty database.

        It keeps track of ``fname -> (job_id, is_done, log_fname, job_name)``.
        An existing database file is removed first.
        """
        deps = self.dependencies
        entries: list[_DBEntry] = [
            _DBEntry(fname=fname, depends_on=deps.get(i, []))  # type: ignore[arg-type]
            for i, fname in enumerate(_ensure_str(self.fnames))
        ]
        # `clear_existing=True` unlinks an existing file for us.
        self._db = SimpleDatabase(self.db_fname, clear_existing=True)
        self._db.insert_multiple(entries)

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return the database as a list of dictionaries."""
        if self._db is None:
            return []
        return self._db.as_dicts()

    def as_df(self) -> pd.DataFrame:
        """Return the database as a `pandas.DataFrame`."""
        return pd.DataFrame(self.as_dicts())

    def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
        """Return the scheduler's output filenames for `job_name`.

        The scheduler's job-id placeholder in each file name is replaced by
        the sanitized `job_id`.
        """
        job_id = self.scheduler.sanatize_job_id(job_id)
        output_fnames = self.scheduler.output_fnames(job_name)
        return [
            f.with_name(f.name.replace(self.scheduler._JOB_ID_VARIABLE, job_id))
            for f in output_fnames
        ]

    def _choose_fname(self) -> tuple[int, str | list[str] | None]:
        """Pick the first entry that can be started.

        Returns
        -------
        tuple
            ``(index, fname)`` of the chosen entry, or ``(-1, None)`` when no
            entry can be started right now.

        Raises
        ------
        RuntimeError
            If all entries in the database are done.
        """
        assert self._db is not None
        db = self._db
        index_entry = db.get_with_index(
            lambda e: e.job_id is None
            and not e.is_done
            and not e.is_pending
            and db.dependencies_satisfied(e),
        )
        if all(e.is_done for e in db.all()):
            msg = "Requested a new job but no more learners to run in the database."
            raise RuntimeError(msg)
        if index_entry is None:
            # Currently, we cannot schedule any more jobs, because we're waiting
            # for dependencies to be satisfied.
            return -1, None
        index, entry = index_entry
        log.debug("choose fname", entry=entry)
        return index, _ensure_str(entry.fname)  # type: ignore[return-value]

    def _confirm_submitted(self, index: int, job_name: str) -> None:
        """Mark entry `index` as pending for the job named `job_name`."""
        assert self._db is not None
        self._db.update(
            {
                "job_name": job_name,
                "is_pending": True,
            },
            indices=[index],
        )

    def _start_request(
        self,
        job_id: str,
        log_fname: str,
        job_name: str,
    ) -> str | list[str] | None:
        """Bind a started job to the pending entry with the same `job_name`.

        Returns
        -------
        str | list[str] | None
            The ``fname`` of the entry, or None if no pending entry has this
            `job_name`.

        Raises
        ------
        JobIDExistsInDbError
            If `job_id` is already assigned to an entry.
        """
        assert self._db is not None
        running = self._db.get(lambda e: e.job_id == job_id)
        if running is not None:
            fname = running.fname  # already running
            # No trailing comma here: `msg` must be a string, not a 1-tuple.
            msg = (
                f"The job_id {job_id} already exists in the database and "
                f"runs {fname}. You might have forgotten to use the "
                "`if __name__ == '__main__': ...` idiom in your code. Read the "
                "warning in the [mpi4py](https://bit.ly/2HAk0GG) documentation."
            )
            raise JobIDExistsInDbError(msg)
        index_entry = self._db.get_with_index(
            lambda e: e.job_name == job_name and e.is_pending,
        )
        log.debug(
            "choose fname",
            entry=None if index_entry is None else index_entry[1],
        )
        if index_entry is None:
            return None
        index, entry = index_entry
        self._db.update(
            {
                "job_id": job_id,
                "log_fname": log_fname,
                "output_logs": _ensure_str(self._output_logs(job_id, job_name)),
                "start_time": _now(),
                "is_pending": False,
            },
            indices=[index],
        )
        return _ensure_str(entry.fname)  # type: ignore[return-value]

    def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
        """Mark every entry with this `fname` as done and detach its job."""
        fname_str = _ensure_str(fname)
        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
        assert self._db is not None
        entry_indices = [index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)]
        self._db.update(reset, entry_indices)

    def _stop_requests(self, fnames: FnamesTypes) -> None:
        """Mark all entries whose ``fname`` is in `fnames` as done."""
        # Same as `_stop_request` but optimized for processing many `fnames` at once
        assert self._db is not None
        # Entries are compared via `str(...)` so list-valued fnames are hashable.
        fnames_str = {str(fname) for fname in _ensure_str(fnames)}
        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
        entry_indices = [
            index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
        ]
        self._db.update(reset, entry_indices)

    def _dispatch(
        self,
        request: tuple[str, str | list[str]] | tuple[str],
    ) -> str | list[str] | Exception | None:
        """Handle one worker request and return the reply.

        Parameters
        ----------
        request
            ``("start", job_id, log_fname, job_name)`` or ``("stop", fname)``.

        Returns
        -------
        str | list[str] | Exception | None
            The ``fname`` for a "start" request, None for a "stop" request, or
            the exception raised while handling either of them (returned, not
            raised).

        Raises
        ------
        ValueError
            If the request type is neither "start" nor "stop".
        """
        request_type, *request_arg = request
        log.debug("got a request", request=request)
        try:
            if request_type == "start":
                # workers send us their slurm ID for us to fill in
                job_id, log_fname, job_name = request_arg
                # give the worker a job and send back the fname to the worker
                fname = self._start_request(job_id, log_fname, job_name)  # type: ignore[arg-type]
                if fname is None:
                    # This should never happen because the _manage co-routine
                    # should have stopped the workers before this happens.
                    msg = "No more learners to run in the database."
                    raise RuntimeError(msg)  # noqa: TRY301
                log.debug(
                    "choose a fname",
                    fname=fname,
                    job_id=job_id,
                    log_fname=log_fname,
                    job_name=job_name,
                )
                return fname
            if request_type == "stop":
                fname = request_arg[0]  # workers send us the fname they were given
                log.debug("got a stop request", fname=fname)
                self._stop_request(fname)  # reset the job_id to None
                return None
        except Exception as e:  # noqa: BLE001
            return e
        msg = f"Unknown request type: {request_type}"
        raise ValueError(msg)

    async def _manage(self) -> None:
        """Database manager co-routine.

        Receives requests on a ``REP`` socket bound to ``self.url``, replies
        with the result of `_dispatch`, and stops once `is_done` is True.

        Returns
        -------
        coroutine
        """
        log.debug("started database")
        socket = ctx.socket(zmq.REP)
        try:
            # Bind inside the `try` so the socket is closed if binding fails.
            socket.bind(self.url)
            while True:
                try:
                    self._last_request = await socket.recv_serialized(_deserialize)
                except zmq.error.Again:
                    log.exception(
                        "socket.recv_serialized failed in the DatabaseManager"
                        " with `zmq.error.Again`.",
                    )
                except pickle.UnpicklingError as e:
                    # A message containing "\x03" means an empty frame was received.
                    # TODO: not sure why this happens
                    if r"\x03" not in str(e):
                        log.exception(
                            "socket.recv_serialized failed in the DatabaseManager"
                            " with `pickle.UnpicklingError` in _deserialize.",
                        )
                else:
                    assert self._last_request is not None  # for mypy
                    self._last_reply = self._dispatch(self._last_request)  # type: ignore[arg-type]
                    await socket.send_serialized(self._last_reply, _serialize)
                if self.is_done():
                    break
        finally:
            socket.close()
            # Ensure any pending database changes are saved
            if self._db is not None:
                self._db.close()

    def replace_learner(self, index: int, new_learner: adaptive.BaseLearner) -> None:
        """Replace a learner and update the corresponding database entry and cloudpickled file.

        Parameters
        ----------
        index
            The index of the learner to replace.
        new_learner
            The new learner to replace the old one.

        Raises
        ------
        IndexError
            If `index` is not a valid non-negative index into ``self.learners``.
        ValueError
            If the corresponding database entry is already done.
        """
        if index < 0 or index >= len(self.learners):
            msg = "Index out of range"
            raise IndexError(msg)

        fname = self.fnames[index]

        # Look up the database entry; convert the fname once, not per entry.
        assert self._db is not None
        fname_str = _ensure_str(fname)
        index_entry = self._db.get_with_index(lambda e: e.fname == fname_str)
        assert index_entry is not None
        # Keep `index` (the position in `self.learners`) untouched; the
        # database position is not needed.
        _, entry = index_entry

        if entry.is_done:
            msg = f"Learner at index {index} is already done and cannot be replaced."
            raise ValueError(msg)
        assert not entry.is_pending
        assert entry.job_id is None

        # Replace the learner in the list
        self.learners[index] = new_learner

        # Cloudpickle the new learner
        cloudpickle_learners(
            [new_learner],
            [fname],  # type: ignore[arg-type]
        )
        # Note that self._total_learner_size and self._pickling_time are
        # not updated now! But we don't care about that.
        log.debug(f"Replaced learner at index {index} with a new learner")

    def cancel(self) -> bool | None:
        """Cancel the database manager and clean up resources."""
        result = super().cancel()
        if self._db is not None:
            self._db.close()
        return result