54 | 54 | import numpy as np |
55 | 55 | import pytensor |
56 | 56 | import pytensor.tensor as at |
| 57 | +import xarray |
57 | 58 |
58 | 59 | from pytensor.graph.basic import Variable |
59 | 60 |
@@ -977,7 +978,7 @@ def symbolic_random(self): |
977 | 978 |
978 | 979 | @pytensor.config.change_flags(compute_test_value="off") |
979 | 980 | def set_size_and_deterministic( |
980 | | - self, node: Variable, s, d: bool, more_replacements: dict | None = None |
| 981 | + self, node: Variable, s, d: bool, more_replacements: dict = None |
981 | 982 | ) -> list[Variable]: |
982 | 983 |     """*Dev* - after the node is sampled via :func:`symbolic_sample_over_posterior` or |
983 | 984 |     :func:`symbolic_single_sample`, a new random generator can be allocated and applied to the node |
@@ -1105,16 +1106,47 @@ def __str__(self): |
1105 | 1106 | return f"{self.__class__.__name__}[{shp}]" |
1106 | 1107 |
1107 | 1108 | @node_property |
1108 | | - def std(self): |
1109 | | - raise NotImplementedError |
| 1109 | + def std(self) -> at.TensorVariable: |
| 1110 | + """Standard deviation of the latent variables as an unstructured 1-dimensional tensor variable""" |
| 1111 | + raise NotImplementedError() |
1110 | 1112 |
1111 | 1113 | @node_property |
1112 | | - def cov(self): |
1113 | | - raise NotImplementedError |
| 1114 | + def cov(self) -> at.TensorVariable: |
| 1115 | + """Covariance between the latent variables as an unstructured 2-dimensional tensor variable""" |
| 1116 | + raise NotImplementedError() |
1114 | 1117 |
1115 | 1118 | @node_property |
1116 | | - def mean(self): |
1117 | | - raise NotImplementedError |
| 1119 | + def mean(self) -> at.TensorVariable: |
| 1120 | + """Mean of the latent variables as an unstructured 1-dimensional tensor variable""" |
| 1121 | + raise NotImplementedError() |
| 1122 | + |
| 1123 | + def var_to_data(self, shared: at.TensorVariable) -> xarray.Dataset: |
| 1124 | + """Takes a flat 1-dimensional tensor variable and maps it to an xarray data set based on the information in |
| 1125 | + `self.ordering`. |
| 1126 | + """ |
| 1127 | + # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have |
| 1128 | + # `RaveledVars` and need to take the information from `self.ordering` instead |
| 1129 | + shared_nda = shared.eval() |
| 1130 | + result = dict() |
| 1131 | + for name, s, shape, dtype in self.ordering.values(): |
| 1132 | + dims = self.model.named_vars_to_dims.get(name, None) |
| 1133 | + if dims is not None: |
| 1134 | + coords = {d: np.array(self.model.coords[d]) for d in dims} |
| 1135 | + else: |
| 1136 | + coords = None |
| 1137 | + values = shared_nda[s].reshape(shape).astype(dtype) |
| 1138 | + result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name) |
| 1139 | + return xarray.Dataset(result) |
| 1140 | + |
| 1141 | + @property |
| 1142 | + def mean_data(self) -> xarray.Dataset: |
| 1143 | + """Mean of the latent variables as an xarray Dataset""" |
| 1144 | + return self.var_to_data(self.mean) |
| 1145 | + |
| 1146 | + @property |
| 1147 | + def std_data(self) -> xarray.Dataset: |
| 1148 | + """Standard deviation of the latent variables as an xarray Dataset""" |
| 1149 | + return self.var_to_data(self.std) |
1118 | 1150 |
1119 | 1151 |
1120 | 1152 | group_for_params = Group.group_for_params |
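
For orientation, a minimal usage sketch (not part of this diff). It assumes a model fitted with ADVI and that the fitted single-group approximation exposes the group's properties via attribute delegation, as PyMC's MeanField approximation does; the model, the "city" coords, and the `pm.fit` call are illustrative only, while `mean_data` and `std_data` are the properties added above.

    import pymc as pm

    with pm.Model(coords={"city": ["Berlin", "Paris", "Tokyo"]}) as model:
        mu = pm.Normal("mu", 0.0, 1.0, dims="city")
        approx = pm.fit(n=10_000, method="advi")  # returns a MeanField approximation

    # mean_data / std_data call var_to_data on the group's `mean` / `std` tensors:
    # the flat parameter vector is sliced and reshaped per `self.ordering`, and each
    # variable becomes a DataArray whose dims/coords come from the model when declared.
    posterior_means = approx.mean_data  # xarray.Dataset with a "mu" DataArray over "city"
    posterior_stds = approx.std_data
    print(posterior_means["mu"].sel(city="Paris"))

Entries in the resulting Dataset are keyed by whatever names `self.ordering` stores for the group's variables, and variables without declared dims fall back to xarray's default dimension names.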