@@ -12,8 +12,10 @@ ARG TORCHVISION_VERSION
12
12
FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
13
13
FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
14
14
FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
15
- ENV CUDA_MAJOR_VERSION=11
16
- ENV CUDA_MINOR_VERSION=0
15
+ ARG CUDA_MAJOR_VERSION
16
+ ARG CUDA_MINOR_VERSION
17
+ ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
18
+ ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
17
19
# NVIDIA binaries from the host are mounted to /opt/bin.
18
20
ENV PATH=/opt/bin:${PATH}
19
21
# Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
@@ -51,6 +53,13 @@ RUN pip uninstall -y horovod && \
51
53
/tmp/clean-layer.sh
52
54
{{ end }}
53
55
56
+ {{ if eq .Accelerator "gpu" }}
57
+ # b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
58
+ RUN rm /etc/apt/sources.list.d/cuda.list && \
59
+ apt-key del 7fa2af80 && \
60
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
61
+ {{ end }}
62
+
54
63
# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
55
64
# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
56
65
RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
@@ -72,8 +81,7 @@ RUN conda config --add channels nvidia && \
72
81
conda config --add channels rapidsai && \
73
82
# Base image channel order: conda-forge (highest priority), defaults.
74
83
# End state: rapidsai (highest priority), nvidia, conda-forge, defaults.
75
- # b/216162758 Pin mkl which last version breaks spacy.
76
- conda install mkl=2021.4.0 cartopy=0.19 imagemagick=7.1 pyproj==3.1.0 && \
84
+ conda install mkl cartopy=0.19 imagemagick=7.1 pyproj==3.1.0 && \
77
85
/tmp/clean-layer.sh
78
86
79
87
{{ if eq .Accelerator "gpu" }}
@@ -93,11 +101,12 @@ RUN conda install implicit && \
93
101
# Install PyTorch
94
102
{{ if eq .Accelerator "gpu" }}
95
103
COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
96
- RUN pip install /tmp/torch/*.whl && \
104
+ RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
105
+ pip install /tmp/torch/*.whl && \
97
106
rm -rf /tmp/torch && \
98
107
/tmp/clean-layer.sh
99
108
{{ else }}
100
- RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==$TORCHAUDIO_VERSION torchtext==$TORCHTEXT_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
109
+ RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==$TORCHAUDIO_VERSION+cpu torchtext==$TORCHTEXT_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
101
110
/tmp/clean-layer.sh
102
111
{{ end }}
103
112
@@ -155,7 +164,7 @@ RUN pip install pycuda && \
155
164
156
165
RUN pip install pysal && \
157
166
pip install seaborn python-dateutil dask python-igraph && \
158
- pip install pyyaml joblib husl geopy ml_metrics mne pyshp && \
167
+ pip install pyyaml joblib husl geopy mne pyshp && \
159
168
pip install pandas && \
160
169
pip install flax && \
161
170
# Install h2o from source.
@@ -212,6 +221,8 @@ RUN pip install ibis-framework && \
212
221
213
222
RUN pip install scipy && \
214
223
pip install scikit-learn && \
224
+ # Scikit-learn accelerated library for x86
225
+ pip install scikit-learn-intelex && \
215
226
# HDF5 support
216
227
pip install h5py && \
217
228
pip install biopython && \
@@ -277,8 +288,6 @@ RUN pip install mpld3 && \
277
288
pip install pyldavis==3.2.2 && \
278
289
pip install mlxtend && \
279
290
pip install altair && \
280
- # b/183944405 pystan 3.x is not compatible with fbprophet.
281
- pip install pystan==2.19.1.1 && \
282
291
pip install ImageHash && \
283
292
pip install ecos && \
284
293
pip install CVXcanon && \
@@ -301,7 +310,7 @@ RUN pip install mpld3 && \
301
310
pip install pyexcel-ods && \
302
311
pip install sklearn-pandas && \
303
312
pip install stemming && \
304
- pip install fbprophet && \
313
+ pip install prophet && \
305
314
pip install holoviews && \
306
315
pip install geoviews && \
307
316
pip install hypertools && \
@@ -314,9 +323,7 @@ RUN pip install mpld3 && \
314
323
pip install lightfm && \
315
324
pip install folium && \
316
325
pip install scikit-plot && \
317
- # dipy requires the optional fury dependency for visualizations.
318
- # b/217761018 pinned fury to fix test
319
- pip install fury==0.7.1 dipy && \
326
+ pip install fury dipy && \
320
327
pip install plotnine && \
321
328
pip install scikit-surprise && \
322
329
pip install pymongo && \
@@ -391,17 +398,19 @@ RUN pip install bleach && \
391
398
pip install ipywidgets && \
392
399
pip install isoweek && \
393
400
pip install jedi && \
394
- pip install Jinja2 && \
395
401
pip install jsonschema && \
396
402
pip install jupyter-client && \
397
403
pip install jupyter-console && \
398
404
pip install jupyter-core && \
405
+ pip install jupyterlab-lsp && \
399
406
pip install MarkupSafe && \
400
407
pip install mistune && \
401
- pip install nbconvert && \
408
+ # b/227194111 install latest version of nbconvert until the base image includes nbconvert >= 6.4.5
409
+ pip install --upgrade nbconvert Jinja2 && \
402
410
pip install nbformat && \
403
411
pip install notebook && \
404
412
pip install papermill && \
413
+ pip install python-lsp-server[all] && \
405
414
pip install olefile && \
406
415
# b/198300835 kornia 0.5.10 is not compatible with our version of numpy.
407
416
pip install kornia==0.5.8 && \
@@ -488,15 +497,16 @@ RUN pip install flashtext && \
488
497
pip install bqplot && \
489
498
pip install earthengine-api && \
490
499
pip install transformers && \
500
+ # b/232247930 >= 2.2.0 requires pyarrow >= 6.0.0 which conflicts with dependencies for rapidsai 0.21.*
501
+ pip install datasets==2.1.0 && \
491
502
pip install dlib && \
492
503
pip install kaggle-environments && \
493
504
pip install geopandas && \
494
505
pip install nnabla && \
495
506
pip install vowpalwabbit && \
496
507
pip install pydub && \
497
508
pip install pydegensac && \
498
- # b/215182966 torchmetrics 0.7.0 is causing an issue with pytorch-lightning.
499
- pip install torchmetrics==0.6.2 && \
509
+ pip install torchmetrics && \
500
510
pip install pytorch-lightning && \
501
511
pip install datatable && \
502
512
pip install sympy && \
0 commit comments