Skip to content

Commit f398de9

Browse files
bouweandelapp-mo
authored andcommitted
Lazy rolling_window (SciTools#5775)
* Lazy rolling_window * Add test and whatsnew entry
1 parent d9b0680 commit f398de9

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

docs/src/whatsnew/latest.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ This document explains the changes made to Iris for this release
4848
🚀 Performance Enhancements
4949
===========================
5050

51-
#. N/A
51+
#. `@bouweandela`_ made :func:`iris.util.rolling_window` work with lazy arrays.
52+
(:pull:`5775`)
5253

5354

5455
🔥 Deprecations

lib/iris/tests/unit/util/test_rolling_window.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# importing anything else
99
import iris.tests as tests # isort:skip
1010

11+
import dask.array as da
1112
import numpy as np
1213
import numpy.ma as ma
1314

@@ -35,6 +36,12 @@ def test_2d(self):
3536
result = rolling_window(a, window=3, axis=1)
3637
self.assertArrayEqual(result, expected_result)
3738

39+
def test_3d_lazy(self):
40+
a = da.arange(2 * 3 * 4).reshape((2, 3, 4))
41+
expected_result = np.arange(2 * 3 * 4).reshape((1, 2, 3, 4))
42+
result = rolling_window(a, window=2, axis=0).compute()
43+
self.assertArrayEqual(result, expected_result)
44+
3845
def test_1d_masked(self):
3946
# 1-d masked array input
4047
a = ma.array([0, 1, 2, 3, 4], mask=[0, 0, 1, 0, 0], dtype=np.int32)

lib/iris/util.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This file is part of Iris and is released under the BSD license.
44
# See LICENSE in the root of the repository for full licensing details.
55
"""Miscellaneous utility functions."""
6+
from __future__ import annotations
67

78
from abc import ABCMeta, abstractmethod
89
from collections.abc import Hashable, Iterable
@@ -282,7 +283,12 @@ def guess_coord_axis(coord):
282283
return axis
283284

284285

285-
def rolling_window(a, window=1, step=1, axis=-1):
286+
def rolling_window(
287+
a: np.ndarray | da.Array,
288+
window: int = 1,
289+
step: int = 1,
290+
axis: int = -1,
291+
) -> np.ndarray | da.Array:
286292
"""Make an ndarray with a rolling window of the last dimension.
287293
288294
Parameters
@@ -323,34 +329,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
323329
See more at :doc:`/userguide/real_and_lazy_data`.
324330
325331
"""
326-
# NOTE: The implementation of this function originates from
327-
# https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011
328332
if window < 1:
329333
raise ValueError("`window` must be at least 1.")
330334
if window > a.shape[axis]:
331335
raise ValueError("`window` is too long.")
332336
if step < 1:
333337
raise ValueError("`step` must be at least 1.")
334338
axis = axis % a.ndim
335-
num_windows = (a.shape[axis] - window + step) // step
336-
shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :]
337-
strides = (
338-
a.strides[:axis]
339-
+ (step * a.strides[axis], a.strides[axis])
340-
+ a.strides[axis + 1 :]
339+
array_module = da if isinstance(a, da.Array) else np
340+
steps = tuple(
341+
slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim)
341342
)
342-
rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
343-
if ma.isMaskedArray(a):
344-
mask = ma.getmaskarray(a)
345-
strides = (
346-
mask.strides[:axis]
347-
+ (step * mask.strides[axis], mask.strides[axis])
348-
+ mask.strides[axis + 1 :]
349-
)
350-
rw = ma.array(
351-
rw,
352-
mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides),
343+
344+
def _rolling_window(array):
345+
return array_module.moveaxis(
346+
array_module.lib.stride_tricks.sliding_window_view(
347+
array,
348+
window_shape=window,
349+
axis=axis,
350+
)[steps],
351+
-1,
352+
axis + 1,
353353
)
354+
355+
rw = _rolling_window(a)
356+
if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray):
357+
mask = _rolling_window(array_module.ma.getmaskarray(a))
358+
rw = array_module.ma.masked_array(rw, mask)
354359
return rw
355360

356361

0 commit comments

Comments
 (0)