-
Notifications
You must be signed in to change notification settings - Fork 280
Optimize tree_sum compile time using tree_reduce_associative #1503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Optimize tree_sum compile time using tree_reduce_associative #1503
Conversation
a017f4e to
edf446a
Compare
| # Use tree_reduce_associative for better compile time performance when | ||
| # available (JAX >= 0.6.0). However, tree_reduce_associative doesn't | ||
| # support empty trees, so we need to check for that case. | ||
| if hasattr(jtu, 'tree_reduce_associative'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be
jax.tree.reduce_associative(operator.add, sums, initializer=0)
and the pythonic way would probably use AttributeError and try catch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @SobhanMP
Thanks for the review! However, jax.tree.reduce_associative doesn't support the initializer parameter - it only accepts (function, tree). That's why the original implementation failed with "Must specify identity for parallel reduction of empty sequence" when encountering empty trees. The current approach with hasattr is appropriate here because:
- We're checking for API availability across JAX versions (0.5.3 vs 0.6.0+)
- The empty tree check is necessary since tree_reduce_associative lacks initializer support
- Using hasattr for version compatibility is a common pattern in the JAX ecosystem
A try/except approach would be:
try:
leaves = jax.tree.leaves(sums)
if not leaves:
return 0
return jtu.tree_reduce_associative(operator.add, sums)
except AttributeError:
return jax.tree.reduce(operator.add, sums, initializer=0)
But this doesn't provide much benefit over hasattr for this use case, and the hasattr check makes the version compatibility intent clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Aaryan-549 jax.tree.reduce_associative has an identity argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're absolutely right, thank you for the correction! I've updated the implementation to use identity=0 instead of manually checking for empty trees. Much cleaner now:
if hasattr(jtu, 'tree_reduce_associative'):
return jtu.tree_reduce_associative(operator.add, sums, identity=0)
else:
return jax.tree.reduce(operator.add, sums, initializer=0)
Changed tree_sum implementation to use jax.tree_util.tree_reduce_associative when available (JAX >= 0.6.0). Since addition is an associative operation, tree_reduce_associative can provide better compilation performance. Testing shows runtime is very close but compile time is significantly lower (18s vs 23s in reported cases). For compatibility with older JAX versions (< 0.6.0), the implementation falls back to jax.tree.reduce when tree_reduce_associative is not available.
edf446a to
c946fde
Compare
In Response to #1498
Changed tree_sum implementation to use jax.tree_util.tree_reduce_associative instead of jax.tree.reduce. Since addition is an associative operation, tree_reduce_associative can provide better compilation performance.
Testing shows runtime is very close but compile time is significantly lower (18s vs 23s in reported cases).