|
3 | 3 | # This file is part of Iris and is released under the BSD license.
|
4 | 4 | # See LICENSE in the root of the repository for full licensing details.
|
5 | 5 | """Miscellaneous utility functions."""
|
| 6 | +from __future__ import annotations |
6 | 7 |
|
7 | 8 | from abc import ABCMeta, abstractmethod
|
8 | 9 | from collections.abc import Hashable, Iterable
|
@@ -282,7 +283,12 @@ def guess_coord_axis(coord):
|
282 | 283 | return axis
|
283 | 284 |
|
284 | 285 |
|
285 |
| -def rolling_window(a, window=1, step=1, axis=-1): |
| 286 | +def rolling_window( |
| 287 | + a: np.ndarray | da.Array, |
| 288 | + window: int = 1, |
| 289 | + step: int = 1, |
| 290 | + axis: int = -1, |
| 291 | +) -> np.ndarray | da.Array: |
286 | 292 | """Make an ndarray with a rolling window of the last dimension.
|
287 | 293 |
|
288 | 294 | Parameters
|
@@ -323,34 +329,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
|
323 | 329 | See more at :doc:`/userguide/real_and_lazy_data`.
|
324 | 330 |
|
325 | 331 | """
|
326 |
| - # NOTE: The implementation of this function originates from |
327 |
| - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 |
328 | 332 | if window < 1:
|
329 | 333 | raise ValueError("`window` must be at least 1.")
|
330 | 334 | if window > a.shape[axis]:
|
331 | 335 | raise ValueError("`window` is too long.")
|
332 | 336 | if step < 1:
|
333 | 337 | raise ValueError("`step` must be at least 1.")
|
334 | 338 | axis = axis % a.ndim
|
335 |
| - num_windows = (a.shape[axis] - window + step) // step |
336 |
| - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] |
337 |
| - strides = ( |
338 |
| - a.strides[:axis] |
339 |
| - + (step * a.strides[axis], a.strides[axis]) |
340 |
| - + a.strides[axis + 1 :] |
| 339 | + array_module = da if isinstance(a, da.Array) else np |
| 340 | + steps = tuple( |
| 341 | + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) |
341 | 342 | )
|
342 |
| - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) |
343 |
| - if ma.isMaskedArray(a): |
344 |
| - mask = ma.getmaskarray(a) |
345 |
| - strides = ( |
346 |
| - mask.strides[:axis] |
347 |
| - + (step * mask.strides[axis], mask.strides[axis]) |
348 |
| - + mask.strides[axis + 1 :] |
349 |
| - ) |
350 |
| - rw = ma.array( |
351 |
| - rw, |
352 |
| - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), |
| 343 | + |
| 344 | + def _rolling_window(array): |
| 345 | + return array_module.moveaxis( |
| 346 | + array_module.lib.stride_tricks.sliding_window_view( |
| 347 | + array, |
| 348 | + window_shape=window, |
| 349 | + axis=axis, |
| 350 | + )[steps], |
| 351 | + -1, |
| 352 | + axis + 1, |
353 | 353 | )
|
| 354 | + |
| 355 | + rw = _rolling_window(a) |
| 356 | + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): |
| 357 | + mask = _rolling_window(array_module.ma.getmaskarray(a)) |
| 358 | + rw = array_module.ma.masked_array(rw, mask) |
354 | 359 | return rw
|
355 | 360 |
|
356 | 361 |
|
|
0 commit comments