tensor-bridge

tensor-bridge is a lightweight library that transfers tensors between deep learning libraries through native cudaMemcpy calls with minimal overhead.

import torch
import jax
from tensor_bridge import copy_tensor


# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")

# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))

# Copy Jax tensor to PyTorch tensor
copy_tensor(torch_data, jax_data)

# And the other way around: copy PyTorch tensor to Jax tensor
copy_tensor(jax_data, torch_data)

⚠️ This repository is under active development. In particular, transfer between tensors with different memory layouts is not implemented yet. I recommend trying copy_tensor_with_assertion before starting experiments: it raises an error if the copy doesn't work. If it raises an error, you need to force the tensor to be contiguous:

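# PyTorch example
import torch
# assumes copy_tensor_with_assertion is exported alongside copy_tensor
from tensor_bridge import copy_tensor_with_assertion

# different layout raises an error
a = torch.rand(2, 3, device="cuda:0")
# b has the same shape as a, but transpose makes it non-contiguous
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)
copy_tensor_with_assertion(a, b)  # AssertionError !!

# make b contiguous so that both tensors share the same layout
b = b.contiguous()
copy_tensor_with_assertion(a, b)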
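
The transposed tensor presumably fails because a raw cudaMemcpy copies bytes in storage order, so the source and destination have to agree on strides and not only on shape. Below is a minimal pure-Python sketch of that check. `row_major_strides` and `is_contiguous` are hypothetical helpers written for illustration, not part of tensor-bridge, and the check is simplified (it ignores special cases such as size-1 dimensions). Strides are counted in elements, as PyTorch reports them.

```python
def row_major_strides(shape):
    """Return the strides (in elements) of a contiguous row-major tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if the strides describe a contiguous row-major layout."""
    return tuple(strides) == row_major_strides(shape)


# a = torch.rand(2, 3): shape (2, 3), strides (3, 1)
a_shape, a_strides = (2, 3), row_major_strides((2, 3))

# b = torch.rand(3, 2).transpose(0, 1): same shape (2, 3), but the strides
# of the original (3, 2) tensor are swapped, giving (1, 2)
b_shape, b_strides = (2, 3), tuple(reversed(row_major_strides((3, 2))))

print(is_contiguous(a_shape, a_strides))  # True
print(is_contiguous(b_shape, b_strides))  # False
```

Calling `.contiguous()` on the real tensor rewrites its storage so that the strides match the row-major ones again.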

Since copy_tensor_with_assertion performs an additional GPU-CPU transfer internally, make sure that you switch to copy_tensor in your experiments. Otherwise, your training loop will be significantly slower.
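
One way to follow this advice is to verify only the first few copies and then fall back to the fast path. The sketch below is an assumption about how you might structure that, not something tensor-bridge provides: `make_copier` is a hypothetical helper, and the two copy functions are passed in as arguments so the example runs without a GPU. In practice they would be copy_tensor_with_assertion and copy_tensor.

```python
def make_copier(checked_copy, fast_copy, n_checked=1):
    """Return copy(dst, src) that verifies only the first n_checked calls."""
    calls = 0

    def copy(dst, src):
        nonlocal calls
        calls += 1
        if calls <= n_checked:
            checked_copy(dst, src)  # slow path: raises if the copy is wrong
        else:
            fast_copy(dst, src)  # fast path: no verification

    return copy


# Stand-in functions that only record which path was taken.
log = []
copy = make_copier(
    checked_copy=lambda dst, src: log.append("checked"),
    fast_copy=lambda dst, src: log.append("fast"),
)
for _ in range(3):
    copy(None, None)
print(log)  # ['checked', 'fast', 'fast']
```

This only makes sense when the tensor layouts stay the same across iterations; if they can change, the early check says nothing about later copies.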

Features

  • Fast inter-library tensor copy.
  • Inter-GPU copy (I believe the current implementation supports this, but it has not been tested yet.)

Supported deep learning libraries

  • PyTorch
  • Jax
  • nnabla

Installation

PyPI

If installation with pip doesn't work, please try installing from source code.

Python 3.10.x

You can install a pre-built package.

pip install tensor-bridge

Other Python versions

Your machine needs nvcc to compile the native code and Cython to compile the .pyx files.

pip install Cython==0.29.36
pip install tensor-bridge

Pre-built packages for other Python versions are in progress.

From source code

Your machine needs nvcc to compile the native code and Cython to compile the .pyx files.

git clone git@github.com:takuseno/tensor-bridge
cd tensor-bridge
pip install Cython==0.29.36
pip install -e .

Unit test

Your machine needs an NVIDIA GPU and nvidia-driver to run the tests.

./bin/build-docker
./bin/test

Benchmark

To benchmark round-trip copies between Jax and PyTorch:

./bin/build-docker
./bin/benchmark

These are the results on my local desktop with an RTX 4070.

Benchmarking copy_tensor...
Average compute time: 1.3043880462646485e-05 sec
Benchmarking copy via CPU...
Average compute time: 0.0016725873947143555 sec
Benchmarking dlpack...
Average compute time: 7.467031478881836e-05 sec

Surprisingly, copy_tensor is faster than DLPack, by a factor of roughly 5.7 in this run. Looking at PyTorch's implementation, it seems that PyTorch performs an additional CUDA stream synchronization, which adds compute time.
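
The relative speedups can be read off the reported averages with a few lines of arithmetic. The numbers are copied from the benchmark output above; the dictionary keys are just labels chosen for this sketch.

```python
# Average times in seconds, copied from the benchmark output above.
timings = {
    "copy_tensor": 1.3043880462646485e-05,
    "copy via CPU": 0.0016725873947143555,
    "dlpack": 7.467031478881836e-05,
}

# Express every method as a multiple of the copy_tensor time.
baseline = timings["copy_tensor"]
speedups = {name: t / baseline for name, t in timings.items()}

for name, ratio in speedups.items():
    print(f"{name}: {timings[name] * 1e6:.1f} us ({ratio:.1f}x copy_tensor)")
```

In this run the DLPack path takes about 5.7 times as long as copy_tensor, and the copy via CPU about 128 times as long. These figures come from a single machine, so they will differ on other hardware.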
