diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index d2b67e24e..95bd8c4c7 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -677,6 +677,16 @@ def reshape_t(x, shape): return x[0] +class PointFunc: + """Wraps so a function so it takes a dict of arguments instead of arguments.""" + + def __init__(self, f): + self.f = f + + def __call__(self, state): + return self.f(**state) + + class CallableTensor: """Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input. diff --git a/pymc/model.py b/pymc/model.py index 019015bcc..e2a27bccb 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -48,6 +48,7 @@ from aesara.tensor.var import TensorConstant, TensorVariable from pymc.aesaraf import ( + PointFunc, compile_pymc, convert_observed_data, gradient, @@ -640,7 +641,7 @@ def compile_logp( vars: Optional[Union[Variable, Sequence[Variable]]] = None, jacobian: bool = True, sum: bool = True, - ): + ) -> PointFunc: """Compiled log probability density function. Parameters @@ -660,7 +661,7 @@ def compile_dlogp( self, vars: Optional[Union[Variable, Sequence[Variable]]] = None, jacobian: bool = True, - ): + ) -> PointFunc: """Compiled log probability density gradient function. Parameters @@ -677,7 +678,7 @@ def compile_d2logp( self, vars: Optional[Union[Variable, Sequence[Variable]]] = None, jacobian: bool = True, - ): + ) -> PointFunc: """Compiled log probability density hessian function. Parameters @@ -1597,15 +1598,18 @@ def compile_fn( mode=None, point_fn: bool = True, **kwargs, - ) -> Union["PointFunc", Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]: + ) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]: """Compiles an Aesara function Parameters ---------- - outs: Aesara variable or iterable of Aesara variables - inputs: Aesara input variables, defaults to aesaraf.inputvars(outs). - mode: Aesara compilation mode, default=None - point_fn: + outs + Aesara variable or iterable of Aesara variables. + inputs + Aesara input variables, defaults to aesaraf.inputvars(outs). + mode + Aesara compilation mode, default=None. + point_fn : bool Whether to wrap the compiled function in a PointFunc, which takes a Point dictionary with model variable names and values as input. @@ -1871,22 +1875,30 @@ def set_data(new_data, model=None, *, coords=None): model.set_data(variable_name, new_value, coords=coords) -def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): +def compile_fn( + outs, mode=None, point_fn: bool = True, model: Optional[Model] = None, **kwargs +) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]: """Compiles an Aesara function which returns ``outs`` and takes values of model vars as a dict as an argument. + Parameters ---------- - outs: Aesara variable or iterable of Aesara variables - mode: Aesara compilation mode - point_fn: + outs + Aesara variable or iterable of Aesara variables. + mode + Aesara compilation mode, default=None. + point_fn : bool Whether to wrap the compiled function in a PointFunc, which takes a Point dictionary with model variable names and values as input. + model : Model, optional + Current model on stack. + Returns ------- Compiled Aesara function as point function. """ model = modelcontext(model) - return model.compile_fn(outs, mode, point_fn=point_fn, **kwargs) + return model.compile_fn(outs, mode=mode, point_fn=point_fn, **kwargs) def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]: @@ -1913,16 +1925,6 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]: } -class PointFunc: - """Wraps so a function so it takes a dict of arguments instead of arguments.""" - - def __init__(self, f): - self.f = f - - def __call__(self, state): - return self.f(**state) - - def Deterministic(name, var, model=None, dims=None, auto=False): """Create a named deterministic variable diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 34dedaa1c..45c7c4207 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -996,3 +996,21 @@ def test_deterministic(): def test_empty_model_representation(): assert pm.Model().str_repr() == "" + + +def test_compile_fn(): + with pm.Model() as m: + x = pm.Normal("x", 0, 1, size=2) + y = pm.LogNormal("y", 0, 1, size=2) + + test_vals = np.array([0.0, -1.0]) + state = {"x": test_vals, "y": test_vals} + + with m: + func = pm.compile_fn(x + y, inputs=[x, y]) + result_compute = func(state) + + func = m.compile_fn(x + y, inputs=[x, y]) + result_expect = func(state) + + np.testing.assert_allclose(result_compute, result_expect)