diff --git a/xcdat/temporal.py b/xcdat/temporal.py
index 6d7dd68d..8f29ca9f 100644
--- a/xcdat/temporal.py
+++ b/xcdat/temporal.py
@@ -2,7 +2,7 @@
 import warnings
 from datetime import datetime
-from itertools import chain
+from itertools import chain, product
 from typing import Literal, TypedDict, get_args
 
 import cf_xarray  # noqa: F401
@@ -17,7 +17,7 @@
 from xcdat import bounds  # noqa: F401
 from xcdat._logger import _setup_custom_logger
-from xcdat.axis import get_dim_coords
+from xcdat.axis import get_dim_coords, get_dim_keys
 from xcdat.dataset import _get_data_var
 
 logger = _setup_custom_logger(__name__)
@@ -2086,6 +2086,210 @@ def _calculate_departures(
 
         return ds_departs
 
+    def compute_monthly_average(self, data_var):
+        """
+        Computes monthly averages for a data variable in the dataset.
+
+        This method ensures that the dataset's time bounds are ordered
+        correctly, computes the target monthly time bounds and the
+        associated weights, and then computes the monthly average.
+
+        Parameters
+        ----------
+        data_var : str
+            The key of the data variable.
+
+        Returns
+        -------
+        xr.Dataset
+            Dataset with the computed monthly average.
+
+        Notes
+        -----
+        The monthly averages are computed from January through December,
+        but it is possible the source dataset starts after January or ends
+        before December. A potential enhancement would be to tailor the
+        bounds to the source dataset. For example, if the source dataset
+        starts in March 2010, the resulting monthly dataset would begin
+        in March 2010.
+        """
+        ds = self._dataset.copy()
+        # Ensure the source time bounds are ordered correctly.
+        ds.temporal.ensure_bounds_order()
+        # Get the target time axis and bounds.
+        target_time, target_bnds = ds.temporal.generate_monthly_bounds()
+        # Get the temporal weights.
+        weights = ds.temporal.get_temporal_weights(target_bnds)
+        # Compute the average and return the resulting dataset.
+        return ds.temporal._experimental_averager(data_var, weights, target_bnds)
+
+
+    def _experimental_averager(self, data_var, weights, target_bnds):
+        """
+        Calculates time period averages for a set of weights and bounds.
+
+        Parameters
+        ----------
+        data_var : str
+            The key of the data variable.
+
+        weights : xr.DataArray
+            The weight of each source time slice used to compute the temporal
+            average for each target time slice, with dimensions
+            [target_time, source_time].
+
+        target_bnds : xr.DataArray
+            The time bounds for the target time slices.
+
+        Returns
+        -------
+        xr.Dataset
+            The dataset with the computed temporal averages.
+        """
+        ds = self._dataset.copy()
+        # Get the time dimension key.
+        time_key = get_dim_keys(ds, "T")
+        # Convert the data variable to a weighted array.
+        da_weighted = ds[data_var].weighted(weights)
+        # Compute the weighted mean.
+        with xr.set_options(keep_attrs=True):
+            da_mean = da_weighted.mean(dim=time_key)
+        # Revert to the original time coordinate name.
+        da_mean = da_mean.rename({"target_time": time_key})
+        # Ensure the dimension order matches the original dataset.
+        da_mean = da_mean.transpose(*ds[data_var].dims)
+        # Create the output dataset.
+        dsmean = ds.copy()
+        # The original time dimension is dropped from the dataset because
+        # it becomes obsolete after the data variable is averaged. When the
+        # averaged data variable is added to the dataset, the new time
+        # dimension and its associated coordinates are also added.
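+        # Note that drop_dims() also removes every other variable defined on
+        # the original time dimension, not just the data variable being
+        # averaged.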
+        dsmean = dsmean.drop_dims(time_key)
+        # Add the weighted mean data array to the output dataset.
+        dsmean[data_var] = da_mean
+        # Add the target time bounds to the dataset.
+        dsmean[time_key + "_bnds"] = target_bnds
+        return dsmean
+
+
+    def get_temporal_weights(self, target_bnds):
+        """Compute the temporal weights for a set of target time bounds.
+
+        Parameters
+        ----------
+        target_bnds : xr.DataArray
+            The time bounds for the target time averages.
+
+        Returns
+        -------
+        xr.DataArray
+            The temporal weights that should be applied to the source data
+            to produce time-averaged data corresponding to the target time
+            bounds.
+        """
+        ds = self._dataset.copy()
+        # Get the time dimension key and the source time bounds.
+        time_key = get_dim_keys(ds, "T")
+        source_bnds = ds.cf.get_bounds(time_key).values
+        target_time = target_bnds[time_key]
+
+        # Preallocate the weight matrix [target_time, source_time].
+        weights = np.zeros((len(target_bnds), len(ds[time_key])))
+
+        # Clip the source bounds to each target interval to get the overlaps.
+        for i, tbnd in enumerate(target_bnds.values):
+            # Adjust the source bounds to fit within the target bounds.
+            sbnds = source_bnds.copy()
+            sbnds[:, 0] = np.maximum(sbnds[:, 0], tbnd[0])  # Lower bound adjustment
+            sbnds[:, 1] = np.minimum(sbnds[:, 1], tbnd[1])  # Upper bound adjustment
+
+            # Handle source bounds that fall entirely outside the target range
+            # so that their overlap (and weight) is zero.
+            sbnds[:, 0] = np.minimum(sbnds[:, 0], tbnd[1])  # Lower bound > upper target bound
+            sbnds[:, 1] = np.maximum(sbnds[:, 1], tbnd[0])  # Upper bound < lower target bound
+
+            # Compute the weights as the overlap length between the bounds
+            # (in nanoseconds).
+            w = (sbnds[:, 1] - sbnds[:, 0]).astype("timedelta64[ns]")
+            weights[i, :] = w.astype(np.float64)
+
+        # Convert the weights to an xarray DataArray.
+        weights = xr.DataArray(
+            data=weights,
+            dims=["target_time", time_key],
+            coords={"target_time": target_time.values, time_key: ds[time_key].values},
+        )
+        return weights
+
+
+    def generate_monthly_bounds(self):
+        """Generates monthly time bounds and the corresponding time axis
+        for a dataset.
+
+        This method generates monthly time bounds, e.g.,
+        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
+         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
+         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
+         ...]
+
+        and a time axis, e.g.,
+        ["2010-01-16 12:00:00",
+         "2010-02-15 00:00:00",
+         "2010-03-16 12:00:00",
+         ...]
+
+        for a dataset. The arrays start in January of the first year in the
+        original dataset and run through December of the final year in the
+        original dataset.
+
+        Returns
+        -------
+        monthly_time : xr.DataArray
+            The centered time axis corresponding to the generated bounds.
+
+        monthly_bnds : xr.DataArray
+            The generated monthly bounds.
+        """
+        ds = self._dataset.copy()
+        # Get all of the years in the source dataset (in chronological order).
+        time_key = get_dim_keys(ds, "T")
+        years = sorted({t.year for t in ds[time_key].values})
+        # Get the time type (e.g., a cftime datetime class).
+        ttype = type(ds[time_key].values[0])
+        # Create the target time bounds and time axis.
+        monthly_bnds = []
+        monthly_time = []
+        for year, month in product(years, range(1, 13)):
+            lower_bnd = ttype(year, month, 1)
+            upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
+            center_time = lower_bnd + (upper_bnd - lower_bnd) / 2
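+            # For example, January 2010 has bounds
+            # [2010-01-01 00:00:00, 2010-02-01 00:00:00] and a center time
+            # of 2010-01-16 12:00:00.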
+            monthly_bnds.append([lower_bnd, upper_bnd])
+            monthly_time.append(center_time)
+        # Generate the xarray DataArray objects.
+        monthly_time = xr.DataArray(
+            data=monthly_time,
+            dims=[time_key],
+            coords={time_key: monthly_time},
+        )
+        monthly_time.encoding = ds[time_key].encoding
+        monthly_time = monthly_time.assign_attrs({"bounds": time_key + "_bnds"})
+        monthly_bnds = xr.DataArray(
+            data=monthly_bnds,
+            dims=[time_key, "bnds"],
+            coords={time_key: monthly_time},
+        )
+        monthly_bnds.encoding = ds[time_key].encoding
+        return monthly_time, monthly_bnds
+
+
+    def ensure_bounds_order(self):
+        """Ensures that the time bounds are ordered [earlier, later].
+
+        Raises
+        ------
+        ValueError
+            If any bounds are out of order.
+        """
+        ds = self._dataset.copy()
+        time_bnds = ds.bounds.get_bounds("T")
+        for tbnd in time_bnds.values:
+            if tbnd[0] >= tbnd[1]:
+                raise ValueError("Time bounds are not ordered from low to high.")
+
+
 def _infer_freq(time_coords: xr.DataArray) -> Frequency:
     """Infers the time frequency from the coordinates.