"""The DatabaseManager."""
from __future__ import annotations
import json
import pickle
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Union
import pandas as pd
import zmq
import zmq.asyncio
import zmq.ssh
from adaptive_scheduler.utils import (
_deserialize,
_now,
_serialize,
cloudpickle_learners,
)
from .base_manager import BaseManager
from .common import log
if TYPE_CHECKING:
import adaptive
from adaptive_scheduler.scheduler import BaseScheduler
# Module-level asyncio-aware ZeroMQ context, created once at import time.
# `DatabaseManager._manage` creates its REP socket from this context.
ctx = zmq.asyncio.Context()
# Either one path (``str``/``Path``) or a flat list of paths.
# NOTE(review): presumably "the file name(s) of a single learner" — confirm with callers.
FnameType = Union[str, Path, List[str], List[Path]]
# Either a flat list of paths or a list of lists of paths; this is the type of
# `DatabaseManager.fnames` and the list-like input accepted by `_ensure_str`.
FnamesTypes = Union[List[str], List[Path], List[List[str]], List[List[Path]]]
class JobIDExistsInDbError(Exception):
    """Signal that a start request used a job id the database already tracks.

    Raised by `DatabaseManager._start_request` when an entry with the same
    ``job_id`` is already present.
    """
def _ensure_str(
fnames: str | Path | FnamesTypes,
) -> str | list[str] | list[list[str]]:
"""Make sure that `pathlib.Path`s are converted to strings."""
if isinstance(fnames, (str, Path)):
return str(fnames)
if isinstance(fnames, (list, tuple)):
if len(fnames) == 0:
return [] # type: ignore[return-value]
if isinstance(fnames[0], (str, Path)):
return [str(f) for f in fnames]
if isinstance(fnames[0], list):
return [[str(f) for f in sublist] for sublist in fnames] # type: ignore[union-attr]
msg = (
"Invalid input: expected a string/Path, or list of"
" strings/Paths, a list of lists of strings/Paths."
)
raise ValueError(msg)
@dataclass
class _DBEntry:
fname: str | list[str]
job_id: str | None = None
is_pending: bool = False
is_done: bool = False
log_fname: str | None = None
job_name: str | None = None
output_logs: list[str] = field(default_factory=list)
start_time: float | None = None
class SimpleDatabase:
    """A tiny JSON-file-backed list of `_DBEntry` records.

    All records live in memory in ``self._data``; every mutating method
    rewrites the whole JSON file (``{"data": [...], "meta": {...}}``).

    Parameters
    ----------
    db_fname
        Path of the JSON file.
    clear_existing
        If True, delete an existing file instead of loading it.
    """

    def __init__(self, db_fname: str | Path, *, clear_existing: bool = False) -> None:
        self.db_fname = Path(db_fname)
        self._data: list[_DBEntry] = []
        self._meta: dict[str, Any] = {}

        if not self.db_fname.exists():
            return
        if clear_existing:
            self.db_fname.unlink()
            return
        with self.db_fname.open() as f:
            raw_data = json.load(f)
        self._data = [_DBEntry(**entry) for entry in raw_data["data"]]
        # Restore the metadata as well; otherwise the next `_save` would
        # silently overwrite whatever the file stored under "meta" with `{}`.
        self._meta = raw_data.get("meta", {})

    def all(self) -> list[_DBEntry]:
        """Return the internal list of entries (not a copy)."""
        return self._data

    def insert_multiple(self, entries: list[_DBEntry]) -> None:
        """Append `entries` and write the database to disk."""
        self._data.extend(entries)
        self._save()

    def update(self, update_dict: dict, indices: list[int] | None = None) -> None:
        """Set the attributes in `update_dict` on the selected entries and save.

        Parameters
        ----------
        update_dict
            Mapping of attribute name to new value; every key must be an
            existing attribute of the entry.
        indices
            Positions of the entries to update. ``None`` updates all entries;
            positions that do not exist are ignored.
        """
        # A set makes the per-row membership test O(1) instead of O(len(indices)).
        selected = None if indices is None else set(indices)
        for index, entry in enumerate(self._data):
            if selected is None or index in selected:
                for key, value in update_dict.items():
                    assert hasattr(entry, key)
                    setattr(entry, key, value)
        # The file is rewritten even when nothing matched.
        self._save()

    def count(self, condition: Callable[[_DBEntry], bool]) -> int:
        """Return the number of entries for which `condition` is true."""
        return sum(1 for entry in self._data if condition(entry))

    def get(self, condition: Callable[[_DBEntry], bool]) -> _DBEntry | None:
        """Return the first entry for which `condition` is true, else None."""
        return next((entry for entry in self._data if condition(entry)), None)

    def get_all(
        self,
        condition: Callable[[_DBEntry], bool],
    ) -> list[tuple[int, _DBEntry]]:
        """Return ``(index, entry)`` pairs for all entries matching `condition`."""
        return [(i, entry) for i, entry in enumerate(self._data) if condition(entry)]

    def contains(self, condition: Callable[[_DBEntry], bool]) -> bool:
        """Return True if at least one entry matches `condition`."""
        return any(condition(entry) for entry in self._data)

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return every entry converted with `dataclasses.asdict`."""
        return [asdict(entry) for entry in self._data]

    def _save(self) -> None:
        """Write all entries and the metadata to `db_fname` as JSON."""
        with self.db_fname.open("w") as f:
            json.dump({"data": self.as_dicts(), "meta": self._meta}, f, indent=4)
class DatabaseManager(BaseManager):
    """Database manager.

    Keeps a `SimpleDatabase` with one entry per element of `fnames` and serves
    "start"/"stop" requests from workers over a ZeroMQ REP socket.

    Parameters
    ----------
    url
        The url of the database manager, with the format
        ``tcp://ip_of_this_machine:allowed_port.``. Use `get_allowed_url`
        to get a `url` that will work.
    scheduler
        A scheduler instance from `adaptive_scheduler.scheduler`.
    db_fname
        Filename of the database, e.g. 'running.json'.
    learners
        List of `learners` corresponding to `fnames`.
    fnames
        List of `fnames` corresponding to `learners`.
    overwrite_db
        Overwrite the existing database upon starting.
    initializers
        List of functions that are called before the job starts, can populate
        a cache.

    Attributes
    ----------
    failed : list
        A list of entries that have failed and have been removed from the database.

    """

    def __init__(
        self,
        url: str,
        scheduler: BaseScheduler,
        db_fname: str | Path,
        learners: list[adaptive.BaseLearner],
        fnames: FnamesTypes,
        *,
        overwrite_db: bool = True,
        initializers: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__()
        self.url = url
        self.scheduler = scheduler
        self.db_fname = Path(db_fname)
        self.learners = learners
        self.fnames = fnames
        self.overwrite_db = overwrite_db
        self.initializers = initializers

        # Most recent request/reply handled by `_manage` (kept for inspection).
        self._last_reply: str | list[str] | Exception | None = None
        self._last_request: tuple[str, ...] | None = None
        self.failed: list[dict[str, Any]] = []
        self._pickling_time: float | None = None
        self._total_learner_size: int | None = None
        # Created by `_setup` / `create_empty_db`; None until then.
        self._db: SimpleDatabase | None = None

    def _setup(self) -> None:
        """Prepare the database before requests are served.

        With an existing database file and ``overwrite_db=False`` the file is
        loaded and nothing else happens. Otherwise a fresh database is created
        and the learners are cloudpickled.

        NOTE(review): presumably called by `BaseManager` when the manager is
        started — confirm against `base_manager`.
        """
        if self.db_fname.exists() and not self.overwrite_db:
            # Load the database we were asked to keep. Without this `self._db`
            # stays None and every method that asserts on it fails.
            if self._db is None:
                self._db = SimpleDatabase(self.db_fname)
            return
        self.create_empty_db()
        self._total_learner_size, self._pickling_time = cloudpickle_learners(
            self.learners,
            self.fnames,
            initializers=self.initializers,
            with_progress_bar=True,
        )

    def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
        """If the ``job_id`` isn't running anymore, replace it with None.

        Entries whose ``job_name`` is set but absent from `queue` are recorded
        in `failed` and have ``job_id``, ``job_name`` and ``is_pending`` reset.

        Parameters
        ----------
        queue
            Mapping whose values each contain a ``"job_name"`` key. When None,
            ``self.scheduler.queue(me_only=True)`` is used.
        """
        if self._db is None:
            return
        if queue is None:
            queue = self.scheduler.queue(me_only=True)
        # A set gives O(1) membership tests inside the per-entry predicate.
        job_names_in_queue = {x["job_name"] for x in queue.values()}
        failed = self._db.get_all(
            lambda e: e.job_name is not None and e.job_name not in job_names_in_queue,
        )
        self.failed.extend([asdict(entry) for _, entry in failed])
        indices = [index for index, _ in failed]
        self._db.update(
            {"job_id": None, "job_name": None, "is_pending": False},
            indices,
        )

    def n_done(self) -> int:
        """Return the number of jobs that are done."""
        if self._db is None:
            return 0
        return self._db.count(lambda e: e.is_done)

    def is_done(self) -> bool:
        """Return True if all jobs are done."""
        return self.n_done() == len(self.fnames)

    def create_empty_db(self) -> None:
        """Create an empty database.

        It keeps track of ``fname -> (job_id, is_done, log_fname, job_name)``.
        Any existing database file is deleted first.
        """
        entries: list[_DBEntry] = [
            _DBEntry(fname=fname) for fname in _ensure_str(self.fnames)
        ]
        if self.db_fname.exists():
            self.db_fname.unlink()
        self._db = SimpleDatabase(self.db_fname)
        self._db.insert_multiple(entries)

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return the database as a list of dictionaries."""
        if self._db is None:
            return []
        return self._db.as_dicts()

    def as_df(self) -> pd.DataFrame:
        """Return the database as a `pandas.DataFrame`."""
        return pd.DataFrame(self.as_dicts())

    def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
        """Return the scheduler's output file names for `job_name`.

        The scheduler's job-id placeholder (``_JOB_ID_VARIABLE``) in each file
        name is replaced by the sanitized `job_id`.
        """
        job_id = self.scheduler.sanatize_job_id(job_id)
        output_fnames = self.scheduler.output_fnames(job_name)
        return [
            f.with_name(f.name.replace(self.scheduler._JOB_ID_VARIABLE, job_id))
            for f in output_fnames
        ]

    def _choose_fname(self) -> tuple[int, str | list[str] | None]:
        """Pick the first entry that is neither running, pending nor done.

        Returns
        -------
        tuple
            The entry's index in the database and its file name(s).

        Raises
        ------
        RuntimeError
            If no such entry exists.
        """
        assert self._db is not None
        entry = self._db.get(
            lambda e: e.job_id is None and not e.is_done and not e.is_pending,
        )
        if entry is None:
            msg = "Requested a new job but no more learners to run in the database."
            raise RuntimeError(msg)
        log.debug("choose fname", entry=entry)
        index = self._db.all().index(entry)
        return index, _ensure_str(entry.fname)  # type: ignore[return-value]

    def _confirm_submitted(self, index: int, job_name: str) -> None:
        """Mark the entry at `index` as pending under `job_name`."""
        assert self._db is not None
        self._db.update(
            {
                "job_name": job_name,
                "is_pending": True,
            },
            indices=[index],
        )

    def _start_request(
        self,
        job_id: str,
        log_fname: str,
        job_name: str,
    ) -> str | list[str] | None:
        """Register a started job and return the file name(s) it should run.

        Parameters
        ----------
        job_id
            Job id reported by the worker.
        log_fname
            Log file name reported by the worker.
        job_name
            Job name; it is matched against pending entries.

        Returns
        -------
        str | list[str] | None
            The entry's file name(s), or None when no pending entry has this
            `job_name`.

        Raises
        ------
        JobIDExistsInDbError
            If an entry with `job_id` is already in the database.
        """
        assert self._db is not None
        # One scan is enough: `get` returns None when no entry has this job_id.
        running = self._db.get(lambda e: e.job_id == job_id)
        if running is not None:
            fname = running.fname  # already running
            # No trailing comma after the last literal: `msg` must be a string,
            # not a 1-tuple.
            msg = (
                f"The job_id {job_id} already exists in the database and "
                f"runs {fname}. You might have forgotten to use the "
                "`if __name__ == '__main__': ...` idiom in your code. Read the "
                "warning in the [mpi4py](https://bit.ly/2HAk0GG) documentation."
            )
            raise JobIDExistsInDbError(msg)
        entry = self._db.get(lambda e: e.job_name == job_name and e.is_pending)
        log.debug("choose fname", entry=entry)
        if entry is None:
            return None
        index = self._db.all().index(entry)
        self._db.update(
            {
                "job_id": job_id,
                "log_fname": log_fname,
                "output_logs": _ensure_str(self._output_logs(job_id, job_name)),
                "start_time": _now(),
                "is_pending": False,
            },
            indices=[index],
        )
        return _ensure_str(entry.fname)  # type: ignore[return-value]

    def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
        """Mark every entry whose ``fname`` equals `fname` as done and unassigned."""
        fname_str = _ensure_str(fname)
        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
        assert self._db is not None
        entry_indices = [
            index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)
        ]
        self._db.update(reset, entry_indices)

    def _stop_requests(self, fnames: FnamesTypes) -> None:
        """Mark all entries matching any of `fnames` as done and unassigned."""
        # Same as `_stop_request` but optimized for processing many `fnames` at once
        assert self._db is not None
        # Entries are compared through `str(...)` so that list-valued fnames
        # are hashable set members.
        fnames_str = {str(fname) for fname in _ensure_str(fnames)}
        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
        entry_indices = [
            index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
        ]
        self._db.update(reset, entry_indices)

    def _dispatch(
        self,
        request: tuple[str, str | list[str]] | tuple[str],
    ) -> str | list[str] | Exception | None:
        """Handle one request and return the reply to send back.

        Parameters
        ----------
        request
            ``("start", job_id, log_fname, job_name)`` or ``("stop", fname)``.

        Returns
        -------
        str | list[str] | Exception | None
            The file name(s) for a "start" request, None for a "stop" request,
            or the exception instance if handling the request failed (it is
            returned, not raised, so that it can be sent as the reply).

        Raises
        ------
        ValueError
            If the request type is neither "start" nor "stop".
        """
        request_type, *request_arg = request
        log.debug("got a request", request=request)
        try:
            if request_type == "start":
                # workers send us their slurm ID for us to fill in
                job_id, log_fname, job_name = request_arg
                # give the worker a job and send back the fname to the worker
                fname = self._start_request(job_id, log_fname, job_name)  # type: ignore[arg-type]
                if fname is None:
                    # This should never happen because the _manage co-routine
                    # should have stopped the workers before this happens.
                    msg = "No more learners to run in the database."
                    raise RuntimeError(msg)  # noqa: TRY301
                log.debug(
                    "choose a fname",
                    fname=fname,
                    job_id=job_id,
                    log_fname=log_fname,
                    job_name=job_name,
                )
                return fname
            if request_type == "stop":
                fname = request_arg[0]  # workers send us the fname they were given
                log.debug("got a stop request", fname=fname)
                self._stop_request(fname)  # reset the job_id to None
                return None
        except Exception as e:  # noqa: BLE001
            return e
        # Only reached for an unrecognised request type; raised outside the
        # `try` so it is not converted into a reply.
        msg = f"Unknown request type: {request_type}"
        raise ValueError(msg)

    async def _manage(self) -> None:
        """Database manager co-routine.

        Binds a REP socket to `url` and answers requests via `_dispatch` until
        `is_done` returns True. The socket is always closed on exit.

        Returns
        -------
        coroutine

        """
        log.debug("started database")
        socket = ctx.socket(zmq.REP)
        socket.bind(self.url)
        try:
            while True:
                try:
                    self._last_request = await socket.recv_serialized(_deserialize)
                except zmq.error.Again:
                    log.exception(
                        "socket.recv_serialized failed in the DatabaseManager"
                        " with `zmq.error.Again`.",
                    )
                except pickle.UnpicklingError as e:
                    if r"\x03" in str(e):
                        # Empty frame received.
                        # TODO: not sure why this happens
                        pass
                    else:
                        log.exception(
                            "socket.recv_serialized failed in the DatabaseManager"
                            " with `pickle.UnpicklingError` in _deserialize.",
                        )
                else:
                    assert self._last_request is not None  # for mypy
                    self._last_reply = self._dispatch(self._last_request)  # type: ignore[arg-type]
                    await socket.send_serialized(self._last_reply, _serialize)
                if self.is_done():
                    break
        finally:
            socket.close()