4545from pytensor .tensor .random .op import RandomVariable
4646from pytensor .tensor .random .type import RandomType
4747from pytensor .tensor .random .var import RandomGeneratorSharedVariable
48+ from pytensor .tensor .rewriting .basic import topo_unconditional_constant_folding
4849from pytensor .tensor .rewriting .shape import ShapeFeature
4950from pytensor .tensor .sharedvar import SharedVariable , TensorSharedVariable
5051from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
@@ -1057,7 +1058,7 @@ def compile_pymc(
10571058
10581059def constant_fold (
10591060 xs : Sequence [TensorVariable ], raise_not_constant : bool = True
1060- ) -> tuple [np .ndarray , ...]:
1061+ ) -> tuple [np .ndarray | Variable , ...]:
10611062 """Use constant folding to get constant values of a graph.
10621063
10631064 Parameters
@@ -1072,8 +1073,12 @@ def constant_fold(
10721073 """
10731074 fg = FunctionGraph (outputs = xs , features = [ShapeFeature ()], copy_inputs = False , clone = True )
10741075
1075- # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
1076- folded_xs = rewrite_graph (fg ).outputs
1076+ # The default rewrite_graph includes a constant_folding pass that is not always applied.
1077+ # We use an unconditional constant_folding as the last pass to ensure a thorough constant folding.
1078+ rewrite_graph (fg )
1079+ topo_unconditional_constant_folding .apply (fg )
1080+
1081+ folded_xs = fg .outputs
10771082
10781083 if raise_not_constant and not all (isinstance (folded_x , Constant ) for folded_x in folded_xs ):
10791084 raise NotConstantValueError
0 commit comments