Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions src/ibex_bluesky_core/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ibex_bluesky_core.callbacks._plotting import LivePlot, show_plot
from ibex_bluesky_core.callbacks._utils import get_default_output_path
from ibex_bluesky_core.fitting import FitMethod
from ibex_bluesky_core.utils import is_matplotlib_backend_qt

logger = logging.getLogger(__name__)

Expand All @@ -45,7 +46,7 @@
class ISISCallbacks:
"""ISIS standard callbacks for use within plans."""

def __init__(
def __init__( # noqa: PLR0912
self,
*,
x: str,
Expand All @@ -66,6 +67,8 @@ def __init__(
live_fit_logger_output_dir: str | PathLike[str] | None = None,
live_fit_logger_postfix: str = "",
human_readable_file_postfix: str = "",
live_fit_update_every: int | None = 1,
live_plot_update_on_every_event: bool = True,
) -> None:
"""A collection of ISIS standard callbacks for use within plans.

Expand Down Expand Up @@ -124,6 +127,8 @@ def _inner():
live_fit_logger_output_dir: the output directory for live fit logger.
live_fit_logger_postfix: the postfix to add to live fit logger.
human_readable_file_postfix: optional postfix to add to human-readable file logger.
live_fit_update_every: How often, in points, to recompute the fit. If None, do not compute until the end.
live_plot_update_on_every_event: whether to show the live plot on every event, or just at the end.
""" # noqa
self._subs = []
self._peak_stats = None
Expand Down Expand Up @@ -164,31 +169,42 @@ def _inner():

if (add_plot_cb or show_fit_on_plot) and not ax:
logger.debug("No axis provided, creating a new one")
fig, ax, exc, result = None, None, None, None
done_event = threading.Event()

class _Cb(QtAwareCallback):
def start(self, doc: RunStart) -> None:
nonlocal result, exc, fig, ax
try:
plt.close("all")
fig, ax = plt.subplots()
finally:
done_event.set()

cb = _Cb()
cb("start", {"time": 0, "uid": ""})
done_event.wait(10.0)
fig, ax = None, None

if is_matplotlib_backend_qt():
done_event = threading.Event()

# Note: not really a callback, this never gets attached to the runengine
class _Cb(QtAwareCallback):
def start(self, doc: RunStart) -> None:
nonlocal fig, ax
try:
plt.close("all")
fig, ax = plt.subplots()
finally:
done_event.set()

cb = _Cb()
cb("start", {"time": 0, "uid": ""})
done_event.wait(10.0)
else:
plt.close("all")
fig, ax = plt.subplots()

if fit is not None:
self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr)
self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr, update_every=live_fit_update_every)

# Ideally this would append either livefitplot or livefit, not both, but there's a
# race condition if using the Qt backend where a fit result can be returned before
# the QtAwareCallback has had a chance to process it.
self._subs.append(self._live_fit)
if show_fit_on_plot:
if is_matplotlib_backend_qt():
# Ideally this would append either livefitplot
# or livefit, not both, but there's a
# race condition if using the Qt backend
# where a fit result can be returned before
# the QtAwareCallback has had a chance to process it.
self._subs.append(self._live_fit)
self._subs.append(LiveFitPlot(livefit=self._live_fit, ax=ax))
else:
self._subs.append(self._live_fit)

if add_live_fit_logger:
self._subs.append(
Expand All @@ -211,6 +227,7 @@ def start(self, doc: RunStart) -> None:
linestyle="none",
ax=ax,
yerr=yerr,
update_on_every_event=live_plot_update_on_every_event,
)
)

Expand Down
16 changes: 12 additions & 4 deletions src/ibex_bluesky_core/callbacks/_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,21 @@ class LiveFit(_DefaultLiveFit):
"""LiveFit, customized for IBEX."""

def __init__(
self, method: FitMethod, y: str, x: str, *, update_every: int = 1, yerr: str | None = None
self,
method: FitMethod,
y: str,
x: str,
*,
update_every: int | None = 1,
yerr: str | None = None,
) -> None:
"""Call Bluesky LiveFit with assumption that there is only one independant variable.

Args:
method (FitMethod): The FitMethod (Model & Guess) to use when fitting.
y (str): The name of the dependant variable.
x (str): The name of the independant variable.
update_every (int, optional): How often to update the fit. (seconds)
update_every (int, optional): How often, in points, to update the fit.
yerr (str or None, optional): Name of field in the Event document
that provides standard deviation for each Y value. None meaning
do not use uncertainties in fit.
Expand All @@ -54,7 +60,10 @@ def __init__(
self.weight_data = []

super().__init__(
model=method.model, y=y, independent_vars={"x": x}, update_every=update_every
model=method.model,
y=y,
independent_vars={"x": x},
update_every=update_every, # type: ignore
)

def event(self, doc: Event) -> None:
Expand Down Expand Up @@ -110,7 +119,6 @@ def update_fit(self) -> None:
self.result = self.model.fit(
self.ydata, weights=None if self.yerr is None else self.weight_data, **kwargs
)
self.__stale = False


class LiveFitLogger(CallbackBase):
Expand Down
32 changes: 27 additions & 5 deletions src/ibex_bluesky_core/callbacks/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
from bluesky.callbacks import LivePlot as _DefaultLivePlot
from bluesky.callbacks.core import get_obj_fields, make_class_safe
from event_model import RunStop
from event_model.documents import Event, RunStart

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(
x: str | None = None,
yerr: str | None = None,
*args: Any, # noqa: ANN401
update_on_every_event: bool = True,
**kwargs: Any, # noqa: ANN401
) -> None:
"""Initialise LivePlot.
Expand All @@ -45,28 +47,41 @@ def __init__(
yerr (str or None, optional): Name of uncertainties signal.
Providing None means do not plot uncertainties.
*args: As per mpl_plotting.py
update_on_every_event (bool, optional): Whether to update plot every event,
or just at the end.
**kwargs: As per mpl_plotting.py

"""
self.update_on_every_event = update_on_every_event
super().__init__(y=y, x=x, *args, **kwargs) # noqa: B026
if yerr is not None:
self.yerr, *_others = get_obj_fields([yerr])
else:
self.yerr = None
self.yerr_data = []

self._mpl_errorbar_container = None

def event(self, doc: Event) -> None:
"""Process an event document (delegate to superclass, then show the plot)."""
new_yerr = None if self.yerr is None else doc["data"][self.yerr]
self.update_yerr(new_yerr)
super().event(doc)
show_plot()
if self.update_on_every_event:
show_plot()

def update_plot(self) -> None:
def update_plot(self, force: bool = False) -> None:
"""Create error bars if needed, then update plot."""
if self.yerr is not None:
self.ax.errorbar(x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none") # type: ignore
super().update_plot()
if self.update_on_every_event or force:
if self.yerr is not None:
if self._mpl_errorbar_container is not None:
# Remove old error bars before drawing new ones
self._mpl_errorbar_container.remove()
self._mpl_errorbar_container = self.ax.errorbar( # type: ignore
x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none"
)

super().update_plot()

def update_yerr(self, yerr: float | None) -> None:
"""Update uncertainties data."""
Expand All @@ -76,3 +91,10 @@ def start(self, doc: RunStart) -> None:
"""Process an start document (delegate to superclass, then show the plot)."""
super().start(doc)
show_plot()

def stop(self, doc: RunStop) -> None:
"""Process an start document (delegate to superclass, then show the plot)."""
super().stop(doc)
if not self.update_on_every_event:
self.update_plot(force=True)
show_plot()
4 changes: 4 additions & 0 deletions src/ibex_bluesky_core/plans/reflectometry/_autoalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from bluesky.protocols import NamedMovable
from bluesky.utils import Msg
from lmfit.model import ModelResult
from ophyd_async.core import Device
from ophyd_async.plan_stubs import ensure_connected

from ibex_bluesky_core.callbacks import ISISCallbacks
from ibex_bluesky_core.devices.simpledae import SimpleDae
Expand Down Expand Up @@ -206,6 +208,8 @@ def optimise_axis_against_intensity( # noqa: PLR0913
Instance of :obj:`ibex_bluesky_core.callbacks.ISISCallbacks`.

"""
assert isinstance(alignment_param, Device)
yield from ensure_connected(dae, alignment_param)
problem_found_plan = problem_found_plan or bps.null

logger.info(
Expand Down
3 changes: 3 additions & 0 deletions src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def _angle_scan_callback_and_fit(
ax=ax,
live_fit_logger_postfix="_angle",
human_readable_file_postfix="_angle",
live_fit_update_every=len(angle_map) - 1,
live_plot_update_on_every_event=False,
)
for cb in angle_scan_callbacks.subs:
angle_scan_ld.subscribe(cb)
Expand Down Expand Up @@ -134,6 +136,7 @@ def angle_scan_plan(
yield from ensure_connected(dae)

yield from call_qt_aware(plt.close, "all")
yield from call_qt_aware(plt.show)
_, ax = yield from call_qt_aware(plt.subplots)

angle_cb, angle_fit = _angle_scan_callback_and_fit(reducer, angle_map, ax)
Expand Down
4 changes: 2 additions & 2 deletions src/ibex_bluesky_core/run_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from threading import Event

import bluesky.preprocessors as bpp
import matplotlib
from bluesky.run_engine import RunEngine
from bluesky.utils import DuringTask

Expand All @@ -19,6 +18,7 @@

from ibex_bluesky_core.plan_stubs import CALL_QT_AWARE_MSG_KEY, CALL_SYNC_MSG_KEY
from ibex_bluesky_core.run_engine._msg_handlers import call_qt_aware_handler, call_sync_handler
from ibex_bluesky_core.utils import is_matplotlib_backend_qt
from ibex_bluesky_core.version import version

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_run_engine() -> RunEngine:
# See https://github.com/bluesky/bluesky/pull/1770 for details
# We don't need to use our custom _DuringTask if matplotlib is
# configured to use Qt.
dt = None if "qt" in matplotlib.get_backend() else _DuringTask()
dt = None if is_matplotlib_backend_qt() else _DuringTask()

RE = RunEngine(
loop=loop,
Expand Down
3 changes: 1 addition & 2 deletions src/ibex_bluesky_core/run_engine/_msg_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _wrapper() -> Any: # noqa: ANN401


async def call_qt_aware_handler(msg: Msg) -> Any: # noqa: ANN401
"""Handle ibex_bluesky_core.plan_stubs.call_sync."""
"""Handle ibex_bluesky_core.plan_stubs.call_qt_aware."""
func = msg.obj
done_event = Event()
result: Any = None
Expand All @@ -121,7 +121,6 @@ def start(self, doc: RunStart) -> None:
msg.kwargs,
)
result = func(*msg.args, **msg.kwargs)

import matplotlib.pyplot as plt # noqa: PLC0415

plt.show()
Expand Down
8 changes: 7 additions & 1 deletion src/ibex_bluesky_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
import os
from typing import Any, Protocol

import matplotlib
from bluesky.protocols import NamedMovable, Readable

__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix"]
__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix", "is_matplotlib_backend_qt"]


def is_matplotlib_backend_qt() -> bool:
"""Return True if matplotlib is using a qt backend."""
return "qt" in matplotlib.get_backend().lower()


def centred_pixel(centre: int, pixel_range: int) -> list[int]:
Expand Down
55 changes: 32 additions & 23 deletions tests/callbacks/test_isis_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# pyright: reportMissingParameterType=false
from unittest.mock import patch

import bluesky.plan_stubs as bps
import pytest
from bluesky.callbacks import LiveTable
from bluesky.callbacks.fitting import PeakStats
from bluesky.callbacks import LiveFitPlot, LiveTable
from bluesky.callbacks.fitting import LiveFit, PeakStats

from ibex_bluesky_core.callbacks import (
HumanReadableFileCallback,
Expand Down Expand Up @@ -154,28 +156,35 @@ def test_do_not_add_live_fit_logger_then_not_added():
assert not any([isinstance(i, LiveFitLogger) for i in icc.subs])


def test_call_decorator(RE):
x = "X_signal"
y = "Y_signal"
icc = ISISCallbacks(
x=x,
y=y,
add_plot_cb=True,
add_table_cb=False,
add_peak_stats=False,
add_human_readable_file_cb=False,
show_fit_on_plot=False,
)
@pytest.mark.parametrize("matplotlib_using_qt", [True, False])
def test_call_decorator(RE, matplotlib_using_qt):
with patch(
"ibex_bluesky_core.callbacks.is_matplotlib_backend_qt", return_value=matplotlib_using_qt
):
x = "X_signal"
y = "Y_signal"
icc = ISISCallbacks(
x=x,
y=y,
add_plot_cb=True,
add_table_cb=False,
add_peak_stats=False,
add_human_readable_file_cb=False,
show_fit_on_plot=True,
fit=Linear().fit(),
)

def f():
def _outer():
@icc
def _inner():
assert isinstance(icc.subs[0], LivePlot)
yield from bps.null()
def f():
def _outer():
@icc
def _inner():
assert any(isinstance(sub, LivePlot) for sub in icc.subs)
assert any(isinstance(sub, LiveFitPlot) for sub in icc.subs)
assert any(isinstance(sub, LiveFit) for sub in icc.subs) == matplotlib_using_qt
yield from bps.null()

yield from _inner()
yield from _inner()

return (yield from _outer())
return (yield from _outer())

RE(f())
RE(f())
Loading