Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
fe77982
adding easyconfigs: jax-0.2.18-foss-2020b.eb
Aug 5, 2021
f3e3c7c
adding jax patch
Aug 5, 2021
4339d94
tmp dir prefix
Aug 5, 2021
4909247
removed unnecessary code
Aug 5, 2021
56b1d48
patch comment
Aug 5, 2021
40672c6
patch cleanup
Aug 5, 2021
04ddadd
cleanup
Aug 5, 2021
23c2f13
Update easybuild/easyconfigs/j/jax/jax-0.2.18-foss-2020b.eb
deniskristak Aug 5, 2021
af41b43
Update easybuild/easyconfigs/j/jax/jax-0.2.18-foss-2020b.eb
deniskristak Aug 5, 2021
09fe906
Update easybuild/easyconfigs/j/jax/jax-0.2.18-foss-2020b.eb
deniskristak Aug 5, 2021
bd3383d
Update easybuild/easyconfigs/j/jax/jax-0.2.18-foss-2020b.eb
deniskristak Aug 5, 2021
14f386c
latest progress on jax installation
Aug 11, 2021
885ebcb
removing absl-py from deps
Aug 12, 2021
544bda6
using buildcmd option
Aug 12, 2021
497398e
latest progress on jax
Aug 12, 2021
98dbee5
latest progress on jax
Aug 13, 2021
a4988de
latest progress on jax
Aug 13, 2021
a7cb72c
latest jax changes
Aug 13, 2021
edc2060
remove opt-einsum dependency for jax, since it's installed as an exte…
boegel Aug 13, 2021
ea71459
add checksums to jax easyconfig
boegel Aug 13, 2021
415ba09
latest jax changes
Aug 13, 2021
2286043
remove commented out line from jax easyconfig
boegel Aug 13, 2021
1210ab0
stick to specific TensorFlow commit used by jaxlib 0.1.70, enable bui…
boegel Aug 13, 2021
c1735ce
enable building with native optimizations for build host (triggers us…
boegel Aug 13, 2021
4b67573
Merge branch '20210805110301_new_pr_jax0218' of github.com:deniskrist…
boegel Aug 13, 2021
71475e9
run jax test suite
boegel Aug 13, 2021
54fb8ef
rename jax patch to prevent TensorFlow download during build
boegel Aug 13, 2021
94ccf12
add easyconfigs for pytest-xdist and pytest-benchmark build dependenc…
boegel Aug 13, 2021
e0f7c0f
update to jax 0.2.19
boegel Aug 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.2.18-controlled_tf_tarball.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# using system environment variables
diff -ruN tensorflow-2.5.0_orig/third_party/mlir/tblgen.bzl tensorflow-2.5.0/third_party/mlir/tblgen.bzl
--- tensorflow-2.5.0_orig/third_party/mlir/tblgen.bzl 2021-05-12 15:26:41.000000000 +0200
+++ tensorflow-2.5.0/third_party/mlir/tblgen.bzl 2021-08-11 12:23:34.301593036 +0200
@@ -153,6 +153,7 @@
inputs = trans_srcs,
executable = ctx.executable.tblgen,
arguments = [args],
+ use_default_shell_env = True,
)

return [DefaultInfo()]
36 changes: 36 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.2.18-correct_libraries.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
diff -ruN jax-jaxlib-v0.1.70_orig/WORKSPACE jax-jaxlib-v0.1.70/WORKSPACE
--- jax-jaxlib-v0.1.70_orig/WORKSPACE 2021-08-12 13:12:54.025608000 +0200
+++ jax-jaxlib-v0.1.70/WORKSPACE 2021-08-12 13:32:51.581961000 +0200
@@ -5,21 +5,20 @@
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.
-http_archive(
+# http_archive(
+# name = "org_tensorflow",
+# sha256 = "0f13410284b9186e436350e9617b3bed2d65f1dc1a220fd37ad9ef43c2035663",
+# strip_prefix = "tensorflow-4039feeb743bc42cd0a3d8146ce63fc05d23eb8d",
+# urls = [
+# "https://github.com/tensorflow/tensorflow/archive/4039feeb743bc42cd0a3d8146ce63fc05d23eb8d.tar.gz",
+# ],
+# )
+
+local_repository(
name = "org_tensorflow",
- sha256 = "0f13410284b9186e436350e9617b3bed2d65f1dc1a220fd37ad9ef43c2035663",
- strip_prefix = "tensorflow-4039feeb743bc42cd0a3d8146ce63fc05d23eb8d",
- urls = [
- "https://github.com/tensorflow/tensorflow/archive/4039feeb743bc42cd0a3d8146ce63fc05d23eb8d.tar.gz",
- ],
+ path = "pathToSed",
)

-# For development, one can use a local TF repository instead.
-# local_repository(
-# name = "org_tensorflow",
-# path = "tensorflow",
-# )
-
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()

87 changes: 87 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.2.18-foss-2020b.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
easyblock = 'PythonBundle'

name = 'jax'
version = '0.2.18'
local_jaxlib_ver = '0.1.70'
homepage = 'https://pypi.python.org/pypi/jax'
description = """Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and
more"""

toolchain = {'name': 'foss', 'version': '2020b'}

# downloading TF tarball for better control of what Bazel will use during installation of Jax
local_tf_tarball_version = '2.5.0'

builddependencies = [
('Bazel', '3.7.2'),
]

dependencies = [
('Python', '3.8.6'),
('SciPy-bundle', '2020.11'),
('Cython', '0.29.22'),
('opt-einsum', '3.3.0'),
('flatbuffers-python', '1.12'),
]

local_tf_builddir = "%%(builddir)s/tensorflow-%s" % local_tf_tarball_version

local_jax_preinstallopts = "sed -i -e 's$pathToSed$%s$g' WORKSPACE && " % local_tf_builddir

local_jax_build_cmd = "sed -i -e 's$pathToSed$%s$g' WORKSPACE && " % local_tf_builddir
local_jax_build_cmd += 'python build/build.py '
local_jax_build_cmd += '--bazel_startup_options="--output_user_root=%(builddir)s" '
local_jax_build_cmd += '--bazel_path="$EBROOTBAZEL/bin/bazel" '
local_jax_build_cmd += '--bazel_options=--subcommands '
local_jax_build_cmd += '--bazel_options=--jobs=1 --bazel_options=--action_env=PYTHONPATH '
local_jax_build_cmd += '--bazel_options=--action_env=EBPYTHONPREFIXES'


exts_list = [
('opt-einsum', '3.3.0', {
'source_tmpl': 'opt_einsum-%(version)s.tar.gz',
'modulename': 'opt_einsum',
}),
('absl-py', '0.13.0', {
'modulename': 'absl'
})]

default_component_specs = {
'sources': [SOURCE_TAR_GZ],
'start_dir': '%(name)s-%(version)s',
}

components = [
('jaxlib', local_jaxlib_ver, {
'easyblock': 'PythonPackage',
'sources': [
'jaxlib-v%s.zip' % local_jaxlib_ver,
{
'download_filename': 'v%s.tar.gz' % local_tf_tarball_version,
'filename': 'tensorflow_v%s.tar.gz' % local_tf_tarball_version,
}
],
'source_urls': [
'https://github.com/google/jax/archive/',
'https://github.com/tensorflow/tensorflow/archive/'
],
'patches': [
('jaxlib-%s-correct_libraries.patch' % local_jaxlib_ver, 1),
('jaxlib-%s-controlled_tf_tarball.patch' % local_jaxlib_ver, '../tensorflow-%s' % local_tf_tarball_version),
],
'start_dir': 'jax-jaxlib-v%s' % local_jaxlib_ver,
'preinstallopts': local_jax_preinstallopts,
'use_pip': True,
'sanity_pip_check': True,
'download_dep_fail': True,
'buildcmd': local_jax_build_cmd,
'install_src': 'dist/*.whl',
}),
('jax', '0.2.18', {
'easyblock': 'PythonPackage',
}),
]

moduleclass = 'tools'