More sophisticated bounds handling for temporal averaging #735

base: main
```diff
@@ -2,7 +2,7 @@
 import warnings
 from datetime import datetime
-from itertools import chain
+from itertools import chain, product
 from typing import Literal, TypedDict, get_args

 import cf_xarray  # noqa: F401
@@ -17,7 +17,7 @@
 from xcdat import bounds  # noqa: F401
 from xcdat._logger import _setup_custom_logger
-from xcdat.axis import get_dim_coords
+from xcdat.axis import get_dim_coords, get_dim_keys
 from xcdat.dataset import _get_data_var

 logger = _setup_custom_logger(__name__)
@@ -2086,6 +2086,210 @@ def _calculate_departures(
     return ds_departs
```
```python
def compute_monthly_average(self, data_var):
    """
    Computes monthly averages for a dataset.

    This function ensures that the dataset's time bounds are
    ordered correctly, computes the target monthly time bounds
    and associated weights, and then computes the monthly average.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    Returns
    -------
    xr.Dataset
        Dataset with the computed monthly average.

    Notes
    -----
    The monthly averages are computed from January - December, but
    it is possible the source dataset starts after January or ends
    before December. A potential enhancement would be to tailor the
    bounds to the source dataset. For example, if the source dataset
    starts in March 2010, the resulting monthly dataset would begin
    in March 2010.
    """
    ds = self._dataset.copy()
    # ensure source time bounds are ordered correctly
    ds.temporal.ensure_bounds_order()
    # get target time and bounds
    target_time, target_bnds = ds.temporal.generate_monthly_bounds()
    # get temporal weights
    weights = ds.temporal.get_temporal_weights(target_bnds)
    # compute average and return resulting dataset
    return ds.temporal._experimental_averager(data_var, weights, target_bnds)
```
> **Author comment:** This is intended to be a generic averager, averaging data variable information into targeted time periods (using the supplied weights).

```python
def _experimental_averager(self, data_var, weights, target_bnds):
    """
    Calculates time period averages for a set of weights and bounds.

    Parameters
    ----------
    data_var : str
        The key of the data variable.
    weights : xr.DataArray
        The weight of each source time slice that should be used to compute
        a temporal average for each target time slice
        [target_time, source_time].
    target_bnds : xr.DataArray
        The time_bnds for the target time slices.

    Returns
    -------
    xr.Dataset
        The dataset with the computed temporal averages.
    """
    ds = self._dataset.copy()
    # get time key
    time_key = get_dim_keys(ds, 'T')
    # convert to weighted array
    da_weighted = ds[data_var].weighted(weights)
    # compute weighted mean
    with xr.set_options(keep_attrs=True):
        da_mean = da_weighted.mean(dim=time_key)
    # revert to original time coordinate name
    da_mean = da_mean.rename({'target_time': time_key})
    # ensure order is the same as original dataset
    da_mean = da_mean.transpose(*ds[data_var].dims)
    # create output dataset
    dsmean = ds.copy()
    # The original time dimension is dropped from the dataset because
    # it becomes obsolete after the data variable is averaged. When the
    # averaged data variable is added to the dataset, the new time
    # dimension and its associated coordinates are also added.
    dsmean = dsmean.drop_dims(time_key)
    # add weighted mean data array to output dataset
    dsmean[data_var] = da_mean
    # add the time bounds to the dataset
    dsmean[time_key + '_bnds'] = target_bnds
    return dsmean
```
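For a one-dimensional variable, the weighted-mean step in `_experimental_averager` reduces to a normalized matrix product over the source time axis. A minimal NumPy sketch with hypothetical values (not xCDAT code):

```python
import numpy as np

# weights[i, j]: how much of source time step j falls inside target period i
# (hypothetical values; the PR derives these from time-bounds overlap)
weights = np.array([
    [1.0, 1.0, 0.5, 0.0],
    [0.0, 0.0, 0.5, 1.0],
])
data = np.array([10.0, 20.0, 30.0, 40.0])  # one value per source time step

# weighted mean per target period: sum(w * x) / sum(w), i.e. what
# xarray's .weighted(weights).mean(dim=...) computes along the source axis
means = weights @ data / weights.sum(axis=1)
print(means)
```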
> **Author comment:** This function basically gets the intersection between the dataset's own time bounds and the targeted time bounds (i.e., averaging periods). For a given time step, it assigns weight proportional to the duration in which that time step falls within a given averaging period.

> **Author comment:** Note that this PR is ~10x slower than existing functionality. The slowdown is almost entirely in this function. If we could speed this step up, that would be great (but we can likely tolerate this slowdown, since the approach in this PR should be more robust/accurate).

```python
def get_temporal_weights(self, target_bnds):
    """Compute the temporal weights for a set of target time bounds.

    Parameters
    ----------
    target_bnds : xr.DataArray
        The bounds for target time averages.

    Returns
    -------
    xr.DataArray
        The temporal weights that should be applied to the source data
        to produce time averaged data corresponding to the target time
        bounds.
    """
    ds = self._dataset.copy()
    # Get time key and source time bounds
    time_key = get_dim_keys(ds, 'T')
    source_bnds = ds.cf.get_bounds(time_key).values
    target_time = target_bnds['time']

    # Preallocate weight matrix
    weights = np.zeros((len(target_bnds), len(ds[time_key])))

    # bounds adjustment
    for i, tbnd in enumerate(target_bnds.values):
        # Adjust source bounds to fit within target bounds
        sbnds = source_bnds.copy()
        sbnds[:, 0] = np.maximum(sbnds[:, 0], tbnd[0])  # lower bound adjustment
        sbnds[:, 1] = np.minimum(sbnds[:, 1], tbnd[1])  # upper bound adjustment

        # Handle cases where bounds are entirely outside the target range
        sbnds[:, 0] = np.minimum(sbnds[:, 0], tbnd[1])  # lower bound > upper target bound
        sbnds[:, 1] = np.maximum(sbnds[:, 1], tbnd[0])  # upper bound < lower target bound

        # Compute weights as the difference between bounds
        w = (sbnds[:, 1] - sbnds[:, 0]).astype("timedelta64[ns]")
        weights[i, :] = w

    # Convert weights to xarray DataArray
    weights = xr.DataArray(
        data=weights,
        dims=['target_time', 'time'],
        coords={'target_time': target_time.values, 'time': ds[time_key].values},
    )
    return weights
```
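The clipping logic above can be exercised in isolation. A standalone sketch with made-up 10-day source bounds and one monthly target period (plain `datetime64` here, though xCDAT datasets may carry cftime objects instead):

```python
import numpy as np

# hypothetical source bounds, one row per source time step
source_bnds = np.array([
    ["2010-01-25", "2010-02-04"],  # straddles the target's lower edge
    ["2010-02-04", "2010-02-14"],  # fully inside the target period
    ["2010-02-24", "2010-03-06"],  # straddles the target's upper edge
], dtype="datetime64[D]")
tbnd = np.array(["2010-02-01", "2010-03-01"], dtype="datetime64[D]")

sbnds = source_bnds.copy()
sbnds[:, 0] = np.maximum(sbnds[:, 0], tbnd[0])  # clip lower edges up
sbnds[:, 1] = np.minimum(sbnds[:, 1], tbnd[1])  # clip upper edges down
sbnds[:, 0] = np.minimum(sbnds[:, 0], tbnd[1])  # collapse steps entirely after
sbnds[:, 1] = np.maximum(sbnds[:, 1], tbnd[0])  # collapse steps entirely before

# overlap duration of each source step with the target period:
# 3, 10, and 5 days respectively
w = (sbnds[:, 1] - sbnds[:, 0]).astype("timedelta64[D]").astype(float)
print(w)
```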
> **Author comment:** Prototype function for generating target bounds (i.e., what bins you want to average your source data into). We could make other functions for other frequencies (e.g., daily, seasonal, yearly).

```python
def generate_monthly_bounds(self):
    """Generates monthly time bounds and the corresponding time axis
    for a dataset.

    This method will generate monthly time bounds, e.g.,

        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
         ...]

    and a time axis, e.g.,

        ["2010-01-16 12:00:00",
         "2010-02-15 00:00:00",
         "2010-03-16 12:00:00",
         ...]

    for a dataset. The arrays will start with January 1 of the first
    year in the original dataset and go through December of the final
    year in the original dataset.

    Returns
    -------
    monthly_time : xr.DataArray
        The centered time axis corresponding to the generated bounds.
    monthly_bnds : xr.DataArray
        The generated monthly bounds.
    """
    ds = self._dataset.copy()
    # get all years in source dataset
    time_key = get_dim_keys(ds, 'T')
    years = sorted({t.year for t in ds[time_key].values})
    # get time type
    ttype = type(ds[time_key].values[0])
    # create target time bounds and time axis
    monthly_bnds = []
    monthly_time = []
    for year, month in product(years, range(1, 13)):
        lower_bnd = ttype(year, month, 1)
        upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
        center_time = lower_bnd + (upper_bnd - lower_bnd) / 2.
        monthly_bnds.append([lower_bnd, upper_bnd])
        monthly_time.append(center_time)
    # generate xarray DataArray objects
    monthly_time = xr.DataArray(
        data=monthly_time, dims=[time_key], coords={time_key: monthly_time}
    )
    monthly_time.encoding = ds[time_key].encoding
    monthly_time = monthly_time.assign_attrs({'bounds': time_key + '_bnds'})
    monthly_bnds = xr.DataArray(
        data=monthly_bnds, dims=[time_key, 'bnds'], coords={time_key: monthly_time}
    )
    monthly_bnds.encoding = ds[time_key].encoding
    return monthly_time, monthly_bnds
```
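The bounds/center construction in the loop above can be mirrored with plain `datetime` objects; `add_month` below is a hypothetical stand-in for `ds.bounds._add_months_to_timestep`:

```python
from datetime import datetime
from itertools import product

def add_month(t):
    # hypothetical stand-in for ds.bounds._add_months_to_timestep:
    # roll forward to the first of the next calendar month
    year, month = (t.year + 1, 1) if t.month == 12 else (t.year, t.month + 1)
    return datetime(year, month, 1)

monthly_bnds, monthly_time = [], []
for year, month in product([2010], range(1, 13)):
    lower_bnd = datetime(year, month, 1)
    upper_bnd = add_month(lower_bnd)
    monthly_bnds.append([lower_bnd, upper_bnd])
    monthly_time.append(lower_bnd + (upper_bnd - lower_bnd) / 2)

# January is 31 days long, so its center lands on Jan 16 at 12:00,
# matching the docstring example above
print(monthly_time[0])
```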
> **Author comment:** This function just makes sure the bounds are ordered as expected.

```python
def ensure_bounds_order(self):
    """Ensures that time bounds are ordered [earlier, later].

    Raises
    ------
    ValueError
        If there are any bounds that are out of order.
    """
    ds = self._dataset.copy()
    time_bnds = ds.bounds.get_bounds("T")
    for tbnd in time_bnds.values:
        if tbnd[0] >= tbnd[1]:
            raise ValueError('Time bounds are not ordered from low-to-high')
```
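If the per-row loop ever shows up in profiles, the same check can be vectorized with a single comparison; a sketch (hypothetical helper, not part of the PR):

```python
import numpy as np

def bounds_ordered(time_bnds):
    # vectorized equivalent of the validation loop: True iff every
    # bounds pair is strictly [earlier, later]
    b = np.asarray(time_bnds)
    return bool((b[:, 0] < b[:, 1]).all())

print(bounds_ordered([[0, 1], [1, 2]]))  # True
print(bounds_ordered([[0, 1], [2, 2]]))  # False
```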
The diff's trailing context resumes at the next existing function:

```python
def _infer_freq(time_coords: xr.DataArray) -> Frequency:
    """Infers the time frequency from the coordinates.
```
> **Review comment** (on `compute_monthly_average`): This function wraps several steps/functions (that are defined below) in order to compute monthly averages (e.g., from hourly/daily/pentad data to monthly means):
>
> - `ensure_bounds_order`: ensures that dataset bounds are in order [earlier time, later time] (since the PR logic depends on this)
> - `generate_monthly_bounds`: creates monthly bounds
> - `get_temporal_weights`: computes weights for averaging the source dataset into targeted time periods
> - `_experimental_averager`: uses the temporal weights to average data into targeted time periods

> **Review comment:** I think we could generalize this function (kind of like `.temporal.group_average()`) by having it call different functions to generate target bounds (e.g., `generate_daily_bounds`, `generate_seasonal_bounds`, `generate_yearly_bounds`). The other steps would work as-is.