Skip to content

Conversation

@samanklesaria
Copy link
Collaborator

Closes #5046

@samanklesaria samanklesaria changed the title Add split and fork methods to RngStream Add split and key methods to RngStream Oct 27, 2025
@samanklesaria
Copy link
Collaborator Author

Unfortunately, the key attribute on streams already exists. I have renamed the existing attribute key_, and replaced its use with the hand sed command sed -I '' "s/\.key\([^_(]\)/.key_\1/".

@samanklesaria samanklesaria force-pushed the issues/5046 branch 3 times, most recently from 3223d3e to aa6ce6a Compare October 27, 2025 20:03
@samanklesaria samanklesaria changed the title Add split and key methods to RngStream Add split and key methods to RngStream and Rngs Oct 27, 2025
@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae
Copy link
Collaborator

Thanks @samanklesaria !
I'll have to run all internal tests before merging.

As a follow up PR, we could consider:

  • replace all usage of jax.random.<function>(rngs(), ...) with rngs.<function>(...). I've done some of this.
  • when we do need a key, favor rngs.key() or rngs.some_stream.key().

return pad_shard_unpad_wrapper


class _DictOrList(dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a utility datatype used by build_tree_from_paths, which is an inverse of jax.tree.leaves_with_path. It's useful for defining the bridge code porting old rng state to the new names. See the changes to the tutorial.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Give RngStream the same interface as Jax's stateful RNGs

2 participants