From bb1ba52695133529a2311736537f69ec60576cb0 Mon Sep 17 00:00:00 2001 From: Dhruvanshu Joshi Date: Fri, 10 Feb 2023 21:40:29 +0530 Subject: [PATCH 1/6] adding function for mutable coords --- pymc/model.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 62bbf66605..fb0ef86cdf 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -51,6 +51,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, is_minibatch +from pymc.distributions.logprob import _joint_logp from pymc.distributions.transforms import _default_transform from pymc.exceptions import ( BlockModelAccessError, @@ -60,7 +61,6 @@ ShapeWarning, ) from pymc.initial_point import make_initial_point_fn -from pymc.logprob.joint_logprob import joint_logp from pymc.pytensorf import ( PointFunc, SeedSequenceSeed, @@ -586,6 +586,7 @@ def __init__( self._coords = {} self._dim_lengths = {} self.add_coords(coords) + self.add_mutable_coords(coords) from pymc.printing import str_for_model @@ -754,7 +755,7 @@ def logp( rv_logps: List[TensorVariable] = [] if rvs: - rv_logps = joint_logp( + rv_logps = _joint_logp( rvs=rvs, rvs_to_values=self.rvs_to_values, rvs_to_transforms=self.rvs_to_transforms, @@ -1080,6 +1081,20 @@ def add_coords( for name, values in coords.items(): self.add_coord(name, values, length=lengths.get(name, None)) + def add_mutable_coords( + self, + coords: Dict[str, Optional[Sequence]], + *, + lengths: Optional[Dict[str, Optional[Union[int, Variable]]]] = None, + ): + """Registers a mutable dimension coordinate with the model""" + if coords is None: + return + lengths = lengths or {} + + for name, values in coords.items(): + self.add_coord(name, values, mutable=True, length=lengths.get(name, None)) + def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] = None): """Update a mutable dimension. From 464dd45bc805584dbe79e87c06de24c79c914685 Mon Sep 17 00:00:00 2001 From: Dhruvanshu Joshi Date: Mon, 13 Feb 2023 01:55:32 +0530 Subject: [PATCH 2/6] Added coords_mutable --- pymc/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/model.py b/pymc/model.py index 62bbf66605..25e5cf51f1 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -548,6 +548,7 @@ def __init__( self, name="", coords=None, + coords_mutable=None, check_bounds=True, *, pytensor_config=None, @@ -586,6 +587,8 @@ def __init__( self._coords = {} self._dim_lengths = {} self.add_coords(coords) + for name, values in coords_mutable.items(): + self.add_coord(name, values, mutable=True, length=lengths.get(name, None)) from pymc.printing import str_for_model From 8b848c88476d5eb158e506e77a3d13087b8b9858 Mon Sep 17 00:00:00 2001 From: Dhruvanshu Joshi Date: Mon, 13 Feb 2023 09:50:41 +0530 Subject: [PATCH 3/6] add coords_mutable --- pymc/model.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 2ad447e07e..25e5cf51f1 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -51,7 +51,6 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, is_minibatch -from pymc.distributions.logprob import _joint_logp from pymc.distributions.transforms import _default_transform from pymc.exceptions import ( BlockModelAccessError, @@ -61,6 +60,7 @@ ShapeWarning, ) from pymc.initial_point import make_initial_point_fn +from pymc.logprob.joint_logprob import joint_logp from pymc.pytensorf import ( PointFunc, SeedSequenceSeed, @@ -587,7 +587,8 @@ def __init__( self._coords = {} self._dim_lengths = {} self.add_coords(coords) - self.add_mutable_coords(coords) + for name, values in coords_mutable.items(): + self.add_coord(name, values, mutable=True, length=lengths.get(name, None)) from pymc.printing import str_for_model @@ -756,7 +757,7 @@ def logp( rv_logps: List[TensorVariable] = [] if rvs: - rv_logps = _joint_logp( + rv_logps = joint_logp( rvs=rvs, rvs_to_values=self.rvs_to_values, rvs_to_transforms=self.rvs_to_transforms, @@ -1082,20 +1083,6 @@ def add_coords( for name, values in coords.items(): self.add_coord(name, values, length=lengths.get(name, None)) - def add_mutable_coords( - self, - coords: Dict[str, Optional[Sequence]], - *, - lengths: Optional[Dict[str, Optional[Union[int, Variable]]]] = None, - ): - """Registers a mutable dimension coordinate with the model""" - if coords is None: - return - lengths = lengths or {} - - for name, values in coords.items(): - self.add_coord(name, values, mutable=True, length=lengths.get(name, None)) - def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] = None): """Update a mutable dimension. From 0db7864d0d0befe6621c21662436a167e6327de4 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 13 Feb 2023 09:47:42 +0100 Subject: [PATCH 4/6] Make `coords_mutable` an explicit kwarg --- pymc/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model.py b/pymc/model.py index 25e5cf51f1..7ae7f8ab25 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -548,9 +548,9 @@ def __init__( self, name="", coords=None, - coords_mutable=None, check_bounds=True, *, + coords_mutable=None, pytensor_config=None, model=None, ): From a940f4674a7d8993f01a732a5deb29de7b28492a Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 13 Feb 2023 09:48:43 +0100 Subject: [PATCH 5/6] Remove reference to undefined variable --- pymc/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model.py b/pymc/model.py index 7ae7f8ab25..f6e8c740f3 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -588,7 +588,7 @@ def __init__( self._dim_lengths = {} self.add_coords(coords) for name, values in coords_mutable.items(): - self.add_coord(name, values, mutable=True, length=lengths.get(name, None)) + self.add_coord(name, values, mutable=True) from pymc.printing import str_for_model From 1927320bc70466c5a0d25ebdc227853ff207c545 Mon Sep 17 00:00:00 2001 From: Dhruvanshu Joshi Date: Wed, 15 Feb 2023 17:21:05 +0530 Subject: [PATCH 6/6] Adding an if block --- pymc/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index f6e8c740f3..104d810ff8 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -587,8 +587,9 @@ def __init__( self._coords = {} self._dim_lengths = {} self.add_coords(coords) - for name, values in coords_mutable.items(): - self.add_coord(name, values, mutable=True) + if coords_mutable is not None: + for name, values in coords_mutable.items(): + self.add_coord(name, values, mutable=True) from pymc.printing import str_for_model