"""Client support for Adaptive Scheduler."""
from __future__ import annotations
import datetime
import json
import logging
import os
import socket
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable
import psutil
import structlog
import zmq
from adaptive_scheduler.utils import (
_deserialize,
_get_npoints,
_serialize,
fname_to_learner,
log_exception,
sleep_unless_task_is_done,
)
if TYPE_CHECKING:
import argparse
import asyncio
from pathlib import Path
from adaptive import AsyncRunner, BaseLearner

def _dumps(event_dict: dict[str, Any], **kwargs: Any) -> str:
"""Custom json.dumps to ensure 'event' key is always first in the JSON output."""
event = event_dict.pop("event", None)
return json.dumps({"event": event, **event_dict}, **kwargs)
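# For example (illustrative values), _dumps({"npoints": 10, "event": "current status"})
# returns '{"event": "current status", "npoints": 10}', so the "event" key always
# leads the JSON object and log lines stay easy to scan and grep.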

ctx = zmq.Context()
logger = logging.getLogger("adaptive_scheduler.client")
logger.setLevel(logging.INFO)
log = structlog.wrap_logger(
logger,
processors=[
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M.%S", utc=False),
structlog.processors.JSONRenderer(serializer=_dumps),
],
)
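# A rendered log line then looks like (illustrative):
#   {"event": "current status", "timestamp": "2024-01-01 12:00.00", "npoints": 10}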
def add_log_file_handler(log_fname: str | Path) -> None: # pragma: no cover
"""Add a file handler to the logger."""
fh = logging.FileHandler(log_fname)
logger.addHandler(fh)
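# Typically called once at the start of a job script, e.g. (illustrative filename):
#     add_log_file_handler("adaptive-scheduler.job-12345.log")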
def get_learner(
url: str,
log_fname: str,
job_id: str,
job_name: str,
) -> tuple[BaseLearner, str | list[str], Callable[[], None] | None]:
"""Get a learner from the database (running at `url`).
This learner's process will be logged in `log_fname`
and running under `job_id`.
Parameters
----------
url
The url of the database manager running via
(`adaptive_scheduler.server_support.manage_database`).
log_fname
The filename of the log-file. Should be passed in the job-script.
job_id
The job_id of the process the job. Should be passed in the job-script.
job_name
The name of the job. Should be passed in the job-script.
Returns
-------
learner
Learner that is chosen.
fname
The filename of the learner that was chosen.
initializer
A function that runs before the process is forked.
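
    Examples
    --------
    A job script would typically call this with values passed in by the
    scheduler (the URL and IDs below are illustrative)::

        learner, fname, initializer = get_learner(
            url="tcp://10.0.0.1:37371",
            log_fname="job-12345.log",
            job_id="12345",
            job_name="my-job",
        )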
"""
log.info(
"trying to get learner",
job_id=job_id,
log_fname=log_fname,
job_name=job_name,
)
with ctx.socket(zmq.REQ) as socket:
socket.setsockopt(zmq.LINGER, 0)
socket.setsockopt(zmq.SNDTIMEO, 300_000) # timeout after 300s
socket.connect(url)
socket.send_serialized(("start", job_id, log_fname, job_name), _serialize)
log.info("sent start signal, going to wait 300s for a reply.")
socket.setsockopt(zmq.RCVTIMEO, 300_000) # timeout after 300s
reply = socket.recv_serialized(_deserialize)
log.info("got reply", reply=str(reply))
if reply is None:
msg = "No learners to be run."
exception = RuntimeError(msg)
log_exception(log, msg, exception)
raise exception
if isinstance(reply, Exception):
log_exception(log, "got an exception", exception=reply)
raise reply
fname = reply
learner, initializer = fname_to_learner(fname, return_initializer=True)
log.info("got fname and loaded learner")
log.info("picked a learner")
return learner, fname, initializer
def tell_done(url: str, fname: str | list[str]) -> None:
"""Tell the database that the learner has reached it's goal.
Parameters
----------
url
The url of the database manager running via
(`adaptive_scheduler.server_support.manage_database`).
fname
The filename of the learner that is done.
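
    Examples
    --------
    Typically called once the runner's goal is reached (values are
    illustrative)::

        tell_done("tcp://10.0.0.1:37371", fname="data/learner-0.pickle")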
"""
log.info("goal reached! 🎉🎊🥳")
with ctx.socket(zmq.REQ) as socket:
socket.setsockopt(zmq.LINGER, 0)
socket.connect(url)
socket.setsockopt(zmq.SNDTIMEO, 300_000) # timeout after 300s
socket.send_serialized(("stop", fname), _serialize)
socket.setsockopt(zmq.RCVTIMEO, 300_000) # timeout after 300s
log.info("sent stop signal, going to wait 300s for a reply", fname=fname)
        socket.recv_serialized(_deserialize)  # a REQ socket must receive the reply

def _get_log_entry(runner: AsyncRunner, npoints_start: int) -> dict[str, Any]:
    """Collect a snapshot of runner statistics (timings, npoints, CPU, memory)."""
    learner = runner.learner
    info: dict[str, float | str] = {}
Δt = datetime.timedelta(seconds=runner.elapsed_time()) # noqa: N806
info["elapsed_time"] = str(Δt)
info["overhead"] = runner.overhead()
npoints = _get_npoints(learner)
if npoints is not None:
info["npoints"] = npoints
        Δnpoints = npoints - npoints_start  # noqa: N806
        with suppress(ZeroDivisionError):
            # Use total_seconds(); Δt.seconds drops whole days. It can still
            # be zero if the job is already done when starting.
            info["npoints/s"] = Δnpoints / Δt.total_seconds()
with suppress(Exception):
info["latest_loss"] = learner._cache["loss"]
with suppress(AttributeError):
info["nlearners"] = len(learner.learners)
if "npoints" in info:
info["npoints/learner"] = info["npoints"] / info["nlearners"] # type: ignore[operator]
info["cpu_usage"] = psutil.cpu_percent()
info["mem_usage"] = psutil.virtual_memory().percent
for k, v in psutil.cpu_times()._asdict().items():
info[f"cputimes.{k}"] = v
return info
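# A returned entry contains keys like "elapsed_time", "overhead", "npoints",
# "npoints/s", "cpu_usage", "mem_usage", and "cputimes.user" (the exact set
# depends on the learner type and on what psutil reports for the platform).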
def log_now(runner: AsyncRunner, npoints_start: int) -> None:
"""Create a log message now."""
info = _get_log_entry(runner, npoints_start)
log.info("current status", **info)
def log_info(runner: AsyncRunner, interval: float = 300) -> asyncio.Task:
"""Log info in the job's logfile, similar to `runner.live_info`.
Parameters
----------
runner
Adaptive Runner instance.
interval
Time in seconds between log entries.
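
    Examples
    --------
    Typically called from a job script right after the runner is created
    (the goal below is illustrative)::

        runner = adaptive.Runner(learner, goal=lambda lrn: lrn.npoints >= 1000)
        task = log_info(runner, interval=60)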
"""
async def coro(runner: AsyncRunner, interval: float) -> None:
log.info(f"started logger on hostname {socket.gethostname()}") # noqa: G004
learner = runner.learner
npoints_start = _get_npoints(learner)
assert npoints_start is not None
log.info("npoints at start", npoints=npoints_start)
while runner.status() == "running":
if await sleep_unless_task_is_done(runner.task, interval):
break
log_now(runner, npoints_start)
log.info("runner status changed", status=runner.status())
log.info("current status", **_get_log_entry(runner, npoints_start))
return runner.ioloop.create_task(coro(runner, interval))
def args_to_env(args: argparse.Namespace, prefix: str = "ADAPTIVE_SCHEDULER_") -> None:
"""Convert parsed arguments to environment variables."""
env_vars = {}
for arg, value in vars(args).items():
if value is not None:
env_vars[f"{prefix}{arg.upper()}"] = str(value)
os.environ.update(env_vars)
log.info("set environment variables", **env_vars)