Skip to content

Commit 5978db8

Browse files
committed
Get tensorflow, jax, and pytorch working on TPU1VM
http://b/213335159
1 parent b055d91 commit 5978db8

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

tpu/Dockerfile

+10-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,13 @@ COPY --from=libtpu /libtpu.so /lib
1010

1111
COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/
1212
RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
13-
rm -rf /tmp/tensorflow_pkg
13+
rm -rf /tmp/tensorflow_pkg
14+
15+
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
16+
RUN pip uninstall -y torch
17+
RUN pip uninstall -y torch_xla
18+
RUN pip install torch==1.10
19+
RUN pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl
20+
21+
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
22+
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

tpu/tensorflow.Dockerfile

+17-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ RUN cd /usr/local/src && \
1717
pip install keras_applications --no-deps && \
1818
pip install keras_preprocessing --no-deps
1919

20-
# Create a TensorFlow wheel for CPU
20+
# Create a TensorFlow wheel for TPU
2121
RUN cd /usr/local/src/tensorflow && \
2222
cat /dev/null | ./configure && \
2323
bazel build \
@@ -32,7 +32,22 @@ RUN cd /usr/local/src/tensorflow && \
3232
RUN cd /usr/local/src/tensorflow && \
3333
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
3434

35-
# TODO(b/152075195): Will likely need to install custom build for TFA & tensorflow-gcs-config
35+
# Build TensorFlow addons library against TensorFlow CPU.
36+
#RUN cd /usr/local/src/ && \
37+
# git clone https://github.com/tensorflow/addons && \
38+
# cd addons && \
39+
# git checkout tags/v0.12.1 && \
40+
# python ./configure.py && \
41+
# bazel build --enable_runfiles build_pip_pkg && \
42+
# bazel-bin/build_pip_pkg /tmp/tfa_cpu && \
43+
# bazel clean
44+
45+
# Build tensorflow_gcs_config library against TensorFlow CPU.
46+
#ADD tensorflow-gcs-config /usr/local/src/tensorflow_gcs_config/
47+
#RUN cd /usr/local/src/tensorflow_gcs_config && \
48+
# apt-get install -y libcurl4-openssl-dev && \
49+
# python setup.py bdist_wheel -d /tmp/tensorflow_gcs_config && \
50+
# bazel clean
3651

3752
# Use multi-stage builds to minimize image output size.
3853
FROM alpine:latest

0 commit comments

Comments
 (0)