|
3 | 3 | # This file is part of Iris and is released under the BSD license.
|
4 | 4 | # See LICENSE in the root of the repository for full licensing details.
|
5 | 5 | """Miscellaneous utility functions."""
|
| 6 | +from __future__ import annotations |
6 | 7 |
|
7 | 8 | from abc import ABCMeta, abstractmethod
|
8 | 9 | from collections.abc import Hashable, Iterable
|
@@ -281,7 +282,12 @@ def guess_coord_axis(coord):
|
281 | 282 | return axis
|
282 | 283 |
|
283 | 284 |
|
284 |
| -def rolling_window(a, window=1, step=1, axis=-1): |
| 285 | +def rolling_window( |
| 286 | + a: np.ndarray | da.Array, |
| 287 | + window: int = 1, |
| 288 | + step: int = 1, |
| 289 | + axis: int = -1, |
| 290 | +) -> np.ndarray | da.Array: |
285 | 291 | """Make an ndarray with a rolling window of the last dimension.
|
286 | 292 |
|
287 | 293 | Parameters
|
@@ -322,34 +328,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
|
322 | 328 | See more at :doc:`/userguide/real_and_lazy_data`.
|
323 | 329 |
|
324 | 330 | """
|
325 |
| - # NOTE: The implementation of this function originates from |
326 |
| - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 |
327 | 331 | if window < 1:
|
328 | 332 | raise ValueError("`window` must be at least 1.")
|
329 | 333 | if window > a.shape[axis]:
|
330 | 334 | raise ValueError("`window` is too long.")
|
331 | 335 | if step < 1:
|
332 | 336 | raise ValueError("`step` must be at least 1.")
|
333 | 337 | axis = axis % a.ndim
|
334 |
| - num_windows = (a.shape[axis] - window + step) // step |
335 |
| - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] |
336 |
| - strides = ( |
337 |
| - a.strides[:axis] |
338 |
| - + (step * a.strides[axis], a.strides[axis]) |
339 |
| - + a.strides[axis + 1 :] |
| 338 | + array_module = da if isinstance(a, da.Array) else np |
| 339 | + steps = tuple( |
| 340 | + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) |
340 | 341 | )
|
341 |
| - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) |
342 |
| - if ma.isMaskedArray(a): |
343 |
| - mask = ma.getmaskarray(a) |
344 |
| - strides = ( |
345 |
| - mask.strides[:axis] |
346 |
| - + (step * mask.strides[axis], mask.strides[axis]) |
347 |
| - + mask.strides[axis + 1 :] |
348 |
| - ) |
349 |
| - rw = ma.array( |
350 |
| - rw, |
351 |
| - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), |
| 342 | + |
| 343 | + def _rolling_window(array): |
| 344 | + return array_module.moveaxis( |
| 345 | + array_module.lib.stride_tricks.sliding_window_view( |
| 346 | + array, |
| 347 | + window_shape=window, |
| 348 | + axis=axis, |
| 349 | + )[steps], |
| 350 | + -1, |
| 351 | + axis + 1, |
352 | 352 | )
|
| 353 | + |
| 354 | + rw = _rolling_window(a) |
| 355 | + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): |
| 356 | + mask = _rolling_window(array_module.ma.getmaskarray(a)) |
| 357 | + rw = array_module.ma.masked_array(rw, mask) |
353 | 358 | return rw
|
354 | 359 |
|
355 | 360 |
|
|
0 commit comments