Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion openmc/statepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def add_volume_information(self, volume_calc):
def get_tally(self, scores=[], filters=[], nuclides=[],
name=None, id=None, estimator=None, exact_filters=False,
exact_nuclides=False, exact_scores=False,
multiply_density=None, derivative=None):
multiply_density=None, derivative=None, filter_type=None):
"""Finds and returns a Tally object with certain properties.

This routine searches the list of Tallies and returns the first Tally
Expand Down Expand Up @@ -575,6 +575,9 @@ def get_tally(self, scores=[], filters=[], nuclides=[],
to the same value as this parameter.
derivative : openmc.TallyDerivative, optional
TallyDerivative object to match.
filter_type : type, optional
If not None, the Tally must have at least one Filter that is an
instance of this type. For example `openmc.MeshFilter`.

Returns
-------
Expand Down Expand Up @@ -648,6 +651,10 @@ def get_tally(self, scores=[], filters=[], nuclides=[],
if not contains_filters:
continue

if filter_type is not None:
if not any(isinstance(f, filter_type) for f in test_tally.filters):
continue

# Determine if Tally has the queried Nuclide(s)
if nuclides:
if not all(nuclide in test_tally.nuclides for nuclide in nuclides):
Expand Down
65 changes: 65 additions & 0 deletions tests/unit_tests/test_statepoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import openmc


def test_get_tally_filter_type(run_in_tmpdir):
"""Test various ways of retrieving tallies from a StatePoint object."""

mat = openmc.Material()
mat.add_nuclide("H1", 1.0)
mat.set_density("g/cm3", 10.0)

sphere = openmc.Sphere(r=10.0, boundary_type="vacuum")
cell = openmc.Cell(fill=mat, region=-sphere)
geometry = openmc.Geometry([cell])

settings = openmc.Settings()
settings.particles = 10
settings.batches = 2
settings.run_mode = "fixed source"

reg_mesh = openmc.RegularMesh().from_domain(cell)
tally1 = openmc.Tally(tally_id=1)
mesh_filter = openmc.MeshFilter(reg_mesh)
tally1.filters = [mesh_filter]
tally1.scores = ["flux"]

tally2 = openmc.Tally(tally_id=2, name="heating tally")
cell_filter = openmc.CellFilter(cell)
tally2.filters = [cell_filter]
tally2.scores = ["heating"]

tallies = openmc.Tallies([tally1, tally2])
model = openmc.Model(
geometry=geometry, materials=[mat], settings=settings, tallies=tallies
)

sp_filename = model.run()

sp = openmc.StatePoint(sp_filename)

tally_found = sp.get_tally(filter_type=openmc.MeshFilter)
assert tally_found.id == 1

tally_found = sp.get_tally(filter_type=openmc.CellFilter)
assert tally_found.id == 2

tally_found = sp.get_tally(filters=[mesh_filter])
assert tally_found.id == 1

tally_found = sp.get_tally(filters=[cell_filter])
assert tally_found.id == 2

tally_found = sp.get_tally(scores=["heating"])
assert tally_found.id == 2

tally_found = sp.get_tally(name="heating tally")
assert tally_found.id == 2

tally_found = sp.get_tally(name=None)
assert tally_found.id == 1

tally_found = sp.get_tally(id=1)
assert tally_found.id == 1

tally_found = sp.get_tally(id=2)
assert tally_found.id == 2