Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 206 additions & 2 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from datetime import datetime
from itertools import chain
from itertools import chain, product
from typing import Literal, TypedDict, get_args

import cf_xarray # noqa: F401
Expand All @@ -17,7 +17,7 @@

from xcdat import bounds # noqa: F401
from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
from xcdat.axis import get_dim_coords, get_dim_keys
from xcdat.dataset import _get_data_var

logger = _setup_custom_logger(__name__)
Expand Down Expand Up @@ -2086,6 +2086,210 @@ def _calculate_departures(
return ds_departs


def compute_monthly_average(self, data_var: str) -> "xr.Dataset":
    """Compute monthly averages of a data variable.

    This method wraps several steps in order to compute monthly averages
    (e.g., from hourly/daily/pentad data to monthly means):

    1. ``ensure_bounds_order`` checks that the dataset's time bounds are
       ordered [earlier time, later time], since the remaining steps
       depend on this.
    2. ``generate_monthly_bounds`` creates the target monthly bounds.
    3. ``get_temporal_weights`` computes the weights for averaging the
       source dataset into the target time periods.
    4. ``_experimental_averager`` uses those weights to average the data
       into the target time periods.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    Returns
    -------
    xr.Dataset
        Dataset with the computed monthly average.

    Raises
    ------
    ValueError
        If any of the source time bounds are out of order (raised by
        ``ensure_bounds_order``).

    Notes
    -----
    The monthly averages are computed from January - December, but
    it is possible the source dataset starts after January or ends
    before December. A potential enhancement would be to cater the
    bounds to the source dataset. For example, if the source dataset
    starts in March 2010, the resulting monthly dataset would begin
    in March 2010.
    """
    # NOTE(review): this could be generalized (similar to
    # ``.temporal.group_average()``) by calling a different function to
    # generate the target bounds (e.g., generate_daily_bounds,
    # generate_seasonal_bounds, generate_yearly_bounds). The other steps
    # would work as-is.

    # Validate the source bounds first; the weighting logic assumes that
    # each bound is ordered [earlier, later].
    self.ensure_bounds_order()
    # Only the target bounds are needed here; the centered time axis that
    # is returned alongside them is carried as a coordinate of the bounds.
    _, target_bnds = self.generate_monthly_bounds()
    weights = self.get_temporal_weights(target_bnds)
    return self._experimental_averager(data_var, weights, target_bnds)


def _experimental_averager(
    self,
    data_var: str,
    weights: "xr.DataArray",
    target_bnds: "xr.DataArray",
) -> "xr.Dataset":
    """Calculate time period averages for a set of weights and bounds.

    This is intended to be a generic averager that averages a data
    variable into target time periods using the supplied weights.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    weights : xr.DataArray
        The weight of each source time slice that should be used to compute
        a temporal average for each target time slice. It must have a
        ``target_time`` dimension in addition to the source time dimension.

    target_bnds : xr.DataArray
        The time_bnds for the target time slices.

    Returns
    -------
    xr.Dataset
        The dataset with the computed temporal averages. The target bounds
        are stored under ``<time key>_bnds`` and the time coordinate's
        ``bounds`` attribute points at that variable.
    """
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, "T")
    bnds_key = time_key + "_bnds"
    source_dims = ds[data_var].dims

    # Weighted mean over the source time axis; the result is indexed by the
    # "target_time" dimension carried by the weights.
    with xr.set_options(keep_attrs=True):
        averaged = ds[data_var].weighted(weights).mean(dim=time_key)

    # Revert to the original time coordinate name and dimension order.
    averaged = averaged.rename({"target_time": time_key})
    averaged = averaged.transpose(*source_dims)

    # The original time dimension is dropped from the dataset because
    # it becomes obsolete after the data variable is averaged. When the
    # averaged data variable is added to the dataset, the new time dimension
    # and its associated coordinates are also added.
    ds_mean = ds.drop_dims(time_key)
    ds_mean[data_var] = averaged
    ds_mean[bnds_key] = target_bnds
    # Link the new time coordinate to its bounds so that bounds lookups
    # based on the "bounds" attribute can find them in the output.
    ds_mean[time_key].attrs["bounds"] = bnds_key

    return ds_mean


def get_temporal_weights(self, target_bnds: "xr.DataArray") -> "xr.DataArray":
    """Compute the temporal weights for a set of target time bounds.

    The weights are the intersection between the dataset's own time bounds
    and the target time bounds (i.e., the averaging periods). Each source
    time step is weighted by the duration for which it lies within a given
    averaging period; a time step that does not overlap a period at all
    gets a weight of zero for that period.

    Parameters
    ----------
    target_bnds : xr.DataArray
        The bounds for target time averages. It must carry a coordinate
        named after the dataset's time dimension.

    Returns
    -------
    xr.DataArray
        The temporal weights that should be applied to the source data
        to produce time averaged data corresponding to the target time
        bounds. Dimensions are ``("target_time", <time key>)`` and the
        values are overlap durations in nanoseconds (not normalized).

    Notes
    -----
    Both the source and the target bounds are assumed to be ordered
    [earlier time, later time].
    """
    # NOTE(review): this approach is reported to be ~10x slower than the
    # existing functionality, with the slowdown almost entirely in this
    # function. Speeding this step up would be welcome, but the slowdown is
    # likely tolerable since the approach should be more robust/accurate.
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, "T")

    source_bnds = ds.cf.get_bounds(time_key).values
    source_lower = source_bnds[:, 0]
    source_upper = source_bnds[:, 1]

    # Use the detected time key rather than assuming it is named "time".
    target_time = target_bnds[time_key]
    target_values = target_bnds.values

    # One row per target period, one column per source time step.
    weights = np.zeros((len(target_values), len(ds[time_key])))

    # The loop is kept (rather than broadcasting to a full 2D array of time
    # objects) so that the temporaries stay the size of one row.
    for i, (target_lower, target_upper) in enumerate(target_values):
        # Clip each source interval to the target interval. The outer
        # minimum/maximum collapse intervals that lie entirely outside the
        # target period onto a single point, giving them zero duration.
        lower = np.minimum(np.maximum(source_lower, target_lower), target_upper)
        upper = np.maximum(np.minimum(source_upper, target_upper), target_lower)

        # The weight is the duration of the clipped interval.
        weights[i, :] = (upper - lower).astype("timedelta64[ns]")

    return xr.DataArray(
        data=weights,
        dims=["target_time", time_key],
        coords={
            "target_time": target_time.values,
            time_key: ds[time_key].values,
        },
    )


def generate_monthly_bounds(self) -> "tuple[xr.DataArray, xr.DataArray]":
    """Generate monthly time bounds and the corresponding time axis
    for a dataset.

    This is a prototype for generating target bounds (i.e., the bins that
    the source data should be averaged into). Similar functions could be
    written for other frequencies (e.g., daily, seasonal, yearly).

    This method will generate monthly time bounds, e.g.,
        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
         ...]

    and a time axis, e.g.,
        ["2010-01-16 12:00:00",
         "2010-02-15 00:00:00",
         "2010-03-16 12:00:00",
         ...]

    for a dataset. For every year that appears in the original dataset,
    the arrays cover January through December, in chronological order.

    Returns
    -------
    monthly_time : xr.DataArray
        The centered time axis corresponding to the generated bounds. Its
        ``bounds`` attribute is set to ``<time key>_bnds``.

    monthly_bnds : xr.DataArray
        The generated monthly bounds.

    Notes
    -----
    The time values must expose a ``.year`` attribute and their type must
    be constructible as ``type(year, month, day)``.
    NOTE(review): presumably this means cftime-style objects rather than
    ``np.datetime64`` -- confirm against the callers.
    """
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, "T")
    time_values = ds[time_key].values

    # Sort the unique years: iterating over a plain set gives no ordering
    # guarantee, which could produce non-chronological target bounds.
    years = sorted({t.year for t in time_values})
    # Build the targets with the same time type as the source time axis.
    ttype = type(time_values[0])

    bnds_list = []
    centers = []
    for year, month in product(years, range(1, 13)):
        lower_bnd = ttype(year, month, 1)
        upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
        centers.append(lower_bnd + (upper_bnd - lower_bnd) / 2.0)
        bnds_list.append([lower_bnd, upper_bnd])

    # Generate the xarray DataArray objects.
    monthly_time = xr.DataArray(
        data=centers,
        dims=[time_key],
        coords={time_key: centers},
    )
    monthly_time.encoding = ds[time_key].encoding
    # Point the time axis at its bounds variable on the object that is
    # actually returned.
    monthly_time.attrs["bounds"] = time_key + "_bnds"

    monthly_bnds = xr.DataArray(
        data=bnds_list,
        dims=[time_key, "bnds"],
        coords={time_key: monthly_time},
    )
    monthly_bnds.encoding = ds[time_key].encoding

    return monthly_time, monthly_bnds


def ensure_bounds_order(self):
    """Check that every time bound is ordered [earlier, later].

    This method just makes sure the bounds are ordered as expected. It only
    validates the bounds; nothing is reordered or returned.

    Raises
    ------
    ValueError
        If the first value of any time bound is greater than or equal to
        its second value.
    """
    bounds_pairs = self._dataset.copy().bounds.get_bounds("T").values
    # Short-circuits on the first offending pair, like the explicit loop.
    if any(pair[0] >= pair[1] for pair in bounds_pairs):
        raise ValueError('Time bounds are not ordered from low-to-high')


def _infer_freq(time_coords: xr.DataArray) -> Frequency:
"""Infers the time frequency from the coordinates.

Expand Down
Loading