diff --git a/doc/devices/dae.md b/doc/devices/dae.md
index 76f68581..f2e38cd7 100644
--- a/doc/devices/dae.md
+++ b/doc/devices/dae.md
@@ -13,7 +13,7 @@ This means that [`SimpleDae`](ibex_bluesky_core.devices.simpledae.SimpleDae) is
 example running using either one DAE run per scan point, or one DAE period per scan point.
 
 For complex use-cases, particularly those where the DAE may need to start and stop multiple
-acquisitions per scan point (e.g. polarization measurements), [`SimpleDae`](ibex_bluesky_core.devices.simpledae.SimpleDae) is unlikely to be
+acquisitions per scan point (e.g. polarisation measurements), [`SimpleDae`](ibex_bluesky_core.devices.simpledae.SimpleDae) is unlikely to be
 suitable; instead the [`Dae`](ibex_bluesky_core.devices.dae.Dae) class should be subclassed directly to allow for finer control.
 
 ## Example configurations
@@ -285,10 +285,10 @@ as a list, and `units` (μs/microseconds for time of flight bounding, and angstr
 
 If you don't specify either of these options, they will default to summing over the entire spectrum.
 
-### Polarization/Asymmetry
+### Polarisation/Asymmetry
 
 ibex_bluesky_core provides a helper method,
-{py:obj}`ibex_bluesky_core.devices.simpledae.polarization`, for calculating the quantity (a-b)/(a+b). This quantity is used, for example, in neutron polarization measurements, and in calculating asymmetry for muon measurements.
+{py:obj}`ibex_bluesky_core.utils.calculate_polarisation`, for calculating the quantity (a-b)/(a+b). This quantity is used, for example, in neutron polarisation measurements, and in calculating asymmetry for muon measurements.
 
 For this expression, scipp's default uncertainty propagation rules cannot be used as the uncertainties on (a-b) are correlated with those of (a+b) in the division step - but scipp assumes uncorrelated data.
 This helper method calculates the uncertainties following linear error propagation theory, using the partial derivatives of the above expression.
@@ -304,7 +304,7 @@ Which then means the variances computed by this helper function are:
 
 $ Variance = (\frac{\delta}{\delta a}^2 * variance_a) + (\frac{\delta}{\delta b}^2 * variance_b) $
 
-The polarization funtion provided will calculate the polarization between two values, A and B, which
+The polarisation function provided will calculate the polarisation between two values, A and B, which
 have different definitions based on the instrument context.
 
 Instrument-Specific Interpretations
@@ -318,6 +318,10 @@ Similar to LARMOR, A and B represent intensities before and after flipper switch
 Muon Instruments
 A and B refer to Measurements from different detector banks.
 
+This quantity is calculated by {py:obj}`ibex_bluesky_core.utils.calculate_polarisation`.
+
+See [`PolarisationReducer`](#polarisationreducer) for how this is integrated into DAE behaviour.
+
 ## Waiters
 
 A [`waiter`](ibex_bluesky_core.devices.simpledae.Waiter) defines an arbitrary strategy for how long to count at each point.
@@ -361,6 +365,63 @@ Waits for a user-specified time duration, irrespective of DAE state.
 
 Does not publish any additional signals.
 
+## Polarising DAE
+
+The polarising DAE provides specialised functionality for collecting data while taking the polarisation state of the beam into account.
+
+### DualRunDae
+
+[`DualRunDae`](ibex_bluesky_core.devices.polarisingdae.DualRunDae) is a more complex version of [`SimpleDae`](ibex_bluesky_core.devices.simpledae.SimpleDae), designed specifically for taking polarisation measurements. It requires a flipper device and uses it to flip from one neutron state to the other between runs.
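+
+At each scan point, `trigger()` moves the flipper to the first state, counts and reduces a run with
+the 'up' reducer, then moves the flipper to the second state, counts and reduces a run with the
+'down' reducer, and finally runs the main reducer over the results of both runs.
+
+A minimal construction sketch is shown below (the flipper block, spectra numbers, wavelength band
+and frame count are illustrative; the `polarising_dae` helper described further down assembles a
+similar object for you):
+
+```python
+import scipp as sc
+
+flipper = block_rw(float, "flipper")
+band = sc.array(dims=["tof"], values=[0.0, 5.0], unit=sc.units.angstrom, dtype="float64")
+sum_bands = [
+    wavelength_bounded_spectra(
+        bounds=band, total_flight_path_length=sc.scalar(value=10, unit=sc.units.m)
+    )
+]
+
+prefix = get_pv_prefix()
+reducer_up = MultiWavelengthBandNormalizer(
+    prefix=prefix, detector_spectra=[1], monitor_spectra=[2], sum_wavelength_bands=sum_bands
+)
+reducer_down = MultiWavelengthBandNormalizer(
+    prefix=prefix, detector_spectra=[1], monitor_spectra=[2], sum_wavelength_bands=sum_bands
+)
+reducer_final = PolarisationReducer(
+    intervals=[band], reducer_up=reducer_up, reducer_down=reducer_down
+)
+
+dae = DualRunDae(
+    prefix=prefix,
+    controller=RunPerPointController(save_run=False),
+    waiter=GoodFramesWaiter(500),
+    reducer_final=reducer_final,
+    reducer_up=reducer_up,
+    reducer_down=reducer_down,
+    movable=flipper,
+    movable_states=[0.0, 1.0],
+)
+```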
+
+Key features:
+- Controls a flipper device to switch between neutron states
+- Handles three separate reduction strategies:
+  - An 'up' reducer and a 'down' reducer, which run after their respective runs
+  - A main reducer, which runs after both runs have completed
+
+### polarising_dae
+
+[`polarising_dae`](ibex_bluesky_core.devices.polarisingdae.polarising_dae) is a helper function that creates a configured [`DualRunDae`](ibex_bluesky_core.devices.polarisingdae.DualRunDae) instance with wavelength-bounded normalisation and polarisation calculation capabilities.
+
+The following is how you may want to use `polarising_dae`:
+```python
+import scipp as sc
+
+flipper = block_rw(float, "alice")
+wavelength_interval = sc.array(dims=["tof"], values=[0.0, 9999999999.0], unit=sc.units.angstrom, dtype="float64")  # A wavelength interval covering the whole spectrum
+total_flight_path_length = sc.scalar(value=10, unit=sc.units.m)
+
+dae = polarising_dae(det_pixels=[1], frames=500, movable=flipper, movable_states=[0.0, 1.0], intervals=[wavelength_interval], total_flight_path_length=total_flight_path_length, monitor=2)
+```
+
+:::{note}
+You must tell the polarising DAE what the `movable_states` are. These are the values written to the `movable` (flipper) device to move it to the "up" state and the "down" state respectively.
+:::
+
+### Polarising Reducers
+
+#### MultiWavelengthBandNormalizer
+
+[`MultiWavelengthBandNormalizer`](ibex_bluesky_core.devices.polarisingdae.MultiWavelengthBandNormalizer) sums wavelength-bounded spectra and normalises by monitor intensity.
+
+Published signals:
+- `wavelength_bands`: DeviceVector containing wavelength band measurements
+  - `det_counts`: detector counts in the wavelength band
+  - `mon_counts`: monitor counts in the wavelength band
+  - `intensity`: normalised intensity in the wavelength band
+  - Associated uncertainty measurements for each value
+
+#### PolarisationReducer
+
+[`PolarisationReducer`](ibex_bluesky_core.devices.polarisingdae.PolarisationReducer) calculates polarisation from the 'spin-up' and 'spin-down' states of a polarising DAE. Uses the [`Polarisation`](#polarisationasymmetry) algorithm.
+
+Published signals:
+- `wavelength_bands`: DeviceVector containing polarisation measurements
+  - `polarisation`: The calculated polarisation value for that wavelength band
+  - `polarisation_ratio`: Ratio between up and down states for that wavelength band
+  - Associated uncertainty measurements for each value
+
 ---
 
 ## `Dae` (base class, advanced)
@@ -479,9 +540,6 @@ A [`DaeSpectra`](ibex_bluesky_core.devices.dae.DaeSpectra) object provides 3 arr
 
 The [`Dae`](ibex_bluesky_core.devices.dae) base class does not provide any spectra by default. User-level
 classes should specify the set of spectra which they are interested in.
 
-
-
-
 Spectra can be summed between two bounds based on time of flight bounds, or wavelength bounds, for
 both detector and monitor normalizers.
 
 Both Scalar Normalizers (PeriodGoodFramesNormalizer, GoodFramesNormalizer) and MonitorNormalizers
diff --git a/doc/fitting/fitting.md b/doc/fitting/fitting.md
index 6c6eef0f..e8d43cd2 100644
--- a/doc/fitting/fitting.md
+++ b/doc/fitting/fitting.md
@@ -1,4 +1,4 @@
-# Fitting Callback
+# Fitting Callbacks
 
 Similar to [`LivePlot`](../callbacks/plotting.md), [`ibex_bluesky_core`](ibex_bluesky_core) provides a thin wrapper around Bluesky's [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) class, enhancing it with additional functionality to better support real-time data fitting.
 This wrapper not only offers a wide selection of models to fit your data on, but also introduces guess generation for fit parameters. As new data points are acquired, the wrapper refines these guesses dynamically, improving the accuracy of the fit with each additional piece of data, allowing for more efficient and adaptive real-time fitting workflows.
 
@@ -7,7 +7,8 @@ In order to use the wrapper, import[`LiveFit`](ibex_bluesky_core.callbacks.LiveF
 ```py
 from ibex_bluesky_core.callbacks.fitting import LiveFit
 ```
 
-**Note:** that you do not *need* [`LivePlot`](ibex_bluesky_core.callbacks.LivePlot) for [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) to work but it may be useful to know visaully how well the model fits to the raw data.
+```{note}
+You do not *need* [`LivePlot`](ibex_bluesky_core.callbacks.LivePlot) for [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) to work, but it may be useful to see visually how well the model fits the raw data.
+```
 
 ## Configuration
 
@@ -31,7 +32,8 @@ fit_callback = LiveFit(Gaussian.fit(), y="y_signal", x="x_signal", yerr="yerr_si
 fit_plot_callback = LiveFitPlot(fit_callback, ax=ax, color="r")
 ```
 
-**Note:** that the [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) callback doesn't directly do the plotting, it will return function parameters of the model its trying to fit to; a [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) object must be passed to `LiveFitPlot` which can then be subscribed to the `RunEngine`. See the [Bluesky Documentation](https://blueskyproject.io/bluesky/main/callbacks.html#livefitplot) for information on the various arguments that can be passed to the `LiveFitPlot` class.
+```{note}
+The [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) callback doesn't do the plotting itself; it returns the parameters of the model it is trying to fit. A [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) object must be passed to `LiveFitPlot`, which can then be subscribed to the `RunEngine`. See the [Bluesky Documentation](https://blueskyproject.io/bluesky/main/callbacks.html#livefitplot) for information on the various arguments that can be passed to the `LiveFitPlot` class.
+```
 
 Using the `yerr` argument allows you to pass uncertainties via a signal to LiveFit, so that the "weight" of each point influences the fit produced. By not providing a signal name you choose not to use uncertainties/weighting in the fitting calculation. Each weight is computed as `1/(standard deviation at point)` and is taken into account to determine how much a point affects the overall fit of the data. Same as the rest of [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit), the fit will be updated after every new point collected now taking into account the weights of each point. Uncertainty data is collected from Bluesky event documents after each new point.
 
@@ -89,7 +91,8 @@ lf = LiveFit([FIT].fit(), y="y_signal", x="x_signal", update_every=0.5)
 
 The `[FIT].fit()` function will pass the [`FitMethod`](ibex_bluesky_core.fitting.FitMethod) object straight to the [`LiveFit`](ibex_bluesky_core.callbacks.LiveFit) class.
 
-**Note:** that for the fits in the above table that require parameters, you will need to pass value(s) to their `.fit` method. For example Polynomial fitting:
+```{note}
+For the fits in the above table that require parameters, you will need to pass value(s) to their `.fit` method.
+```
+
+For example, for `Polynomial` fitting:
 
 ```py
 lf = LiveFit(Polynomial.fit(3), y="y_signal", x="x_signal", update_every=0.5)
 ```
 
@@ -146,7 +149,8 @@ lf = LiveFit(fit_method, y="y_signal", x="x_signal", update_every=0.5)
 # Then subscribe to LiveFitPlot(lf, ...)
 ```
 
-**Note:** that the parameters returned from the guess function must allocate to the arguments to the model function, ignoring the independant variable e.g `x` in this case. Array-like structures are not allowed. See the [lmfit documentation](https://lmfit.github.io/lmfit-py/parameters.html) for more information.
+```{note}
+The parameters returned from the guess function must map onto the arguments of the model function, ignoring the independent variable (e.g. `x` in this case). Array-like structures are not allowed. See the [lmfit documentation](https://lmfit.github.io/lmfit-py/parameters.html) for more information.
+```
 
 #### Option 2: Continued
 
@@ -201,10 +205,60 @@ lf = LiveFit(fit_method, y="y_signal", x="x_signal", update_every=0.5)
 
 Or you can create a completely user-defined fitting method.
 
-**Note:** that for fits that require arguments, you will need to pass values to their respecitive `.model` and `.guess` functions. E.g for `Polynomial` fitting:
+```{note}
+For fits that require arguments, you will need to pass values to their respective `.model` and `.guess` functions. For example, for `Polynomial` fitting:
+```
 
 ```py
 fit_method = FitMethod(Polynomial.model(3), different_guess) # If using a custom guess function
 lf = LiveFit(fit_method, ...)
 ```
 
 See the [standard fits](#models) list above for standard fits which require parameters. It gets more complicated if you want to define your own custom model or guess which you want to pass parameters to. You will have to define a function that takes these parameters and returns the model / guess function with the subsituted values.
+
+## Chained Fitting
+
+[`ChainedLiveFit`](ibex_bluesky_core.callbacks.ChainedLiveFit) is a specialised callback that manages multiple LiveFit instances in a chain, where each fit's results inform the next fit's initial parameters. This is particularly useful when dealing with complex data sets where subsequent fits depend on the parameters obtained from previous fits.
+
+This is useful when you need to be careful with your curve fitting because of noisy data: it allows you to fit your widest (full) wavelength band first and then use its fit parameters as the initial guess for the next fit.
+
+### Usage
+
+To show how we expect this to be used, we will use the polarising DAE and wavelength bands to highlight the need for carrying over fitting parameters. The example below defines two wavelength bands, the first wider than the second; we fit the data in the first band and carry the resulting parameters over to the second band to improve its otherwise poorer fit.
+
+```python
+# Needed for the polarising DAE
+flipper = block_rw(float, "flipper")
+total_flight_path_length = sc.scalar(value=10, unit=sc.units.m)
+
+x_axis = block_rw(float, "x_axis", write_config=BlockWriteConfig(settle_time_s=0.5))
+wavelength_band_0 = sc.array(dims=["tof"], values=[0.0, 9999999999.0], unit=sc.units.angstrom, dtype="float64")
+wavelength_band_1 = sc.array(dims=["tof"], values=[0.0, 0.07], unit=sc.units.angstrom, dtype="float64")
+
+dae = polarising_dae(det_pixels=[1], frames=50, movable=flipper, movable_states=[0.0, 1.0],
+                     intervals=[wavelength_band_0, wavelength_band_1],
+                     total_flight_path_length=total_flight_path_length, monitor=2)
+
+
+def plan() -> Generator[Msg, None, None]:
+    fig, (ax1, ax2) = yield from call_qt_aware(plt.subplots, 2)
+    chained_fit = ChainedLiveFit(
+        method=Linear.fit(),
+        y=dae.reducer_final.polarisation_names,  # One polarisation signal per wavelength band
+        x=x_axis.name,
+        ax=[ax1, ax2],
+    )
+
+    # Subscribe chained_fit to the RunEngine and run a scan, for example.
+    # chained_fit.live_fits[-1].result will give you the fitting results for the last wavelength band
+```
+
+- You are expected to pass the list of signal names for each fitted (dependent) variable to `y`, in the order in which you want the fits to be chained.
+- You may also pass in a list of matplotlib axes, in which case a LiveFitPlot is created per LiveFit and each fit is plotted on its respective axis. LiveFitPlots are not created if you do not pass `ax`.
+- Similar to the `y` parameter, you may pass a list of signal names which correspond to the uncertainty values for each `y` signal via `yerr`.
+
+```{hint}
+The same fit method is used for all of the fitted signals.
+```
+
+```{note}
+Parameter uncertainties are not carried over between fits.
+```
+
+```{important}
+If a fit fails to converge, subsequent fits will use their default guess functions.
+```
\ No newline at end of file
diff --git a/src/ibex_bluesky_core/callbacks/__init__.py b/src/ibex_bluesky_core/callbacks/__init__.py
index 919d977f..b54924e6 100644
--- a/src/ibex_bluesky_core/callbacks/__init__.py
+++ b/src/ibex_bluesky_core/callbacks/__init__.py
@@ -20,7 +20,7 @@ from ibex_bluesky_core.callbacks._file_logger import (
     HumanReadableFileCallback,
 )
-from ibex_bluesky_core.callbacks._fitting import LiveFit, LiveFitLogger
+from ibex_bluesky_core.callbacks._fitting import ChainedLiveFit, LiveFit, LiveFitLogger
 from ibex_bluesky_core.callbacks._plotting import LivePColorMesh, LivePlot, PlotPNGSaver, show_plot
 from ibex_bluesky_core.callbacks._utils import get_default_output_path
 from ibex_bluesky_core.fitting import FitMethod
@@ -32,6 +32,7 @@
 
 __all__ = [
+    "ChainedLiveFit",
     "DocLoggingCallback",
     "HumanReadableFileCallback",
     "ISISCallbacks",
diff --git a/src/ibex_bluesky_core/callbacks/_fitting.py b/src/ibex_bluesky_core/callbacks/_fitting.py
index 3bd218aa..b4f69d30 100644
--- a/src/ibex_bluesky_core/callbacks/_fitting.py
+++ b/src/ibex_bluesky_core/callbacks/_fitting.py
@@ -4,13 +4,18 @@
 import logging
 import os
 import warnings
+from itertools import zip_longest
 from pathlib import Path
 
+import lmfit
 import numpy as np
-from bluesky.callbacks import CallbackBase
+from bluesky.callbacks import CallbackBase, LiveFitPlot
 from bluesky.callbacks import LiveFit as _DefaultLiveFit
 from bluesky.callbacks.core import make_class_safe
-from event_model import Event, RunStart, RunStop
+from event_model import Event, EventDescriptor, RunStart, RunStop
+from lmfit import Parameter
+from matplotlib.axes import
Axes +from numpy import typing as npt from ibex_bluesky_core.callbacks._utils import ( DATA, @@ -24,7 +29,7 @@ logger = logging.getLogger(__name__) -__all__ = ["LiveFit", "LiveFitLogger"] +__all__ = ["ChainedLiveFit", "LiveFit", "LiveFitLogger"] @make_class_safe(logger=logger) # pyright: ignore (pyright doesn't understand this decorator) @@ -84,20 +89,17 @@ def update_weight(self, weight: float | None = 0.0) -> None: if self.yerr is not None: self.weight_data.append(weight) - def update_fit(self) -> None: - """Use the provided guess function with the most recent x and y values after every update. - - Args: - None - - Returns: - None - - """ + def can_fit(self) -> bool: + """Check if enough data points have been collected to fit.""" n = len(self.model.param_names) - if len(self.ydata) < n: + return len(self.ydata) >= n + + def update_fit(self) -> None: + """Use the guess function with the most recent x and y values after every update.""" + if not self.can_fit(): warnings.warn( - f"LiveFitPlot cannot update fit until there are at least {n} data points", + f"""LiveFitPlot cannot update fit until there are at least + {len(self.model.param_names)} data points""", stacklevel=1, ) else: @@ -174,7 +176,7 @@ def start(self, doc: RunStart) -> None: self.filename = self.output_dir / f"{rb_num}" / file def event(self, doc: Event) -> Event: - """Start collecting, y, x and yerr data. + """Start collecting, y, x, and yerr data. Args: doc: (Event): An event document. @@ -251,3 +253,145 @@ def write_fields_table_uncertainty(self) -> None: rows = zip(self.x_data, self.y_data, self.yerr_data, self.y_fit_data, strict=True) self.csvwriter.writerows(rows) + + +class ChainedLiveFit(CallbackBase): + """Processes multiple LiveFits, each fit's results inform the next, with optional plotting. + + This callback handles a sequence of LiveFit instances where the parameters from each + completed fit serve as the initial guess for the subsequent fit. Optional plotting + is built in using LivePlotFits. Note that you should not subscribe to the LiveFit/LiveFitPlot + callbacks directly, but rather subscribe just this callback. + """ + + def __init__( + self, + method: FitMethod, + y: list[str], + x: str, + *, + yerr: list[str] | None = None, + ax: list[Axes] | None = None, + ) -> None: + """Initialise ChainedLiveFit with multiple LiveFits. + + Args: + method: FitMethod instance for fitting + y: List of y-axis variable names + x: x-axis variable name + yerr: Optional list of error values corresponding to y variables + ax: A list of axes to plot fits on to. Creates LiveFitPlot instances. + + """ + super().__init__() + + if yerr and len(y) != len(yerr): + raise ValueError("yerr must be the same length as y") + + if ax and len(y) != len(ax): + raise ValueError("ax must be the same length as y") + + self._livefits = [ + LiveFit(method=method, y=y_name, x=x, yerr=yerr_name) + for y_name, yerr_name in zip_longest(y, yerr or []) + ] # if yerrs then create a LiveFit with a yerr else create a LiveFit without a yerr + + self._livefitplots = [ + LiveFitPlot(livefit=livefit, ax=axis) + for livefit, axis in zip(self._livefits, ax or [], strict=False) + ] # if ax then create a LiveFitPlot with ax else do not create any LiveFitPlots + + def _process_doc( + self, doc: RunStart | Event | RunStop | EventDescriptor, method_name: str + ) -> None: + """Process a document for either LivePlots or LiveFits. 
+ + Args: + doc: document to process + method_name: Name of the method to call ('start', 'descriptor', 'event', or 'stop') + + """ + callbacks = self._livefitplots or self._livefits + for callback in callbacks: + assert hasattr(callback, method_name) + getattr(callback, method_name)(doc) + + def start(self, doc: RunStart) -> None: + """Process start document for all callbacks. + + Args: + doc: RunStart document + + """ + self._process_doc(doc, "start") + + def descriptor(self, doc: EventDescriptor) -> None: + """Process descriptor document for all callbacks. + + Args: + doc: EventDescriptor document. + + """ + self._process_doc(doc, "descriptor") + + def event(self, doc: Event) -> Event: + """Process event document for all callbacks. + + Args: + doc: Event document + + """ + init_guess = {} + + for livefit in self._livefits: + rem_guess = livefit.method.guess + try: + if init_guess: + # Use previous fit results as initial guess for next fit + def guess_func( + a: npt.NDArray[np.float64], b: npt.NDArray[np.float64] + ) -> dict[str, lmfit.Parameter]: + nonlocal init_guess + return { + name: Parameter(name, value.value) + for name, value in init_guess.items() # noqa: B023 + } # ruff doesn't understand nonlocal + + # Using value.value means that parameter uncertainty + # is not carried over between fits + livefit.method.guess = guess_func + + if self._livefitplots: + self._livefitplots[self._livefits.index(livefit)].event(doc) + else: + livefit.event(doc) + + finally: + livefit.method.guess = rem_guess + + if livefit.can_fit(): + if livefit.result is None: + raise RuntimeError("LiveFit.result was None. Could not update fit.") + + init_guess = livefit.result.params + + return doc + + def stop(self, doc: RunStop) -> None: + """Process stop document and update fitting parameters. 
+ + Args: + doc: RunStop document + + """ + self._process_doc(doc, "stop") + + @property + def live_fits(self) -> list[LiveFit]: + """Return a list of the livefits.""" + return self._livefits + + @property + def live_fit_plots(self) -> list[LiveFitPlot]: + """Return a list of the livefitplots.""" + return self._livefitplots diff --git a/src/ibex_bluesky_core/devices/dae/__init__.py b/src/ibex_bluesky_core/devices/dae/__init__.py index 13d748e4..369ef57a 100644 --- a/src/ibex_bluesky_core/devices/dae/__init__.py +++ b/src/ibex_bluesky_core/devices/dae/__init__.py @@ -33,7 +33,9 @@ SinglePeriodSettings, ) from ibex_bluesky_core.devices.dae._settings import DaeSettings, DaeSettingsData, DaeTimingSource -from ibex_bluesky_core.devices.dae._spectra import DaeSpectra +from ibex_bluesky_core.devices.dae._spectra import ( + DaeSpectra, +) from ibex_bluesky_core.devices.dae._tcb_settings import ( DaeTCBSettings, DaeTCBSettingsData, diff --git a/src/ibex_bluesky_core/devices/polarisingdae/__init__.py b/src/ibex_bluesky_core/devices/polarisingdae/__init__.py new file mode 100644 index 00000000..a83bd287 --- /dev/null +++ b/src/ibex_bluesky_core/devices/polarisingdae/__init__.py @@ -0,0 +1,252 @@ +"""An interface to the DAE for bluesky, suited for polarisation.""" + +import logging +from typing import Generic, TypeAlias + +import scipp as sc +from bluesky.protocols import Movable, Triggerable +from ophyd_async.core import ( + AsyncStageable, + AsyncStatus, + Reference, +) +from typing_extensions import TypeVar + +from ibex_bluesky_core.devices.dae import Dae +from ibex_bluesky_core.devices.polarisingdae._reducers import ( + MultiWavelengthBandNormalizer, + PolarisationReducer, +) +from ibex_bluesky_core.devices.simpledae import ( + Controller, + GoodFramesWaiter, + PeriodGoodFramesWaiter, + PeriodPerPointController, + Reducer, + RunPerPointController, + Waiter, + wavelength_bounded_spectra, +) +from ibex_bluesky_core.utils import get_pv_prefix + +logger = logging.getLogger(__name__) + +__all__ = [ + "DualRunDae", + "MultiWavelengthBandNormalizer", + "PolarisationReducer", + "polarising_dae", +] + +TController_co = TypeVar("TController_co", bound="Controller", default=Controller, covariant=True) +TWaiter_co = TypeVar("TWaiter_co", bound="Waiter", default=Waiter, covariant=True) +TPReducer_co = TypeVar( + "TPReducer_co", + bound="Reducer", + default=Reducer, + covariant=True, +) +TMWBReducer_co = TypeVar( + "TMWBReducer_co", + bound="Reducer", + default=Reducer, + covariant=True, +) + + +class DualRunDae( + Dae, + Triggerable, + AsyncStageable, + Generic[TController_co, TWaiter_co, TPReducer_co, TMWBReducer_co], +): + """DAE with strategies for data collection, waiting, and reduction, suited for polarisation. + + This class is a more complex version of SimpleDae. It requires a flipper device to be provided + and will perform two runs, changing the flipper device at the start and inbetween runs. + """ + + def __init__( # noqa: PLR0913 + self, + *, + prefix: str, + name: str = "DAE", + controller: TController_co, + waiter: TWaiter_co, + reducer_final: TPReducer_co, + reducer_up: TMWBReducer_co, + reducer_down: TMWBReducer_co, + movable: Movable[float], + movable_states: list[float], + ) -> None: + """Initialise a DualRunDae. + + Args: + prefix: the PV prefix of the instrument being controlled. + name: A friendly name for this DAE object. 
+ controller: A DAE control strategy, defines how the DAE begins and ends data acquisition + Pre-defined strategies in the ibex_bluesky_core.devices.controllers module + waiter: A waiting strategy, defines how the DAE waits for an acquisition to be complete + Pre-defined strategies in the ibex_bluesky_core.devices.waiters module + reducer_final: A data reduction strategy. It will be triggered once after the two runs. + reducer_up: A data reduction strategy. Triggers once after the first run completes. + reducer_down: A data reduction strategy. Triggers once after the second run completes. + movable: A device which will be changed at the start of the first run and between runs. + movable_states: A tuple of two floats, the states to set at the start and between runs. + + """ + self.movable: Reference[Movable[float]] = Reference(movable) + self.movable_states: list[float] = movable_states + + self._prefix = prefix + self.controller: TController_co = controller + self.waiter: TWaiter_co = waiter + self.reducer_up: TMWBReducer_co = reducer_up + self.reducer_down: TMWBReducer_co = reducer_down + self.reducer_final: TPReducer_co = reducer_final + + logger.info( + """created polarisingdae with prefix=%s, controller=%s, + waiter=%s, reducer=%s, reducer_up=%s, reducer_down=%s""", + prefix, + controller, + waiter, + reducer_final, + reducer_up, + reducer_down, + ) + + # controller, waiter and reducers may be Devices (but don't necessarily have to be), + # so can define their own signals. do __init__ after defining those, so that the signals + # are connected/named and usable. + super().__init__(prefix=prefix, name=name) + + # Ask each defined strategy what it's interesting signals are, and ensure those signals are + # published when the top-level SimpleDae object is read. + extra_readables = set() + for strategy in [ + self.controller, + self.waiter, + self.reducer_up, + self.reducer_down, + self.reducer_final, + ]: + extra_readables.update(strategy.additional_readable_signals(self)) + logger.info("extra readables: %s", list(extra_readables)) + self.add_readables(devices=list(extra_readables)) + + @AsyncStatus.wrap + async def stage(self) -> None: + """Pre-scan setup. Delegate to the controller.""" + await self.controller.setup(self) + + @AsyncStatus.wrap + async def trigger(self) -> None: + """Take a single measurement and prepare it for later reading. 
+ + This waits for the acquisition and any defined reduction to be complete, such that + after this coroutine completes, all relevant data is available via read() + """ + self.movable().set(self.movable_states[0]) + + await self.controller.start_counting(self) + await self.waiter.wait(self) + await self.controller.stop_counting(self) + await self.reducer_up.reduce_data(self) + + self.movable().set(self.movable_states[1]) + + await self.controller.start_counting(self) + await self.waiter.wait(self) + await self.controller.stop_counting(self) + await self.reducer_down.reduce_data(self) + + await self.reducer_final.reduce_data(self) + + @AsyncStatus.wrap + async def unstage(self) -> None: + """Post-scan teardown, delegate to the controller.""" + await self.controller.teardown(self) + + +PolarisingDualRunDae: TypeAlias = DualRunDae[ + Controller, Waiter, PolarisationReducer, MultiWavelengthBandNormalizer +] + + +def polarising_dae( # noqa: PLR0913 + *, + det_pixels: list[int], + frames: int, + movable: Movable[float], + movable_states: list[float], + intervals: list[sc.Variable], + total_flight_path_length: sc.Variable, + periods: bool = True, + monitor: int = 1, + save_run: bool = False, +) -> PolarisingDualRunDae: + """Create a Polarising DAE which uses wavelength binning and calculates polarisation. + + This is a different version of monitor_normalising_dae, with a more complex set of strategies. + While already normalising using a monitor and waiting for frames, it requires a movable device + to be provided and will change the movable between two neutron states between runs. It uses + wavelength-bounded binning, and on completion of the two runs will calculate polarisation. + + Args: + det_pixels: list of detector pixel to use for scanning. + frames: number of frames to wait for. + movable: A device which can be used to change the neutron state between runs. + movable_states: A tuple of two floats, the neutron states to be set between runs. + intervals: list of wavelength intervals to use for binning. + total_flight_path_length: total flight path length of the neutron beam + from monitor to detector. + periods: whether or not to use software periods. + monitor: the monitor spectra number. + save_run: whether or not to save the run of the DAE. 
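+
+    A minimal usage sketch (the block name, spectra numbers and wavelength band are illustrative)::
+
+        import scipp as sc
+
+        flipper = block_rw(float, "flipper")
+        band = sc.array(dims=["tof"], values=[0.0, 5.0], unit=sc.units.angstrom, dtype="float64")
+        dae = polarising_dae(
+            det_pixels=[1],
+            frames=500,
+            movable=flipper,
+            movable_states=[0.0, 1.0],
+            intervals=[band],
+            total_flight_path_length=sc.scalar(value=10, unit=sc.units.m),
+            monitor=2,
+        )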
+ + """ + prefix = get_pv_prefix() + + if periods: + controller = PeriodPerPointController(save_run=save_run) + waiter = PeriodGoodFramesWaiter(frames) + else: + controller = RunPerPointController(save_run=save_run) + waiter = GoodFramesWaiter(frames) + + sum_wavelength_bands = [ + wavelength_bounded_spectra(bounds=i, total_flight_path_length=total_flight_path_length) + for i in intervals + ] + + reducer_up = MultiWavelengthBandNormalizer( + prefix=prefix, + detector_spectra=det_pixels, + monitor_spectra=[monitor], + sum_wavelength_bands=sum_wavelength_bands, + ) + + reducer_down = MultiWavelengthBandNormalizer( + prefix=prefix, + detector_spectra=det_pixels, + monitor_spectra=[monitor], + sum_wavelength_bands=sum_wavelength_bands, + ) + + reducer_final = PolarisationReducer( + intervals=intervals, reducer_up=reducer_up, reducer_down=reducer_down + ) + + dae = DualRunDae( + prefix=prefix, + controller=controller, + waiter=waiter, + reducer_final=reducer_final, + reducer_up=reducer_up, + reducer_down=reducer_down, + movable=movable, + movable_states=movable_states, + ) + + return dae diff --git a/src/ibex_bluesky_core/devices/polarisingdae/_reducers.py b/src/ibex_bluesky_core/devices/polarisingdae/_reducers.py new file mode 100644 index 00000000..d3eaabc5 --- /dev/null +++ b/src/ibex_bluesky_core/devices/polarisingdae/_reducers.py @@ -0,0 +1,229 @@ +"""Data reduction strategies for polarising DAEs.""" + +import asyncio +import logging +import math +import typing +from collections.abc import Awaitable, Callable, Collection, Sequence + +import scipp as sc +from ophyd_async.core import Device, DeviceVector, Reference, StandardReadable + +from ibex_bluesky_core.devices.dae import Dae, DaeSpectra +from ibex_bluesky_core.devices.polarisingdae._spectra import ( + _PolarisedWavelengthBand, + _WavelengthBand, +) +from ibex_bluesky_core.devices.simpledae import INTENSITY_PRECISION, VARIANCE_ADDITION, Reducer +from ibex_bluesky_core.utils import calculate_polarisation + +logger = logging.getLogger(__name__) + + +class MultiWavelengthBandNormalizer(Reducer, StandardReadable): + """Sum a set of wavelength-bounded spectra, then normalise by monitor intensity.""" + + def __init__( + self, + prefix: str, + detector_spectra: Sequence[int], + monitor_spectra: Sequence[int], + sum_wavelength_bands: list[ + Callable[[Collection[DaeSpectra]], Awaitable[sc.Variable | sc.DataArray]] + ], + ) -> None: + """Init. + + Args: + prefix: the PV prefix of the instrument to get spectra from (e.g. IN:DEMO:) + detector_spectra: a sequence of spectra numbers (detectors) to sum. + monitor_spectra: a sequence of spectra numbers (monitors) to sum. + sum_wavelength_bands: takes a sequence of summing functions, each of which takes + spectra objects and returns a scipp scalar describing the detector intensity. 
+ + """ + self.sum_wavelength_bands = sum_wavelength_bands + + dae_prefix = prefix + "DAE:" + + self.detectors = DeviceVector( + {i: DaeSpectra(dae_prefix=dae_prefix, spectra=i, period=0) for i in detector_spectra} + ) + self.monitors = DeviceVector( + {i: DaeSpectra(dae_prefix=dae_prefix, spectra=i, period=0) for i in monitor_spectra} + ) + + self._wavelength_bands = DeviceVector( + {i: _WavelengthBand() for i in range(len(self.sum_wavelength_bands))} + ) + + super().__init__(name="") + + async def reduce_data(self, dae: "Dae") -> None: + """Apply the normalisation.""" + logger.info("starting normalisation") + + for i in range(len(self.sum_wavelength_bands)): + sum_wavelength_band = self.sum_wavelength_bands[i] + wavelength_band = self._wavelength_bands[i] + detector_counts_sc, monitor_counts_sc = await asyncio.gather( + sum_wavelength_band(self.detectors.values()), + sum_wavelength_band(self.monitors.values()), + ) + + if monitor_counts_sc.value == 0.0: + raise ValueError( + f"""Cannot normalize; got zero monitor counts in wavelength band {i}. + Check beamline configuration.""" + ) + + # See doc\architectural_decisions\005-variance-addition.md + # for justification of this addition to variances. + detector_counts_sc.variance += VARIANCE_ADDITION + intensity_sc = detector_counts_sc / monitor_counts_sc + + intensity = float(intensity_sc.value) + det_counts = float(detector_counts_sc.value) + mon_counts = float(monitor_counts_sc.value) + + intensity_stddev = math.sqrt(intensity_sc.variance) + det_counts_stddev = math.sqrt(detector_counts_sc.variance) + mon_counts_stddev = math.sqrt(monitor_counts_sc.variance) + + wavelength_band.setter( + det_counts=det_counts, + det_counts_stddev=det_counts_stddev, + mon_counts=mon_counts, + mon_counts_stddev=mon_counts_stddev, + intensity=intensity, + intensity_stddev=intensity_stddev, + ) + + def additional_readable_signals(self, dae: Dae) -> list[Device]: + """Publish interesting signals derived or used by this reducer.""" + return list(self._wavelength_bands.values()) + + @property + def det_counts_names(self) -> list[str]: + return [band.det_counts.name for band in self._wavelength_bands.values()] + + @property + def det_counts_stddev_names(self) -> list[str]: + return [band.det_counts_stddev.name for band in self._wavelength_bands.values()] + + @property + def mon_counts_names(self) -> list[str]: + return [band.mon_counts.name for band in self._wavelength_bands.values()] + + @property + def mon_counts_stddev_names(self) -> list[str]: + return [band.mon_counts_stddev.name for band in self._wavelength_bands.values()] + + @property + def intensity_names(self) -> list[str]: + return [band.intensity.name for band in self._wavelength_bands.values()] + + @property + def intensity_stddev_names(self) -> list[str]: + return [band.intensity_stddev.name for band in self._wavelength_bands.values()] + + +class PolarisationReducer(Reducer, StandardReadable): + """Calculate polarisation from 'spin-up' and 'spin-down' states of a polarising DAE.""" + + def __init__( + self, + intervals: list[sc.Variable], + reducer_up: MultiWavelengthBandNormalizer, + reducer_down: MultiWavelengthBandNormalizer, + ) -> None: + """Init. + + Args: + intervals: a sequence of scipp describing the wavelength intervals over which + to calculate polarisation. + reducer_up: A data reduction strategy, defines the post-processing on raw DAE data. + Used to retrieve intensity values from the up-spin state. + reducer_down: A data reduction strategy, defines the post-processing on raw DAE data. 
+ Used to retrieve intensity values from the down-spin state. + + """ + self.intervals = intervals + self.reducer_up = Reference(reducer_up) + self.reducer_down = Reference(reducer_down) + self._wavelength_bands = DeviceVector( + { + i: _PolarisedWavelengthBand(intensity_precision=INTENSITY_PRECISION) + for i in range(len(intervals)) + } + ) + super().__init__(name="") + + async def reduce_data(self, dae: Dae) -> None: + """Apply the polarisation.""" + logger.info("starting polarisation") + + if len(self.reducer_up().additional_readable_signals(dae)) != len( + self.reducer_down().additional_readable_signals(dae) + ): + raise ValueError("Mismatched number of wavelength bands") + + for i in range(len(self.intervals)): + wavelength_band = self._wavelength_bands[i] + + band_up = typing.cast( + _WavelengthBand, self.reducer_up().additional_readable_signals(dae)[i] + ) + intensity_up = await band_up.intensity.get_value() + + band_down = typing.cast( + _WavelengthBand, self.reducer_down().additional_readable_signals(dae)[i] + ) + intensity_down = await band_down.intensity.get_value() + + if intensity_up + intensity_down == 0.0: + raise ValueError("Cannot calculate polarisation; zero intensity sum detected") + + intensity_up_stddev = await band_up.intensity_stddev.get_value() + intensity_down_stddev = await band_down.intensity_stddev.get_value() + intensity_up_sc = sc.scalar( + value=intensity_up, variance=intensity_up_stddev, dtype=float + ) + intensity_down_sc = sc.scalar( + value=intensity_down, variance=intensity_down_stddev, dtype=float + ) + + polarisation_sc = calculate_polarisation(intensity_up_sc, intensity_down_sc) + polarisation_ratio_sc = intensity_up_sc / intensity_down_sc + + polarisation_val = float(polarisation_sc.value) + polarisation_ratio = float(polarisation_ratio_sc.value) + polarisation_stddev = float(polarisation_sc.variance) + polarisation_ratio_stddev = float(polarisation_ratio_sc.variance) + + wavelength_band.setter( + polarisation=polarisation_val, + polarisation_stddev=polarisation_stddev, + polarisation_ratio=polarisation_ratio, + polarisation_ratio_stddev=polarisation_ratio_stddev, + ) + + def additional_readable_signals(self, dae: Dae) -> list[Device]: + """Publish interesting signals derived or used by this reducer.""" + return list(self._wavelength_bands.values()) + + @property + def polarisation_names(self) -> list[str]: + return [band.polarisation.name for band in self._wavelength_bands.values()] + + @property + def polarisation_stddev_names(self) -> list[str]: + return [band.polarisation_stddev.name for band in self._wavelength_bands.values()] + + @property + def polarisation_ratio(self) -> list[str]: + return [band.polarisation_ratio.name for band in self._wavelength_bands.values()] + + @property + def polarisation_ratio_stddev(self) -> list[str]: + return [band.polarisation_ratio_stddev.name for band in self._wavelength_bands.values()] diff --git a/src/ibex_bluesky_core/devices/polarisingdae/_spectra.py b/src/ibex_bluesky_core/devices/polarisingdae/_spectra.py new file mode 100644 index 00000000..ad1cf55b --- /dev/null +++ b/src/ibex_bluesky_core/devices/polarisingdae/_spectra.py @@ -0,0 +1,95 @@ +from ophyd_async.core import StandardReadable, soft_signal_r_and_setter + +__all__ = ["_PolarisedWavelengthBand", "_WavelengthBand"] + + +class _WavelengthBand(StandardReadable): + """Subdevice for a single wavelength band. + + Represents a few measurements within a specific wavelength band. + Has a setter method to assign values to the published signals. 
+ """ + + def __init__(self, *, name: str = "") -> None: + self.det_counts, self._det_counts_setter = soft_signal_r_and_setter(float, 0.0) + self.mon_counts, self._mon_counts_setter = soft_signal_r_and_setter(float, 0.0) + + self.intensity, self._intensity_setter = soft_signal_r_and_setter(float, 0.0) + self.det_counts_stddev, self._det_counts_stddev_setter = soft_signal_r_and_setter( + float, 0.0 + ) + self.mon_counts_stddev, self._mon_counts_stddev_setter = soft_signal_r_and_setter( + float, 0.0 + ) + self.intensity_stddev, self._intensity_stddev_setter = soft_signal_r_and_setter(float, 0.0) + + self.intensity.set_name("intensity") + self.intensity_stddev.set_name("intensity_stddev") + + super().__init__(name=name) + + def setter( + self, + *, + det_counts: float, + det_counts_stddev: float, + mon_counts: float, + mon_counts_stddev: float, + intensity: float, + intensity_stddev: float, + ) -> None: + self._intensity_setter(intensity) + self._det_counts_setter(det_counts) + self._mon_counts_setter(mon_counts) + + self._intensity_stddev_setter(intensity_stddev) + self._det_counts_stddev_setter(det_counts_stddev) + self._mon_counts_stddev_setter(mon_counts_stddev) + + +class _PolarisedWavelengthBand(StandardReadable): + """Subdevice that holds polarisation info for two wavelength bands. + + Represents the polarisation information calculated using measurements + taken from two :obj:`ibex_bluesky_core.devices.polarisingdae._spectra._WavelengthBand` + objects, one published from an "up state" + :obj:`ibex_bluesky_core.devices.polarisingdae.MultiWavelengthBandNormalizer`and the + other from a "down state" + :obj:`ibex_bluesky_core.devices.polarisingdae.MultiWavelengthBandNormalizer. + Has a setter method to assign values to the published signals. + """ + + def __init__(self, *, name: str = "", intensity_precision: int = 6) -> None: + with self.add_children_as_readables(): + self.polarisation, self._polarisation_setter = soft_signal_r_and_setter( + float, 0.0, precision=intensity_precision + ) + self.polarisation_stddev, self._polarisation_stddev_setter = soft_signal_r_and_setter( + float, 0.0, precision=intensity_precision + ) + self.polarisation_ratio, self._polarisation_ratio_setter = soft_signal_r_and_setter( + float, 0.0, precision=intensity_precision + ) + self.polarisation_ratio_stddev, self._polarisation_ratio_stddev_setter = ( + soft_signal_r_and_setter(float, 0.0, precision=intensity_precision) + ) + + self.polarisation.set_name("polarisation") + self.polarisation_stddev.set_name("polarisation_stddev") + self.polarisation_ratio.set_name("polarisation_ratio") + self.polarisation_ratio_stddev.set_name("polarisation_ratio_stddev") + + super().__init__(name=name) + + def setter( + self, + *, + polarisation: float, + polarisation_stddev: float, + polarisation_ratio: float, + polarisation_ratio_stddev: float, + ) -> None: + self._polarisation_setter(polarisation) + self._polarisation_stddev_setter(polarisation_stddev) + self._polarisation_ratio_setter(polarisation_ratio) + self._polarisation_ratio_stddev_setter(polarisation_ratio_stddev) diff --git a/src/ibex_bluesky_core/devices/simpledae/__init__.py b/src/ibex_bluesky_core/devices/simpledae/__init__.py index b9ba94db..a06cbfb8 100644 --- a/src/ibex_bluesky_core/devices/simpledae/__init__.py +++ b/src/ibex_bluesky_core/devices/simpledae/__init__.py @@ -16,13 +16,13 @@ RunPerPointController, ) from ibex_bluesky_core.devices.simpledae._reducers import ( + INTENSITY_PRECISION, VARIANCE_ADDITION, DSpacingMappingReducer, MonitorNormalizer, 
PeriodGoodFramesNormalizer, PeriodSpecIntegralsReducer, ScalarNormalizer, - polarization, sum_spectra, tof_bounded_spectra, wavelength_bounded_spectra, @@ -50,6 +50,7 @@ logger = logging.getLogger(__name__) __all__ = [ + "INTENSITY_PRECISION", "VARIANCE_ADDITION", "Controller", "DSpacingMappingReducer", @@ -72,15 +73,14 @@ "Waiter", "check_dae_strategies", "monitor_normalising_dae", - "polarization", "sum_spectra", "tof_bounded_spectra", "wavelength_bounded_spectra", ] -TController_co = TypeVar("TController_co", bound="Controller", default="Controller", covariant=True) -TWaiter_co = TypeVar("TWaiter_co", bound="Waiter", default="Waiter", covariant=True) -TReducer_co = TypeVar("TReducer_co", bound="Reducer", default="Reducer", covariant=True) +TController_co = TypeVar("TController_co", bound="Controller", default=Controller, covariant=True) +TWaiter_co = TypeVar("TWaiter_co", bound="Waiter", default=Waiter, covariant=True) +TReducer_co = TypeVar("TReducer_co", bound="Reducer", default=Reducer, covariant=True) class SimpleDae(Dae, Triggerable, AsyncStageable, Generic[TController_co, TWaiter_co, TReducer_co]): diff --git a/src/ibex_bluesky_core/devices/simpledae/_controllers.py b/src/ibex_bluesky_core/devices/simpledae/_controllers.py index 9dfd1d8e..e3ef6e91 100644 --- a/src/ibex_bluesky_core/devices/simpledae/_controllers.py +++ b/src/ibex_bluesky_core/devices/simpledae/_controllers.py @@ -1,7 +1,6 @@ """DAE control strategies.""" import logging -import typing from ophyd_async.core import ( Device, @@ -10,16 +9,13 @@ wait_for_value, ) -from ibex_bluesky_core.devices.dae import BeginRunExBits, RunstateEnum +from ibex_bluesky_core.devices.dae import BeginRunExBits, Dae, RunstateEnum from ibex_bluesky_core.devices.simpledae._strategies import Controller logger = logging.getLogger(__name__) -if typing.TYPE_CHECKING: - from ibex_bluesky_core.devices.simpledae import SimpleDae - -async def _end_or_abort_run(dae: "SimpleDae", save: bool) -> None: +async def _end_or_abort_run(dae: Dae, save: bool) -> None: if save: logger.info("ending run") await dae.controls.end_run.trigger(wait=True, timeout=None) @@ -49,7 +45,7 @@ def __init__(self, save_run: bool) -> None: self._save_run = save_run self._current_period = 0 - async def setup(self, dae: "SimpleDae") -> None: + async def setup(self, dae: Dae) -> None: """Pre-scan setup (begin a new run in paused mode).""" self._current_period = 0 logger.info("setting up new run") @@ -57,7 +53,7 @@ async def setup(self, dae: "SimpleDae") -> None: await wait_for_value(dae.run_state, RunstateEnum.PAUSED, timeout=10) logger.info("setup complete") - async def start_counting(self, dae: "SimpleDae") -> None: + async def start_counting(self, dae: Dae) -> None: """Start counting a scan point. Increments the period by 1, then unpauses the run. 
@@ -86,17 +82,17 @@ async def start_counting(self, dae: "SimpleDae") -> None: timeout=10, ) - async def stop_counting(self, dae: "SimpleDae") -> None: + async def stop_counting(self, dae: Dae) -> None: """Stop counting a scan point, by pausing the run.""" logger.info("stop counting") await dae.controls.pause_run.trigger(wait=True, timeout=None) await wait_for_value(dae.run_state, RunstateEnum.PAUSED, timeout=10) - async def teardown(self, dae: "SimpleDae") -> None: + async def teardown(self, dae: Dae) -> None: """Finish taking data, ending or aborting the run.""" await _end_or_abort_run(dae, self._save_run) - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """period_num is always an interesting signal if using this controller.""" return [dae.period_num] @@ -123,7 +119,7 @@ def __init__(self, save_run: bool) -> None: self.run_number, self._run_number_setter = soft_signal_r_and_setter(int, 0) super().__init__() - async def start_counting(self, dae: "SimpleDae") -> None: + async def start_counting(self, dae: Dae) -> None: """Start counting a scan point, by starting a DAE run.""" logger.info("start counting") await dae.controls.begin_run.trigger(wait=True, timeout=None) @@ -139,11 +135,11 @@ async def start_counting(self, dae: "SimpleDae") -> None: run_number = await dae.current_or_next_run_number.get_value() self._run_number_setter(run_number) - async def stop_counting(self, dae: "SimpleDae") -> None: + async def stop_counting(self, dae: Dae) -> None: """Stop counting a scan point, by ending or aborting the run.""" await _end_or_abort_run(dae, self._save_run) - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Run number is an interesting signal only if saving runs.""" if self._save_run: return [self.run_number] diff --git a/src/ibex_bluesky_core/devices/simpledae/_reducers.py b/src/ibex_bluesky_core/devices/simpledae/_reducers.py index 963a8c6c..82d74fbb 100644 --- a/src/ibex_bluesky_core/devices/simpledae/_reducers.py +++ b/src/ibex_bluesky_core/devices/simpledae/_reducers.py @@ -5,7 +5,6 @@ import math from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Collection, Sequence -from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt @@ -21,14 +20,11 @@ from scippneutron import conversion from scippneutron.conversion.tof import dspacing_from_tof -from ibex_bluesky_core.devices.dae import DaeSpectra +from ibex_bluesky_core.devices.dae import Dae, DaeSpectra from ibex_bluesky_core.devices.simpledae._strategies import Reducer logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from ibex_bluesky_core.devices.simpledae import SimpleDae - INTENSITY_PRECISION = 6 VARIANCE_ADDITION = 0.5 @@ -135,60 +131,6 @@ async def sum_spectra_with_wavelength( return sum_spectra_with_wavelength -def polarization( - a: sc.Variable | sc.DataArray, b: sc.Variable | sc.DataArray -) -> sc.Variable | sc.DataArray: - """Calculate polarization value and propagate uncertainties. - - This function computes the polarization given by the formula (a-b)/(a+b) - and propagates the uncertainties associated with a and b. 
- - Args: - a: scipp :external+scipp:py:obj:`Variable ` - or :external+scipp:py:obj:`DataArray ` - b: scipp :external+scipp:py:obj:`Variable ` - or :external+scipp:py:obj:`DataArray ` - - Returns: - polarization, ``(a - b) / (a + b)``, as a scipp - :external+scipp:py:obj:`Variable ` - or :external+scipp:py:obj:`DataArray ` - - On SANS instruments e.g. LARMOR, A and B correspond to intensity in different DAE - periods (before/after switching a flipper) and the output is interpreted as a neutron - polarization ratio. - - On reflectometry instruments e.g. POLREF, the situation is the same as on LARMOR. - - On muon instruments, A and B correspond to measuring from forward/backward detector - banks, and the output is interpreted as a muon asymmetry. - - """ - if a.unit != b.unit: - raise ValueError("The units of a and b are not equivalent.") - if a.sizes != b.sizes: - raise ValueError("Dimensions/shape of a and b must match.") - - # This line allows for dims, units, and dtype to be handled by scipp - polarization_value = (a - b) / (a + b) - - variances_a = a.variances - variances_b = b.variances - values_a = a.values - values_b = b.values - - # Calculate partial derivatives - partial_a = 2 * values_b / (values_a + values_b) ** 2 - partial_b = -2 * values_a / (values_a + values_b) ** 2 - - variance_return = (partial_a**2 * variances_a) + (partial_b**2 * variances_b) - - # Propagate uncertainties - polarization_value.variances = variance_return - - return polarization_value - - class ScalarNormalizer(Reducer, StandardReadable, ABC): """Sum a set of user-specified spectra, then normalize by a scalar signal.""" @@ -232,10 +174,10 @@ def __init__( super().__init__(name="") @abstractmethod - def denominator(self, dae: "SimpleDae") -> SignalR[int] | SignalR[float]: + def denominator(self, dae: Dae) -> SignalR[int] | SignalR[float]: """Get the normalization denominator, which is assumed to be a scalar signal.""" - async def reduce_data(self, dae: "SimpleDae") -> None: + async def reduce_data(self, dae: Dae) -> None: """Apply the normalization.""" logger.info("starting reduction") summed_counts, denominator = await asyncio.gather( @@ -259,7 +201,7 @@ async def reduce_data(self, dae: "SimpleDae") -> None: logger.info("reduction complete") - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Publish interesting signals derived or used by this reducer.""" return [ self.det_counts, @@ -273,7 +215,7 @@ def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: class PeriodGoodFramesNormalizer(ScalarNormalizer): """Sum a set of user-specified spectra, then normalize by period good frames.""" - def denominator(self, dae: "SimpleDae") -> SignalR[int]: + def denominator(self, dae: Dae) -> SignalR[int]: """Get normalization denominator (period good frames).""" return dae.period.good_frames @@ -335,7 +277,7 @@ def __init__( super().__init__(name="") - async def reduce_data(self, dae: "SimpleDae") -> None: + async def reduce_data(self, dae: Dae) -> None: """Apply the normalization.""" logger.info("starting reduction") detector_counts, monitor_counts = await asyncio.gather( @@ -364,7 +306,7 @@ async def reduce_data(self, dae: "SimpleDae") -> None: logger.info("reduction complete") - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Publish interesting signals derived or used by this reducer.""" return [ self.det_counts, 
@@ -425,7 +367,7 @@ def monitors(self) -> npt.NDArray[np.int64]: """Get the monitors used by this reducer.""" return self._monitors - async def reduce_data(self, dae: "SimpleDae") -> None: + async def reduce_data(self, dae: Dae) -> None: """Expose detector & monitor integrals. After this method returns, it is valid to read from det_integrals and @@ -453,7 +395,7 @@ async def reduce_data(self, dae: "SimpleDae") -> None: logger.info("reduction complete") - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Publish interesting signals derived or used by this reducer.""" return [ self.mon_integrals, @@ -531,7 +473,7 @@ def __init__( super().__init__(name="") - async def reduce_data(self, dae: "SimpleDae") -> None: + async def reduce_data(self, dae: Dae) -> None: """Expose calculated d-spacing. This will be in units of counts, which may be fractional due to rebinning. @@ -577,6 +519,6 @@ async def reduce_data(self, dae: "SimpleDae") -> None: self._dspacing_setter(summed_data.values) logger.info("reduction complete") - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Publish interesting signals derived or used by this reducer.""" return [self.dspacing] diff --git a/src/ibex_bluesky_core/devices/simpledae/_strategies.py b/src/ibex_bluesky_core/devices/simpledae/_strategies.py index 7021612f..b42865a2 100644 --- a/src/ibex_bluesky_core/devices/simpledae/_strategies.py +++ b/src/ibex_bluesky_core/devices/simpledae/_strategies.py @@ -1,13 +1,6 @@ -"""Base classes for DAE strategies.""" +from ophyd_async.core import Device -from typing import TYPE_CHECKING - -from ophyd_async.core import ( - Device, -) - -if TYPE_CHECKING: - from ibex_bluesky_core.devices.simpledae import SimpleDae +from ibex_bluesky_core.devices.dae import Dae class ProvidesExtraReadables: @@ -16,7 +9,7 @@ class ProvidesExtraReadables: Those signals will then be added to read() and describe() on the top-level SimpleDae object. """ - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Define signals that this strategy considers important. These will be added to the dae's default-read signals and made available by read() on the @@ -25,33 +18,39 @@ def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: return [] +class Waiter(ProvidesExtraReadables): + """Waiter specifies how the dae will wait for a scan point to complete counting.""" + + async def wait(self, dae: Dae) -> None: + """Wait for the acquisition to complete.""" + + class Controller(ProvidesExtraReadables): - """Controller specifies how DAE runs should be started & stopped.""" + """Controller specifies how DAE runs should be started & stopped. + + Controller specifies how DAE runs should be started & stopped. - async def start_counting(self, dae: "SimpleDae") -> None: + .. 
py:class:: Controller: + :canonical: ibex_bluesky_core.devices.dae.strategies.Controller: + """ + + async def start_counting(self, dae: Dae) -> None: """Start counting for a single scan point.""" - async def stop_counting(self, dae: "SimpleDae") -> None: + async def stop_counting(self, dae: Dae) -> None: """Stop counting for a single scan point.""" - async def setup(self, dae: "SimpleDae") -> None: + async def setup(self, dae: Dae) -> None: """Pre-scan setup.""" - async def teardown(self, dae: "SimpleDae") -> None: + async def teardown(self, dae: Dae) -> None: """Post-scan teardown.""" -class Waiter(ProvidesExtraReadables): - """Waiter specifies how the dae will wait for a scan point to complete counting.""" - - async def wait(self, dae: "SimpleDae") -> None: - """Wait for the acquisition to complete.""" - - class Reducer(ProvidesExtraReadables): """Reducer specifies any post-processing which needs to be done after a scan point completes.""" - async def reduce_data(self, dae: "SimpleDae") -> None: + async def reduce_data(self, dae: Dae) -> None: """Triggers a reduction of DAE data after a scan point has been measured. Data that should be published by this reducer should be added as soft signals, in diff --git a/src/ibex_bluesky_core/devices/simpledae/_waiters.py b/src/ibex_bluesky_core/devices/simpledae/_waiters.py index 63dbb085..f01c66a1 100644 --- a/src/ibex_bluesky_core/devices/simpledae/_waiters.py +++ b/src/ibex_bluesky_core/devices/simpledae/_waiters.py @@ -3,7 +3,7 @@ import asyncio import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import Generic, TypeVar from ophyd_async.core import ( Device, @@ -12,14 +12,11 @@ wait_for_value, ) +from ibex_bluesky_core.devices.dae import Dae from ibex_bluesky_core.devices.simpledae._strategies import Waiter logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from ibex_bluesky_core.devices.simpledae import SimpleDae - - T = TypeVar("T", int, float) @@ -35,7 +32,7 @@ def __init__(self, value: T) -> None: """ self.finish_wait_at = soft_signal_rw(float, value) - async def wait(self, dae: "SimpleDae") -> None: + async def wait(self, dae: Dae) -> None: """Wait for signal to reach the user-specified value.""" signal = self.get_signal(dae) logger.info("starting wait for signal %s", signal.source) @@ -43,19 +40,19 @@ async def wait(self, dae: "SimpleDae") -> None: await wait_for_value(signal, lambda v: v >= value, timeout=None) logger.info("completed wait for signal %s", signal.source) - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: """Publish the signal we're waiting on as an interesting signal.""" return [self.get_signal(dae)] @abstractmethod - def get_signal(self, dae: "SimpleDae") -> SignalR[T]: + def get_signal(self, dae: Dae) -> SignalR[T]: """Get the numeric signal to wait for.""" class PeriodGoodFramesWaiter(SimpleWaiter[int]): """Wait for period good frames to reach a user-specified value.""" - def get_signal(self, dae: "SimpleDae") -> SignalR[int]: + def get_signal(self, dae: Dae) -> SignalR[int]: """Wait for period good frames.""" return dae.period.good_frames @@ -63,7 +60,7 @@ def get_signal(self, dae: "SimpleDae") -> SignalR[int]: class GoodUahWaiter(SimpleWaiter[float]): """Wait for good microamp-hours to reach a user-specified value.""" - def get_signal(self, dae: "SimpleDae") -> SignalR[float]: + def get_signal(self, dae: Dae) -> SignalR[float]: """Wait for good uah.""" return 
dae.good_uah @@ -71,7 +68,7 @@ def get_signal(self, dae: "SimpleDae") -> SignalR[float]: class MEventsWaiter(SimpleWaiter[float]): """Wait for a user-specified number of millions of events.""" - def get_signal(self, dae: "SimpleDae") -> SignalR[float]: + def get_signal(self, dae: Dae) -> SignalR[float]: """Wait for mevents.""" return dae.m_events @@ -88,7 +85,7 @@ def __init__(self, *, seconds: float) -> None: """ self._secs = seconds - async def wait(self, dae: "SimpleDae") -> None: + async def wait(self, dae: Dae) -> None: """Wait for the specified time duration.""" logger.info("starting wait for %f seconds", self._secs) await asyncio.sleep(self._secs) diff --git a/src/ibex_bluesky_core/run_engine/__init__.py b/src/ibex_bluesky_core/run_engine/__init__.py index d6d5d6fe..b5a6d68d 100644 --- a/src/ibex_bluesky_core/run_engine/__init__.py +++ b/src/ibex_bluesky_core/run_engine/__init__.py @@ -60,6 +60,7 @@ def get_run_engine() -> RunEngine: - Run a plan:: from bluesky.plans import count # Or any other plan + det = ... # A "detector" object, for example a Block or Dae device. RE(count([det])) diff --git a/src/ibex_bluesky_core/utils.py b/src/ibex_bluesky_core/utils.py index 6c2e00dd..b6567cb4 100644 --- a/src/ibex_bluesky_core/utils.py +++ b/src/ibex_bluesky_core/utils.py @@ -6,9 +6,16 @@ from typing import Any, Protocol import matplotlib +import scipp as sc from bluesky.protocols import NamedMovable, Readable -__all__ = ["NamedReadableAndMovable", "centred_pixel", "get_pv_prefix", "is_matplotlib_backend_qt"] +__all__ = [ + "NamedReadableAndMovable", + "calculate_polarisation", + "centred_pixel", + "get_pv_prefix", + "is_matplotlib_backend_qt", +] def is_matplotlib_backend_qt() -> bool: @@ -43,3 +50,57 @@ def get_pv_prefix() -> str: class NamedReadableAndMovable(Readable[Any], NamedMovable[Any], Protocol): """Abstract class for type checking that an object is readable, named and movable.""" + + +def calculate_polarisation( + a: sc.Variable | sc.DataArray, b: sc.Variable | sc.DataArray +) -> sc.Variable | sc.DataArray: + """Calculate polarisation value and propagate uncertainties. + + This function computes the polarisation given by the formula (a-b)/(a+b) + and propagates the uncertainties associated with a and b. + + Args: + a: scipp :external+scipp:py:obj:`Variable ` + or :external+scipp:py:obj:`DataArray ` + b: scipp :external+scipp:py:obj:`Variable ` + or :external+scipp:py:obj:`DataArray ` + + Returns: + polarisation, ``(a - b) / (a + b)``, as a scipp + :external+scipp:py:obj:`Variable ` + or :external+scipp:py:obj:`DataArray ` + + On SANS instruments e.g. LARMOR, A and B correspond to intensity in different DAE + periods (before/after switching a flipper) and the output is interpreted as a neutron + polarisation ratio. + + On reflectometry instruments e.g. POLREF, the situation is the same as on LARMOR. + + On muon instruments, A and B correspond to measuring from forward/backward detector + banks, and the output is interpreted as a muon asymmetry. 
+ + """ + if a.unit != b.unit: + raise ValueError("The units of a and b are not equivalent.") + if a.sizes != b.sizes: + raise ValueError("Dimensions/shape of a and b must match.") + + # This line allows for dims, units, and dtype to be handled by scipp + polarisation_value = (a - b) / (a + b) + + variances_a = a.variances + variances_b = b.variances + values_a = a.values + values_b = b.values + + # Calculate partial derivatives + partial_a = 2 * values_b / (values_a + values_b) ** 2 + partial_b = -2 * values_a / (values_a + values_b) ** 2 + + variance_return = (partial_a**2 * variances_a) + (partial_b**2 * variances_b) + + # Propagate uncertainties + polarisation_value.variances = variance_return + + return polarisation_value diff --git a/tests/callbacks/fitting/test_chained_fitting_callback.py b/tests/callbacks/fitting/test_chained_fitting_callback.py new file mode 100644 index 00000000..217f795f --- /dev/null +++ b/tests/callbacks/fitting/test_chained_fitting_callback.py @@ -0,0 +1,199 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from _pytest.fixtures import FixtureRequest +from bluesky.callbacks import LiveFitPlot +from event_model import Event +from lmfit import Parameter +from matplotlib.axes import Axes + +from ibex_bluesky_core.callbacks import ChainedLiveFit +from ibex_bluesky_core.fitting import FitMethod, Linear + +Y_VARS = ["y1", "y2"] +X_VAR = "x" +YERR_VARS = ["yerr1", "yerr2"] + + +@pytest.fixture +def method() -> FitMethod: + return Linear().fit() + + +@pytest.fixture +def mock_axes() -> list[Axes]: + return [MagicMock(spec=Axes) for _ in range(2)] + + +@pytest.fixture +def mock_doc() -> dict[str, dict[str, int]]: + return { + "data": { + "y": 1, + "x": 1, + } + } + + +def test_chained_livefit_initialization(method: FitMethod): + """Test that ChainedLiveFit properly initializes with minimum required parameters""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR) + assert len(clf._livefits) == len(Y_VARS) + assert not clf._livefitplots + + +def test_chained_livefit_with_yerr(method: FitMethod): + """Test that ChainedLiveFit properly handles yerr parameters""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, yerr=YERR_VARS) + assert len(clf._livefits) == len(Y_VARS) + + for livefit, yerr in zip(clf._livefits, YERR_VARS, strict=False): + assert livefit.yerr == yerr + + +def test_chained_livefit_with_plotting(method: FitMethod, mock_axes: list[Axes]): + """Test that ChainedLiveFit properly sets up plotting when axes are provided""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, ax=mock_axes) + + assert len(clf._livefitplots) == len(mock_axes) + assert all(isinstance(plot, LiveFitPlot) for plot in clf._livefitplots) + + +@pytest.mark.parametrize("doc_typ", ["start", "descriptor", "stop"]) +def test_document_processing(method: FitMethod, doc_typ: str, mock_doc: dict[str, dict[str, int]]): + """Test that calling start (etc..) 
on ChainedLiveFit correctly calls _process_doc""" + # Does not apply for event as this should not be called unconditionally + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR) + + with patch.object(clf, "_process_doc") as mock_process: + getattr(clf, doc_typ)(mock_doc) + mock_process.assert_called_once_with(mock_doc, doc_typ) + + +def test_livefit_document_processing(method: FitMethod, mock_doc: dict[str, dict[str, int]]): + """Test that documents are properly processed for LiveFit""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR) + + doc_typ = "event" + + for i in range(len(Y_VARS)): + with patch.object(clf._livefits[i], doc_typ) as mock_process: + clf._process_doc(mock_doc, doc_typ) # pyright: ignore + # Generic document type is not assignable + mock_process.assert_called_once_with(mock_doc) + + +def test_livefitplot_document_processing( + method: FitMethod, + mock_axes: list[Axes], + mock_doc: dict[str, dict[str, int]], +): + """Test that documents are properly processed for LiveFitPlots""" + # Test implementation needed for document processing + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, ax=mock_axes) + + doc_typ = "event" + + for i in range(len(Y_VARS)): + with patch.object(clf._livefitplots[i], doc_typ) as mock_process: + clf._process_doc(mock_doc, doc_typ) # pyright: ignore + # Generic document type is not assignable + mock_process.assert_called_once_with(mock_doc) + + +@pytest.mark.parametrize("axes", [None, "mock_axes"]) +def test_first_livefit_uses_normal_guess_function( + method: FitMethod, + mock_doc: dict[str, dict[str, int]], + axes: list[Axes] | None, + request: FixtureRequest, +): + """Test that if using the first LiveFit then it will fit using its own guess function""" + # If axes is a string, get the actual fixture value, otherwise use None + ax = request.getfixturevalue(axes) if isinstance(axes, str) else axes + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, ax=ax) + + with patch.object(clf._livefits[0].method, "guess") as mock_guess: + with patch.object(clf._livefits[0], "event") as mock_event: + clf.event(mock_doc) # pyright: ignore + # Generic document type is not assignable + mock_event.assert_called_once() + assert mock_guess == clf._livefits[0].method.guess + + +@pytest.mark.parametrize("axes", [None, "mock_axes"]) +def test_livefit_param_passing_between_fits( + method: FitMethod, + mock_doc: dict[str, dict[str, int]], + axes: str | None, + request: FixtureRequest, +): + """Test that parameters from first fit are passed correctly to second fit's guess function.""" + # Checks that this works for LiveFit and LiveFitPlot in either case + ax = request.getfixturevalue(axes) if isinstance(axes, str) else axes + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, ax=ax) + callbacks = clf._livefits if axes is None else clf._livefitplots + + # Mock first livefit's result with some parameters + mock_params = {"param1": Parameter("param1", 1.0), "param2": Parameter("param2", 2.0)} + mock_result = MagicMock() + mock_result.params = mock_params + + # Set up the first livefit to return a mocked result + clf._livefits[0].result = mock_result + clf._livefits[0].can_fit = MagicMock(return_value=True) + # Must mock out the first event for LiveFit or LiveFitPlot, otherwise ValueError + callbacks[0].event = MagicMock(spec=Event) + + # Mock second livefit's guess function to check that it gets the parameters correctly + def check_livefit_param(clf: ChainedLiveFit, mock_params: dict[str, Parameter]): + assert 
clf._livefits[1].method.guess(np.array(0), np.array(0)) == mock_params + + with patch.object( + callbacks[1], "event", side_effect=lambda _: check_livefit_param(clf, mock_params) + ): + clf.event(mock_doc) # pyright: ignore + # Generic document type is not assignable + + # Check that second livefit's guess function returns to original state + assert clf._livefits[1].method.guess == method.guess + + +def test_livefit_has_no_result_assert(method: FitMethod, mock_doc: dict[str, dict[str, int]]): + """Test that if LiveFit should be able to fit, but has no result, then throw an assertion""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR) + clf._livefits[0].can_fit = MagicMock(return_value=True) + + with patch.object(clf._livefits[0], "event"): + with pytest.raises(RuntimeError, match=r"LiveFit.result was None. Could not update fit."): + clf.event(mock_doc) # pyright: ignore + # Generic document type is not assignable + + +def test_get_livefits(method: FitMethod): + """Test that get_livefits returns the correct livefits""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR) + assert clf.live_fits == clf._livefits + + +def test_get_livefitplots(method: FitMethod, mock_axes: list[Axes]): + """Test that get_livefitplots returns the correct livefitplots""" + clf = ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, ax=mock_axes) + assert clf.live_fit_plots == clf._livefitplots + + +def test_yerr_length_mismatch_raises_error(method: FitMethod): + """Test that a ValueError is raised when yerr list length doesn't match y list length.""" + wrong_yerr_vars = ["yerr1"] # Only one error variable for two Y variables + + with pytest.raises(ValueError, match="yerr must be the same length as y"): + ChainedLiveFit(method=method, y=Y_VARS, x=X_VAR, yerr=wrong_yerr_vars) + + +def test_axes_length_mismatch_raises_error(method: FitMethod, mock_axes: list[Axes]): + """Test that a ValueError is raised when axes list length doesn't match y list length.""" + y_vars = ["y1", "y2", "y3"] + + with pytest.raises(ValueError, match="ax must be the same length as y"): + ChainedLiveFit(method=method, y=y_vars, x=X_VAR, ax=mock_axes) diff --git a/tests/devices/simpledae/conftest.py b/tests/devices/__init__.py similarity index 100% rename from tests/devices/simpledae/conftest.py rename to tests/devices/__init__.py diff --git a/tests/devices/polarisingdae/__init__.py b/tests/devices/polarisingdae/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/devices/polarisingdae/test_polarising_dae.py b/tests/devices/polarisingdae/test_polarising_dae.py new file mode 100644 index 00000000..e635afa4 --- /dev/null +++ b/tests/devices/polarisingdae/test_polarising_dae.py @@ -0,0 +1,174 @@ +from unittest.mock import MagicMock, call, patch + +import pytest +import scipp as sc +from ophyd_async.core import SignalRW, soft_signal_rw + +from ibex_bluesky_core.devices.polarisingdae import DualRunDae, polarising_dae +from ibex_bluesky_core.devices.simpledae import ( + Controller, + GoodFramesWaiter, + PeriodGoodFramesWaiter, + PeriodPerPointController, + Reducer, + RunPerPointController, + Waiter, +) + + +@pytest.fixture +def mock_controller() -> Controller: + return MagicMock(spec=Controller) + + +@pytest.fixture +def mock_waiter() -> Waiter: + return MagicMock(spec=Waiter) + + +@pytest.fixture +def mock_reducer() -> Reducer: + return MagicMock(spec=Reducer) + + +@pytest.fixture +def mock_reducer_up() -> Reducer: + return MagicMock(spec=Reducer) + + +@pytest.fixture +def mock_reducer_down() -> Reducer: + return 
MagicMock(spec=Reducer) + + +@pytest.fixture +def movable() -> SignalRW[float]: + return soft_signal_rw(float, 0.0) + + +@pytest.fixture +async def mock_dae( + mock_controller: Controller, + mock_waiter: Waiter, + mock_reducer: Reducer, + mock_reducer_up: Reducer, + mock_reducer_down: Reducer, + movable: SignalRW[float], +) -> DualRunDae: + mock_dae = DualRunDae( + prefix="unittest:mock:", + name="polarisingdae", + controller=mock_controller, + waiter=mock_waiter, + reducer_final=mock_reducer, + reducer_up=mock_reducer_up, + reducer_down=mock_reducer_down, + movable=movable, + movable_states=[0.0, 1.0], + ) + + await mock_dae.connect(mock=True) + return mock_dae + + +async def test_polarisingdae_calls_controller_twice_on_trigger( + mock_dae: DualRunDae, mock_controller: MagicMock +): + """Test that the DAE controller is called twice on trigger.""" + await mock_dae.trigger() + assert mock_controller.start_counting.call_count == 2 + mock_controller.start_counting.assert_has_calls([call(mock_dae), call(mock_dae)]) + + +async def test_polarisingdae_calls_waiter_twice_on_trigger( + mock_dae: DualRunDae, mock_waiter: MagicMock +): + """Test that the DAE waiter is called twice on trigger.""" + await mock_dae.trigger() + assert mock_waiter.wait.call_count == 2 + mock_waiter.wait.assert_has_calls([call(mock_dae), call(mock_dae)]) + + +async def test_polarisingdae_calls_reducer_on_trigger( + mock_dae: DualRunDae, + mock_reducer: MagicMock, + mock_reducer_up: MagicMock, + mock_reducer_down: MagicMock, +): + """Test that all reducers are called appropriately on trigger.""" + + await mock_dae.trigger() + mock_reducer.reduce_data.assert_called_once_with(mock_dae) + mock_reducer_up.reduce_data.assert_called_once_with(mock_dae) + mock_reducer_down.reduce_data.assert_called_once_with(mock_dae) + + +async def test_polarising_dae_sets_up_periods_correctly(movable: SignalRW[float]): + """Test that the DAE is correctly configured for period-per-point operation.""" + det_pixels = [1, 2, 3] + frames = 200 + monitor = 20 + intervals = [ + sc.array(dims=["tof"], values=[0, 9999999999.0], unit=sc.units.angstrom, dtype="float64") + ] + movable_states = [0.0, 1.0] + total_flight_path_length = sc.scalar(value=10, unit=sc.units.m) + save_run = False + + with patch("ibex_bluesky_core.devices.polarisingdae.get_pv_prefix"): + dae = polarising_dae( + det_pixels=det_pixels, + frames=frames, + periods=True, + monitor=monitor, + save_run=save_run, + intervals=intervals, + total_flight_path_length=total_flight_path_length, + movable=movable, + movable_states=movable_states, + ) + + assert isinstance(dae.waiter, PeriodGoodFramesWaiter) + value = await dae.waiter.finish_wait_at.get_value() + assert value == frames + assert isinstance(dae.controller, PeriodPerPointController) + + +async def test_polarising_dae_sets_up_single_period_correctly(movable: SignalRW[float]): + """Test that the DAE is correctly configured for run-per-point operation.""" + det_pixels = [1, 2, 3] + frames = 200 + monitor = 20 + intervals = [ + sc.array(dims=["tof"], values=[0, 9999999999.0], unit=sc.units.angstrom, dtype="float64") + ] + movable_states = [0.0, 1.0] + total_flight_path_length = sc.scalar(value=10, unit=sc.units.m) + save_run = False + + with patch("ibex_bluesky_core.devices.polarisingdae.get_pv_prefix"): + dae = polarising_dae( + det_pixels=det_pixels, + frames=frames, + periods=False, + monitor=monitor, + save_run=save_run, + intervals=intervals, + total_flight_path_length=total_flight_path_length, + movable=movable, + 
movable_states=movable_states, + ) + + assert isinstance(dae.waiter, GoodFramesWaiter) + value = await dae.waiter.finish_wait_at.get_value() + assert value == frames + assert isinstance(dae.controller, RunPerPointController) + + +async def test_simpledae_calls_controller_on_stage_and_unstage( + mock_dae: DualRunDae, mock_controller: MagicMock +): + await mock_dae.stage() + mock_controller.setup.assert_called_once_with(mock_dae) + await mock_dae.unstage() + mock_controller.teardown.assert_called_once_with(mock_dae) diff --git a/tests/devices/polarisingdae/test_polarising_reducers.py b/tests/devices/polarisingdae/test_polarising_reducers.py new file mode 100644 index 00000000..1fed2144 --- /dev/null +++ b/tests/devices/polarisingdae/test_polarising_reducers.py @@ -0,0 +1,550 @@ +import re +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import scipp as sc +from ophyd_async.core import DeviceVector, SignalRW, soft_signal_rw + +from ibex_bluesky_core.devices.polarisingdae import ( + DualRunDae, + MultiWavelengthBandNormalizer, + PolarisationReducer, +) +from ibex_bluesky_core.devices.polarisingdae._spectra import ( + _PolarisedWavelengthBand, + _WavelengthBand, +) +from ibex_bluesky_core.devices.simpledae import ( + VARIANCE_ADDITION, + Controller, + Reducer, + Waiter, + wavelength_bounded_spectra, +) +from ibex_bluesky_core.utils import calculate_polarisation + + +@pytest.fixture +def wavelength_bounds_single() -> sc.Variable: + """Single wavelength band spanning the full range.""" + return sc.array(dims=["tof"], values=[0, 9999999999.0], unit=sc.units.angstrom, dtype="float64") + + +@pytest.fixture +def wavelength_bounds_dual() -> list[sc.Variable]: + """Two wavelength bands: one for low and one for high wavelengths.""" + return [ + sc.array(dims=["tof"], values=[0.0, 0.0004], unit=sc.units.angstrom, dtype="float64"), + sc.array( + dims=["tof"], values=[0.0004, 9999999999.0], unit=sc.units.angstrom, dtype="float64" + ), + ] + + +@pytest.fixture +def flight_path() -> sc.Variable: + """Total flight path length for the instrument.""" + return sc.scalar(value=10, unit=sc.units.m, dtype="float64") + + +@pytest.fixture +async def normalizer_single( + wavelength_bounds_single: sc.Variable, flight_path: sc.Variable +) -> MultiWavelengthBandNormalizer: + """Create a normalizer with a single wavelength band.""" + reducer = MultiWavelengthBandNormalizer( + prefix="", + detector_spectra=[1], + monitor_spectra=[2], + sum_wavelength_bands=[ + wavelength_bounded_spectra( + bounds=wavelength_bounds_single, total_flight_path_length=flight_path + ) + ], + ) + await reducer.connect(mock=True) + return reducer + + +@pytest.fixture +async def normalizer_dual( + wavelength_bounds_dual: list[sc.Variable], flight_path: sc.Variable +) -> MultiWavelengthBandNormalizer: + """Create a normalizer with two wavelength bands.""" + reducer = MultiWavelengthBandNormalizer( + prefix="", + detector_spectra=[1], + monitor_spectra=[2], + sum_wavelength_bands=[ + wavelength_bounded_spectra( + bounds=wavelength_bounds_dual[0], total_flight_path_length=flight_path + ), + wavelength_bounded_spectra( + bounds=wavelength_bounds_dual[1], total_flight_path_length=flight_path + ), + ], + ) + await reducer.connect(mock=True) + return reducer + + +@pytest.fixture +async def normalizer_dual_alt( + wavelength_bounds_dual: list[sc.Variable], flight_path: sc.Variable +) -> MultiWavelengthBandNormalizer: + """Create another normalizer with two wavelength bands.""" + reducer = MultiWavelengthBandNormalizer( + prefix="", + 
detector_spectra=[1], + monitor_spectra=[2], + sum_wavelength_bands=[ + wavelength_bounded_spectra(bounds=band, total_flight_path_length=flight_path) + for band in wavelength_bounds_dual + ], + ) + await reducer.connect(mock=True) + return reducer + + +@pytest.fixture +def mock_controller() -> Controller: + return MagicMock(spec=Controller) + + +@pytest.fixture +def mock_waiter() -> Waiter: + return MagicMock(spec=Waiter) + + +@pytest.fixture +def mock_reducer() -> Reducer: + return MagicMock(spec=Reducer) + + +@pytest.fixture +def mock_reducer_up() -> MultiWavelengthBandNormalizer: + return MagicMock(spec=MultiWavelengthBandNormalizer) + + +@pytest.fixture +def mock_reducer_down() -> MultiWavelengthBandNormalizer: + return MagicMock(spec=MultiWavelengthBandNormalizer) + + +@pytest.fixture +def movable() -> SignalRW[float]: + return soft_signal_rw(float, 0.0) + + +@pytest.fixture +def polarising_reducer_single( + wavelength_bounds_single: sc.Variable, + normalizer_dual: MultiWavelengthBandNormalizer, + normalizer_dual_alt: MultiWavelengthBandNormalizer, +) -> PolarisationReducer: + """Create a polarising reducer with a single wavelength band.""" + return PolarisationReducer( + intervals=[wavelength_bounds_single], + reducer_up=normalizer_dual, + reducer_down=normalizer_dual_alt, + ) + + +@pytest.fixture +def polarising_reducer_dual( + wavelength_bounds_dual: list[sc.Variable], + normalizer_dual: MultiWavelengthBandNormalizer, + normalizer_dual_alt: MultiWavelengthBandNormalizer, +) -> PolarisationReducer: + """Create a polarising reducer with two wavelength bands.""" + return PolarisationReducer( + intervals=wavelength_bounds_dual, + reducer_up=normalizer_dual, + reducer_down=normalizer_dual_alt, + ) + + +@pytest.fixture +async def mock_dae( + mock_controller: Controller, + mock_waiter: Waiter, + mock_reducer: Reducer, + mock_reducer_up: Reducer, + mock_reducer_down: Reducer, + movable: SignalRW[float], +) -> DualRunDae: + mock_polarising_dae = DualRunDae( + prefix="unittest:mock:", + name="mock_dae", + controller=mock_controller, + waiter=mock_waiter, + reducer_final=mock_reducer, + reducer_up=mock_reducer_up, + reducer_down=mock_reducer_down, + movable=movable, + movable_states=[0.0, 1.0], + ) + + await mock_polarising_dae.connect(mock=True) + return mock_polarising_dae + + +@pytest.fixture +async def test_dae( + mock_controller: Controller, + mock_waiter: Waiter, + polarising_reducer_dual: PolarisationReducer, + normalizer_dual: MultiWavelengthBandNormalizer, + normalizer_dual_alt: MultiWavelengthBandNormalizer, + movable: SignalRW[float], +) -> DualRunDae: + """Create a test DAE instance with proper mocks and reducers.""" + # Add additional_readable_signals method to controller and waiter mocks + mock_controller.additional_readable_signals = MagicMock(return_value=[]) + mock_waiter.additional_readable_signals = MagicMock(return_value=[]) + + dae = DualRunDae( + prefix="unittest:mock:", + name="mock_dae", + controller=mock_controller, + waiter=mock_waiter, + reducer_final=polarising_reducer_dual, + reducer_up=normalizer_dual, + reducer_down=normalizer_dual_alt, + movable=movable, + movable_states=[0.0, 1.0], + ) + + await dae.connect(mock=True) + return dae + + +def test_wavelength_bounded_normalizer_publishes_wavelength_bands( + mock_dae: DualRunDae, + normalizer_single: MultiWavelengthBandNormalizer, +): + """Test that MultiWavelengthBandNormalizer publishes the correct signals.""" + readables = normalizer_single.additional_readable_signals(mock_dae) + + assert 
list(normalizer_single._wavelength_bands.values()) == readables + + +def test_polarising_reducer_publishes_wavelength_bands( + mock_dae: DualRunDae, + polarising_reducer_single: PolarisationReducer, +): + """Test that PolarisationReducer publishes the correct signals.""" + readables = polarising_reducer_single.additional_readable_signals(mock_dae) + + assert list(polarising_reducer_single._wavelength_bands.values()) == readables + + +async def test_wavelength_band_setter(): + """Test that WavelengthBand correctly sets counts/intensity info.""" + det_counts = 100 + det_counts_stddev = 0.1 + mon_counts = 200 + mon_counts_stddev = 0.2 + intensity = 0.5 + intensity_stddev = 3.7500000000000005e-06 + + wavelength_band = _WavelengthBand() + wavelength_band.setter( + det_counts=det_counts, + det_counts_stddev=det_counts_stddev, + mon_counts=mon_counts, + mon_counts_stddev=mon_counts_stddev, + intensity=intensity, + intensity_stddev=intensity_stddev, + ) + + assert await wavelength_band.det_counts.get_value() == det_counts + assert await wavelength_band.det_counts_stddev.get_value() == det_counts_stddev + assert await wavelength_band.mon_counts.get_value() == mon_counts + assert await wavelength_band.mon_counts_stddev.get_value() == mon_counts_stddev + assert await wavelength_band.intensity.get_value() == intensity + assert await wavelength_band.intensity_stddev.get_value() == intensity_stddev + + +async def test_polarised_wavelength_band_setter(): + """Test that PolarisedWavelengthBand correctly sets polarisation info.""" + polarisation = 1 + polarisation_stddev = 0.5 + polarisation_ratio = 0.1 + polarisation_ratio_stddev = 0.5 + + polarised_wavelength_band = _PolarisedWavelengthBand() + polarised_wavelength_band.setter( + polarisation=polarisation, + polarisation_stddev=polarisation_stddev, + polarisation_ratio=polarisation_ratio, + polarisation_ratio_stddev=polarisation_ratio_stddev, + ) + + assert await polarised_wavelength_band.polarisation.get_value() == polarisation + assert await polarised_wavelength_band.polarisation_stddev.get_value() == polarisation_stddev + assert await polarised_wavelength_band.polarisation_ratio.get_value() == polarisation_ratio + assert ( + await polarised_wavelength_band.polarisation_ratio_stddev.get_value() + == polarisation_ratio_stddev + ) + + +async def test_wavelength_bounded_normalizer( + mock_dae: DualRunDae, normalizer_dual: MultiWavelengthBandNormalizer +): + """Test wavelength bounded normaliser with mock spectrum data.""" + normalizer_dual.detectors[1].read_spectrum_dataarray = AsyncMock( + side_effect=lambda: sc.DataArray( + data=sc.Variable( + dims=["tof"], + values=[1000.0, 2000.0, 3000.0], + variances=[1000.0, 2000.0, 3000.0], + unit=sc.units.counts, + ), + coords={ + "tof": sc.array( + dims=["tof"], values=[0, 1, 2, 3], unit=sc.units.us, dtype="float64" + ) + }, + ) + ) + + normalizer_dual.monitors[2].read_spectrum_dataarray = AsyncMock( + side_effect=lambda: sc.DataArray( + data=sc.Variable( + dims=["tof"], + values=[4000.0, 5000.0, 6000.0], + variances=[4000.0, 5000.0, 6000.0], + unit=sc.units.counts, + ), + coords={ + "tof": sc.array( + dims=["tof"], values=[0, 1, 2, 3], unit=sc.units.us, dtype="float64" + ) + }, + ) + ) + + with patch.object(normalizer_dual._wavelength_bands[0], "setter") as low_band_setter: + with patch.object(normalizer_dual._wavelength_bands[1], "setter") as high_band_setter: + await normalizer_dual.reduce_data(dae=mock_dae) + + # Test low wavelength band + low_band_setter.assert_called_once_with( + 
det_counts=pytest.approx(1022.2273083661819),
+            det_counts_stddev=pytest.approx(31.980108010545898),
+            mon_counts=pytest.approx(4055.568270915455),
+            mon_counts_stddev=pytest.approx(63.68334374791775),
+            intensity=pytest.approx(0.2520552583708415),
+            intensity_stddev=pytest.approx(0.00882304684703932),
+        )
+
+        # Test high wavelength band
+        high_band_setter.assert_called_once_with(
+            det_counts=pytest.approx(4977.772691633818),
+            det_counts_stddev=pytest.approx(70.55687558015744),
+            mon_counts=pytest.approx(10944.431729084547),
+            mon_counts_stddev=pytest.approx(104.6156380713923),
+            intensity=pytest.approx(0.45482239871856617),
+            intensity_stddev=pytest.approx(0.0077757859250577885),
+        )
+
+
+async def test_multi_wavelength_band_normalizer_zero_counts(
+    mock_dae: DualRunDae, normalizer_single: MultiWavelengthBandNormalizer
+):
+    """Test that MultiWavelengthBandNormalizer handles zero counts correctly."""
+
+    mock_spectrum = sc.DataArray(
+        data=sc.Variable(
+            dims=["tof"],
+            values=[0.0, 0.0, 0.0],
+            variances=[0.0, 0.0, 0.0],
+            unit=sc.units.counts,
+        ),
+        coords={"tof": sc.array(dims=["tof"], values=[0, 1, 2, 3], unit=sc.units.us)},
+    )
+
+    normalizer_single.detectors[1].read_spectrum_dataarray = AsyncMock(
+        side_effect=lambda: mock_spectrum
+    )
+    normalizer_single.monitors[2].read_spectrum_dataarray = AsyncMock(
+        side_effect=lambda: sc.DataArray(data=mock_spectrum.data, coords=mock_spectrum.coords)
+    )
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Cannot normalize; got zero monitor counts in wavelength band 0."),
+    ):
+        await normalizer_single.reduce_data(mock_dae)
+
+
+def test_multi_wavelength_band_normalizer_name_properties(
+    normalizer_dual: MultiWavelengthBandNormalizer,
+):
+    """Test that the counts/intensity related name properties return correct values."""
+    # Get all the wavelength bands
+    bands = list(normalizer_dual._wavelength_bands.values())
+
+    # Test detector counts names
+    assert normalizer_dual.det_counts_names == [band.det_counts.name for band in bands]
+
+    # Test detector counts stddev names
+    assert normalizer_dual.det_counts_stddev_names == [
+        band.det_counts_stddev.name for band in bands
+    ]
+
+    # Test monitor counts names
+    assert normalizer_dual.mon_counts_names == [band.mon_counts.name for band in bands]
+
+    # Test monitor counts stddev names
+    assert normalizer_dual.mon_counts_stddev_names == [
+        band.mon_counts_stddev.name for band in bands
+    ]
+
+    # Test intensity names
+    assert normalizer_dual.intensity_names == [band.intensity.name for band in bands]
+
+    # Test intensity stddev names
+    assert normalizer_dual.intensity_stddev_names == [band.intensity_stddev.name for band in bands]
+
+
+async def test_polarising_reducer(
+    test_dae: DualRunDae,
+    polarising_reducer_dual: PolarisationReducer,
+):
+    """Test PolarisationReducer calculates polarisation from up/down data with wavelength bands."""
+    # Test data
+    test_cases = [
+        {
+            # First wavelength band
+            "up_intensity": 6000 / 15000,
+            "up_stddev": (6000 / 15000)
+            * ((6000 + VARIANCE_ADDITION) / 6000**2 + 15000 / 15000**2) ** 0.5,
+            "down_intensity": 7000 / 15000,
+            "down_stddev": (7000 / 15000)
+            * ((7000 + VARIANCE_ADDITION) / 7000**2 + 15000 / 15000**2) ** 0.5,
+        },
+        {
+            # Second wavelength band
+            "up_intensity": 8000 / 15000,
+            "up_stddev": (8000 / 15000)
+            * ((8000 + VARIANCE_ADDITION) / 8000**2 + 15000 / 15000**2) ** 0.5,
+            "down_intensity": 7000 / 15000,
+            "down_stddev": (7000 / 15000)
+            * ((7000 + VARIANCE_ADDITION) / 7000**2 + 15000 / 15000**2) ** 0.5,
+        },
+    ]
+
+    # Configure mock
intensity values for up/down states + for i, case in enumerate(test_cases): + # Set up mock values for up state + polarising_reducer_dual.reducer_up()._wavelength_bands[i].intensity.get_value = AsyncMock( + return_value=case["up_intensity"] + ) + polarising_reducer_dual.reducer_up()._wavelength_bands[ + i + ].intensity_stddev.get_value = AsyncMock(return_value=case["up_stddev"]) + + # Set up mock values for down state + polarising_reducer_dual.reducer_down()._wavelength_bands[i].intensity.get_value = AsyncMock( + return_value=case["down_intensity"] + ) + polarising_reducer_dual.reducer_down()._wavelength_bands[ + i + ].intensity_stddev.get_value = AsyncMock(return_value=case["down_stddev"]) + + # Test both wavelength bands + for i, case in enumerate(test_cases): + with patch.object(polarising_reducer_dual._wavelength_bands[i], "setter") as mock_setter: + await polarising_reducer_dual.reduce_data(test_dae) + + # Calculate expected values + intensity_up = sc.scalar( + value=case["up_intensity"], variance=case["up_stddev"], dtype=float + ) + intensity_down = sc.scalar( + value=case["down_intensity"], variance=case["down_stddev"], dtype=float + ) + + expected_polarisation = calculate_polarisation(intensity_up, intensity_down) + expected_ratio = intensity_up / intensity_down + + # Verify setter was called with correct values + mock_setter.assert_called_once_with( + polarisation=float(expected_polarisation.value), + polarisation_stddev=float(expected_polarisation.variance), + polarisation_ratio=float(expected_ratio.value), + polarisation_ratio_stddev=float(expected_ratio.variance), + ) + + +@pytest.mark.parametrize( + "invalid_intensity", + [ + (0.0, 1.0), # Zero up intensity + (1.0, 0.0), # Zero down intensity + ], +) +async def test_polarising_reducer_zero_intensity( + test_dae: DualRunDae, + polarising_reducer_dual: PolarisationReducer, + invalid_intensity: tuple[float, float], +): + """Test that PolarisationReducer handles zero intensities appropriately.""" + up_intensity, down_intensity = invalid_intensity + + polarising_reducer_dual.reducer_up()._wavelength_bands[0].intensity.get_value = AsyncMock( + return_value=up_intensity + ) + polarising_reducer_dual.reducer_down()._wavelength_bands[0].intensity.get_value = AsyncMock( + return_value=down_intensity + ) + + with pytest.raises( + ValueError, match="Cannot calculate polarisation; zero intensity sum detected" + ): + await polarising_reducer_dual.reduce_data(test_dae) + + +async def test_polarising_reducer_mismatched_bands( + test_dae: DualRunDae, + polarising_reducer_single: PolarisationReducer, +): + """Test that PolarisationReducer handles mismatched wavelength bands appropriately.""" + # Mock different number of wavelength bands for up and down states + + polarising_reducer_single.reducer_up()._wavelength_bands = DeviceVector( + children={0: _WavelengthBand(), 1: _WavelengthBand()} + ) + polarising_reducer_single.reducer_down()._wavelength_bands = DeviceVector( + children={0: _WavelengthBand()} + ) + + with pytest.raises(ValueError, match="Mismatched number of wavelength bands"): + await polarising_reducer_single.reduce_data(test_dae) + + +def test_polarising_reducer_name_properties(polarising_reducer_dual: PolarisationReducer): + """Test that the polarisation-related name properties return correct values.""" + # Get all the wavelength bands + bands = list(polarising_reducer_dual._wavelength_bands.values()) + + # Test polarisation names + assert polarising_reducer_dual.polarisation_names == [band.polarisation.name for band in bands] + + 
# Test polarisation standard deviation names + assert polarising_reducer_dual.polarisation_stddev_names == [ + band.polarisation_stddev.name for band in bands + ] + + # Test polarisation ratio names + assert polarising_reducer_dual.polarisation_ratio == [ + band.polarisation_ratio.name for band in bands + ] + + # Test polarisation ratio standard deviation names + assert polarising_reducer_dual.polarisation_ratio_stddev == [ + band.polarisation_ratio_stddev.name for band in bands + ] diff --git a/tests/devices/simpledae/__init__.py b/tests/devices/simpledae/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/devices/simpledae/test_reducers.py b/tests/devices/simpledae/test_reducers.py index f5aa9447..d8cfd177 100644 --- a/tests/devices/simpledae/test_reducers.py +++ b/tests/devices/simpledae/test_reducers.py @@ -7,10 +7,10 @@ import pytest import scipp as sc from ophyd_async.testing import get_mock_put, set_mock_value -from uncertainties import ufloat, unumpy from ibex_bluesky_core.devices.simpledae import ( VARIANCE_ADDITION, + DSpacingMappingReducer, MonitorNormalizer, PeriodGoodFramesNormalizer, PeriodSpecIntegralsReducer, @@ -19,7 +19,6 @@ tof_bounded_spectra, wavelength_bounded_spectra, ) -from ibex_bluesky_core.devices.simpledae._reducers import DSpacingMappingReducer, polarization @pytest.fixture @@ -642,7 +641,8 @@ async def test_monitor_normalizer_uncertainties( assert det_counts_stddev == math.sqrt(6000 + VARIANCE_ADDITION) assert mon_counts_stddev == math.sqrt(15000) assert intensity_stddev == pytest.approx( - (6000 / 15000) * math.sqrt((6000.5 / 6000**2) + (15000 / 15000**2)), 1e-8 + (6000 / 15000) * math.sqrt(((6000 + VARIANCE_ADDITION) / 6000**2) + (15000 / 15000**2)), + 1e-8, ) @@ -947,93 +947,6 @@ def test_wavelength_bounded_spectra_bounds_missing_or_too_many(data: list[float] ) -# Polarization -@pytest.mark.parametrize( - ("a", "b", "variance_a", "variance_b"), - [ - # Case 1: Symmetric case with equal uncertainties - (5.0, 3.0, 0.1, 0.1), - # Case 2: Asymmetric case with different uncertainties - (10.0, 6.0, 0.2, 0.3), - # Case 3: Case with larger values and different uncertainty magnitudes - (100.0, 60.0, 1.0, 2.0), - ], -) -def test_polarization_function_calculates_accurately(a, b, variance_a, variance_b): - # 'Uncertainties' library ufloat type; a nominal value and an error value - a_ufloat = ufloat(a, variance_a) - b_ufloat = ufloat(b, variance_b) - - # polarization value, i.e. 
(a - b) / (a + b) - polarization_ufloat = (a_ufloat.n - b_ufloat.n) / (a_ufloat.n + b_ufloat.n) - - # the partial derivatives of a and b, calculated with 'uncertainties' library's ufloat type - partial_a = (2 * b_ufloat.n) / ((a_ufloat.n + b_ufloat.n) ** 2) - partial_b = (-2 * a_ufloat.n) / ((a_ufloat.n + b_ufloat.n) ** 2) - - # variance calculated with 'uncertainties' library - variance = (partial_a**2 * a_ufloat.s) + (partial_b**2 * b_ufloat.s) - uncertainty = variance**0.5 # uncertainty is sqrt of variance - - # Two scipp scalars, to test our polarization function - var_a = sc.scalar(value=a, variance=variance_a, unit="", dtype="float64") - var_b = sc.scalar(value=b, variance=variance_b, unit="", dtype="float64") - result_value = polarization(var_a, var_b) - result_uncertainy = (result_value.variance) ** 0.5 # uncertainty is sqrt of variance - - assert result_value.value == pytest.approx(polarization_ufloat) - assert result_uncertainy == pytest.approx(uncertainty) - - -# test that arrays are supported -@pytest.mark.parametrize( - ("a", "b", "variances_a", "variances_b"), - [ - ([5.0, 10.0, 100.0], [3.0, 6.0, 60.0], [0.1, 0.2, 1.0], [0.1, 0.3, 2.0]), - ], -) -def test_polarization_2_arrays(a, b, variances_a, variances_b): - # 'Uncertainties' library ufloat type; a nominal value and an error value - - a_arr = unumpy.uarray(a, [v**0.5 for v in variances_a]) # convert variances to std dev - b_arr = unumpy.uarray(b, [v**0.5 for v in variances_b]) - - # polarization value, i.e. (a - b) / (a + b) - polarization_ufloat = (a_arr - b_arr) / (a_arr + b_arr) - - var_a = sc.array(dims="x", values=a, variances=variances_a, unit="", dtype="float64") - var_b = sc.array(dims="x", values=b, variances=variances_b, unit="", dtype="float64") - - result_value = polarization(var_a, var_b) - - result_uncertainties = (result_value.variances) ** 0.5 - - assert result_value.values == pytest.approx(unumpy.nominal_values(polarization_ufloat)) - assert result_uncertainties == pytest.approx(unumpy.std_devs(polarization_ufloat)) - - -# test that units don't match -def test_polarization_units_mismatch(): - var_a = sc.scalar(value=1, variance=0.1, unit="m", dtype="float64") - var_b = sc.scalar(value=1, variance=0.1, unit="u", dtype="float64") - - with pytest.raises( - expected_exception=ValueError, match=r"The units of a and b are not equivalent." - ): - polarization(var_a, var_b) - - -# test that arrays are of unmatching sizes -def test_polarization_arrays_of_different_sizes(): - var_a = sc.array(dims=["x"], values=[1, 2], variances=[0.1, 0.1], unit="m", dtype="float64") - var_b = sc.array(dims=["x"], values=[1], variances=[0.1], unit="m", dtype="float64") - - with pytest.raises( - expected_exception=ValueError, match=r"Dimensions/shape of a and b must match." 
- ): - polarization(var_a, var_b) - - @pytest.mark.parametrize( ("current_period", "mon_integrals", "det_integrals"), [ diff --git a/tests/devices/simpledae/test_simpledae.py b/tests/devices/simpledae/test_simpledae.py index df501988..107ffc60 100644 --- a/tests/devices/simpledae/test_simpledae.py +++ b/tests/devices/simpledae/test_simpledae.py @@ -4,7 +4,7 @@ from ophyd_async.core import Device, StandardReadable, soft_signal_rw from ophyd_async.testing import set_mock_value -from ibex_bluesky_core.devices.dae import DaeCheckingSignal +from ibex_bluesky_core.devices.dae import Dae, DaeCheckingSignal from ibex_bluesky_core.devices.simpledae import ( Controller, GoodFramesWaiter, @@ -85,16 +85,16 @@ def __init__(self): self.soft_signal = soft_signal_rw(float, 0.0) super().__init__(name="reducer") - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: # Signal explicitly published by this reducer rather than the DAE itself return [self.soft_signal] class TestController(Controller): - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: return [dae.good_uah] class TestWaiter(Waiter): - def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]: + def additional_readable_signals(self, dae: Dae) -> list[Device]: # Same signal as controller, should only be added once. return [dae.good_uah] diff --git a/tests/test_utils.py b/tests/test_utils.py index 4e43e4c2..0b274f67 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,16 @@ +# pyright: reportMissingParameterType=false from unittest.mock import patch import pytest +import scipp as sc +from uncertainties import ufloat, unumpy -from ibex_bluesky_core.utils import centred_pixel, get_pv_prefix, is_matplotlib_backend_qt +from ibex_bluesky_core.utils import ( + calculate_polarisation, + centred_pixel, + get_pv_prefix, + is_matplotlib_backend_qt, +) def test_get_pv_prefix(): @@ -26,3 +34,90 @@ def test_centred_pixel(): def test_is_matplotlib_backend_qt(mpl_backend: str): with patch("ibex_bluesky_core.utils.matplotlib.get_backend", return_value=mpl_backend): assert is_matplotlib_backend_qt() == ("qt" in mpl_backend) + + +# polarisation +@pytest.mark.parametrize( + ("a", "b", "variance_a", "variance_b"), + [ + # Case 1: Symmetric case with equal uncertainties + (5.0, 3.0, 0.1, 0.1), + # Case 2: Asymmetric case with different uncertainties + (10.0, 6.0, 0.2, 0.3), + # Case 3: Case with larger values and different uncertainty magnitudes + (100.0, 60.0, 1.0, 2.0), + ], +) +def test_polarisation_function_calculates_accurately(a, b, variance_a, variance_b): + # 'Uncertainties' library ufloat type; a nominal value and an error value + a_ufloat = ufloat(a, variance_a) + b_ufloat = ufloat(b, variance_b) + + # polarisation value, i.e. 
(a - b) / (a + b)
+    polarisation_ufloat = (a_ufloat.n - b_ufloat.n) / (a_ufloat.n + b_ufloat.n)
+
+    # the partial derivatives of a and b, calculated with 'uncertainties' library's ufloat type
+    partial_a = (2 * b_ufloat.n) / ((a_ufloat.n + b_ufloat.n) ** 2)
+    partial_b = (-2 * a_ufloat.n) / ((a_ufloat.n + b_ufloat.n) ** 2)
+
+    # variance calculated with 'uncertainties' library
+    variance = (partial_a**2 * a_ufloat.s) + (partial_b**2 * b_ufloat.s)
+    uncertainty = variance**0.5  # uncertainty is sqrt of variance
+
+    # Two scipp scalars, to test our polarisation function
+    var_a = sc.scalar(value=a, variance=variance_a, unit="", dtype="float64")
+    var_b = sc.scalar(value=b, variance=variance_b, unit="", dtype="float64")
+    result_value = calculate_polarisation(var_a, var_b)
+    result_uncertainty = (result_value.variance) ** 0.5  # uncertainty is sqrt of variance
+
+    assert result_value.value == pytest.approx(polarisation_ufloat)
+    assert result_uncertainty == pytest.approx(uncertainty)
+
+
+# test that arrays are supported
+@pytest.mark.parametrize(
+    ("a", "b", "variances_a", "variances_b"),
+    [
+        ([5.0, 10.0, 100.0], [3.0, 6.0, 60.0], [0.1, 0.2, 1.0], [0.1, 0.3, 2.0]),
+    ],
+)
+def test_polarisation_2_arrays(a, b, variances_a, variances_b):
+    # 'Uncertainties' library ufloat type; a nominal value and an error value
+
+    a_arr = unumpy.uarray(a, [v**0.5 for v in variances_a])  # convert variances to std dev
+    b_arr = unumpy.uarray(b, [v**0.5 for v in variances_b])
+
+    # polarisation value, i.e. (a - b) / (a + b)
+    polarisation_ufloat = (a_arr - b_arr) / (a_arr + b_arr)
+
+    var_a = sc.array(dims="x", values=a, variances=variances_a, unit="", dtype="float64")
+    var_b = sc.array(dims="x", values=b, variances=variances_b, unit="", dtype="float64")
+
+    result_value = calculate_polarisation(var_a, var_b)
+
+    result_uncertainties = (result_value.variances) ** 0.5
+
+    assert result_value.values == pytest.approx(unumpy.nominal_values(polarisation_ufloat))
+    assert result_uncertainties == pytest.approx(unumpy.std_devs(polarisation_ufloat))
+
+
+# test that units don't match
+def test_polarisation_units_mismatch():
+    var_a = sc.scalar(value=1, variance=0.1, unit="m", dtype="float64")
+    var_b = sc.scalar(value=1, variance=0.1, unit="u", dtype="float64")
+
+    with pytest.raises(
+        expected_exception=ValueError, match=r"The units of a and b are not equivalent."
+    ):
+        calculate_polarisation(var_a, var_b)
+
+
+# test that arrays are of unmatching sizes
+def test_polarisation_arrays_of_different_sizes():
+    var_a = sc.array(dims=["x"], values=[1, 2], variances=[0.1, 0.1], unit="m", dtype="float64")
+    var_b = sc.array(dims=["x"], values=[1], variances=[0.1], unit="m", dtype="float64")
+
+    with pytest.raises(
+        expected_exception=ValueError, match=r"Dimensions/shape of a and b must match."
+    ):
+        calculate_polarisation(var_a, var_b)
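
For reference, a minimal usage sketch of the new `calculate_polarisation` helper, assuming this branch of `ibex_bluesky_core` and `scipp` are installed. It mirrors the scalar case exercised by `test_polarisation_function_calculates_accurately` above; the values chosen here are illustrative, not taken from any instrument.

```python
import scipp as sc

from ibex_bluesky_core.utils import calculate_polarisation

# Two measurements with variances attached, e.g. spin-up and spin-down intensities.
a = sc.scalar(value=5.0, variance=0.1, unit="", dtype="float64")
b = sc.scalar(value=3.0, variance=0.1, unit="", dtype="float64")

pol = calculate_polarisation(a, b)

# Value is (a - b) / (a + b) = 2 / 8 = 0.25
print(pol.value)

# Variance follows linear error propagation:
# (2b / (a + b)^2)^2 * var_a + (-2a / (a + b)^2)^2 * var_b
print(pol.variance**0.5)  # standard deviation
```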