1
1
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
2
- ARG HIPBLASLT_BRANCH="4d40e36 "
2
+ ARG HIPBLASLT_BRANCH="db8e93b4 "
3
3
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
4
4
ARG LEGACY_HIPBLASLT_OPTION=
5
5
ARG RCCL_BRANCH="648a58d"
6
6
ARG RCCL_REPO="https://github.com/ROCm/rccl"
7
7
ARG TRITON_BRANCH="e5be006"
8
8
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
9
- ARG PYTORCH_BRANCH="3a585126 "
10
- ARG PYTORCH_VISION_BRANCH="v0.19.1 "
9
+ ARG PYTORCH_BRANCH="295f2ed4 "
10
+ ARG PYTORCH_VISION_BRANCH="v0.21.0 "
11
11
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
12
12
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
13
- ARG FA_BRANCH="b7d29fb "
14
- ARG FA_REPO="https://github.com/ROCm /flash-attention.git"
15
- ARG AITER_BRANCH="21d47a9 "
13
+ ARG FA_BRANCH="1a7f4dfa "
14
+ ARG FA_REPO="https://github.com/Dao-AILab /flash-attention.git"
15
+ ARG AITER_BRANCH="8970b25b "
16
16
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
17
17
18
18
FROM ${BASE_IMAGE} AS base
19
19
20
20
ENV PATH=/opt/rocm/llvm/bin:$PATH
21
21
ENV ROCM_PATH=/opt/rocm
22
22
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
23
- ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
23
+ ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
24
24
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
25
25
26
26
ARG PYTHON_VERSION=3.12
@@ -31,7 +31,7 @@ ENV DEBIAN_FRONTEND=noninteractive
31
31
32
32
# Install Python and other dependencies
33
33
RUN apt-get update -y \
34
- && apt-get install -y software-properties-common git curl sudo vim less \
34
+ && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
35
35
&& add-apt-repository ppa:deadsnakes/ppa \
36
36
&& apt-get update -y \
37
37
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
@@ -42,7 +42,7 @@ RUN apt-get update -y \
42
42
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
43
43
&& python3 --version && python3 -m pip --version
44
44
45
- RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
45
+ RUN pip install -U packaging ' cmake<4' ninja wheel setuptools pybind11 Cython
46
46
47
47
FROM base AS build_hipblaslt
48
48
ARG HIPBLASLT_BRANCH
@@ -60,7 +60,8 @@ RUN cd hipBLAS-common \
60
60
RUN git clone https://github.com/ROCm/hipBLASLt
61
61
RUN cd hipBLASLt \
62
62
&& git checkout ${HIPBLASLT_BRANCH} \
63
- && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
63
+ && apt-get install -y llvm-dev \
64
+ && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
64
65
&& cd build/release \
65
66
&& make package
66
67
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
@@ -110,11 +111,24 @@ RUN git clone ${FA_REPO}
110
111
RUN cd flash-attention \
111
112
&& git checkout ${FA_BRANCH} \
112
113
&& git submodule update --init \
113
- && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
114
+ && GPU_ARCHS=$(echo $ {PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
114
115
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
115
116
&& cp /app/vision/dist/*.whl /app/install \
116
117
&& cp /app/flash-attention/dist/*.whl /app/install
117
118
119
+ FROM base AS build_aiter
120
+ ARG AITER_BRANCH
121
+ ARG AITER_REPO
122
+ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
123
+ pip install /install/*.whl
124
+ RUN git clone --recursive ${AITER_REPO}
125
+ RUN cd aiter \
126
+ && git checkout ${AITER_BRANCH} \
127
+ && git submodule update --init --recursive \
128
+ && pip install -r requirements.txt
129
+ RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
130
+ RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
131
+
118
132
FROM base AS final
119
133
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
120
134
dpkg -i /install/*deb \
@@ -130,19 +144,12 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
130
144
pip install /install/*.whl
131
145
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
132
146
pip install /install/*.whl
133
-
134
- ARG AITER_REPO
135
- ARG AITER_BRANCH
136
- RUN git clone --recursive ${AITER_REPO}
137
- RUN cd aiter \
138
- && git checkout ${AITER_BRANCH} \
139
- && git submodule update --init --recursive \
140
- && pip install -r requirements.txt \
141
- && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
147
+ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
148
+ pip install /install/*.whl
142
149
143
150
ARG BASE_IMAGE
144
- ARG HIPBLASLT_BRANCH
145
151
ARG HIPBLAS_COMMON_BRANCH
152
+ ARG HIPBLASLT_BRANCH
146
153
ARG LEGACY_HIPBLASLT_OPTION
147
154
ARG RCCL_BRANCH
148
155
ARG RCCL_REPO
@@ -154,6 +161,8 @@ ARG PYTORCH_REPO
154
161
ARG PYTORCH_VISION_REPO
155
162
ARG FA_BRANCH
156
163
ARG FA_REPO
164
+ ARG AITER_BRANCH
165
+ ARG AITER_REPO
157
166
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
158
167
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
159
168
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
@@ -167,6 +176,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
167
176
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
168
177
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
169
178
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
170
- && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
171
179
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
172
180
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
0 commit comments