Skip to content

Commit bb2ec46

Browse files
committed
Lazy rolling_window
1 parent 2b024aa commit bb2ec46

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

lib/iris/util.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This file is part of Iris and is released under the BSD license.
44
# See LICENSE in the root of the repository for full licensing details.
55
"""Miscellaneous utility functions."""
6+
from __future__ import annotations
67

78
from abc import ABCMeta, abstractmethod
89
from collections.abc import Hashable, Iterable
@@ -281,7 +282,12 @@ def guess_coord_axis(coord):
281282
return axis
282283

283284

284-
def rolling_window(a, window=1, step=1, axis=-1):
285+
def rolling_window(
286+
a: np.ndarray | da.Array,
287+
window: int = 1,
288+
step: int = 1,
289+
axis: int = -1,
290+
) -> np.ndarray | da.Array:
285291
"""Make an ndarray with a rolling window of the last dimension.
286292
287293
Parameters
@@ -322,34 +328,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
322328
See more at :doc:`/userguide/real_and_lazy_data`.
323329
324330
"""
325-
# NOTE: The implementation of this function originates from
326-
# https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011
327331
if window < 1:
328332
raise ValueError("`window` must be at least 1.")
329333
if window > a.shape[axis]:
330334
raise ValueError("`window` is too long.")
331335
if step < 1:
332336
raise ValueError("`step` must be at least 1.")
333337
axis = axis % a.ndim
334-
num_windows = (a.shape[axis] - window + step) // step
335-
shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :]
336-
strides = (
337-
a.strides[:axis]
338-
+ (step * a.strides[axis], a.strides[axis])
339-
+ a.strides[axis + 1 :]
338+
array_module = da if isinstance(a, da.Array) else np
339+
steps = tuple(
340+
slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim)
340341
)
341-
rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
342-
if ma.isMaskedArray(a):
343-
mask = ma.getmaskarray(a)
344-
strides = (
345-
mask.strides[:axis]
346-
+ (step * mask.strides[axis], mask.strides[axis])
347-
+ mask.strides[axis + 1 :]
348-
)
349-
rw = ma.array(
350-
rw,
351-
mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides),
342+
343+
def _rolling_window(array):
344+
return array_module.moveaxis(
345+
array_module.lib.stride_tricks.sliding_window_view(
346+
array,
347+
window_shape=window,
348+
axis=axis,
349+
)[steps],
350+
-1,
351+
axis + 1,
352352
)
353+
354+
rw = _rolling_window(a)
355+
if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray):
356+
mask = _rolling_window(array_module.ma.getmaskarray(a))
357+
rw = array_module.ma.masked_array(rw, mask)
353358
return rw
354359

355360

0 commit comments

Comments
 (0)