@@ -714,54 +714,24 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
714714 raise ValueError (f"Can only compute the gradient of continuous types: { var } " )
715715
716716 if tempered :
717- with self :
718- # Convert random variables into their log-likelihood inputs and
719- # apply their transforms, if any
720- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
721-
722- free_RVs_logp = at .sum (
723- [at .sum (logpt (var , self .rvs_to_values .get (var , None ))) for var in self .free_RVs ]
724- + list (potentials )
725- )
726- observed_RVs_logp = at .sum (
727- [at .sum (logpt (obs , obs .tag .observations )) for obs in self .observed_RVs ]
728- )
729-
730- costs = [free_RVs_logp , observed_RVs_logp ]
717+ # TODO: Should this differ from self.datalogpt,
718+ # where the potential terms are added to the observations?
719+ costs = [self .varlogpt + self .potentiallogpt , self .observedlogpt ]
731720 else :
732721 costs = [self .logpt ]
733722
734723 input_vars = {i for i in graph_inputs (costs ) if not isinstance (i , Constant )}
735724 extra_vars = [self .rvs_to_values .get (var , var ) for var in self .free_RVs ]
725+ ip = self .recompute_initial_point (0 )
736726 extra_vars_and_values = {
737- var : self .initial_point [var .name ]
738- for var in extra_vars
739- if var in input_vars and var not in grad_vars
727+ var : ip [var .name ] for var in extra_vars if var in input_vars and var not in grad_vars
740728 }
741729 return ValueGradFunction (costs , grad_vars , extra_vars_and_values , ** kwargs )
742730
743731 @property
744732 def logpt (self ):
745733 """Aesara scalar of log-probability of the model"""
746-
747- rv_values = {}
748- for var in self .free_RVs :
749- rv_values [var ] = self .rvs_to_values .get (var , None )
750- rv_factors = logpt (self .free_RVs , rv_values )
751-
752- obs_values = {}
753- for obs in self .observed_RVs :
754- obs_values [obs ] = obs .tag .observations
755- obs_factors = logpt (self .observed_RVs , obs_values )
756-
757- # Convert random variables into their log-likelihood inputs and
758- # apply their transforms, if any
759- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
760- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
761- if rv_factors is not None :
762- logp_var += rv_factors
763- if obs_factors is not None :
764- logp_var += obs_factors
734+ logp_var = self .varlogpt + self .datalogpt
765735
766736 if self .name :
767737 logp_var .name = f"__logp_{ self .name } "
@@ -777,60 +747,65 @@ def logp_nojact(self):
777747 Note that if there is no transformed variable in the model, logp_nojact
778748 will be the same as logpt as there is no need for Jacobian correction.
779749 """
780- with self :
781- rv_values = {}
782- for var in self .free_RVs :
783- rv_values [var ] = getattr (var .tag , "value_var" , None )
784- rv_factors = logpt (self .free_RVs , rv_values , jacobian = False )
785-
786- obs_values = {}
787- for obs in self .observed_RVs :
788- obs_values [obs ] = obs .tag .observations
789- obs_factors = logpt (self .observed_RVs , obs_values , jacobian = False )
790-
791- # Convert random variables into their log-likelihood inputs and
792- # apply their transforms, if any
793- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
794- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
795-
796- if rv_factors is not None :
797- logp_var += rv_factors
798- if obs_factors is not None :
799- logp_var += obs_factors
800-
801- if self .name :
802- logp_var .name = f"__logp_nojac_{ self .name } "
803- else :
804- logp_var .name = "__logp_nojac"
805- return logp_var
750+ logp_var = self .varlogp_nojact + self .datalogpt
751+
752+ if self .name :
753+ logp_var .name = f"__logp_nojac_{ self .name } "
754+ else :
755+ logp_var .name = "__logp_nojac"
756+ return logp_var
757+
758+ @property
759+ def datalogpt (self ):
760+ """Aesara scalar of log-probability of the observed variables and
761+ potential terms"""
762+ return self .observedlogpt + self .potentiallogpt
806763
807764 @property
808765 def varlogpt (self ):
809766 """Aesara scalar of log-probability of the unobserved random variables
810767 (excluding deterministic)."""
811- with self :
812- rv_values = {}
813- for var in self .free_RVs :
814- rv_values [ var ] = getattr ( var . tag , "value_var" , None )
768+ rv_values = {}
769+ for var in self . free_RVs :
770+ rv_values [ var ] = self .rvs_to_values [ var ]
771+ if rv_values :
815772 return logpt (self .free_RVs , rv_values )
773+ else :
774+ return 0
816775
817776 @property
818- def datalogpt (self ):
819- with self :
820- obs_values = {}
821- for obs in self .observed_RVs :
822- obs_values [obs ] = obs .tag .observations
823- obs_factors = logpt (self .observed_RVs , obs_values )
824-
825- # Convert random variables into their log-likelihood inputs and
826- # apply their transforms, if any
827- potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
828- logp_var = at .sum ([at .sum (factor ) for factor in potentials ])
777+ def varlogp_nojact (self ):
778+ """Aesara scalar of log-probability of the unobserved random variables
779+ (excluding deterministic) without jacobian term."""
780+ rv_values = {}
781+ for var in self .free_RVs :
782+ rv_values [var ] = self .rvs_to_values [var ]
783+ if rv_values :
784+ return logpt (self .free_RVs , rv_values , jacobian = False )
785+ else :
786+ return 0
829787
830- if obs_factors is not None :
831- logp_var += obs_factors
788+ @property
789+ def observedlogpt (self ):
790+ """Aesara scalar of log-probability of the observed variables"""
791+ obs_values = {}
792+ for obs in self .observed_RVs :
793+ obs_values [obs ] = obs .tag .observations
794+ if obs_values :
795+ return logpt (self .observed_RVs , obs_values )
796+ else :
797+ return 0
832798
833- return logp_var
799+ @property
800+ def potentiallogpt (self ):
801+ """Aesara scalar of log-probability of the Potential terms"""
802+ # Convert random variables in Potential expression into their log-likelihood
803+ # inputs and apply their transforms, if any
804+ potentials , _ = rvs_to_value_vars (self .potentials , apply_transforms = True )
805+ if potentials :
806+ return at .sum ([at .sum (factor ) for factor in potentials ])
807+ else :
808+ return 0
834809
835810 @property
836811 def vars (self ):
0 commit comments