From cfbaf7a48e25b70c8164ddaf13f90fdbb5f3288c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Aug 2025 17:09:06 +0000 Subject: [PATCH] remove jax.device_put from imagenet test pipeline because it results in an OOM for subsequent training steps --- algoperf/workloads/imagenet_resnet/imagenet_v2.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index 77f075545..75f844b86 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -6,10 +6,9 @@ import functools from typing import Dict, Iterator, Tuple -import jax import tensorflow_datasets as tfds -from algoperf import data_utils, jax_sharding_utils, spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -47,10 +46,4 @@ def _decode_example(example: Dict[str, float]) -> Dict[str, float]: if framework == 'pytorch': it = map(data_utils.shard, it) - elif framework == 'jax': - f = functools.partial( - jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding() - ) - it = map(f, it) - return it