@@ -12,8 +12,10 @@ ARG TORCHVISION_VERSION
1212FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
1313FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
1414FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
15- ENV CUDA_MAJOR_VERSION=11
16- ENV CUDA_MINOR_VERSION=0
15+ ARG CUDA_MAJOR_VERSION
16+ ARG CUDA_MINOR_VERSION
17+ ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
18+ ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
1719# NVIDIA binaries from the host are mounted to /opt/bin.
1820ENV PATH=/opt/bin:${PATH}
1921# Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
@@ -51,6 +53,13 @@ RUN pip uninstall -y horovod && \
5153 /tmp/clean-layer.sh
5254{{ end }}
5355
56+ {{ if eq .Accelerator "gpu" }}
57+ # b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
58+ RUN rm /etc/apt/sources.list.d/cuda.list && \
59+ apt-key del 7fa2af80 && \
60+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
61+ {{ end }}
62+
5463# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
5564# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
5665RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
@@ -72,8 +81,7 @@ RUN conda config --add channels nvidia && \
7281 conda config --add channels rapidsai && \
7382 # Base image channel order: conda-forge (highest priority), defaults.
7483 # End state: rapidsai (highest priority), nvidia, conda-forge, defaults.
75- # b/216162758 Pin mkl which last version breaks spacy.
76- conda install mkl=2021.4.0 cartopy=0.19 imagemagick=7.1 pyproj==3.1.0 && \
84+ conda install mkl cartopy=0.19 imagemagick=7.1 pyproj==3.1.0 && \
7785 /tmp/clean-layer.sh
7886
7987{{ if eq .Accelerator "gpu" }}
@@ -93,11 +101,12 @@ RUN conda install implicit && \
93101# Install PyTorch
94102{{ if eq .Accelerator "gpu" }}
95103COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
96- RUN pip install /tmp/torch/*.whl && \
104+ RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
105+ pip install /tmp/torch/*.whl && \
97106 rm -rf /tmp/torch && \
98107 /tmp/clean-layer.sh
99108{{ else }}
100- RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==$TORCHAUDIO_VERSION torchtext==$TORCHTEXT_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
109+ RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==$TORCHAUDIO_VERSION+cpu torchtext==$TORCHTEXT_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
101110 /tmp/clean-layer.sh
102111{{ end }}
103112
@@ -155,7 +164,7 @@ RUN pip install pycuda && \
155164
156165RUN pip install pysal && \
157166 pip install seaborn python-dateutil dask python-igraph && \
158- pip install pyyaml joblib husl geopy ml_metrics mne pyshp && \
167+ pip install pyyaml joblib husl geopy mne pyshp && \
159168 pip install pandas && \
160169 pip install flax && \
161170 # Install h2o from source.
@@ -212,6 +221,8 @@ RUN pip install ibis-framework && \
212221
213222RUN pip install scipy && \
214223 pip install scikit-learn && \
224+ # Scikit-learn accelerated library for x86
225+ pip install scikit-learn-intelex && \
215226 # HDF5 support
216227 pip install h5py && \
217228 pip install biopython && \
@@ -277,8 +288,6 @@ RUN pip install mpld3 && \
277288 pip install pyldavis==3.2.2 && \
278289 pip install mlxtend && \
279290 pip install altair && \
280- # b/183944405 pystan 3.x is not compatible with fbprophet.
281- pip install pystan==2.19.1.1 && \
282291 pip install ImageHash && \
283292 pip install ecos && \
284293 pip install CVXcanon && \
@@ -301,7 +310,7 @@ RUN pip install mpld3 && \
301310 pip install pyexcel-ods && \
302311 pip install sklearn-pandas && \
303312 pip install stemming && \
304- pip install fbprophet && \
313+ pip install prophet && \
305314 pip install holoviews && \
306315 pip install geoviews && \
307316 pip install hypertools && \
@@ -314,9 +323,7 @@ RUN pip install mpld3 && \
314323 pip install lightfm && \
315324 pip install folium && \
316325 pip install scikit-plot && \
317- # dipy requires the optional fury dependency for visualizations.
318- # b/217761018 pinned fury to fix test
319- pip install fury==0.7.1 dipy && \
326+ pip install fury dipy && \
320327 pip install plotnine && \
321328 pip install scikit-surprise && \
322329 pip install pymongo && \
@@ -391,17 +398,19 @@ RUN pip install bleach && \
391398 pip install ipywidgets && \
392399 pip install isoweek && \
393400 pip install jedi && \
394- pip install Jinja2 && \
395401 pip install jsonschema && \
396402 pip install jupyter-client && \
397403 pip install jupyter-console && \
398404 pip install jupyter-core && \
405+ pip install jupyterlab-lsp && \
399406 pip install MarkupSafe && \
400407 pip install mistune && \
401- pip install nbconvert && \
408+ # b/227194111 install latest version of nbconvert until the base image includes nbconvert >= 6.4.5
409+ pip install --upgrade nbconvert Jinja2 && \
402410 pip install nbformat && \
403411 pip install notebook && \
404412 pip install papermill && \
413+ pip install python-lsp-server[all] && \
405414 pip install olefile && \
406415 # b/198300835 kornia 0.5.10 is not compatible with our version of numpy.
407416 pip install kornia==0.5.8 && \
@@ -488,15 +497,16 @@ RUN pip install flashtext && \
488497 pip install bqplot && \
489498 pip install earthengine-api && \
490499 pip install transformers && \
500+ # b/232247930 >= 2.2.0 requires pyarrow >= 6.0.0 which conflicts with dependencies for rapidsai 0.21.*
501+ pip install datasets==2.1.0 && \
491502 pip install dlib && \
492503 pip install kaggle-environments && \
493504 pip install geopandas && \
494505 pip install nnabla && \
495506 pip install vowpalwabbit && \
496507 pip install pydub && \
497508 pip install pydegensac && \
498- # b/215182966 torchmetrics 0.7.0 is causing an issue with pytorch-lightning.
499- pip install torchmetrics==0.6.2 && \
509+ pip install torchmetrics && \
500510 pip install pytorch-lightning && \
501511 pip install datatable && \
502512 pip install sympy && \
0 commit comments