
get_vars_in_point_list should ignore variables that are not in the model #6199

@lucianopaz


Description of your problem

This thread pointed out a bug in get_vars_in_point_list.

import pymc as pm

with pm.Model() as model:
    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1)

# "c" is present in the point but is not a variable of the model
vars_in_trace = pm.sampling.get_vars_in_point_list(
    [{"a": 0, "c": 0}],
    model,
)

This should return a list containing a, but it raises a KeyError instead because "c" is not in the model.

The solution is simple: include only the tensors that are both in the model and in the trace, ignoring the trace variables whose lookup raises a KeyError.
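A minimal sketch of that fix, assuming only that the model supports name lookup via `model[name]` and raises `KeyError` for names it does not know. The function name `get_vars_in_point_list_sketch` and the dict-based `toy_model` are hypothetical stand-ins so the example runs without PyMC, and reading the names from the first point only is an assumption about how the real function works.

```python
from typing import Any, List, Mapping, Sequence


def get_vars_in_point_list_sketch(
    point_list: Sequence[Mapping[str, Any]],
    model: Any,
) -> List[Any]:
    """Hypothetical helper: return the model variables named in the first point.

    Names that the model cannot resolve (lookup raises KeyError) are skipped
    instead of letting the error propagate.
    """
    if not point_list:
        return []
    vars_in_points = []
    for name in point_list[0]:
        try:
            vars_in_points.append(model[name])
        except KeyError:
            # The name is in the point but not in the model: ignore it.
            continue
    return vars_in_points


# Stand-in for a model: a plain dict mapping names to "variables".
toy_model = {"a": "<a>", "b": "<b>"}
print(get_vars_in_point_list_sketch([{"a": 0, "c": 0}], toy_model))
```

Catching `KeyError` around each lookup mirrors the proposal literally. A membership check (`if name in model`) would be an equivalent alternative for mappings, but the try/except form only relies on `__getitem__`. With PyMC installed, passing the `model` from the reproduction instead of `toy_model` should return a list containing only `a`, assuming `Model.__getitem__` raises `KeyError` for unknown names.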
