Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,16 @@ def reshape_t(x, shape):
return x[0]


class PointFunc:
"""Wraps so a function so it takes a dict of arguments instead of arguments."""

def __init__(self, f):
self.f = f

def __call__(self, state):
return self.f(**state)


class CallableTensor:
"""Turns a symbolic variable with one input into a function that returns symbolic arguments
with the one variable replaced with the input.
Expand Down
48 changes: 25 additions & 23 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from aesara.tensor.var import TensorConstant, TensorVariable

from pymc.aesaraf import (
PointFunc,
compile_pymc,
convert_observed_data,
gradient,
Expand Down Expand Up @@ -640,7 +641,7 @@ def compile_logp(
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
jacobian: bool = True,
sum: bool = True,
):
) -> PointFunc:
"""Compiled log probability density function.

Parameters
Expand All @@ -660,7 +661,7 @@ def compile_dlogp(
self,
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
jacobian: bool = True,
):
) -> PointFunc:
"""Compiled log probability density gradient function.

Parameters
Expand All @@ -677,7 +678,7 @@ def compile_d2logp(
self,
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
jacobian: bool = True,
):
) -> PointFunc:
"""Compiled log probability density hessian function.

Parameters
Expand Down Expand Up @@ -1597,15 +1598,18 @@ def compile_fn(
mode=None,
point_fn: bool = True,
**kwargs,
) -> Union["PointFunc", Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
"""Compiles an Aesara function

Parameters
----------
outs: Aesara variable or iterable of Aesara variables
inputs: Aesara input variables, defaults to aesaraf.inputvars(outs).
mode: Aesara compilation mode, default=None
point_fn:
outs
Aesara variable or iterable of Aesara variables.
inputs
Aesara input variables, defaults to aesaraf.inputvars(outs).
mode
Aesara compilation mode, default=None.
point_fn : bool
Whether to wrap the compiled function in a PointFunc, which takes a Point
dictionary with model variable names and values as input.

Expand Down Expand Up @@ -1871,22 +1875,30 @@ def set_data(new_data, model=None, *, coords=None):
model.set_data(variable_name, new_value, coords=coords)


def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs):
def compile_fn(
outs, mode=None, point_fn: bool = True, model: Optional[Model] = None, **kwargs
) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
"""Compiles an Aesara function which returns ``outs`` and takes values of model
vars as a dict as an argument.

Parameters
----------
outs: Aesara variable or iterable of Aesara variables
mode: Aesara compilation mode
point_fn:
outs
Aesara variable or iterable of Aesara variables.
mode
Aesara compilation mode, default=None.
point_fn : bool
Whether to wrap the compiled function in a PointFunc, which takes a Point
dictionary with model variable names and values as input.
model : Model, optional
Current model on stack.

Returns
-------
Compiled Aesara function as point function.
"""
model = modelcontext(model)
return model.compile_fn(outs, mode, point_fn=point_fn, **kwargs)
return model.compile_fn(outs, mode=mode, point_fn=point_fn, **kwargs)


def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
Expand All @@ -1913,16 +1925,6 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
}


class PointFunc:
"""Wraps so a function so it takes a dict of arguments instead of arguments."""

def __init__(self, f):
self.f = f

def __call__(self, state):
return self.f(**state)


def Deterministic(name, var, model=None, dims=None, auto=False):
"""Create a named deterministic variable

Expand Down
18 changes: 18 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,3 +996,21 @@ def test_deterministic():

def test_empty_model_representation():
assert pm.Model().str_repr() == ""


def test_compile_fn():
with pm.Model() as m:
x = pm.Normal("x", 0, 1, size=2)
y = pm.LogNormal("y", 0, 1, size=2)

test_vals = np.array([0.0, -1.0])
state = {"x": test_vals, "y": test_vals}

with m:
func = pm.compile_fn(x + y, inputs=[x, y])
result_compute = func(state)

func = m.compile_fn(x + y, inputs=[x, y])
result_expect = func(state)

np.testing.assert_allclose(result_compute, result_expect)