diff --git a/src/ibex_bluesky_core/callbacks/__init__.py b/src/ibex_bluesky_core/callbacks/__init__.py index 70fc0c96..bcfc37f4 100644 --- a/src/ibex_bluesky_core/callbacks/__init__.py +++ b/src/ibex_bluesky_core/callbacks/__init__.py @@ -24,6 +24,7 @@ from ibex_bluesky_core.callbacks._plotting import LivePlot, show_plot from ibex_bluesky_core.callbacks._utils import get_default_output_path from ibex_bluesky_core.fitting import FitMethod +from ibex_bluesky_core.utils import is_matplotlib_backend_qt logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ class ISISCallbacks: """ISIS standard callbacks for use within plans.""" - def __init__( + def __init__( # noqa: PLR0912 self, *, x: str, @@ -66,6 +67,8 @@ def __init__( live_fit_logger_output_dir: str | PathLike[str] | None = None, live_fit_logger_postfix: str = "", human_readable_file_postfix: str = "", + live_fit_update_every: int | None = 1, + live_plot_update_on_every_event: bool = True, ) -> None: """A collection of ISIS standard callbacks for use within plans. @@ -124,6 +127,8 @@ def _inner(): live_fit_logger_output_dir: the output directory for live fit logger. live_fit_logger_postfix: the postfix to add to live fit logger. human_readable_file_postfix: optional postfix to add to human-readable file logger. + live_fit_update_every: How often, in points, to recompute the fit. If None, do not compute until the end. + live_plot_update_on_every_event: whether to show the live plot on every event, or just at the end. """ # noqa self._subs = [] self._peak_stats = None @@ -164,31 +169,42 @@ def _inner(): if (add_plot_cb or show_fit_on_plot) and not ax: logger.debug("No axis provided, creating a new one") - fig, ax, exc, result = None, None, None, None - done_event = threading.Event() - - class _Cb(QtAwareCallback): - def start(self, doc: RunStart) -> None: - nonlocal result, exc, fig, ax - try: - plt.close("all") - fig, ax = plt.subplots() - finally: - done_event.set() - - cb = _Cb() - cb("start", {"time": 0, "uid": ""}) - done_event.wait(10.0) + fig, ax = None, None + + if is_matplotlib_backend_qt(): + done_event = threading.Event() + + # Note: not really a callback, this never gets attached to the runengine + class _Cb(QtAwareCallback): + def start(self, doc: RunStart) -> None: + nonlocal fig, ax + try: + plt.close("all") + fig, ax = plt.subplots() + finally: + done_event.set() + + cb = _Cb() + cb("start", {"time": 0, "uid": ""}) + done_event.wait(10.0) + else: + plt.close("all") + fig, ax = plt.subplots() if fit is not None: - self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr) + self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr, update_every=live_fit_update_every) - # Ideally this would append either livefitplot or livefit, not both, but there's a - # race condition if using the Qt backend where a fit result can be returned before - # the QtAwareCallback has had a chance to process it. - self._subs.append(self._live_fit) if show_fit_on_plot: + if is_matplotlib_backend_qt(): + # Ideally this would append either livefitplot + # or livefit, not both, but there's a + # race condition if using the Qt backend + # where a fit result can be returned before + # the QtAwareCallback has had a chance to process it. + self._subs.append(self._live_fit) self._subs.append(LiveFitPlot(livefit=self._live_fit, ax=ax)) + else: + self._subs.append(self._live_fit) if add_live_fit_logger: self._subs.append( @@ -211,6 +227,7 @@ def start(self, doc: RunStart) -> None: linestyle="none", ax=ax, yerr=yerr, + update_on_every_event=live_plot_update_on_every_event, ) ) diff --git a/src/ibex_bluesky_core/callbacks/_fitting.py b/src/ibex_bluesky_core/callbacks/_fitting.py index dcb581f3..fc255fdc 100644 --- a/src/ibex_bluesky_core/callbacks/_fitting.py +++ b/src/ibex_bluesky_core/callbacks/_fitting.py @@ -35,7 +35,13 @@ class LiveFit(_DefaultLiveFit): """LiveFit, customized for IBEX.""" def __init__( - self, method: FitMethod, y: str, x: str, *, update_every: int = 1, yerr: str | None = None + self, + method: FitMethod, + y: str, + x: str, + *, + update_every: int | None = 1, + yerr: str | None = None, ) -> None: """Call Bluesky LiveFit with assumption that there is only one independant variable. @@ -43,7 +49,7 @@ def __init__( method (FitMethod): The FitMethod (Model & Guess) to use when fitting. y (str): The name of the dependant variable. x (str): The name of the independant variable. - update_every (int, optional): How often to update the fit. (seconds) + update_every (int, optional): How often, in points, to update the fit. yerr (str or None, optional): Name of field in the Event document that provides standard deviation for each Y value. None meaning do not use uncertainties in fit. @@ -54,7 +60,10 @@ def __init__( self.weight_data = [] super().__init__( - model=method.model, y=y, independent_vars={"x": x}, update_every=update_every + model=method.model, + y=y, + independent_vars={"x": x}, + update_every=update_every, # type: ignore ) def event(self, doc: Event) -> None: @@ -110,7 +119,6 @@ def update_fit(self) -> None: self.result = self.model.fit( self.ydata, weights=None if self.yerr is None else self.weight_data, **kwargs ) - self.__stale = False class LiveFitLogger(CallbackBase): diff --git a/src/ibex_bluesky_core/callbacks/_plotting.py b/src/ibex_bluesky_core/callbacks/_plotting.py index c3d18b01..e3ddbd81 100644 --- a/src/ibex_bluesky_core/callbacks/_plotting.py +++ b/src/ibex_bluesky_core/callbacks/_plotting.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt from bluesky.callbacks import LivePlot as _DefaultLivePlot from bluesky.callbacks.core import get_obj_fields, make_class_safe +from event_model import RunStop from event_model.documents import Event, RunStart logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def __init__( x: str | None = None, yerr: str | None = None, *args: Any, # noqa: ANN401 + update_on_every_event: bool = True, **kwargs: Any, # noqa: ANN401 ) -> None: """Initialise LivePlot. @@ -45,9 +47,12 @@ def __init__( yerr (str or None, optional): Name of uncertainties signal. Providing None means do not plot uncertainties. *args: As per mpl_plotting.py + update_on_every_event (bool, optional): Whether to update plot every event, + or just at the end. **kwargs: As per mpl_plotting.py """ + self.update_on_every_event = update_on_every_event super().__init__(y=y, x=x, *args, **kwargs) # noqa: B026 if yerr is not None: self.yerr, *_others = get_obj_fields([yerr]) @@ -55,18 +60,28 @@ def __init__( self.yerr = None self.yerr_data = [] + self._mpl_errorbar_container = None + def event(self, doc: Event) -> None: """Process an event document (delegate to superclass, then show the plot).""" new_yerr = None if self.yerr is None else doc["data"][self.yerr] self.update_yerr(new_yerr) super().event(doc) - show_plot() + if self.update_on_every_event: + show_plot() - def update_plot(self) -> None: + def update_plot(self, force: bool = False) -> None: """Create error bars if needed, then update plot.""" - if self.yerr is not None: - self.ax.errorbar(x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none") # type: ignore - super().update_plot() + if self.update_on_every_event or force: + if self.yerr is not None: + if self._mpl_errorbar_container is not None: + # Remove old error bars before drawing new ones + self._mpl_errorbar_container.remove() + self._mpl_errorbar_container = self.ax.errorbar( # type: ignore + x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none" + ) + + super().update_plot() def update_yerr(self, yerr: float | None) -> None: """Update uncertainties data.""" @@ -76,3 +91,10 @@ def start(self, doc: RunStart) -> None: """Process an start document (delegate to superclass, then show the plot).""" super().start(doc) show_plot() + + def stop(self, doc: RunStop) -> None: + """Process an start document (delegate to superclass, then show the plot).""" + super().stop(doc) + if not self.update_on_every_event: + self.update_plot(force=True) + show_plot() diff --git a/src/ibex_bluesky_core/plans/reflectometry/_autoalign.py b/src/ibex_bluesky_core/plans/reflectometry/_autoalign.py index 43183666..c84bb1f6 100644 --- a/src/ibex_bluesky_core/plans/reflectometry/_autoalign.py +++ b/src/ibex_bluesky_core/plans/reflectometry/_autoalign.py @@ -7,6 +7,8 @@ from bluesky.protocols import NamedMovable from bluesky.utils import Msg from lmfit.model import ModelResult +from ophyd_async.core import Device +from ophyd_async.plan_stubs import ensure_connected from ibex_bluesky_core.callbacks import ISISCallbacks from ibex_bluesky_core.devices.simpledae import SimpleDae @@ -206,6 +208,8 @@ def optimise_axis_against_intensity( # noqa: PLR0913 Instance of :obj:`ibex_bluesky_core.callbacks.ISISCallbacks`. """ + assert isinstance(alignment_param, Device) + yield from ensure_connected(dae, alignment_param) problem_found_plan = problem_found_plan or bps.null logger.info( diff --git a/src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py b/src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py index 2ad5fc15..66f998f2 100644 --- a/src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py +++ b/src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py @@ -87,6 +87,8 @@ def _angle_scan_callback_and_fit( ax=ax, live_fit_logger_postfix="_angle", human_readable_file_postfix="_angle", + live_fit_update_every=len(angle_map) - 1, + live_plot_update_on_every_event=False, ) for cb in angle_scan_callbacks.subs: angle_scan_ld.subscribe(cb) @@ -134,6 +136,7 @@ def angle_scan_plan( yield from ensure_connected(dae) yield from call_qt_aware(plt.close, "all") + yield from call_qt_aware(plt.show) _, ax = yield from call_qt_aware(plt.subplots) angle_cb, angle_fit = _angle_scan_callback_and_fit(reducer, angle_map, ax) diff --git a/src/ibex_bluesky_core/run_engine/__init__.py b/src/ibex_bluesky_core/run_engine/__init__.py index 391c2ce4..2727bda6 100644 --- a/src/ibex_bluesky_core/run_engine/__init__.py +++ b/src/ibex_bluesky_core/run_engine/__init__.py @@ -7,7 +7,6 @@ from threading import Event import bluesky.preprocessors as bpp -import matplotlib from bluesky.run_engine import RunEngine from bluesky.utils import DuringTask @@ -19,6 +18,7 @@ from ibex_bluesky_core.plan_stubs import CALL_QT_AWARE_MSG_KEY, CALL_SYNC_MSG_KEY from ibex_bluesky_core.run_engine._msg_handlers import call_qt_aware_handler, call_sync_handler +from ibex_bluesky_core.utils import is_matplotlib_backend_qt from ibex_bluesky_core.version import version logger = logging.getLogger(__name__) @@ -87,7 +87,7 @@ def get_run_engine() -> RunEngine: # See https://github.com/bluesky/bluesky/pull/1770 for details # We don't need to use our custom _DuringTask if matplotlib is # configured to use Qt. - dt = None if "qt" in matplotlib.get_backend() else _DuringTask() + dt = None if is_matplotlib_backend_qt() else _DuringTask() RE = RunEngine( loop=loop, diff --git a/src/ibex_bluesky_core/run_engine/_msg_handlers.py b/src/ibex_bluesky_core/run_engine/_msg_handlers.py index 8dcc3414..1965925d 100644 --- a/src/ibex_bluesky_core/run_engine/_msg_handlers.py +++ b/src/ibex_bluesky_core/run_engine/_msg_handlers.py @@ -100,7 +100,7 @@ def _wrapper() -> Any: # noqa: ANN401 async def call_qt_aware_handler(msg: Msg) -> Any: # noqa: ANN401 - """Handle ibex_bluesky_core.plan_stubs.call_sync.""" + """Handle ibex_bluesky_core.plan_stubs.call_qt_aware.""" func = msg.obj done_event = Event() result: Any = None @@ -121,7 +121,6 @@ def start(self, doc: RunStart) -> None: msg.kwargs, ) result = func(*msg.args, **msg.kwargs) - import matplotlib.pyplot as plt # noqa: PLC0415 plt.show() diff --git a/src/ibex_bluesky_core/utils.py b/src/ibex_bluesky_core/utils.py index 55dcbbd7..6c2e00dd 100644 --- a/src/ibex_bluesky_core/utils.py +++ b/src/ibex_bluesky_core/utils.py @@ -5,9 +5,15 @@ import os from typing import Any, Protocol +import matplotlib from bluesky.protocols import NamedMovable, Readable -__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix"] +__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix", "is_matplotlib_backend_qt"] + + +def is_matplotlib_backend_qt() -> bool: + """Return True if matplotlib is using a qt backend.""" + return "qt" in matplotlib.get_backend().lower() def centred_pixel(centre: int, pixel_range: int) -> list[int]: diff --git a/tests/callbacks/test_isis_callbacks.py b/tests/callbacks/test_isis_callbacks.py index fbdef1a6..4209311f 100644 --- a/tests/callbacks/test_isis_callbacks.py +++ b/tests/callbacks/test_isis_callbacks.py @@ -1,8 +1,10 @@ # pyright: reportMissingParameterType=false +from unittest.mock import patch + import bluesky.plan_stubs as bps import pytest -from bluesky.callbacks import LiveTable -from bluesky.callbacks.fitting import PeakStats +from bluesky.callbacks import LiveFitPlot, LiveTable +from bluesky.callbacks.fitting import LiveFit, PeakStats from ibex_bluesky_core.callbacks import ( HumanReadableFileCallback, @@ -154,28 +156,35 @@ def test_do_not_add_live_fit_logger_then_not_added(): assert not any([isinstance(i, LiveFitLogger) for i in icc.subs]) -def test_call_decorator(RE): - x = "X_signal" - y = "Y_signal" - icc = ISISCallbacks( - x=x, - y=y, - add_plot_cb=True, - add_table_cb=False, - add_peak_stats=False, - add_human_readable_file_cb=False, - show_fit_on_plot=False, - ) +@pytest.mark.parametrize("matplotlib_using_qt", [True, False]) +def test_call_decorator(RE, matplotlib_using_qt): + with patch( + "ibex_bluesky_core.callbacks.is_matplotlib_backend_qt", return_value=matplotlib_using_qt + ): + x = "X_signal" + y = "Y_signal" + icc = ISISCallbacks( + x=x, + y=y, + add_plot_cb=True, + add_table_cb=False, + add_peak_stats=False, + add_human_readable_file_cb=False, + show_fit_on_plot=True, + fit=Linear().fit(), + ) - def f(): - def _outer(): - @icc - def _inner(): - assert isinstance(icc.subs[0], LivePlot) - yield from bps.null() + def f(): + def _outer(): + @icc + def _inner(): + assert any(isinstance(sub, LivePlot) for sub in icc.subs) + assert any(isinstance(sub, LiveFitPlot) for sub in icc.subs) + assert any(isinstance(sub, LiveFit) for sub in icc.subs) == matplotlib_using_qt + yield from bps.null() - yield from _inner() + yield from _inner() - return (yield from _outer()) + return (yield from _outer()) - RE(f()) + RE(f()) diff --git a/tests/plans/test_reflectometry.py b/tests/plans/test_reflectometry.py index 8f26f446..e7b21ab9 100644 --- a/tests/plans/test_reflectometry.py +++ b/tests/plans/test_reflectometry.py @@ -106,6 +106,10 @@ def plan(mock) -> Generator[Msg, None, None]: mock = MagicMock() with ( + patch( + "ibex_bluesky_core.plans.reflectometry._autoalign.ensure_connected", + return_value=bps.null(), + ), patch("ibex_bluesky_core.devices.reflectometry.get_pv_prefix", return_value=prefix), patch("ibex_bluesky_core.plans.reflectometry._autoalign.scan", return_value=_fake_scan()), patch("ibex_bluesky_core.plans.reflectometry._autoalign.bps.mv", return_value=bps.null()), @@ -250,6 +254,10 @@ def counter(str: str): return "2" with ( + patch( + "ibex_bluesky_core.plans.reflectometry._autoalign.ensure_connected", + return_value=bps.null(), + ), patch("ibex_bluesky_core.devices.reflectometry.get_pv_prefix", return_value=prefix), patch( "ibex_bluesky_core.plans.reflectometry._autoalign._check_parameter", @@ -299,6 +307,10 @@ def counter(str: str): return "2" with ( + patch( + "ibex_bluesky_core.plans.reflectometry._autoalign.ensure_connected", + return_value=bps.null(), + ), patch("ibex_bluesky_core.devices.reflectometry.get_pv_prefix", return_value=prefix), patch( "ibex_bluesky_core.plans.reflectometry._autoalign._check_parameter", diff --git a/tests/test_utils.py b/tests/test_utils.py index ea4da143..4e43e4c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import pytest -from ibex_bluesky_core.utils import centred_pixel, get_pv_prefix +from ibex_bluesky_core.utils import centred_pixel, get_pv_prefix, is_matplotlib_backend_qt def test_get_pv_prefix(): @@ -20,3 +20,9 @@ def test_cannot_get_pv_prefix(): def test_centred_pixel(): assert centred_pixel(50, 3) == [47, 48, 49, 50, 51, 52, 53] + + +@pytest.mark.parametrize("mpl_backend", ["qt5Agg", "qt6Agg", "qtCairo", "something_else"]) +def test_is_matplotlib_backend_qt(mpl_backend: str): + with patch("ibex_bluesky_core.utils.matplotlib.get_backend", return_value=mpl_backend): + assert is_matplotlib_backend_qt() == ("qt" in mpl_backend)