
Conversation

priyakasimbeg
Contributor

priyakasimbeg commented Mar 6, 2025

Purpose

The goal of this PR is to enable sharding of model parameters and optimizer state, and to migrate the JAX code from jax.pmap to jax.jit.
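
To make the shape of the change concrete, here is a minimal, hypothetical sketch of a pmap-to-jit migration with sharded inputs; the names (`mesh`, `loss_fn`, `train_step`) and the single `'batch'` mesh axis are illustrative assumptions, not the actual workload code in this PR:

```python
# Hypothetical sketch of the pmap -> jit style of change; none of these names
# come from the algoperf codebase.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
replicate = NamedSharding(mesh, P())           # params/opt state live on every device
shard_data = NamedSharding(mesh, P('batch'))   # batch split along its leading dimension

def loss_fn(params, batch):
  preds = batch['inputs'] @ params['w']
  return jnp.mean((preds - batch['targets']) ** 2)

# Old style: jax.pmap(step, axis_name='batch') plus an explicit jax.lax.pmean on grads.
# New style: one jit-compiled program over the whole mesh; the compiler derives the
# collectives from the shardings of the inputs.
@jax.jit
def train_step(params, batch):
  return jax.value_and_grad(loss_fn)(params, batch)

params = jax.device_put({'w': jnp.ones((8, 1))}, replicate)
# (assumes the global batch size divides the number of devices evenly)
batch = jax.device_put(
    {'inputs': jnp.ones((16, 8)), 'targets': jnp.ones((16, 1))}, shard_data)
loss, grads = train_step(params, batch)
```

The per-device leading axis and the explicit cross-device mean of the pmap style disappear; jit runs a single program and inserts the necessary collectives based on how the inputs are sharded.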

TODOs:

  • Migrate reference optimizers to use jax.jit
    • Nesterov
    • AdamW
    • Others
  • Migrate workloads to use jax.jit
    • (Test workload) MNIST
    • (Test workload) CIFAR
    • WMT
    • Criteo1TB
    • FastMRI
    • Librispeech
    • OGBG
    • ImageNet

Changelog

  • Added sharding utilities to handle data distribution across devices (a rough sketch follows this list).
  • Replaced the pmap code in all workloads with jit.
  • Modified the reference submissions accordingly.
  • Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls).
  • Upgraded the JAX version to 0.7.0.
  • Factored the sharding into shape (bsz/n_devices, ...) out of shard_and_maybe_pad; PyTorch workloads that previously used shard_and_maybe_pad now call a separate shard function.
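
As a rough illustration of the kind of data-sharding utility referred to above, here is a minimal hypothetical sketch; the helper names (`get_batch_sharding`, `shard_batch`) and the single `'batch'` mesh axis are assumptions for illustration, not the utilities actually added in this PR:

```python
# Hypothetical helpers, for illustration only; not the repo's actual utilities.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def get_batch_sharding():
  """One-dimensional mesh over all devices, sharded along the batch axis."""
  mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
  return NamedSharding(mesh, P('batch'))

def shard_batch(batch, sharding):
  """Places a host-side batch (a pytree of numpy arrays) onto the mesh,
  splitting each leaf along its leading dimension across devices."""
  return jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch)

# With jit there is no per-device leading axis and no jax_utils.replicate call:
# a jit-compiled step can consume these sharded arrays directly, whereas the
# old pmap path reshaped each batch to a per-device leading dimension and
# replicated the parameters explicitly.
```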

priyakasimbeg requested a review from a team as a code owner March 6, 2025 21:47

github-actions bot commented Mar 6, 2025

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅

priyakasimbeg changed the title from "Jit switch" to "[WIP] Migrate JAX workloads from pmap to jit" Mar 6, 2025
priyakasimbeg changed the base branch from main to dev March 7, 2025 00:17
priyakasimbeg changed the base branch from dev to main August 13, 2025 21:00
priyakasimbeg changed the title from "[WIP] Migrate JAX workloads from pmap to jit" to "Migrate JAX workloads from pmap to jit" Aug 19, 2025
priyakasimbeg changed the base branch from main to dev August 19, 2025 16:42
@rka97
Contributor

rka97 commented Aug 20, 2025

We didn't migrate all the optimizers. Should we look into migrating the following?

  • adafactor
  • lamb
  • sam
  • shampoo

@priyakasimbeg
Contributor Author

We didn't migrate all the optimizers. Should we look into migrating the following?

  • adafactor
  • lamb
  • sam
  • shampoo

The above algorithms are buggy and we have no plans to fix them, so instead of maintaining them we should delete them.

priyakasimbeg merged commit 47c8d2b into dev Aug 21, 2025
29 checks passed
github-actions bot locked and limited conversation to collaborators Aug 21, 2025