|
22 | 22 | from typing import ( |
23 | 23 | TYPE_CHECKING, |
24 | 24 | Any, |
| 25 | + Callable, |
25 | 26 | Dict, |
26 | 27 | List, |
27 | 28 | Optional, |
@@ -650,7 +651,6 @@ def __init__( |
650 | 651 | # The sequence of model-generated RNGs |
651 | 652 | self.rng_seq = [] |
652 | 653 | self._initial_values = {} |
653 | | - self._initial_point_cache = {} |
654 | 654 |
|
655 | 655 | if self.parent is not None: |
656 | 656 | self.named_vars = treedict(parent=self.parent.named_vars) |
@@ -935,42 +935,59 @@ def test_point(self) -> Dict[str, np.ndarray]: |
935 | 935 | @property |
936 | 936 | def initial_point(self) -> Dict[str, np.ndarray]: |
937 | 937 | """Maps free variable names to transformed, numeric initial values.""" |
938 | | - if set(self._initial_point_cache) != { |
939 | | - get_var_name(self.rvs_to_values[k]) for k in self.initial_values |
940 | | - }: |
941 | | - return self.recompute_initial_point() |
942 | | - return self._initial_point_cache |
| 938 | + return self.recompute_initial_point() |
943 | 939 |
|
944 | 940 | def recompute_initial_point(self) -> Dict[str, np.ndarray]: |
| 941 | + """Recomputes the initial point of the model. |
| 942 | +
|
| 943 | + Returns |
| 944 | + ------- |
| 945 | + ip : dict |
| 946 | + Maps names of transformed variables to numeric initial values in the transformed space. |
| 947 | + """ |
| 948 | + fn = self.make_initial_point_fn() |
| 949 | + return Point(fn(), model=self) |
| 950 | + |
| 951 | + def make_initial_point_fn( |
| 952 | + self, |
| 953 | + *, |
| 954 | + return_transformed: bool = True, |
| 955 | + ) -> Callable[[], Dict[TensorVariable, np.ndarray]]: |
945 | 956 | """Recomputes numeric initial values for all free model variables. |
946 | 957 |
|
| 958 | + Parameters |
| 959 | + ---------- |
| 960 | + return_transformed : bool |
| 961 | + Switches between returning the dictionary based on RV vars or RV value vars as keys. |
| 962 | +
|
947 | 963 | Returns |
948 | 964 | ------- |
949 | 965 | initial_point : dict |
950 | 966 | Maps transformed free variable names to transformed, numeric initial values. |
951 | 967 | """ |
952 | | - numeric_initvals = {} |
953 | | - # The entries in `initial_values` are already in topological order and can be evaluated one by one. |
954 | | - for rv_var, initval in self.initial_values.items(): |
955 | | - rv_value = self.rvs_to_values[rv_var] |
956 | | - transform = getattr(rv_value.tag, "transform", None) |
957 | | - if isinstance(initval, np.ndarray) and transform is None: |
958 | | - # Only untransformed, numeric initvals can be taken as they are. |
959 | | - numeric_initvals[rv_var] = initval |
960 | | - else: |
961 | | - # Evaluate initvals that are None, symbolic or need to be transformed. |
962 | | - # They can depend on other initvals from higher up in the graph, |
963 | | - # which are therefore fed to the evaluation as "givens". |
964 | | - test_value = getattr(rv_var.tag, "test_value", None) |
965 | | - numeric_initvals[rv_var] = self._eval_initval( |
966 | | - rv_var, initval, test_value, transform, given=numeric_initvals |
967 | | - ) |
968 | 968 |
|
969 | | - # Cache the evaluation results for next time. |
970 | | - self._initial_point_cache = Point( |
971 | | - [(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self |
972 | | - ) |
973 | | - return self._initial_point_cache |
| 969 | + def fn(): |
| 970 | + numeric_initvals = {} |
| 971 | + # The entries in `initial_values` are already in topological order and can be evaluated one by one. |
| 972 | + for rv_var, initval in self.initial_values.items(): |
| 973 | + rv_value = self.rvs_to_values[rv_var] |
| 974 | + transform = getattr(rv_value.tag, "transform", None) |
| 975 | + if isinstance(initval, np.ndarray) and transform is None: |
| 976 | + # Only untransformed, numeric initvals can be taken as they are. |
| 977 | + numeric_initvals[rv_var] = initval |
| 978 | + else: |
| 979 | + # Evaluate initvals that are None, symbolic or need to be transformed. |
| 980 | + # They can depend on other initvals from higher up in the graph, |
| 981 | + # which are therefore fed to the evaluation as "givens". |
| 982 | + test_value = getattr(rv_var.tag, "test_value", None) |
| 983 | + numeric_initvals[rv_var] = self._eval_initval( |
| 984 | + rv_var, initval, test_value, transform, given=numeric_initvals |
| 985 | + ) |
| 986 | + if return_transformed: |
| 987 | + return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()} |
| 988 | + return numeric_initvals |
| 989 | + |
| 990 | + return fn |
974 | 991 |
|
975 | 992 | @property |
976 | 993 | def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]: |
|
0 commit comments