diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index 77f075545..75f844b86 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -6,10 +6,9 @@ import functools from typing import Dict, Iterator, Tuple -import jax import tensorflow_datasets as tfds -from algoperf import data_utils, jax_sharding_utils, spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -47,10 +46,4 @@ def _decode_example(example: Dict[str, float]) -> Dict[str, float]: if framework == 'pytorch': it = map(data_utils.shard, it) - elif framework == 'jax': - f = functools.partial( - jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding() - ) - it = map(f, it) - return it