Skip to content

Commit 353fc3a

Browse files
committed
Optimize
1 parent 9099a5c commit 353fc3a

File tree

11 files changed

+142
-58
lines changed

11 files changed

+142
-58
lines changed

src/ibex_bluesky_core/callbacks/__init__.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ibex_bluesky_core.callbacks._plotting import LivePlot, show_plot
2525
from ibex_bluesky_core.callbacks._utils import get_default_output_path
2626
from ibex_bluesky_core.fitting import FitMethod
27+
from ibex_bluesky_core.utils import is_matplotlib_backend_qt
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -45,7 +46,7 @@
4546
class ISISCallbacks:
4647
"""ISIS standard callbacks for use within plans."""
4748

48-
def __init__(
49+
def __init__( # noqa: PLR0912
4950
self,
5051
*,
5152
x: str,
@@ -66,6 +67,8 @@ def __init__(
6667
live_fit_logger_output_dir: str | PathLike[str] | None = None,
6768
live_fit_logger_postfix: str = "",
6869
human_readable_file_postfix: str = "",
70+
live_fit_update_every: int | None = 1,
71+
live_plot_show_every_event: bool = True,
6972
) -> None:
7073
"""A collection of ISIS standard callbacks for use within plans.
7174
@@ -124,6 +127,8 @@ def _inner():
124127
live_fit_logger_output_dir: the output directory for live fit logger.
125128
live_fit_logger_postfix: the postfix to add to live fit logger.
126129
human_readable_file_postfix: optional postfix to add to human-readable file logger.
130+
live_fit_update_every: How often to recompute the fit. If None, do not compute until the end.
131+
live_plot_show_every_event: whether to show the live plot on every event, or just at the end.
127132
""" # noqa
128133
self._subs = []
129134
self._peak_stats = None
@@ -164,31 +169,41 @@ def _inner():
164169

165170
if (add_plot_cb or show_fit_on_plot) and not ax:
166171
logger.debug("No axis provided, creating a new one")
167-
fig, ax, exc, result = None, None, None, None
168-
done_event = threading.Event()
169-
170-
class _Cb(QtAwareCallback):
171-
def start(self, doc: RunStart) -> None:
172-
nonlocal result, exc, fig, ax
173-
try:
174-
plt.close("all")
175-
fig, ax = plt.subplots()
176-
finally:
177-
done_event.set()
178-
179-
cb = _Cb()
180-
cb("start", {"time": 0, "uid": ""})
181-
done_event.wait(10.0)
172+
fig, ax = None, None
173+
174+
if is_matplotlib_backend_qt():
175+
done_event = threading.Event()
176+
177+
class _Cb(QtAwareCallback):
178+
def start(self, doc: RunStart) -> None:
179+
nonlocal fig, ax
180+
try:
181+
plt.close("all")
182+
fig, ax = plt.subplots()
183+
finally:
184+
done_event.set()
185+
186+
cb = _Cb()
187+
cb("start", {"time": 0, "uid": ""})
188+
done_event.wait(10.0)
189+
else:
190+
plt.close("all")
191+
fig, ax = plt.subplots()
182192

183193
if fit is not None:
184-
self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr)
194+
self._live_fit = LiveFit(fit, y=y, x=x, yerr=yerr, update_every=live_fit_update_every)
185195

186-
# Ideally this would append either livefitplot or livefit, not both, but there's a
187-
# race condition if using the Qt backend where a fit result can be returned before
188-
# the QtAwareCallback has had a chance to process it.
189-
self._subs.append(self._live_fit)
190196
if show_fit_on_plot:
197+
if is_matplotlib_backend_qt():
198+
# Ideally this would append either livefitplot
199+
# or livefit, not both, but there's a
200+
# race condition if using the Qt backend
201+
# where a fit result can be returned before
202+
# the QtAwareCallback has had a chance to process it.
203+
self._subs.append(self._live_fit)
191204
self._subs.append(LiveFitPlot(livefit=self._live_fit, ax=ax))
205+
else:
206+
self._subs.append(self._live_fit)
192207

193208
if add_live_fit_logger:
194209
self._subs.append(
@@ -211,6 +226,7 @@ def start(self, doc: RunStart) -> None:
211226
linestyle="none",
212227
ax=ax,
213228
yerr=yerr,
229+
show_every_event=live_plot_show_every_event,
214230
)
215231
)
216232

src/ibex_bluesky_core/callbacks/_fitting.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,21 @@ class LiveFit(_DefaultLiveFit):
3535
"""LiveFit, customized for IBEX."""
3636

3737
def __init__(
38-
self, method: FitMethod, y: str, x: str, *, update_every: int = 1, yerr: str | None = None
38+
self,
39+
method: FitMethod,
40+
y: str,
41+
x: str,
42+
*,
43+
update_every: int | None = 1,
44+
yerr: str | None = None,
3945
) -> None:
4046
"""Call Bluesky LiveFit with assumption that there is only one independant variable.
4147
4248
Args:
4349
method (FitMethod): The FitMethod (Model & Guess) to use when fitting.
4450
y (str): The name of the dependant variable.
4551
x (str): The name of the independant variable.
46-
update_every (int, optional): How often to update the fit. (seconds)
52+
update_every (int, optional): How often to update the fit.
4753
yerr (str or None, optional): Name of field in the Event document
4854
that provides standard deviation for each Y value. None meaning
4955
do not use uncertainties in fit.
@@ -54,7 +60,10 @@ def __init__(
5460
self.weight_data = []
5561

5662
super().__init__(
57-
model=method.model, y=y, independent_vars={"x": x}, update_every=update_every
63+
model=method.model,
64+
y=y,
65+
independent_vars={"x": x},
66+
update_every=update_every, # type: ignore
5867
)
5968

6069
def event(self, doc: Event) -> None:
@@ -110,7 +119,6 @@ def update_fit(self) -> None:
110119
self.result = self.model.fit(
111120
self.ydata, weights=None if self.yerr is None else self.weight_data, **kwargs
112121
)
113-
self.__stale = False
114122

115123

116124
class LiveFitLogger(CallbackBase):

src/ibex_bluesky_core/callbacks/_plotting.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib.pyplot as plt
88
from bluesky.callbacks import LivePlot as _DefaultLivePlot
99
from bluesky.callbacks.core import get_obj_fields, make_class_safe
10+
from event_model import RunStop
1011
from event_model.documents import Event, RunStart
1112

1213
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ def __init__(
3536
x: str | None = None,
3637
yerr: str | None = None,
3738
*args: Any, # noqa: ANN401
39+
show_every_event: bool = True,
3840
**kwargs: Any, # noqa: ANN401
3941
) -> None:
4042
"""Initialise LivePlot.
@@ -45,28 +47,40 @@ def __init__(
4547
yerr (str or None, optional): Name of uncertainties signal.
4648
Providing None means do not plot uncertainties.
4749
*args: As per mpl_plotting.py
50+
show_every_event (bool, optional): Whether to show plot every event, or just at the end.
4851
**kwargs: As per mpl_plotting.py
4952
5053
"""
54+
self.show_every_event = show_every_event
5155
super().__init__(y=y, x=x, *args, **kwargs) # noqa: B026
5256
if yerr is not None:
5357
self.yerr, *_others = get_obj_fields([yerr])
5458
else:
5559
self.yerr = None
5660
self.yerr_data = []
5761

62+
self._mpl_errorbar_container = None
63+
5864
def event(self, doc: Event) -> None:
5965
"""Process an event document (delegate to superclass, then show the plot)."""
6066
new_yerr = None if self.yerr is None else doc["data"][self.yerr]
6167
self.update_yerr(new_yerr)
6268
super().event(doc)
63-
show_plot()
69+
if self.show_every_event:
70+
show_plot()
6471

65-
def update_plot(self) -> None:
72+
def update_plot(self, force: bool = False) -> None:
6673
"""Create error bars if needed, then update plot."""
67-
if self.yerr is not None:
68-
self.ax.errorbar(x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none") # type: ignore
69-
super().update_plot()
74+
if self.show_every_event or force:
75+
if self.yerr is not None:
76+
if self._mpl_errorbar_container is not None:
77+
# Remove old error bars before drawing new ones
78+
self._mpl_errorbar_container.remove()
79+
self._mpl_errorbar_container = self.ax.errorbar( # type: ignore
80+
x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none"
81+
)
82+
83+
super().update_plot()
7084

7185
def update_yerr(self, yerr: float | None) -> None:
7286
"""Update uncertainties data."""
@@ -76,3 +90,10 @@ def start(self, doc: RunStart) -> None:
7690
"""Process an start document (delegate to superclass, then show the plot)."""
7791
super().start(doc)
7892
show_plot()
93+
94+
def stop(self, doc: RunStop) -> None:
95+
"""Process an start document (delegate to superclass, then show the plot)."""
96+
super().stop(doc)
97+
if not self.show_every_event:
98+
self.update_plot(force=True)
99+
show_plot()

src/ibex_bluesky_core/plans/reflectometry/_autoalign.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from bluesky.protocols import NamedMovable
88
from bluesky.utils import Msg
99
from lmfit.model import ModelResult
10+
from ophyd_async.core import Device
11+
from ophyd_async.plan_stubs import ensure_connected
1012

1113
from ibex_bluesky_core.callbacks import ISISCallbacks
1214
from ibex_bluesky_core.devices.simpledae import SimpleDae
@@ -206,6 +208,8 @@ def optimise_axis_against_intensity( # noqa: PLR0913
206208
Instance of :obj:`ibex_bluesky_core.callbacks.ISISCallbacks`.
207209
208210
"""
211+
assert isinstance(alignment_param, Device)
212+
yield from ensure_connected(dae, alignment_param)
209213
problem_found_plan = problem_found_plan or bps.null
210214

211215
logger.info(

src/ibex_bluesky_core/plans/reflectometry/_det_map_align.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def _angle_scan_callback_and_fit(
8787
ax=ax,
8888
live_fit_logger_postfix="_angle",
8989
human_readable_file_postfix="_angle",
90+
live_fit_update_every=len(angle_map) - 1,
91+
live_plot_show_every_event=False,
9092
)
9193
for cb in angle_scan_callbacks.subs:
9294
angle_scan_ld.subscribe(cb)
@@ -134,6 +136,7 @@ def angle_scan_plan(
134136
yield from ensure_connected(dae)
135137

136138
yield from call_qt_aware(plt.close, "all")
139+
yield from call_qt_aware(plt.show)
137140
_, ax = yield from call_qt_aware(plt.subplots)
138141

139142
angle_cb, angle_fit = _angle_scan_callback_and_fit(reducer, angle_map, ax)

src/ibex_bluesky_core/run_engine/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from threading import Event
88

99
import bluesky.preprocessors as bpp
10-
import matplotlib
1110
from bluesky.run_engine import RunEngine
1211
from bluesky.utils import DuringTask
1312

@@ -19,6 +18,7 @@
1918

2019
from ibex_bluesky_core.plan_stubs import CALL_QT_AWARE_MSG_KEY, CALL_SYNC_MSG_KEY
2120
from ibex_bluesky_core.run_engine._msg_handlers import call_qt_aware_handler, call_sync_handler
21+
from ibex_bluesky_core.utils import is_matplotlib_backend_qt
2222
from ibex_bluesky_core.version import version
2323

2424
logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ def get_run_engine() -> RunEngine:
8787
# See https://github.com/bluesky/bluesky/pull/1770 for details
8888
# We don't need to use our custom _DuringTask if matplotlib is
8989
# configured to use Qt.
90-
dt = None if "qt" in matplotlib.get_backend() else _DuringTask()
90+
dt = None if is_matplotlib_backend_qt() else _DuringTask()
9191

9292
RE = RunEngine(
9393
loop=loop,

src/ibex_bluesky_core/run_engine/_msg_handlers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _wrapper() -> Any: # noqa: ANN401
100100

101101

102102
async def call_qt_aware_handler(msg: Msg) -> Any: # noqa: ANN401
103-
"""Handle ibex_bluesky_core.plan_stubs.call_sync."""
103+
"""Handle ibex_bluesky_core.plan_stubs.call_qt_aware."""
104104
func = msg.obj
105105
done_event = Event()
106106
result: Any = None
@@ -121,7 +121,6 @@ def start(self, doc: RunStart) -> None:
121121
msg.kwargs,
122122
)
123123
result = func(*msg.args, **msg.kwargs)
124-
125124
import matplotlib.pyplot as plt # noqa: PLC0415
126125

127126
plt.show()

src/ibex_bluesky_core/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
import os
66
from typing import Any, Protocol
77

8+
import matplotlib
89
from bluesky.protocols import NamedMovable, Readable
910

1011
__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix"]
1112

1213

14+
def is_matplotlib_backend_qt() -> bool:
15+
"""Return True if matplotlib is using a qt backend."""
16+
return "qt" in matplotlib.get_backend().lower()
17+
18+
1319
def centred_pixel(centre: int, pixel_range: int) -> list[int]:
1420
"""Given a centre and range, return a contiguous range of pixels around the centre, inclusive.
1521

tests/callbacks/test_isis_callbacks.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# pyright: reportMissingParameterType=false
2+
from unittest.mock import patch
3+
24
import bluesky.plan_stubs as bps
35
import pytest
4-
from bluesky.callbacks import LiveTable
5-
from bluesky.callbacks.fitting import PeakStats
6+
from bluesky.callbacks import LiveFitPlot, LiveTable
7+
from bluesky.callbacks.fitting import LiveFit, PeakStats
68

79
from ibex_bluesky_core.callbacks import (
810
HumanReadableFileCallback,
@@ -154,28 +156,35 @@ def test_do_not_add_live_fit_logger_then_not_added():
154156
assert not any([isinstance(i, LiveFitLogger) for i in icc.subs])
155157

156158

157-
def test_call_decorator(RE):
158-
x = "X_signal"
159-
y = "Y_signal"
160-
icc = ISISCallbacks(
161-
x=x,
162-
y=y,
163-
add_plot_cb=True,
164-
add_table_cb=False,
165-
add_peak_stats=False,
166-
add_human_readable_file_cb=False,
167-
show_fit_on_plot=False,
168-
)
159+
@pytest.mark.parametrize("matplotlib_using_qt", [True, False])
160+
def test_call_decorator(RE, matplotlib_using_qt):
161+
with patch(
162+
"ibex_bluesky_core.callbacks.is_matplotlib_backend_qt", return_value=matplotlib_using_qt
163+
):
164+
x = "X_signal"
165+
y = "Y_signal"
166+
icc = ISISCallbacks(
167+
x=x,
168+
y=y,
169+
add_plot_cb=True,
170+
add_table_cb=False,
171+
add_peak_stats=False,
172+
add_human_readable_file_cb=False,
173+
show_fit_on_plot=True,
174+
fit=Linear().fit(),
175+
)
169176

170-
def f():
171-
def _outer():
172-
@icc
173-
def _inner():
174-
assert isinstance(icc.subs[0], LivePlot)
175-
yield from bps.null()
177+
def f():
178+
def _outer():
179+
@icc
180+
def _inner():
181+
assert any(isinstance(sub, LivePlot) for sub in icc.subs)
182+
assert any(isinstance(sub, LiveFitPlot) for sub in icc.subs)
183+
assert any(isinstance(sub, LiveFit) for sub in icc.subs) == matplotlib_using_qt
184+
yield from bps.null()
176185

177-
yield from _inner()
186+
yield from _inner()
178187

179-
return (yield from _outer())
188+
return (yield from _outer())
180189

181-
RE(f())
190+
RE(f())

0 commit comments

Comments
 (0)