diff --git a/.ci/aarch64_linux/build_aarch64_wheel.py b/.ci/aarch64_linux/build_aarch64_wheel.py index 025d0a20579d..7a4715d33006 100755 --- a/.ci/aarch64_linux/build_aarch64_wheel.py +++ b/.ci/aarch64_linux/build_aarch64_wheel.py @@ -438,9 +438,7 @@ def build_torchvision( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -495,9 +493,7 @@ def build_torchdata( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -553,9 +549,7 @@ def build_torchtext( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -613,9 +607,7 @@ def build_torchaudio( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index a286d8da39ac..aabfbd5a4772 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -144,16 +144,6 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9) - CUDA_VERSION=12.6.3 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 @@ -164,39 +154,6 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.12 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.13 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 @@ -219,19 +176,7 @@ case "$tag" in VISION=yes TRITON=yes ;; - pytorch-linux-jammy-py3.11-clang12) - ANACONDA_PYTHON_VERSION=3.11 - CLANG_VERSION=12 - VISION=yes - TRITON=yes - ;; - pytorch-linux-jammy-py3.9-gcc9) - ANACONDA_PYTHON_VERSION=3.9 - GCC_VERSION=9 - VISION=yes - TRITON=yes - ;; - pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) + pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then ANACONDA_PYTHON_VERSION=3.10 else @@ -245,7 +190,9 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - INDUCTOR_BENCHMARKS=yes + if [[ $tag =~ "benchmarks" ]]; then + INDUCTOR_BENCHMARKS=yes + fi ;; pytorch-linux-noble-rocm-alpha-py3) ANACONDA_PYTHON_VERSION=3.12 @@ -257,7 +204,6 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - INDUCTOR_BENCHMARKS=yes PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950" ;; pytorch-linux-jammy-xpu-2025.0-py3) diff --git a/.ci/docker/ci_commit_pins/huggingface.txt b/.ci/docker/ci_commit_pins/huggingface.txt index f00d6ca4f9ca..4fc4729a25da 100644 --- a/.ci/docker/ci_commit_pins/huggingface.txt +++ b/.ci/docker/ci_commit_pins/huggingface.txt @@ -1 +1 @@ -243e186efbf7fb93328dd6b34927a4e8c8f24395 +v4.54.0 diff --git a/.github/ci_commit_pins/torchbench.txt b/.ci/docker/ci_commit_pins/torchbench.txt similarity index 100% rename from .github/ci_commit_pins/torchbench.txt rename to .ci/docker/ci_commit_pins/torchbench.txt diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 6dc1c44507eb..60c896b80c8f 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -11ec6354315768a85da41032535e3b7b99c5f706 +f7888497a1eb9e98d4c07537f0d0bcfe180d1363 diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index d7fc6ea264dd..c160e5704ba3 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -66,8 +66,9 @@ function do_cpython_build { ln -s pip3 ${prefix}/bin/pip fi # install setuptools since python 3.12 is required to use distutils - ${prefix}/bin/pip install wheel==0.45.1 setuptools==80.9.0 - local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))") + # packaging is needed to create symlink since wheel no longer provides needed information + ${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0 + local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))") ln -sf ${prefix} /opt/python/${abi_tag} } diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index c8a780f65c8e..ebebd195d6b7 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -68,8 +68,8 @@ function install_nvshmem { # download, unpack, install wget -q "${url}" tar xf "${filename}.tar.gz" - cp -a "libnvshmem/include/"* /usr/local/include/ - cp -a "libnvshmem/lib/"* /usr/local/lib/ + cp -a "libnvshmem/include/"* /usr/local/cuda/include/ + cp -a "libnvshmem/lib/"* /usr/local/cuda/lib64/ # cleanup cd .. diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 7312dce170db..21fced2e851d 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -15,11 +15,37 @@ function install_timm() { commit=$(get_pinned_commit timm) pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}" - # Clean up - conda_run pip uninstall -y torch torchvision triton +} + +function install_torchbench() { + local commit + commit=$(get_pinned_commit torchbench) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + python install.py --continue_on_fail + + # soxr comes from https://github.com/huggingface/transformers/pull/39429 + pip install transformers==4.54.0 soxr==0.5.0 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd + + chown -R jenkins torchbench + chown -R jenkins /opt/conda } # Pango is needed for weasyprint which is needed for doctr conda_install pango + +# Stable packages are ok here, just to satisfy TorchBench check +pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 + +install_torchbench install_huggingface install_timm + +# Clean up +conda_run pip uninstall -y torch torchvision torchaudio triton torchao diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index f49bdf3f2b46..a965f0f743d4 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -103,5 +103,5 @@ fi # It depends on torch and triton. We don't want to install # triton and torch from production on Docker CI images if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then - pip_install helion==0.0.10 --no-deps + pip_install helion --no-deps fi diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index ecbbb8ccccf8..7f21d2e42c72 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -34,18 +34,27 @@ function install_ubuntu() { # The xpu-smi packages apt-get install -y flex bison xpu-smi - # Compute and Media Runtimes - apt-get install -y \ - intel-opencl-icd intel-level-zero-gpu level-zero \ - intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ - libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ - libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ - mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo - if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then - apt-get install -y intel-ocloc + + if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then + # Compute and Media Runtimes + apt-get install -y \ + intel-opencl-icd intel-level-zero-gpu level-zero \ + intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ + libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo + # Development Packages + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev + else # rolling driver + apt-get install -y \ + intel-opencl-icd libze-intel-gpu1 libze1 \ + intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ + libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev fi - # Development Packages - apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev + # Install Intel Support Packages apt-get install -y ${XPU_PACKAGES} @@ -130,11 +139,11 @@ function install_sles() { } -# Default use GPU driver LTS releases -XPU_DRIVER_VERSION="/lts/2350" -if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then - # Use GPU driver rolling releases - XPU_DRIVER_VERSION="" +# Default use GPU driver rolling releases +XPU_DRIVER_VERSION="" +if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then + # Use GPU driver LTS releases + XPU_DRIVER_VERSION="/lts/2350" fi # Default use IntelĀ® oneAPI Deep Learning Essentials 2025.0 diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index d25f79766baf..d4bdd9b2a9cb 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -63,11 +63,12 @@ lark==0.12.0 #Pinned versions: 0.12.0 #test that import: -librosa>=0.6.2 ; python_version < "3.11" -librosa==0.10.2 ; python_version == "3.12" +librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x" +librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #Description: A python package for music and audio analysis #Pinned versions: >=0.6.2 #test that import: test_spectral_ops.py +#librosa depends on numba; disable it for s390x while numba is disabled too #mkl #this breaks linux-bionic-rocm4.5-py3.7 #Description: Intel oneAPI Math Kernel Library @@ -110,14 +111,15 @@ ninja==1.11.1.3 #Pinned versions: 1.11.1.3 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py -numba==0.49.0 ; python_version < "3.9" -numba==0.55.2 ; python_version == "3.9" -numba==0.55.2 ; python_version == "3.10" -numba==0.60.0 ; python_version == "3.12" +numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x" +numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x" +numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x" +numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 #test that import: test_numba_integration.py #For numba issue see https://github.com/pytorch/pytorch/issues/51511 +#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073 #numpy #Description: Provides N-dimensional arrays and linear algebra @@ -307,7 +309,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.15.1.0 +z3-solver==4.15.1.0 ; platform_machine != "s390x" #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -361,7 +363,6 @@ pwlf==2.2.1 #Pinned versions: 2.2.1 #test that import: test_sac_estimator.py - # To build PyTorch itself pyyaml pyzstd diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 4997e15d4687..3de4d8e0e44e 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -1,7 +1,7 @@ sphinx==5.3.0 #Description: This is used to generate PyTorch docs #Pinned versions: 5.3.0 --e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering # but it doesn't seem to work and hangs around idly. The initial thought that it is probably @@ -50,7 +50,7 @@ IPython==8.12.0 #Pinned versions: 8.12.0 myst-nb==0.17.2 -#Description: This is used to generate PyTorch functorch and torch.compile docs +#Description: This is used to generate PyTorch functorch and torch.compile docs. #Pinned versions: 0.17.2 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 2528da07c69e..8f2cc6eef958 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt # (optional) Install non-default Ninja version ARG NINJA_VERSION diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 27c466dd8d41..077910cef9f3 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt ARG TRITON ARG TRITON_CPU diff --git a/.ci/manywheel/build.sh b/.ci/manywheel/build.sh index 4c4d51134715..6b2a60bc5ca2 100755 --- a/.ci/manywheel/build.sh +++ b/.ci/manywheel/build.sh @@ -5,10 +5,6 @@ set -ex SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" case "${GPU_ARCH_TYPE:-BLANK}" in - BLANK) - # Legacy behavior for CircleCI - bash "${SCRIPTPATH}/build_cuda.sh" - ;; cuda) bash "${SCRIPTPATH}/build_cuda.sh" ;; diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 49549c9f2994..4c268befb30e 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -138,28 +138,11 @@ fi echo "Calling setup.py bdist at $(date)" -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \ +time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR - echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" - time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \ - BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ - USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ - CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR - echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" -else - time CMAKE_ARGS=${CMAKE_ARGS[@]} \ - EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ - USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ - python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR -fi echo "Finished setup.py bdist at $(date)" # Build libtorch packages @@ -272,10 +255,6 @@ ls /tmp/$WHEELHOUSE_DIR mkdir -p "/$WHEELHOUSE_DIR" mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/ -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true -fi - if [[ -n "$BUILD_PYTHONLESS" ]]; then mkdir -p /$LIBTORCH_HOUSE_DIR mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR @@ -452,16 +431,8 @@ if [[ -z "$BUILD_PYTHONLESS" ]]; then pushd $PYTORCH_ROOT/test # Install the wheel for this Python version - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true - fi - pip uninstall -y "$TORCH_PACKAGE_NAME" - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v - fi - pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v # Print info on the libraries installed in this wheel diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index 690600efdb37..ffc15bcdc5fa 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library ROCBLAS_LIB_DST=lib/rocblas/library ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) -ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES) +ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES) # hipblaslt library files HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index a7ce0fef736c..65f97389324a 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -50,9 +50,6 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi -# Enable LLVM dependency for TensorExpr testing -export USE_LLVM=/opt/llvm -export LLVM_DIR=/opt/llvm/lib/cmake/llvm if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with @@ -176,7 +173,7 @@ fi # We only build FlashAttention files for CUDA 8.0+, and they require large amounts of # memory to build and will OOM -if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' ' '\n' | sed 's/$/>= 8.0/' | bc | grep -q 1; then export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2" fi @@ -192,7 +189,6 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export USE_ASAN=1 export REL_WITH_DEB_INFO=1 export UBSAN_FLAGS="-fno-sanitize-recover=all" - unset USE_LLVM fi if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then @@ -265,22 +261,13 @@ else WERROR=1 python setup.py clean - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - python3 tools/packaging/split_wheel.py bdist_wheel - else - WERROR=1 python setup.py bdist_wheel - fi + WERROR=1 python setup.py bdist_wheel else python setup.py clean if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then source .ci/pytorch/install_cache_xla.sh fi - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "USE_SPLIT_BUILD cannot be used with xla or rocm" - exit 1 - else - python setup.py bdist_wheel - fi + python setup.py bdist_wheel fi pip_install_whl "$(echo dist/*.whl)" diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index e9c7741947cf..06decc2ea64b 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -229,7 +229,6 @@ function install_torchrec_and_fbgemm() { pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm - pushd /tmp local wheel_dir=dist/fbgemm_gpu local found_whl=0 @@ -245,7 +244,7 @@ function install_torchrec_and_fbgemm() { if [ "${found_whl}" == "0" ]; then git clone --recursive https://github.com/pytorch/fbgemm pushd fbgemm/fbgemm_gpu - git checkout "${fbgemm_commit}" + git checkout "${fbgemm_commit}" --recurse-submodules python setup.py bdist_wheel \ --build-variant=rocm \ -DHIP_ROOT_DIR="${ROCM_PATH}" \ @@ -264,7 +263,6 @@ function install_torchrec_and_fbgemm() { done rm -rf fbgemm - popd else pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu @@ -283,30 +281,6 @@ function clone_pytorch_xla() { fi } -function checkout_install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench - pushd torchbench - git checkout "$commit" - - if [ "$1" ]; then - python install.py --continue_on_fail models "$@" - else - # Occasionally the installation may fail on one model but it is ok to continue - # to install and test other models - python install.py --continue_on_fail - fi - - # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 - # is regressing speedup metric. This needs to be investigated further - pip install transformers==4.38.1 - - echo "Print all dependencies after TorchBench is installed" - python -mpip freeze - popd -} - function install_torchao() { local commit commit=$(get_pinned_commit torchao) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 83f8e4e04331..c9d926a5df37 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -157,6 +157,32 @@ test_jit_hooks() { assert_git_not_dirty } +# Shellcheck doesn't like it when you pass no arguments to a function +# that can take args. See https://www.shellcheck.net/wiki/SC2120 +# shellcheck disable=SC2120 +checkout_install_torchbench() { + local commit + commit=$(cat .ci/docker/ci_commit_pins/torchbench.txt) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + if [ "$1" ]; then + python install.py --continue_on_fail models "$@" + else + # Occasionally the installation may fail on one model but it is ok to continue + # to install and test other models + python install.py --continue_on_fail + fi + + # soxr comes from https://github.com/huggingface/transformers/pull/39429 + pip install transformers==4.54.0 soxr==0.5.0 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd +} + torchbench_setup_macos() { git clone --recursive https://github.com/pytorch/vision torchvision git clone --recursive https://github.com/pytorch/audio torchaudio @@ -179,8 +205,6 @@ torchbench_setup_macos() { USE_OPENMP=0 python setup.py develop popd - # Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120 - # shellcheck disable=SC2119,SC2120 checkout_install_torchbench } diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index fb4e0759d508..daa258d283fa 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -627,6 +627,8 @@ test_perf_for_dashboard() { device=cuda_a10g elif [[ "${TEST_CONFIG}" == *h100* ]]; then device=cuda_h100 + elif [[ "${TEST_CONFIG}" == *b200* ]]; then + device=cuda_b200 elif [[ "${TEST_CONFIG}" == *rocm* ]]; then device=rocm fi @@ -801,6 +803,16 @@ test_dynamo_benchmark() { if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@" elif [[ "${TEST_CONFIG}" == *perf* ]]; then + # TODO (huydhn): Just smoke test some sample models + if [[ "${TEST_CONFIG}" == *b200* ]]; then + if [[ "${suite}" == "huggingface" ]]; then + export TORCHBENCH_ONLY_MODELS="DistillGPT2" + elif [[ "${suite}" == "timm_models" ]]; then + export TORCHBENCH_ONLY_MODELS="inception_v3" + elif [[ "${suite}" == "torchbench" ]]; then + export TORCHBENCH_ONLY_MODELS="hf_Bert" + fi + fi test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@" else if [[ "${TEST_CONFIG}" == *cpu* ]]; then @@ -1039,20 +1051,10 @@ test_libtorch_api() { mkdir -p $TEST_REPORTS_DIR OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml - "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml else # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest" - # On s390x, pytorch is built without llvm. - # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and - # test fails with errors like: - # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer - # unknown file: Failure - # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) } - if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then - python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr - fi fi # quantization is not fully supported on s390x yet @@ -1672,43 +1674,34 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then elif [[ "${TEST_CONFIG}" == cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco - PYTHONPATH=$(pwd)/torchbench test_cachebench + PYTHONPATH=/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt - PYTHONPATH=$(pwd)/torchbench test_verify_cachebench + PYTHONPATH=/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision - install_torchao id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then - checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then - checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \ - llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \ - functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0 - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then - checkout_install_torchbench - TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest + TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest else - checkout_install_torchbench # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in if [[ "${TEST_CONFIG}" != *cpu* ]]; then install_torchrec_and_fbgemm fi - PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" + PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then install_torchvision - PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" + PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti fi diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 7ceb425ce2d1..19d715b9d0b6 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -61,9 +61,10 @@ if "%USE_XPU%"=="1" ( call "C:\Program Files (x86)\Intel\oneAPI\compiler\latest\env\vars.bat" call "C:\Program Files (x86)\Intel\oneAPI\ocloc\latest\env\vars.bat" if errorlevel 1 exit /b 1 - :: Reduce build time. Only have MTL self-hosted runner now - SET TORCH_XPU_ARCH_LIST=xe-lpg - SET USE_KINETO=0 + :: Reduce build time + SET TORCH_XPU_ARCH_LIST=bmg + :: Re-setup python env for build + call pip install -r requirements.txt ) @echo on diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 878d6595c84c..b90e6f38e911 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -192,9 +192,6 @@ retry brew install libomp # For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule export USE_DISTRIBUTED=1 -if [[ -n "$CROSS_COMPILE_ARM64" ]]; then - export CMAKE_OSX_ARCHITECTURES=arm64 -fi export USE_MKLDNN=OFF export USE_QNNPACK=OFF export BUILD_TEST=OFF @@ -202,16 +199,7 @@ export BUILD_TEST=OFF pushd "$pytorch_rootdir" echo "Calling setup.py bdist_wheel at $(date)" -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir" - echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" - BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir" - echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" -else - python setup.py bdist_wheel -d "$whl_tmp_dir" -fi +python setup.py bdist_wheel -d "$whl_tmp_dir" echo "Finished setup.py bdist_wheel at $(date)" diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 11678cabb2c3..c24a50b8b17e 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -65,16 +65,8 @@ fi if [[ "$PACKAGE_TYPE" != libtorch ]]; then if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pkg_no_python="$(ls -1 /final_pkgs/torch_no_python* | sort |tail -1)" - pkg_torch="$(ls -1 /final_pkgs/torch-* | sort |tail -1)" - # todo: after folder is populated use the pypi_pkg channel instead - pip install "\$pkg_no_python" "\$pkg_torch" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}_pypi_pkg" - retry pip install -q numpy protobuf typing-extensions - else - pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" - retry pip install -q numpy protobuf typing-extensions - fi + pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" + retry pip install -q numpy protobuf typing-extensions else pip install "\$pkg" retry pip install -q numpy protobuf typing-extensions diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 7f89c5c2dd8e..87fea14b8d28 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -134,7 +134,6 @@ export DESIRED_PYTHON="${DESIRED_PYTHON:-}" export DESIRED_CUDA="$DESIRED_CUDA" export LIBTORCH_VARIANT="${LIBTORCH_VARIANT:-}" export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}" -export USE_SPLIT_BUILD="${USE_SPLIT_BUILD:-}" if [[ "${OSTYPE}" == "msys" ]]; then export LIBTORCH_CONFIG="${LIBTORCH_CONFIG:-}" if [[ "${LIBTORCH_CONFIG:-}" == 'debug' ]]; then diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh index cf87748d538c..6c4aa8bee1df 100755 --- a/.circleci/scripts/binary_upload.sh +++ b/.circleci/scripts/binary_upload.sh @@ -23,10 +23,6 @@ if [[ "${DRY_RUN}" = "disabled" ]]; then AWS_S3_CP="aws s3 cp" fi -if [[ "${USE_SPLIT_BUILD:-false}" == "true" ]]; then - UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_pkg" -fi - # this is special build with all dependencies packaged if [[ ${BUILD_NAME} == *-full* ]]; then UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_full" diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 647671e8c83d..85c7999c1857 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -54,6 +54,7 @@ self-hosted-runner: - linux.rocm.gpu.2 - linux.rocm.gpu.4 # gfx942 runners + - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.2 - linux.rocm.gpu.gfx942.4 - rocm-docker diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index d3644c52fbcd..a58db801b1cf 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -59,11 +59,6 @@ runs: echo "$msg" exit 1 fi - if [[ $ngpu -eq 1 ]]; then - echo "Error: only 1 GPU detected, at least 2 GPUs are needed for distributed jobs" - echo "$msg" - exit 1 - fi - name: Runner diskspace health check uses: pytorch/pytorch/.github/actions/diskspace-cleanup@main diff --git a/.github/actions/test-pytorch-binary/action.yml b/.github/actions/test-pytorch-binary/action.yml index 63acd791b85c..d4b8be8b609a 100644 --- a/.github/actions/test-pytorch-binary/action.yml +++ b/.github/actions/test-pytorch-binary/action.yml @@ -24,7 +24,6 @@ runs: -e PYTORCH_FINAL_PACKAGE_DIR \ -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ - -e USE_SPLIT_BUILD \ --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index bdcf79a8164c..9f7623cf35ca 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -bf305f538005f2e900f8850ed57146024a8bc559 +bdb88e1d66f272cad72156c90ac8428ca61a601c diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index 71ac82425e57..b86f3276765d 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -ca9e2be3ed6320b51f52f536595cd24e254f8bb2 +458e74eb907f96069e6d8a4f3c9f457001fef2ea diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 031b844d4a55..cf8eb1a1efce 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -29ae4c76c026185f417a25e841d2cd5e65f087a3 +095faec1e7b6cc47220181e74ae9cde2605f9b00 diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 5fc358c820f2..354381755ce5 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -488,6 +488,10 @@ - torch/_dynamo/** - torch/csrc/dynamo/** - test/dynamo/** + - test/dynamo_expected_failures/** + - test/dynamo_skips/** + - test/inductor_expected_failures/** + - test/inductor_skips/** approved_by: - guilhermeleobas mandatory_checks_name: diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 deleted file mode 100644 index b6e9a6ce9f3e..000000000000 --- a/.github/requirements/conda-env-macOS-ARM64 +++ /dev/null @@ -1,5 +0,0 @@ -# Not pinning certifi so that we can always get the latest certificates -certifi -pip=23.2.1 -pkg-config=0.29.2 -wheel=0.37.1 diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 8252a8458067..ce4a44953413 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -193,7 +193,7 @@ def arch_type(arch_version: str) -> str: "cpu": "libtorch-cxx11-builder:cpu", } -FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"] +FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"] def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str: @@ -273,7 +273,6 @@ def generate_wheels_matrix( os: str, arches: Optional[list[str]] = None, python_versions: Optional[list[str]] = None, - use_split_build: bool = False, ) -> list[dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": @@ -315,15 +314,11 @@ def generate_wheels_matrix( # TODO: Enable python 3.13t on cpu-s390x if gpu_arch_type == "cpu-s390x" and python_version == "3.13t": continue - - if use_split_build and ( - arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux" + # TODO: Enable python 3.14 on non linux OSes + if os != "linux" and ( + python_version == "3.14" or python_version == "3.14t" ): - raise RuntimeError( - "Split build is only supported on linux with cuda 12* and cpu.\n" - f"Currently attempting to build on arch version {arch_version} and os {os}.\n" - "Please modify the matrix generation to exclude this combination." - ) + continue # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install @@ -339,7 +334,6 @@ def generate_wheels_matrix( "gpu_arch_type": gpu_arch_type, "gpu_arch_version": gpu_arch_version, "desired_cuda": desired_cuda, - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], @@ -372,7 +366,6 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[ arch_version ].split(":")[0], @@ -395,7 +388,6 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 4df6150f9765..67906d4ad88d 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -59,9 +59,7 @@ class BinaryBuildWorkflow: is_scheduled: str = "" branches: str = "nightly" # Mainly for macos - cross_compile_arm64: bool = False macos_runner: str = "macos-14-xlarge" - use_split_build: bool = False # Mainly used for libtorch builds build_variant: str = "" @@ -72,9 +70,6 @@ def __post_init__(self) -> None: for item in [self.os, "binary", self.package_type, self.build_variant] if item != "" ) - if self.use_split_build: - # added to distinguish concurrency groups - self.build_environment += "-split" def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = ( @@ -117,21 +112,6 @@ class OperatingSystem: isolated_workflow=True, ), ), - # See https://github.com/pytorch/pytorch/issues/138750 - # BinaryBuildWorkflow( - # os=OperatingSystem.LINUX, - # package_type="manywheel", - # build_configs=generate_binary_build_matrix.generate_wheels_matrix( - # OperatingSystem.LINUX, - # use_split_build=True, - # arches=["11.8", "12.1", "12.4", "cpu"], - # ), - # ciflow_config=CIFlowConfig( - # labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, - # isolated_workflow=True, - # ), - # use_split_build=True, - # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", @@ -175,27 +155,11 @@ class OperatingSystem: package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, - arches=["12.6", "12.8", "12.9"], - python_versions=["3.9"], + arches=["12.8"], + python_versions=["3.12"], ), branches="main", ), - # See https://github.com/pytorch/pytorch/issues/138750 - # BinaryBuildWorkflow( - # os=OperatingSystem.LINUX, - # package_type="manywheel", - # build_configs=generate_binary_build_matrix.generate_wheels_matrix( - # OperatingSystem.LINUX, - # arches=["11.8", "12.1", "12.4"], - # python_versions=["3.9"], - # use_split_build=True, - # ), - # ciflow_config=CIFlowConfig( - # labels={LABEL_CIFLOW_PERIODIC}, - # ), - # branches="main", - # use_split_build=True, - # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", @@ -338,7 +302,6 @@ class OperatingSystem: generate_binary_build_matrix.RELEASE, libtorch_variants=["shared-with-deps"], ), - cross_compile_arm64=False, macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH}, @@ -351,7 +314,6 @@ class OperatingSystem: build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.MACOS_ARM64 ), - cross_compile_arm64=False, macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index 1481459d40c4..baf560234549 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -262,7 +262,12 @@ def is_exception_branch(branch: str) -> bool: """ Branches that get opted out of experiments by default, until they're explicitly enabled. """ - return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} + return branch.split("/", maxsplit=1)[0] in { + "main", + "nightly", + "release", + "landchecks", + } def load_yaml(yaml_text: str) -> Any: diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 29b92ad461ef..1a5780b01519 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -47,9 +47,6 @@ env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SKIP_ALL_TESTS: 0 -{%- if cross_compile_arm64 %} - CROSS_COMPILE_ARM64: 1 -{% endif %} !{{ common.concurrency(build_environment) }} jobs: diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index f159d623f1bf..763784f5f3e1 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -25,11 +25,6 @@ DOCKER_IMAGE: !{{ config["container_image"] }} DOCKER_IMAGE_TAG_PREFIX: !{{ config["container_image_tag_prefix"] }} {%- endif %} -{%- if config["package_type"] == "manywheel" %} - {%- if config.use_split_build is defined %} - use_split_build: !{{ config["use_split_build"] }} - {%- endif %} -{%- endif %} {%- if config["package_type"] == "libtorch" %} {%- if config["libtorch_config"] %} LIBTORCH_CONFIG: !{{ config["libtorch_config"] }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index f11ee4a6621e..bfa035bc753b 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -26,13 +26,6 @@ on: default: 240 type: number description: timeout for the job - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false ALPINE_IMAGE: required: false type: string @@ -117,7 +110,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -142,7 +134,6 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" - echo "USE_SPLIT_BUILD=${{ env.use_split_build }}" } >> "${GITHUB_ENV} }}" - name: List the env @@ -261,7 +252,6 @@ jobs: -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ -e PYTORCH_EXTRA_INSTALL_REQUIREMENTS \ - -e USE_SPLIT_BUILD \ --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 434167d0f0c6..476dd182db0f 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -64,13 +64,6 @@ on: required: true type: string description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, linux.arm64.2xlarge, and linux.rocm.gpu - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false secrets: github-token: required: true @@ -104,7 +97,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -129,7 +121,6 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" - echo "USE_SPLIT_BUILD=${{ env.USE_SPLIT_BUILD }}" } >> "${GITHUB_ENV} }}" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index 6750102b5a29..636b76d42931 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -51,13 +51,6 @@ on: required: false type: string description: Desired python version - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false secrets: github-token: required: true @@ -86,7 +79,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 5173425009f6..4d46de4b8657 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -306,7 +306,6 @@ jobs: -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ - -e USE_SPLIT_BUILD \ -e BUILD_ADDITIONAL_PACKAGES \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 1848586d3cef..07be3720b2bf 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -96,7 +96,7 @@ jobs: steps: - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@main - if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} + if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }} with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -109,7 +109,7 @@ jobs: no-sudo: true - name: Setup Python - if: matrix.runner == 'B200' + if: contains(matrix.runner, 'b200') uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: '3.12' @@ -117,7 +117,7 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200' + if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200') - name: configure aws credentials if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} @@ -128,7 +128,7 @@ jobs: aws-region: us-east-1 - name: Login to Amazon ECR - if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }} + if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }} id: login-ecr continue-on-error: true uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 @@ -166,17 +166,17 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-nvidia@main with: driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }} - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }} + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }} - name: Setup GPU_FLAG for docker run id: setup-gpu-flag run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }} + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container id: setup-sscache-port-flag run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}" - if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }} + if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }} - name: Lock NVIDIA A100 40GB Frequency run: | @@ -277,8 +277,8 @@ jobs: NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs - SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }} - SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }} + SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }} + SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }} SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} @@ -403,7 +403,7 @@ jobs: job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} - name: Authenticate with AWS - if: ${{ matrix.runner == 'B200' }} + if: ${{ contains(matrix.runner, 'b200') }} uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 2d660d98905e..f73972942b5f 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -88,6 +88,16 @@ jobs: - name: Setup ROCm uses: ./.github/actions/setup-rocm + - name: Runner check GPU count (distributed jobs) + if: ${{ contains(matrix.config, 'distributed') }} + shell: bash + run: | + ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') + if [[ $ngpu -lt 4 ]]; then + echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" + exit 1 + fi + - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index 44430522b79d..a3a87708e966 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -34,7 +34,8 @@ jobs: contents: read pull-requests: write name: Check labels - if: github.repository_owner == 'pytorch' + # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved + if: github.repository_owner == 'pytorch' && false runs-on: linux.24_04.4x steps: - name: Checkout PyTorch diff --git a/.github/workflows/check_mergeability_ghstack.yml b/.github/workflows/check_mergeability_ghstack.yml index 569a174665ba..689ee250c809 100644 --- a/.github/workflows/check_mergeability_ghstack.yml +++ b/.github/workflows/check_mergeability_ghstack.yml @@ -7,7 +7,8 @@ on: jobs: ghstack-mergeability-check: - if: github.repository_owner == 'pytorch' + # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved + if: github.repository_owner == 'pytorch' && false runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 255e36ebfffa..c83609facbd9 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -51,21 +51,17 @@ jobs: docker-image-name: [ pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.9-clang12, - pytorch-linux-jammy-py3.11-clang12, - pytorch-linux-jammy-py3.12-clang12, pytorch-linux-jammy-py3.13-clang12, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, pytorch-linux-noble-rocm-alpha-py3, + pytorch-linux-jammy-rocm-n-py3-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, pytorch-linux-jammy-py3.9-gcc11, pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, @@ -76,7 +72,8 @@ jobs: pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter, - pytorch-linux-jammy-py3-clang12-executorch, + # Executorch pin needs update + # pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu ] include: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 8cde3006e381..757eadc0cc04 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -84,7 +83,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -108,7 +106,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 secrets: @@ -129,7 +126,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -156,7 +152,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda-aarch64-12_9 secrets: @@ -176,7 +171,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -200,7 +194,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -224,7 +217,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: @@ -245,7 +237,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -272,7 +263,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64-12_9 secrets: @@ -292,7 +282,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -316,7 +305,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -340,7 +328,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: @@ -361,7 +348,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -388,7 +374,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64-12_9 secrets: @@ -408,7 +393,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -432,7 +416,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -456,7 +439,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: @@ -477,7 +459,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -504,7 +485,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64-12_9 secrets: @@ -524,7 +504,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -548,7 +527,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -572,7 +550,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 secrets: @@ -593,7 +570,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -620,7 +596,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda-aarch64-12_9 secrets: @@ -640,7 +615,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -664,7 +638,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -688,7 +661,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 secrets: @@ -709,7 +681,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -736,7 +707,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda-aarch64-12_9 secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index d1e89bb6e2d8..6387d75a73b5 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -42,54 +42,7 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - manywheel-py3_9-cuda12_6-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_6 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_6-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_6 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_8-build: + manywheel-py3_12-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -103,18 +56,17 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False - DESIRED_PYTHON: "3.9" + DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-test: # Testing + manywheel-py3_12-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_8-build + - manywheel-py3_12-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -127,56 +79,8 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_9-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: 12.9 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_9 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_9-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_9-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: 12.9 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_9 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index e309812a86ef..e68d26c669ad 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu @@ -82,7 +81,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel @@ -105,7 +103,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu secrets: @@ -126,7 +123,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_6 @@ -150,7 +146,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel @@ -174,7 +169,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_6 secrets: @@ -195,7 +189,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_8 @@ -219,7 +212,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel @@ -243,7 +235,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_8 secrets: @@ -264,7 +255,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_9 @@ -288,7 +278,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel @@ -312,7 +301,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_9 secrets: @@ -333,7 +321,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_3 @@ -358,7 +345,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -426,7 +412,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_3 secrets: @@ -447,7 +432,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_4 @@ -472,7 +456,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -540,7 +523,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_4 secrets: @@ -560,7 +542,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu @@ -585,7 +566,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" permissions: id-token: write @@ -653,7 +633,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-xpu secrets: @@ -673,7 +652,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu @@ -695,7 +673,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -718,7 +695,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -739,7 +715,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 @@ -763,7 +738,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel @@ -787,7 +761,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 secrets: @@ -808,7 +781,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_8 @@ -832,7 +804,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel @@ -856,7 +827,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 secrets: @@ -877,7 +847,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_9 @@ -901,7 +870,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel @@ -925,7 +893,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_9 secrets: @@ -946,7 +913,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_3 @@ -971,7 +937,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1039,7 +1004,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_3 secrets: @@ -1060,7 +1024,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_4 @@ -1085,7 +1048,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1153,7 +1115,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_4 secrets: @@ -1173,7 +1134,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu @@ -1198,7 +1158,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -1266,7 +1225,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-xpu secrets: @@ -1286,7 +1244,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu @@ -1308,7 +1265,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -1331,7 +1287,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -1352,7 +1307,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 @@ -1376,7 +1330,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel @@ -1400,7 +1353,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 secrets: @@ -1421,7 +1373,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 @@ -1445,7 +1396,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel @@ -1469,7 +1419,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 secrets: @@ -1490,7 +1439,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8-full @@ -1513,7 +1461,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8-full build_environment: linux-binary-manywheel @@ -1537,7 +1484,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8-full secrets: @@ -1558,7 +1504,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_9 @@ -1582,7 +1527,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel @@ -1606,7 +1550,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_9 secrets: @@ -1627,7 +1570,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_3 @@ -1652,7 +1594,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1720,7 +1661,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_3 secrets: @@ -1741,7 +1681,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_4 @@ -1766,7 +1705,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1834,7 +1772,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_4 secrets: @@ -1854,7 +1791,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu @@ -1879,7 +1815,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -1947,7 +1882,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-xpu secrets: @@ -1967,7 +1901,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu @@ -1989,7 +1922,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -2012,7 +1944,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -2033,7 +1964,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 @@ -2057,7 +1987,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel @@ -2081,7 +2010,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 secrets: @@ -2102,7 +2030,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 @@ -2126,7 +2053,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel @@ -2150,7 +2076,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 secrets: @@ -2171,7 +2096,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_9 @@ -2195,7 +2119,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel @@ -2219,7 +2142,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_9 secrets: @@ -2240,7 +2162,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_3 @@ -2265,7 +2186,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -2333,7 +2253,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_3 secrets: @@ -2354,7 +2273,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_4 @@ -2379,7 +2297,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -2447,7 +2364,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_4 secrets: @@ -2467,7 +2383,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu @@ -2492,7 +2407,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -2560,7 +2474,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-xpu secrets: @@ -2580,7 +2493,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu @@ -2602,7 +2514,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel @@ -2625,7 +2536,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu secrets: @@ -2646,7 +2556,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 @@ -2670,7 +2579,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel @@ -2694,7 +2602,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 secrets: @@ -2715,7 +2622,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 @@ -2739,7 +2645,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel @@ -2763,7 +2668,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 secrets: @@ -2784,7 +2688,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_9 @@ -2808,7 +2711,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel @@ -2832,7 +2734,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_9 secrets: @@ -2853,7 +2754,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-rocm6_3 @@ -2878,7 +2778,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -2946,7 +2845,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_3 secrets: @@ -2967,7 +2865,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-rocm6_4 @@ -2992,7 +2889,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -3060,7 +2956,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_4 secrets: @@ -3080,7 +2975,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu @@ -3105,7 +2999,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -3173,7 +3066,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-xpu secrets: @@ -3193,7 +3085,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cpu @@ -3215,7 +3106,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu build_environment: linux-binary-manywheel @@ -3238,7 +3128,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu secrets: @@ -3259,7 +3148,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 @@ -3283,7 +3171,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel @@ -3307,7 +3194,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 secrets: @@ -3328,7 +3214,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 @@ -3352,7 +3237,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel @@ -3376,7 +3260,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 secrets: @@ -3397,7 +3280,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_9 @@ -3421,7 +3303,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel @@ -3445,7 +3326,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_9 secrets: @@ -3466,7 +3346,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-rocm6_3 @@ -3491,7 +3370,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -3559,7 +3437,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_3 secrets: @@ -3580,7 +3457,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-rocm6_4 @@ -3605,7 +3481,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -3673,7 +3548,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_4 secrets: @@ -3693,7 +3567,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-xpu @@ -3718,7 +3591,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" permissions: id-token: write @@ -3786,9 +3658,1192 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-cpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-cuda12_6-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-cuda12_6 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_6 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_6-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-cuda12_6-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_6 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-cuda12_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-cuda12_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-cuda12_8-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-cuda12_9-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda12_9 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-rocm6_3-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-rocm6_3 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-rocm6_3-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-rocm6_3-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14-rocm6_3 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.3 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_14-rocm6_3-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-rocm6_3-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-rocm6_3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-rocm6_4 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_14-rocm6_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-rocm6_4-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-rocm6_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14-xpu + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-xpu-build + - get-label-type + runs-on: linux.idc.xpu + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14" + permissions: + id-token: write + contents: read + steps: + - name: Setup XPU + uses: ./.github/actions/setup-xpu + - name: configure aws credentials + id: aws_creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v2 + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14-xpu + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: xpu + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown XPU + uses: ./.github/actions/teardown-xpu + manywheel-py3_14-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14-xpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-cpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-cuda12_6-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-cuda12_6 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_6 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_6-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-cuda12_6-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_6 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-cuda12_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-cuda12_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-cuda12_8-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-cuda12_9-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda12_9 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-rocm6_3-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-rocm6_3 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-rocm6_3-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-rocm6_3-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14t" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14t-rocm6_3 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.3 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_14t-rocm6_3-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-rocm6_3-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-rocm6_3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-rocm6_4 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14t" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14t-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_14t-rocm6_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-rocm6_4-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-rocm6_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_14t-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_14t-xpu + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-xpu-build + - get-label-type + runs-on: linux.idc.xpu + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14t" + permissions: + id-token: write + contents: read + steps: + - name: Setup XPU + uses: ./.github/actions/setup-xpu + - name: configure aws credentials + id: aws_creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v2 + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_14t-xpu + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: xpu + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown XPU + uses: ./.github/actions/teardown-xpu + manywheel-py3_14t-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_14t-xpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml index b6b63c4e38d5..a3e5937fdcc4 100644 --- a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml @@ -58,7 +58,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_4 @@ -83,7 +82,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 66c0813afe90..9570f8d97a2d 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -84,7 +83,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -107,7 +105,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x secrets: @@ -127,7 +124,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -151,7 +147,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -174,7 +169,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -194,7 +188,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -218,7 +211,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -241,7 +233,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -261,7 +252,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -285,7 +275,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -308,7 +297,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: @@ -328,7 +316,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -352,7 +339,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -375,7 +361,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x secrets: diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml new file mode 100644 index 000000000000..7b59e92386a3 --- /dev/null +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -0,0 +1,154 @@ +name: inductor-perf-b200 + +on: + schedule: + - cron: 0 7 * * 1-6 + - cron: 0 7 * * 0 + # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it + # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs + workflow_dispatch: + inputs: + training: + description: Run training (on by default)? + required: false + type: boolean + default: true + inference: + description: Run inference (on by default)? + required: false + type: boolean + default: true + default: + description: Run inductor_default? + required: false + type: boolean + default: false + dynamic: + description: Run inductor_dynamic_shapes? + required: false + type: boolean + default: false + cppwrapper: + description: Run inductor_cpp_wrapper? + required: false + type: boolean + default: false + cudagraphs: + description: Run inductor_cudagraphs? + required: false + type: boolean + default: true + freezing_cudagraphs: + description: Run inductor_cudagraphs with freezing for inference? + required: false + type: boolean + default: false + aotinductor: + description: Run aot_inductor for inference? + required: false + type: boolean + default: false + maxautotune: + description: Run inductor_max_autotune? + required: false + type: boolean + default: false + benchmark_configs: + description: The list of configs used the benchmark + required: false + type: string + default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200 + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + opt_out_experiments: lf + + build: + name: cuda12.8-py3.10-gcc9-sm100 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100 + # or newer GPUs, so it doesn't benefit much from existing compiler cache + # from trunk. Also use a memory-intensive runner here because memory is + # usually the bottleneck + runner: linux.12xlarge.memory + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '10.0' + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, + { config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, + { config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, + ]} + selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio fbgemm torchao" + secrets: inherit + + test-periodically: + name: cuda12.8-py3.10-gcc9-sm100 + uses: ./.github/workflows/_linux-test.yml + needs: build + if: github.event.schedule == '0 7 * * 1-6' + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true + docker-image: ${{ needs.build.outputs.docker-image }} + test-matrix: ${{ needs.build.outputs.test-matrix }} + aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + timeout-minutes: 720 + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit + + test-weekly: + name: cuda12.8-py3.10-gcc9-sm100 + uses: ./.github/workflows/_linux-test.yml + needs: build + if: github.event.schedule == '0 7 * * 0' + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true + docker-image: ${{ needs.build.outputs.docker-image }} + test-matrix: ${{ needs.build.outputs.test-matrix }} + timeout-minutes: 1440 + aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit + + test: + name: cuda12.8-py3.10-gcc9-sm100 + uses: ./.github/workflows/_linux-test.yml + needs: build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} + docker-image: ${{ needs.build.outputs.docker-image }} + test-matrix: ${{ needs.build.outputs.test-matrix }} + aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + timeout-minutes: 720 + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index 377f6d04bc8c..f329fe74e6b6 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -85,26 +85,26 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-jammy-rocm-py3_10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ - { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index d3f1ff1f1dae..436cf95c156d 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -77,25 +77,25 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-jammy-rocm-py3_10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks sync-tag: rocm-build test-matrix: | { include: [ - { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, - { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, - { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index f4c81ce7d7b8..732ec7eb85f3 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -47,8 +47,8 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7bb1ff9296ab..2acc987e523c 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -75,10 +75,11 @@ jobs: repo-owner: pytorch branch: main pin-folder: .github/ci_commit_pins - - repo-name: executorch - repo-owner: pytorch - branch: main - pin-folder: .ci/docker/ci_commit_pins + # executorch jobs are disabled since it needs some manual work for the hash update + # - repo-name: executorch + # repo-owner: pytorch + # branch: main + # pin-folder: .ci/docker/ci_commit_pins - repo-name: triton repo-owner: triton-lang branch: main diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 976fb241c99f..7d43c68c61b0 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -51,37 +51,6 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build: - name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_4-py3_10-gcc11-sm89-test: - name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build - - target-determination - with: - build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_4-py3_10-gcc11-build: name: linux-jammy-cuda12.4-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index be0bdc527cc1..3fe8ac15a305 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -254,67 +254,6 @@ jobs: timeout-minutes: 600 secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed: - name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: '7.5' - test-matrix: | - { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-test-distributed: - name: linux-jammy-cuda12.8-py3.10-gcc11-test - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-build: - name: linux-jammy-cuda12.8-py3.10-gcc11 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-test: - name: linux-jammy-cuda12.8-py3.10-gcc11 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-build - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml @@ -329,30 +268,6 @@ jobs: ]} secrets: inherit - linux-jammy-py3_9-clang9-xla-build: - name: linux-jammy-py3_9-clang9-xla - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.9-clang9-xla - docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite - test-matrix: | - { include: [ - { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - ]} - secrets: inherit - - linux-jammy-py3_9-clang9-xla-test: - name: linux-jammy-py3_9-clang9-xla - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-py3_9-clang9-xla-build - with: - build-environment: linux-jammy-py3.9-clang9-xla - docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cpu-py3_10-gcc11-bazel-test: name: linux-jammy-cpu-py3.10-gcc11-bazel-test uses: ./.github/workflows/_bazel-build-test.yml @@ -402,38 +317,8 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build: - name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-sm89-test: - name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build - - target-determination - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-py3-clang12-executorch-build: + if: false # Docker build needs pin update name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build.yml needs: get-label-type @@ -458,31 +343,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: - name: cuda12.8-py3.10-gcc9-sm75 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks - cuda-arch-list: '7.5' - test-matrix: | - { include: [ - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: - name: cuda12.8-py3.10-gcc9-sm75 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-xpu-2025_1-py3_9-build: name: linux-jammy-xpu-2025.1-py3.9 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index c51d89e5c955..7e3ba43bf984 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -48,12 +48,12 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index c656c16e97c2..08fcd3340262 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -10,6 +10,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +permissions: + id-token: write + contents: read + jobs: get-default-label-prefix: if: github.repository_owner == 'pytorch' diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 3879b62cc020..19b0e88b5921 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -63,6 +63,43 @@ jobs: ]} secrets: inherit + linux-jammy-cuda12_8-py3_10-gcc11-build: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '7.5 8.9' + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-test: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} + secrets: inherit + + # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build: name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops @@ -205,7 +242,7 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.9-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks test-matrix: | { include: [ { config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, diff --git a/.github/workflows/unstable.yml b/.github/workflows/unstable.yml index 08ae920e7cb0..7f0fe6058bd0 100644 --- a/.github/workflows/unstable.yml +++ b/.github/workflows/unstable.yml @@ -12,7 +12,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: # There must be at least one job here to satisfy GitHub action workflow syntax @@ -51,3 +53,27 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + + linux-jammy-py3_9-clang9-xla-build: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-clang9-xla + docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite + test-matrix: | + { include: [ + { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + ]} + secrets: inherit + + linux-jammy-py3_9-clang9-xla-test: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-py3_9-clang9-xla-build + with: + build-environment: linux-jammy-py3.9-clang9-xla + docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 59d3590665fd..3d445756f7a2 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -23,7 +23,7 @@ jobs: with: repository: pytorch/pytorch stable-branch: viable/strict - requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]' + requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]' secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }} clickhouse-url: ${{ secrets.CLICKHOUSE_URL }} clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }} diff --git a/.gitignore b/.gitignore index b4e78e642b24..ed7208e55aa0 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,9 @@ merge_record.json torchgen/packaged/* !torchgen/packaged/README.md +# This file is injected by ROCm build scripts to bootstrap in torch/__init__.py. +torch/_rocm_init.py + # IPython notebook checkpoints .ipynb_checkpoints diff --git a/.lintrunner.toml b/.lintrunner.toml index 9c46c91b5e35..3e28de5d16b9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1452,8 +1452,6 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - '--no-black-binary', - 'black==23.12.1', 'usort==1.0.8.post1', 'isort==6.0.1', 'ruff==0.12.2', # sync with RUFF diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 2c67fb1981b7..000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -repos: - - repo: local - hooks: - - id: lintrunner - name: Run Lintrunner in an isolated venv before every push. The first run may be slow... - entry: python scripts/run_lintrunner.py # wrapper below - language: python # pre‑commit manages venv for the wrapper - additional_dependencies: [] # wrapper handles lintrunner install - always_run: true - stages: [pre-push] # fire only on pre‑push - pass_filenames: false # Lintrunner gets no per‑file args - verbose: true # stream output as it is produced...allegedly anyways diff --git a/AGENTS.md b/AGENTS.md index daf0f491702b..3d5436a02a85 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1 +1,17 @@ - This is the only AGENTS.md, there are no recursive AGENTS.md +- When you are working on a bug, first create a standalone file that + reproduces the bug and verify it fails in the expected way. Use this to + test if your changes work. Once the change is passing, find an appropriate + test file to add the test to and make sure to follow local conventions on + the test file. +- If you are running the real test suite, DO NOT run the entire test suite. + Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir' +- Do NOT run setup.py, you do not have a working build environment +- Do NOT run pre-commit, it is not setup +- To run lint, run 'lintrunner -a' (which will autoapply changes) +- Do NOT attempt to install dependencies, you do not have Internet access +- When you are ready to make a PR, do exactly these steps: + - git stash -u + - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch + - git stash pop + - Resolve conflicts if necessary diff --git a/CMakeLists.txt b/CMakeLists.txt index 119d845f7391..cc9476bb001a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,9 @@ option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) -cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF) +cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF) +cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF) +option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF) cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF @@ -251,7 +253,6 @@ cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) -option(USE_FAKELOWP "Use FakeLowp operators" OFF) option(USE_GFLAGS "Use GFLAGS" OFF) option(USE_GLOG "Use GLOG" OFF) option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -260,11 +261,13 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) +option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON - "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) + "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_XCCL "Use XCCL" ON "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" OFF) @@ -322,7 +325,6 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN}) cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN" OFF) option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF) -option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option( USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) @@ -564,7 +566,7 @@ if(MSVC) set(CMAKE_NINJA_CMCLDEPS_RC OFF) if(MSVC_Z7_OVERRIDE) # CMake set debug flags to use /Z7 - set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT Embedded) + set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT "$<$:Embedded>") endif() foreach( flag_var @@ -834,10 +836,11 @@ include(ExternalProject) # ---[ Dependencies ---[ FBGEMM doesn't work on x86 32bit and # CMAKE_SYSTEM_PROCESSOR thinks its 64bit -if(USE_FBGEMM - AND((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VOID_P EQUAL - 4) - OR CMAKE_SYSTEM_PROCESSOR STREQUAL "x86")) +if(USE_FBGEMM AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + message(WARNING + "x64 operating system is required for FBGEMM. " + "Not compiling with FBGEMM. " + "Turn this warning off by USE_FBGEMM=OFF.") set(USE_FBGEMM OFF) endif() diff --git a/CODEOWNERS b/CODEOWNERS index d18517c9ef80..1d91adacb062 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,7 +14,6 @@ /torch/csrc/autograd/ @albanD @soulitzer /torch/autograd/ @albanD @soulitzer /tools/autograd/ @albanD @soulitzer -/torch/header_only_apis.txt @janeyx99 /torch/nn/ @albanD @jbschlosser @mikaylagawarecki /torch/optim/ @albanD @janeyx99 /test/test_public_bindings.py @albanD @@ -165,6 +164,7 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd # torch.export /torch/export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi /torch/_export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi +/torch/_export/serde/schema.py @SherlockNoMad @zhxchen17 # Dynamic Shapes /torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka @@ -196,3 +196,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed /torch/utils/_cxx_pytree.py @XuehaiPan /torch/utils/pytree/ @XuehaiPan /torch/_dynamo/polyfills/pytree.py @XuehaiPan + +# Relating to libtorch ABI +/torch/csrc/stable/ @janeyx99 @mikaylagawarecki +/torch/headeronly/ @janeyx99 +/torch/header_only_apis.txt @janeyx99 diff --git a/README.md b/README.md index 62e3b9ea4937..03f76893e3e8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) +![PyTorch Logo](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/pytorch-logo-dark.png) -------------------------------------------------------------------------------- @@ -72,7 +72,7 @@ Elaborating Further: If you use NumPy, then you have used Tensors (a.k.a. ndarray). -![Tensor illustration](./docs/source/_static/img/tensor_illustration.png) +![Tensor illustration](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/tensor_illustration.png) PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount. @@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. -![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) +![Dynamic graph](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/dynamic_graph.gif) ### Python First @@ -243,7 +243,7 @@ git submodule update --init --recursive ```bash conda install cmake ninja -# Run this command from the PyTorch directory after cloning the source code using the ā€œGet the PyTorch Sourceā€œ section below +# Run this command from the PyTorch directory after cloning the source code using the ā€œGet the PyTorch Sourceā€œ section above pip install -r requirements.txt ``` @@ -276,7 +276,7 @@ conda install pkg-config libuv pip install mkl-static mkl-include # Add these packages if torch.distributed is needed. # Distributed package support on Windows is a prototype feature and is subject to changes. -conda install -c conda-forge libuv=1.39 +conda install -c conda-forge libuv ``` #### Install PyTorch @@ -560,7 +560,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. -PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. +PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. diff --git a/android/README.md b/android/README.md index 6b8000c13fcc..f0c74750522d 100644 --- a/android/README.md +++ b/android/README.md @@ -2,7 +2,7 @@ ## Demo applications and tutorials -Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). +Please refer to [meta-pytorch/executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions. diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 2938e690d491..5f4997357f82 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -119,6 +119,8 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp") file(GLOB_RECURSE native_mps_mm "native/mps/*.mm") file(GLOB_RECURSE native_mps_metal "native/mps/*.metal") file(GLOB_RECURSE native_mps_h "native/mps/*.h") +file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm") +file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal") file(GLOB native_sparse_cpp "native/sparse/*.cpp") file(GLOB native_quantized_cpp @@ -178,26 +180,27 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) - if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) - set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) - if(USE_CK_FLASH_ATTENTION STREQUAL "1") - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) - if(NUM_ARCHS GREATER 1) - message(WARNING "Building CK for multiple archs can increase build time considerably! - Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") - endif() - endif() - message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") - message(STATUS "Generating CK kernel instances...") - add_subdirectory(native/transformers/hip/flash_attn/ck) - file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) - # FAv3 Generation - add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) - file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) + if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1") + message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead") + caffe2_update_option(USE_ROCM_CK_SDPA ON) + endif() + if(USE_ROCM_CK_SDPA) + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") endif() + endif() + message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled") + message(STATUS "Generating CK kernel instances...") + add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + # FAv3 Generation + add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) + file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") @@ -416,39 +419,42 @@ if(USE_CUDA) endif() if(USE_ROCM) - # NOTE: The PyTorch build does not actually add_subdirectory - # third_party/composable_kernel or use it as a CMake library. What is used - # is header only, so this should be ok, except that the CMake build generates - # a ck/config.h. We just do that part here. Without this, the ck.h from the - # ROCM SDK may get accidentally used instead. - function(_pytorch_rocm_generate_ck_conf) - set(CK_ENABLE_INT8 "ON") - set(CK_ENABLE_FP16 "ON") - set(CK_ENABLE_FP32 "ON") - set(CK_ENABLE_FP64 "ON") - set(CK_ENABLE_BF16 "ON") - set(CK_ENABLE_FP8 "ON") - set(CK_ENABLE_BF8 "ON") - set(CK_USE_XDL "ON") - set(CK_USE_WMMA "ON") - configure_file( - "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in" - "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h" - ) - endfunction() - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) - _pytorch_rocm_generate_ck_conf() + if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM) + # NOTE: The PyTorch build does not actually add_subdirectory + # third_party/composable_kernel or use it as a CMake library. What is used + # is header only, so this should be ok, except that the CMake build generates + # a ck/config.h. We just do that part here. Without this, the ck.h from the + # ROCM SDK may get accidentally used instead. + function(_pytorch_rocm_generate_ck_conf) + set(CK_ENABLE_INT8 "ON") + set(CK_ENABLE_FP16 "ON") + set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_FP64 "ON") + set(CK_ENABLE_BF16 "ON") + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") + configure_file( + "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h" + ) + endfunction() + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) + _pytorch_rocm_generate_ck_conf() + endif() # Next two lines are needed because TunableOp uses third-party/fmt list(APPEND ATen_HIP_INCLUDE $) list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) -if(USE_FLASH_ATTENTION) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) -endif() + if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) + endif() list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} @@ -458,12 +464,17 @@ endif() ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) - if(WIN32) # Windows doesn't support Composable Kernels + if(NOT USE_ROCM_CK_GEMM) file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" ${native_hip_bgemm} ${native_hip_ck}) endif() + if(WIN32) # Windows doesn't support Composable Kernels and Triton + exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" + ${native_transformers_hip_hip} ${native_transformers_hip_cpp}) + endif() + # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) list(APPEND all_hip_cpp ${native_nested_hip_cpp} @@ -698,29 +709,25 @@ endif() if(USE_MPS) include(../../../cmake/Metal.cmake) - set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h}) + set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm}) if(CAN_COMPILE_METAL) - foreach(SHADER ${native_mps_metal}) + foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) cmake_path(GET SHADER STEM TGT_STEM) - string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air") - string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air") + string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air") list(APPEND AIR_BASIC ${TGT_BASIC}) - list(APPEND AIR_BFLOAT ${TGT_BFLOAT}) - metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0") - metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1") + metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1") endforeach() air_to_metallib(kernels_basic.metallib ${AIR_BASIC}) - air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT}) add_custom_command( COMMAND echo "// $$(date)" > metallib_dummy.cpp - DEPENDS kernels_basic.metallib kernels_bfloat.metallib + DEPENDS kernels_basic.metallib OUTPUT metallib_dummy.cpp COMMENT "Updating metallibs timestamp") - add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp) + add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp) else() file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps") - foreach(SHADER ${native_mps_metal}) + foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) cmake_path(GET SHADER STEM TGT_STEM) string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h") metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME}) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 2b89a46ed9af..30c2235131fb 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -480,6 +480,9 @@ at::BlasBackend Context::blasPreferredBackend() { // call site for blasPreferredBackend(), we set it to an actual value. if (blas_preferred_backend == at::BlasBackend::Default) { blas_preferred_backend = at::BlasBackend::Cublas; + // This logic sits in the getter because it needs to validate + // values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT + // which initialize the backend without calling the setter #ifdef USE_ROCM // AMD Instinct targets prefer hipblaslt static const bool hipblaslt_preferred = []() { @@ -509,6 +512,10 @@ at::BlasBackend Context::blasPreferredBackend() { // hipblaslt support for all archs is not as complete as hipblas if (blas_preferred_backend == at::BlasBackend::Cublaslt) { static const bool hipblaslt_unsupported = []() { + if(!hasCuBLASLt()) + { + return true; + } static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 @@ -534,6 +541,24 @@ at::BlasBackend Context::blasPreferredBackend() { return blas_preferred_backend; } +bool Context::ckSupported() { +#ifdef USE_ROCM + static const std::vector supported_archs = { + "gfx90a", "gfx942", "gfx950" + }; + for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) { + if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) { + TORCH_WARN_ONCE( + "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); + return false; + } + } + return true; +#else + return false; +#endif +} + void Context::setBlasPreferredBackend(at::BlasBackend b) { #ifdef _MSC_VER TORCH_WARN_ONCE( @@ -543,8 +568,14 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #else TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(), "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt."); - TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(), - "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm."); +#ifdef USE_ROCM + static const bool ckSupportedFlag = ckSupported(); + static const bool hasCKGEMMFlag = hasCKGEMM(); + TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag), + "Cannot set preferred blas backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK GEMM support: ", hasCKGEMMFlag); +#endif if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) { TORCH_WARN_ONCE( "torch.backends.cuda.preferred_blas_library is an experimental feature. " @@ -556,35 +587,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #endif } -at::ROCmFABackend Context::getROCmFAPreferredBackend() const { +at::ROCmFABackend Context::getROCmFAPreferredBackend() { +#ifdef USE_ROCM + // Set potential "Default" value so we don't have to interpret at call sites. + // We use aotriton backend as the default, for now. + if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) { + rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton; + } else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) { + // This logic sits in the getter because it needs to validate + // values set via env vars such as TORCH_ROCM_FA_PREFER_CK + // which initialize the backend without calling the setter + // Perform validity checking + static const bool hasCKSDPAFlag = hasCKSDPA(); + static const bool ckSupportedFlag = ckSupported(); + if(!(hasCKSDPAFlag && ckSupportedFlag)){ + TORCH_WARN_ONCE( + "Cannot set preferred SDPA backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag); + rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton; + } + } +#endif + return rocm_fa_preferred_backend; } void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { - - // TODO: add plumbing for hasCK for validity checking - TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(), - "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm."); #ifdef USE_ROCM - if(b == at::ROCmFABackend::Ck) { - static const bool ck_unsupported = []() { - static const std::vector archs = { - "gfx90a", "gfx942" - }; - for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(archs, index)) { - TORCH_WARN_ONCE( - "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); - return true; - } - } - return false; - }(); - if(!ck_unsupported) rocm_fa_preferred_backend = b; - } - else { - rocm_fa_preferred_backend = b; - } + static const bool hasCKSDPAFlag = hasCKSDPA(); + static const bool ckSupportedFlag = ckSupported(); + TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag), + "Cannot set preferred SDPA backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag); #endif rocm_fa_preferred_backend = b; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 945076f3f012..2cc12a38a0b6 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -132,6 +132,7 @@ class TORCH_API Context { static bool hasKleidiAI(); static bool hasLAPACK(); static bool hasMKLDNN(); + static bool ckSupported(); static bool hasMAGMA() { return detail::getCUDAHooks().hasMAGMA(); } @@ -162,6 +163,12 @@ class TORCH_API Context { static bool hasROCM() { return detail::getCUDAHooks().hasROCM(); } + static bool hasCKSDPA() { + return detail::getCUDAHooks().hasCKSDPA(); + } + static bool hasCKGEMM() { + return detail::getCUDAHooks().hasCKGEMM(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -252,7 +259,7 @@ class TORCH_API Context { at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend); - at::ROCmFABackend getROCmFAPreferredBackend() const; + at::ROCmFABackend getROCmFAPreferredBackend(); void setROCmFAPreferredBackend(at::ROCmFABackend); // Note [Enabling Deterministic Operations] diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f37e492c861f..f23b35047fcc 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + } // namespace at::accelerator namespace at { diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 5634733325a2..0e535ab20cd2 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -31,7 +31,9 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { return at::globalContext().getPinnedMemoryAllocator(opt_device_type); } else { TORCH_CHECK( - false, "Need to provide pin_memory allocator to use pin memory.") + false, + "pin_memory=True requires a CUDA or other accelerator backend; " + "no pinned memory allocator is available on this system.") } } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index afd0a6b67674..2bf57a7ca5cb 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -239,6 +239,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp) // fp32 + KERNEL_MPS(conv_transpose3d, input, fp32) KERNEL_MPS(acos, fp32) KERNEL_MPS(asin, fp32) KERNEL_MPS(cosh, fp32) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index c6087f0a68ec..72589436606e 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -97,6 +97,8 @@ c10::TypePtr IValue::TagType::get(const IValue& v) { return ComplexType::get(); case Tag::Int: return IntType::get(); + case Tag::UInt: + return IntType::get(); case Tag::SymInt: return c10::SymIntType::get(); case Tag::SymFloat: @@ -320,6 +322,8 @@ IValue IValue::equals(const IValue& rhs) const { return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble(); case Tag::Int: return rhs.isInt() && lhs.toInt() == rhs.toInt(); + case Tag::UInt: + return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt(); case Tag::SymInt: return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt(); case Tag::SymFloat: @@ -379,6 +383,8 @@ size_t IValue::hash(const IValue& v) { case Tag::Int: return c10::get_hash(v.payload.u.as_int); // NB: these are technically strict aliasing violations + case Tag::UInt: + return c10::get_hash(v.payload.u.as_int); case Tag::SymInt: return c10::get_hash(v.payload.u.as_int); case Tag::SymFloat: @@ -806,6 +812,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return printComplex(out, v); } case IValue::Tag::Int: return out << v.toInt(); + case IValue::Tag::UInt: + return out << v.toUInt(); case IValue::Tag::SymInt: return out << v.toSymInt(); case IValue::Tag::SymFloat: diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 175860dc99a7..ab2039e05820 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -160,6 +161,7 @@ struct Capsule { _(Double) \ _(ComplexDouble) \ _(Int) \ + _(UInt) \ _(SymInt) \ _(SymFloat) \ _(SymBool) \ @@ -653,6 +655,29 @@ struct TORCH_API IValue final { } } + // Unsigned + IValue(uint64_t u) : tag( u <= std::numeric_limits::max() ? Tag::Int : Tag::UInt) { + payload.u.as_uint = u; + } + + + // See Note [Meaning of HAS_u] + // IValue type model closely follows that of c10::Scalar + // Where all integers are upcast to 64-bit representation, and `as_int` is used as default + // representation unless value could not be represented as signed int + bool isUnsigned() const { + return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0); + } + + uint64_t toUInt() const { + if (isUnsigned()) { + return payload.u.as_uint; + } else { + TORCH_INTERNAL_ASSERT(0, "expected unsigned int"); + } + } + + // Bool IValue(bool b) : tag(Tag::Bool) { #if defined(__clang__) && defined(__x86_64__) @@ -893,8 +918,14 @@ struct TORCH_API IValue final { } else { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( s.isIntegral(false), "Unknown type in Scalar"); - tag = Tag::Int; - payload.u.as_int = s.toLong(); + if (s.isUnsigned()) { + const auto val = s.toUInt64(); + payload.u.as_uint = val; + tag = val <= std::numeric_limits::max() ? Tag::Int : Tag::UInt; + } else { + payload.u.as_int = s.toLong(); + tag = Tag::Int; + } } } @@ -918,6 +949,8 @@ struct TORCH_API IValue final { return toSymFloat(); else if (isSymBool()) return toSymBool(); + else if (isUnsigned()) + return toUInt(); TORCH_CHECK(false, "IValue is not a Scalar"); } @@ -1247,6 +1280,8 @@ struct TORCH_API IValue final { return true; case Tag::Int: return false; + case Tag::UInt: + return false; case Tag::SymInt: return true; case Tag::SymFloat: @@ -1343,6 +1378,8 @@ struct TORCH_API IValue final { union TriviallyCopyablePayload { TriviallyCopyablePayload() : as_int(0) {} int64_t as_int; + // See Note [Meaning of HAS_u] + uint64_t as_uint; double as_double; bool as_bool; // Invariant: never nullptr; null state is represented as diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index cf403365b2df..0dbae4aeed5b 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -832,7 +832,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::bgemm_internal_ck(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } @@ -1273,7 +1273,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); } @@ -1289,7 +1289,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100 gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); @@ -1341,7 +1341,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1357,7 +1357,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); } diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 7fba7c4c7424..2800e505a9b7 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c8cae16b624f..4f2aa31dd1c3 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 247fdb2537cb..3dedf3fd64c7 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -207,6 +207,27 @@ bool CUDAHooks::hasCuBLASLt() const { #endif } + +bool CUDAHooks::hasCKSDPA() const { +#if !defined(USE_ROCM) + return false; +#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA) + return true; +#else + return false; +#endif +} + +bool CUDAHooks::hasCKGEMM() const { +#if !defined(USE_ROCM) + return false; +#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) + return true; +#else + return false; +#endif +} + bool CUDAHooks::hasROCM() const { // Currently, this is same as `compiledWithMIOpen`. // But in future if there are ROCm builds without MIOpen, diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index b0dac7a71e80..2780369a37b7 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -31,6 +31,8 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCuSOLVER() const override; bool hasCuBLASLt() const override; bool hasROCM() const override; + bool hasCKSDPA() const override; + bool hasCKGEMM() const override; const at::cuda::NVRTC& nvrtc() const override; DeviceIndex current_device() const override; bool isBuilt() const override {return true;} diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index f99e03d156c9..00573e3cf701 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -118,6 +118,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { return false; } + virtual bool hasCKSDPA() const { + return false; + } + + virtual bool hasCKGEMM() const { + return false; + } + virtual const at::cuda::NVRTC& nvrtc() const { TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/MTIAHooksInterface.cpp b/aten/src/ATen/detail/MTIAHooksInterface.cpp index b6e260e59ec4..d2e331abb0c0 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.cpp +++ b/aten/src/ATen/detail/MTIAHooksInterface.cpp @@ -21,6 +21,10 @@ bool isMTIAHooksBuilt() { } // namespace detail +bool MTIAHooksInterface::isAvailable() const { + return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0; +} + C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs) } // namespace at diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index fb8ed6fb2322..b415862f29e7 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -149,6 +149,8 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return; } + + virtual bool isAvailable() const override; }; struct TORCH_API MTIAHooksArgs {}; diff --git a/aten/src/ATen/mps/EmptyTensor.cpp b/aten/src/ATen/mps/EmptyTensor.cpp index 7b04d65ebdd0..d858df073397 100644 --- a/aten/src/ATen/mps/EmptyTensor.cpp +++ b/aten/src/ATen/mps/EmptyTensor.cpp @@ -43,7 +43,6 @@ TensorBase empty_mps( int64_t nelements = c10::multiply_integers(size); auto dtype = dtype_or_default(dtype_opt); TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED); - TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer"); auto dtype_meta = scalarTypeToTypeMeta(dtype); diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index a70ce2510820..87c820430c98 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -18,11 +18,7 @@ namespace at::mps { // Helper enum to check if a MPSGraph op is supported in a given macOS version enum class MacOSVersion : uint32_t { - MACOS_VER_13_1_PLUS = 0, - MACOS_VER_13_2_PLUS, - MACOS_VER_13_3_PLUS, - MACOS_VER_14_0_PLUS, - MACOS_VER_14_4_PLUS, + MACOS_VER_14_4_PLUS = 0, MACOS_VER_15_0_PLUS, MACOS_VER_15_1_PLUS, MACOS_VER_15_2_PLUS, diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 55af5f83b388..72a066c69450 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -32,11 +32,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de MPSDevice::MPSDevice() : _mtl_device(nil) { // Check that MacOS 13.0+ version of MPS framework is available - // Create the MPSGraph and check method introduced in 13.0 + // Create the MPSGraph and check method introduced in 14.0 // which is used by MPS backend. id mpsCD = NSClassFromString(@"MPSGraph"); - if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) { + if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) { return; } @@ -66,24 +66,12 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}]; } }; - static bool _macos_13_1_plus = is_os_version_at_least(13, 1); - static bool _macos_13_2_plus = is_os_version_at_least(13, 2); - static bool _macos_13_3_plus = is_os_version_at_least(13, 3); - static bool _macos_14_0_plus = is_os_version_at_least(14, 0); static bool _macos_14_4_plus = is_os_version_at_least(14, 4); static bool _macos_15_0_plus = is_os_version_at_least(15, 0); static bool _macos_15_1_plus = is_os_version_at_least(15, 1); static bool _macos_15_2_plus = is_os_version_at_least(15, 2); switch (version) { - case MacOSVersion::MACOS_VER_13_1_PLUS: - return _macos_13_1_plus; - case MacOSVersion::MACOS_VER_13_2_PLUS: - return _macos_13_2_plus; - case MacOSVersion::MACOS_VER_13_3_PLUS: - return _macos_13_3_plus; - case MacOSVersion::MACOS_VER_14_0_PLUS: - return _macos_14_0_plus; case MacOSVersion::MACOS_VER_14_4_PLUS: return _macos_14_4_plus; case MacOSVersion::MACOS_VER_15_0_PLUS: diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index f6133e887722..a2ec221c1bfe 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -34,7 +34,7 @@ case 14: switch (minor) { case 0: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); + return true; case 4: return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); default: @@ -42,19 +42,7 @@ return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); } case 13: - switch (minor) { - case 0: - return true; - case 1: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); - case 2: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); - case 3: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - default: - TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+"); - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - } + return true; default: TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false"); return false; diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 79dbe7353e15..b16c1ef04fa0 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -51,7 +51,7 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * // brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C #if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5) #define ONEDNN_UKERNEL_1 -#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6) +#elif ((IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR >= 6) || (IDEEP_VERSION_MAJOR > 3)) #define ONEDNN_UKERNEL_2 #endif #if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))) diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 95d11903dc77..8b75f12ebaf2 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -206,6 +206,16 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex float +#define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float +#define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float +#define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32 +#define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32 +#define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32 + TORCH_API void brgemm( int64_t M, int64_t N, diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp index 8739f45d8ad1..13bef0a00b9c 100644 --- a/aten/src/ATen/native/ComparisonUtils.cpp +++ b/aten/src/ATen/native/ComparisonUtils.cpp @@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin } } +template<> +void _assert_match>( + const c10::Device& original, + const std::optional& compared, + const std::string& name) { + if (compared) { + const c10::Device& expected = compared.value(); + if (original.type() != expected.type()) { + std::stringstream msg; + msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original; + throw std::runtime_error(msg.str()); + } + + // If the expected device doesn't have an index (e.g., just "cuda"), + // or if both devices have the same index, consider them equal + if (expected.has_index() && original.has_index() && expected.index() != original.index()) { + std::stringstream msg; + msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original; + throw std::runtime_error(msg.str()); + } + } +} + void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional dtype, std::optional device, std::optional layout) { _assert_match(tensor.sym_sizes(), sizes, "sizes"); _assert_match(tensor.sym_strides(), strides, "strides"); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 7932e32b428b..5bcb4fe55fd2 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -463,7 +464,7 @@ struct ConvParams { return true; } // native kernel doesn't support 64-bit non-splittable case - if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { + if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) { static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1; if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index 2a6570bd00e3..7e2cba98ff1d 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -367,27 +367,27 @@ void int8pack_mm_kernel_( auto* C_data = C.data_ptr(); const auto* S_data = scales.const_data_ptr(); - int M = A.size(0); - int N = B.size(0); - int K = A.size(1); - int lda = A.stride(0); - constexpr int BLOCK_M = 4; - constexpr int BLOCK_N = 4; - - const int MB = (M + BLOCK_M - 1) / BLOCK_M; - const int NB = (N + BLOCK_N - 1) / BLOCK_N; - - at::parallel_for(0, MB * NB, 0, [&](int begin, int end) { - int mb{0}, nb{0}; + int64_t M = A.size(0); + int64_t N = B.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 4; + + const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M; + const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N; + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; data_index_init(begin, mb, MB, nb, NB); for (const auto i : c10::irange(begin, end)) { (void)i; - int mb_start = mb * BLOCK_M; - int mb_size = std::min(BLOCK_M, M - mb_start); - int nb_start = nb * BLOCK_N; - int nb_size = std::min(BLOCK_N, N - nb_start); + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); const auto* A_ptr = A_data + mb_start * lda; const auto* B_ptr = B_data + nb_start * K; diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu index 72dc0b6c3463..47c705a667b5 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu @@ -526,7 +526,7 @@ namespace { // we are dealing with packed tensor here. max index is the same as numel. - // TODO: to really support input tensor large enought to go beyond int32, + // TODO: to really support input tensor large enough to go beyond int32, // we will need to restrict out shared memory usage and adjust the launch // config; AT_ASSERT(input_.numel() < std::numeric_limits::max()); @@ -681,7 +681,7 @@ namespace { const dim3 grid(grid_x, grid_y, grid_z); // we are dealing with packed tensor here. max index is the same as numel. - // TODO: to really support input tensor large enought to go beyond int32, + // TODO: to really support input tensor large enough to go beyond int32, // we will need to restrict out shared memory usage and adjust the launch // config; AT_ASSERT(input.numel() < std::numeric_limits::max()); diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 5317ab75ba08..cf8905268bf7 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1634,6 +1634,9 @@ bool use_fast_accum) { TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; + if (!a_is_2d || !b_is_2d) { + TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); + } TORCH_CHECK( mat_a.size(-1) % 16 == 0, "Expected trailing dimension of mat_a to be divisible by 16 ", @@ -1716,6 +1719,9 @@ std::optional out_dtype) { TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; + if (!a_is_2d || !b_is_2d) { + TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); + } // check that the strides are valid, the fn will throw an error if not check_valid_strides_and_return_transposed(mat_a); diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 06276c72c53a..333c21e94f18 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo class CuFFTConfig { public: - // Only move semantics is enought for this class. Although we already use + // Only move semantics is enough for this class. Although we already use // unique_ptr for the plan, still remove copy constructor and assignment op so // we don't accidentally copy and take perf hit. CuFFTConfig(const CuFFTConfig&) = delete; diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index 68acf79f6894..a917b0d6163f 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -241,6 +241,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( Strides tensor_StrideA = make_strides(mat_a.strides()); Strides tensor_StrideB = make_strides(mat_b.strides()); Strides tensor_StrideOutput = make_strides(out.strides()); + Strides tensor_ShapeA = make_strides(mat_a.sizes()); + Strides tensor_ShapeB = make_strides(mat_b.sizes()); at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( reinterpret_cast(mat_a.data_ptr()), @@ -264,6 +266,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( tensor_StrideA, tensor_StrideB, tensor_StrideOutput, + tensor_ShapeA, + tensor_ShapeB, 0, 0, a_row_major, diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh index 943d352a7b42..ed8176b53f84 100644 --- a/aten/src/ATen/native/cuda/GroupMMCommon.cuh +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -38,18 +38,20 @@ __global__ void prepare_grouped_gemm_data( Strides tensor_StrideA, Strides tensor_StrideB, Strides tensor_StrideOutput, + Strides tensor_ShapeA, + Strides tensor_ShapeB, int64_t a_scale_stride, int64_t b_scale_stride, bool a_row_major = true, bool b_row_major = false) { int32_t tid = threadIdx.x; int32_t delta = 0; + int32_t offset = 0; if (offs != nullptr) { int32_t start = tid == 0 ? 0 : offs[tid - 1]; - delta = offs[tid] - start; - if (K < 0) { - CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); - } + offset = offs[tid]; + delta = offset - start; + CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n"); // TMA transfers require global memory tensor addresses to be // aligned to 16 bytes. @@ -84,6 +86,7 @@ __global__ void prepare_grouped_gemm_data( int64_t lda, ldb, ldoutput; if (M < 0) { // A and output is 2d + CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n"); M = delta; lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; @@ -96,6 +99,7 @@ __global__ void prepare_grouped_gemm_data( output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; B_ptrs[tid] = B + tid * tensor_StrideB[0]; } else if (N < 0) { + CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n"); N = delta; lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed @@ -108,6 +112,7 @@ __global__ void prepare_grouped_gemm_data( inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; } } else if (K < 0) { + CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n"); // A, B is 2d, output is 3d K = delta; lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 1696ee64eac6..5bdb3f6cc67d 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -282,6 +282,14 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( } // not coalsced, so now let try to capture lane-matches... + + if (numel > 16 /*<-hueristic threshold*/ * 64 ) { + // well shucks, unlikely to capture same-dest atomics in a wave. + // fall back to direct fastAtomic... + fastAtomicAdd(self_ptr, index, numel, value, true); + return; + } + // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd // __match_any_sync() -- returns bit mask of the threads that have same dest addr auto mask = __match_any_sync(__activemask(), (int64_t)dst); diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index b5908cc0abcf..c6d3c25200d5 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -644,7 +644,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta)) // As above, there may be better configurations to use. - constexpr int max_threads = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double + constexpr int max_threads_ = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double + int max_threads = max_threads_; + // Blackwell launch bounds + if (at::cuda::getCurrentDeviceProperties()->major >= 10) { + max_threads = 512; + } int threads_target = max_threads; while (threads_target / 2 >= 2*max_target_length+1) { threads_target /= 2; diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 15a572804af5..521b46748090 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -209,6 +209,10 @@ struct ReduceConfig { int values_per_thread() const { return div_up(num_inputs, step_input); } + + int mock_values_per_thread(int parallelism) { + return div_up(num_inputs, step_input * parallelism); + } }; std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); @@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ else if (config.ctas_per_output < 16) config.ctas_per_output = 1; bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); - if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) { config.ctas_per_output = 4; + int vpt = config.values_per_thread(); + // Capping the number of values per thread to 2048 for now + // based on known use cases. + while (vpt >= 2048) { + config.ctas_per_output *= 2; + // Computes the new values per thread without side effects + vpt = config.mock_values_per_thread(config.ctas_per_output); + } + } #endif if (config.ctas_per_output > 1) { config.input_mult[2] = config.split_input(config.ctas_per_output); diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index 8afc8970607a..9a06c5907feb 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -298,6 +298,9 @@ void f8f8bf16_grouped_gemm_impl_sm90( Strides tensor_StrideA = make_strides(mat_a.strides()); Strides tensor_StrideB = make_strides(mat_b.strides()); Strides tensor_StrideOutput = make_strides(out.strides()); + Strides tensor_ShapeA = make_strides(mat_a.sizes()); + Strides tensor_ShapeB = make_strides(mat_b.sizes()); + // scale stride will be used inside the kernel only if needed, // so for 1d scales the "1" assigned here won't be used int64_t a_scale_stride = scale_a.stride(0); @@ -325,6 +328,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( tensor_StrideA, tensor_StrideB, tensor_StrideOutput, + tensor_ShapeA, + tensor_ShapeB, a_scale_stride, b_scale_stride); diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 272eb9b9c564..5444bb57eba7 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -1304,7 +1304,7 @@ at::Tensor _convert_weight_to_int4pack_cuda( constexpr int32_t kKTileSize = 16; // GPT-FAST assumes nTileSize of 8 for quantized weight tensor. - // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 + // See https://github.com/meta-pytorch/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 // Torch dynamo also requires the torch ops has the same output shape for each device. // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263 constexpr int32_t kNTileSizeTensor = 8; diff --git a/aten/src/ATen/native/cuda/int8mm.cu b/aten/src/ATen/native/cuda/int8mm.cu new file mode 100644 index 000000000000..60f64cd9fc20 --- /dev/null +++ b/aten/src/ATen/native/cuda/int8mm.cu @@ -0,0 +1,74 @@ +#include +#include +#include +#include + +namespace at::native { + +__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) { + // one thread per output element: [B, N] + int b = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B || n >= N) return; + + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + acc += x[b * K + k] * static_cast(w[n * K + k]); + } + + out[b * N + n] = acc * scale[n]; +} + +void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) { + const int B = x.size(0); + const int K = x.size(1); + const int N = w_int8.size(0); + + const dim3 block(16, 16); + const dim3 grid((N + block.x - 1) / block.x, (B + block.y - 1) / block.y); + + auto stream = at::cuda::getCurrentCUDAStream(); + + weight_int8pack_mm_kernel<<>>( + x.data_ptr(), + w_int8.data_ptr(), + scale.data_ptr(), + out.data_ptr(), + B, K, N); +} + + +// Main GPU entry point +at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) { + // --- Check inputs --- + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor"); + TORCH_CHECK(scale.is_cuda(), "scale must be a CUDA tensor"); + + TORCH_CHECK(x.dim() == 2, "x must be 2D"); + TORCH_CHECK(w_int8.dim() == 2, "w must be 2D"); + TORCH_CHECK(scale.dim() == 1, "scale must be 1D"); + + TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)"); + TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)"); + + // --- Determine shapes --- + auto B = x.size(0); // batch size + auto N = w_int8.size(0); // output dim + + // Ensure inputs are in the correct types for the kernel + auto x_f32 = x.to(at::kFloat); + auto w_int8_contiguous = w_int8.contiguous(); + auto scale_f32 = scale.to(at::kFloat); + + // --- Allocate output --- + auto out = at::empty({B, N}, x.options().dtype(at::kFloat)); + + // --- Launch kernel --- + launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out); + + return out; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index c9e2fb361297..371b77722cd5 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -28,6 +28,22 @@ std::tuple cudnn_batch_norm( TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support"); } +std::tuple cudnn_batch_norm_out( + const Tensor& input, + const Tensor& weight, + const std::optional& bias, + const std::optional& running_mean, + const std::optional& running_var, + bool training, + double exponential_average_factor, + double epsilon, + Tensor& out, + Tensor& save_mean, + Tensor& save_var, + Tensor& reserve) { + AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support"); +} + std::tuple cudnn_batch_norm_backward( const Tensor& input, const Tensor& grad_output, @@ -120,7 +136,12 @@ size_t _get_cudnn_batch_norm_reserve_space_size( return reserve_size; } -std::tuple cudnn_batch_norm( +// Param `reserve` is a placeholder, just passing an empty tensor. +// usage: +// auto reserve = torch::empty({0}, torch::device(torch::kCUDA)); +// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var, +// reserve); +std::tuple cudnn_batch_norm_out( const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, @@ -128,7 +149,11 @@ std::tuple cudnn_batch_norm( const std::optional& running_var_t_opt, bool training, double exponential_average_factor, - double epsilon) { + double epsilon, + Tensor& output_t, + Tensor& save_mean, + Tensor& save_var, + Tensor& reserve) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); @@ -168,9 +193,6 @@ std::tuple cudnn_batch_norm( cudnnBatchNormMode_t mode = getCudnnBatchNormMode( training, input->suggest_memory_format(), input->dim()); - auto output_t = - at::empty_like(*input, input->options(), input->suggest_memory_format()); - TensorArg output{output_t, "output", 0}; auto handle = getCudnnHandle(); @@ -182,15 +204,8 @@ std::tuple cudnn_batch_norm( Constant one(dataType, 1); Constant zero(dataType, 0); - Tensor save_mean, save_var; - - Tensor reserve; if (training) { - int64_t num_features = input_t.size(1); - save_mean = at::empty({num_features}, weight_t.options()); - save_var = at::empty({num_features}, weight_t.options()); - auto op = CUDNN_BATCHNORM_OPS_BN; size_t workspace_size; AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( @@ -238,9 +253,6 @@ std::tuple cudnn_batch_norm( reserve_size)); } else { reserve = at::empty({0}, input->options().dtype(kByte)); - // This keeps a consistent output with native_batch_norm - save_mean = at::empty({0}, weight_t.options()); - save_var = at::empty({0}, weight_t.options()); AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference( handle, mode, @@ -261,10 +273,48 @@ std::tuple cudnn_batch_norm( // save_mean and save_var can be undefined // If this causes problems, we can initialize them to empty tensors // of the correct type - return std::tuple{ + return std::tuple{ output_t, save_mean, save_var, reserve}; } +std::tuple cudnn_batch_norm( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_t_opt, + const std::optional& running_mean_t_opt, + const std::optional& running_var_t_opt, + bool training, + double exponential_average_factor, + double epsilon) { + auto output_t = at::empty_like( + input_t, input_t.options(), input_t.suggest_memory_format()); + Tensor save_mean, save_var, reserve; + + if (training) { + int64_t num_features = input_t.size(1); + save_mean = at::empty({num_features}, weight_t.options()); + save_var = at::empty({num_features}, weight_t.options()); + } else { + // This keeps a consistent output with native_batch_norm + save_mean = at::empty({0}, weight_t.options()); + save_var = at::empty({0}, weight_t.options()); + } + + return cudnn_batch_norm_out( + input_t, + weight_t, + bias_t_opt, + running_mean_t_opt, + running_var_t_opt, + training, + exponential_average_factor, + epsilon, + output_t, + save_mean, + save_var, + reserve); +} + // NB: CuDNN only implements the backward algorithm for batchnorm // in training mode (evaluation mode batchnorm has a different algorithm), // which is why this doesn't accept a 'training' parameter. diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 48119a6a3b4c..a482c9041c90 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -2,9 +2,13 @@ #include #include -#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ - (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) +#if AT_CUDNN_ENABLED() +#include +#endif +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) || \ + (defined(CUDNN_FRONTEND_VERSION) && CUDNN_FRONTEND_VERSION < 10100) namespace at { namespace native { @@ -84,6 +88,37 @@ void run_cudnn_SDP_bprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset) { + TORCH_CHECK( + false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); +} + } // namespace native } // namespace at @@ -95,7 +130,6 @@ void run_cudnn_SDP_bprop( #include #include -#include #include #include @@ -111,40 +145,6 @@ namespace native { #include namespace fe = cudnn_frontend; -using graph_and_tensors = std::tuple< - std::shared_ptr, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::optional>, // Bias - std::shared_ptr, // Attn_scale, - // TODO(eqy): additional options - // std::shared_ptr, // SEQ_LEN_Q, - // std::shared_ptr, // SEQ_LEN_KV, - std::shared_ptr, // Seed, - std::shared_ptr, // Offset, - // std::shared_ptr, // Dropout_mask, - // std::shared_ptr, // Dropout_scale - std::shared_ptr, // O - std::shared_ptr // Stats - >; - -using graph_and_tensors_backward = std::tuple< - std::shared_ptr, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::optional>, // Bias, - std::shared_ptr, // Attn_scale, - std::shared_ptr, // Seed, - std::shared_ptr, // Offset, - std::shared_ptr, // O, - std::shared_ptr, // dO, - std::shared_ptr, // stats, - std::shared_ptr, // dQ, - std::shared_ptr, // dK,, - std::shared_ptr // dV, - >; #define MAX_MHA_DIM 4 @@ -298,11 +298,45 @@ struct MHAGraphCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html -thread_local MHAGraphCache mhagraphcache; -thread_local MHAGraphCache - mhagraphbackwardcache; +// We also leak the caches to workaround potential teardown race issues. + +auto& getMHAGraphCache_() { + thread_local auto& instance = + *new MHAGraphCache, MHACacheKeyWrapper>; + return instance; +} + +auto& getMHAGraphBackwardCache_() { + thread_local auto& instance = + *new MHAGraphCache, MHACacheKeyWrapper>; + return instance; +} namespace { + +enum UIDS { + Q, + K, + V, + O, + BIAS, + SCALE, + SEED, + OFFSET, + LSE, + DO, + DQ, + DK, + DV, + SEQ_LEN_Q, + SEQ_LEN_KV, + RAG_Q_OFF, + RAG_K_OFF, + RAG_V_OFF, + RAG_O_OFF, + RAG_LSE_OFF +}; + // analogous to the same function in Descriptors.h for cuDNN Convolutions... auto fixSizeOneDimStrideSDPA( const IntArrayRef sizes, @@ -320,9 +354,10 @@ auto fixSizeOneDimStrideSDPA( } return strides; } + } // namespace -auto build_graph_and_tensors( +auto build_graph( int64_t b, int64_t h, int64_t s_q, @@ -355,46 +390,55 @@ auto build_graph_and_tensors( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA") .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) - .set_attn_scale(attn_scale) - .set_dropout(dropout_probability, seed, offset); - auto Q = mha_graph->tensor( + .set_attn_scale(attn_scale); + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + scaled_dot_product_flash_attention_options.set_dropout( + dropout_probability, seed, offset); + } + auto Q_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(Q) .set_name("Q") .set_dim(q.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); - auto K = mha_graph->tensor( + auto K_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(K) .set_name("K") .set_dim(k.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); - auto V = mha_graph->tensor( + auto V_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(V) .set_name("V") .set_dim(v.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); @@ -402,17 +446,20 @@ auto build_graph_and_tensors( if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto [O, Stats] = - mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); - O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); + auto [O_, Stats] = + mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); + O_->set_uid(O); + O_->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { + Stats->set_uid(LSE); Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } @@ -423,20 +470,10 @@ auto build_graph_and_tensors( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(seed), - std::move(offset), - std::move(O), - std::move(Stats)); + return mha_graph; } -auto build_graph_and_tensors_nestedtensor( +auto build_graph_nestedtensor( int64_t b, int64_t h_q, int64_t h_k, @@ -473,28 +510,22 @@ auto build_graph_and_tensors_nestedtensor( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV = + auto SEQ_LEN_Q_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_Q) + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_KV) .set_name("Seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -506,41 +537,66 @@ auto build_graph_and_tensors_nestedtensor( .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) .set_attn_scale(attn_scale) - .set_dropout(dropout_probability, seed, offset) - .set_seq_len_q(SEQ_LEN_Q) - .set_seq_len_kv(SEQ_LEN_KV) + .set_seq_len_q(SEQ_LEN_Q_) + .set_seq_len_kv(SEQ_LEN_KV_) .set_padding_mask(true); + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + scaled_dot_product_flash_attention_options.set_dropout( + dropout_probability, seed, offset); + } // We hardcode BSHD to cuDNN even though the underlying layout is THD auto q_strides = q.strides(); auto k_strides = k.strides(); auto v_strides = v.strides(); + // NB: cuDNN API shape is transposed constexpr int strideidx0 = 1; constexpr int strideidx1 = 0; constexpr int strideidx2 = 2; - auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); std::optional> bias; if (attn_bias.has_value()) { TORCH_CHECK( @@ -548,44 +604,48 @@ auto build_graph_and_tensors_nestedtensor( "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("cum_seq_stats") - // .set_dim({b + 1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - auto RAG_STATS_OFF = nullptr; - Q->set_ragged_offset(RAG_Q_OFF); - K->set_ragged_offset(RAG_K_OFF); - V->set_ragged_offset(RAG_V_OFF); - auto [O, Stats] = - mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + auto RAG_Q_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_Q_OFF) + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_K_OFF) + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_V_OFF) + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_O_OFF) + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + Q_->set_ragged_offset(RAG_Q_OFF_); + K_->set_ragged_offset(RAG_K_OFF_); + V_->set_ragged_offset(RAG_V_OFF_); + auto [O_, Stats] = + mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); auto o_strides = o.strides(); - O->set_output(true) + O_->set_output(true) + .set_uid(O) .set_dim({b, h_q, s_q, d_v}) .set_stride( {INT_MAX, @@ -593,16 +653,20 @@ auto build_graph_and_tensors_nestedtensor( o_strides[strideidx1], o_strides[strideidx2]}); - O->set_ragged_offset(RAG_O_OFF); + O_->set_ragged_offset(RAG_O_OFF_); if (Stats) { - TORCH_CHECK( - false, - "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); - // TODO(eqy): fix when stats (backward) support is added + auto RAG_STATS_OFF = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_LSE_OFF) + .set_name("cum_seq_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); Stats->set_output(true) + .set_uid(LSE) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h_q, s_q, 1}) - .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); + .set_stride({h_q * s_q, 1, h_q, 1}); Stats->set_ragged_offset(RAG_STATS_OFF); } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -611,27 +675,10 @@ auto build_graph_and_tensors_nestedtensor( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(seed), - std::move(offset), - std::move(O), - std::move(Stats), - std::move(RAG_Q_OFF), - std::move(RAG_K_OFF), - std::move(RAG_V_OFF), - std::move(RAG_O_OFF), - std::move(RAG_STATS_OFF), - std::move(SEQ_LEN_Q), - std::move(SEQ_LEN_KV)); + return mha_graph; } -auto build_graph_and_tensors_backward( +auto build_graph_backward( int64_t b, int64_t h, int64_t s_q, @@ -667,6 +714,7 @@ auto build_graph_and_tensors_backward( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -676,87 +724,327 @@ auto build_graph_and_tensors_backward( .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(q.sizes().vec()) - .set_stride(q.strides().vec())); - auto K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(k.sizes().vec()) - .set_stride(k.strides().vec())); - auto V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(v.sizes().vec()) - .set_stride(v.strides().vec())); + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); sdpa_backward_options.set_bias(bias.value()); } - auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - - auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutoffset.dtype() == kInt + dropoutseed.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, seed, offset); + } - auto O = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim(o.sizes().vec()) - .set_stride(o.strides().vec())); - auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(O) + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(LSE) .set_name("Stats") .set_dim(softmaxstats.sizes().vec()) .set_stride(softmaxstats.strides().vec()) .set_data_type(fe::DataType_t::FLOAT)); - auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() + auto Do = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(DO) .set_name("DO") .set_dim(dO.sizes().vec()) .set_stride(dO.strides().vec())); + auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( + Q_, K_, V_, O_, Do, Stats, sdpa_backward_options); + Dq->set_uid(DQ); + Dq->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + Dk->set_uid(DK); + Dk->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + Dv->set_uid(DV); + Dv->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); + AT_CUDNN_FRONTEND_CHECK( + mha_graph->create_execution_plans({fe::HeurMode_t::A})); + AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + return mha_graph; +} + +auto build_graph_backward_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset, + cudnnHandle_t& handle) { + auto dtype = fe::DataType_t::HALF; + if (q.scalar_type() == kBFloat16) { + dtype = fe::DataType_t::BFLOAT16; + } + auto mha_graph = std::make_shared(); + // We're baking in float accumulation and scale types + // in theory the graph may support other types, but they + // have not been tested + mha_graph->set_io_data_type(dtype) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto attn_scale = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) + .set_name("Attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto SEQ_LEN_Q_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_Q) + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_KV) + .set_name("Seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() + .set_name("CUDNN_SDPA_NESTEDTENSOR_BACKWARD") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_seq_len_q(SEQ_LEN_Q_) + .set_seq_len_kv(SEQ_LEN_KV_) + .set_padding_mask(true); if (dropout_probability != 0.0f) { - sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, seed, offset); } - auto [DQ, DK, DV] = - mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); - DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); - DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); - DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); + auto q_strides = q.strides(); + auto k_strides = k.strides(); + auto v_strides = v.strides(); + // NB: cuDNN API shape is transposed + constexpr int strideidx0 = 1; + constexpr int strideidx1 = 0; + constexpr int strideidx2 = 2; + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); + auto o_strides = o.strides(); + auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(O) + .set_name("O") + .set_dim({b, h_q, s_q, d_v}) + .set_stride( + {INT_MAX, + o_strides[strideidx0], + o_strides[strideidx1], + o_strides[strideidx2]})); + + std::optional> bias; + if (attn_bias.has_value()) { + TORCH_CHECK( + false, + "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + sdpa_backward_options.set_bias(bias.value()); + } + auto RAG_Q_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_Q_OFF) + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_K_OFF) + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_V_OFF) + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_O_OFF) + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_STATS_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_LSE_OFF) + .set_name("cum_seq_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + O_->set_ragged_offset(RAG_O_OFF_); + Q_->set_ragged_offset(RAG_Q_OFF_); + K_->set_ragged_offset(RAG_K_OFF_); + V_->set_ragged_offset(RAG_V_OFF_); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(LSE) + .set_name("stats") + .set_dim({b, h_q, s_q, 1}) + .set_stride({s_q * h_q, 1, h_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + STATS->set_ragged_offset(RAG_STATS_OFF_); + auto do_strides = dO.strides(); + auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_ragged_offset(RAG_O_OFF_) + .set_uid(DO) + .set_name("DO") + .set_dim({b, h_q, s_q, d_v}) + .set_stride( + {INT_MAX, + do_strides[strideidx0], + do_strides[strideidx1], + do_strides[strideidx2]})); + auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( + Q_, K_, V_, O_, DO_, STATS, sdpa_backward_options); + Dq->set_output(true) + .set_uid(DQ) + .set_ragged_offset(RAG_Q_OFF_) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]}); + Dk->set_output(true) + .set_uid(DK) + .set_ragged_offset(RAG_K_OFF_) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]}); + Dv->set_output(true) + .set_uid(DV) + .set_ragged_offset(RAG_V_OFF_) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]}); + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(Seed), - std::move(Offset), - std::move(O), - std::move(DO), - std::move(STATS), - std::move(DQ), - std::move(DK), - std::move(DV)); + return mha_graph; } void run_cudnn_SDP_fprop( @@ -817,12 +1105,12 @@ void run_cudnn_SDP_fprop( dropout_probability, is_causal, return_softmaxstats); - auto graph_and_tensors_ptr = mhagraphcache.find(key); - graph_and_tensors graph_and_tensors_values; - if (graph_and_tensors_ptr) { - graph_and_tensors_values = *graph_and_tensors_ptr; + auto graph_ptr = getMHAGraphCache_().find(key); + std::shared_ptr mha_graph; + if (graph_ptr) { + mha_graph = *graph_ptr; } else { - graph_and_tensors_values = build_graph_and_tensors( + mha_graph = build_graph( b, h, s_q, @@ -843,29 +1131,28 @@ void run_cudnn_SDP_fprop( _dropoutoffset, handle); } - auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = - graph_and_tensors_values; - std::unordered_map, void*> - variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {attn_scale, &scaling_factor}, - {seed, _dropoutseed.data_ptr()}, - {offset, _dropoutoffset.data_ptr()}, - {O, o.data_ptr()}}; + std::unordered_map variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {SCALE, &scaling_factor}, + {O, o.data_ptr()}}; if (return_softmaxstats) { - variant_pack[Stats] = softmaxstats.data_ptr(); + variant_pack[LSE] = softmaxstats.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[bias.value()] = attn_bias.value().data_ptr(); + variant_pack[BIAS] = attn_bias.value().data_ptr(); + } + if (dropout_probability != 0.0f) { + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - mhagraphcache.update(key, graph_and_tensors_values); + getMHAGraphCache_().update(key, mha_graph); } void run_cudnn_SDP_fprop_nestedtensor( @@ -904,72 +1191,55 @@ void run_cudnn_SDP_fprop_nestedtensor( if (return_softmaxstats && !softmaxstats.defined()) { softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); } - auto - [mha_graph, - Q, - K, - V, - bias, - attn_scale, - seed, - offset, - O, - Stats, - RAG_Q_OFF, - RAG_K_OFF, - RAG_V_OFF, - RAG_O_OFF, - RAG_STATS_OFF, - SEQ_LEN_Q, - SEQ_LEN_KV] = - build_graph_and_tensors_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - return_softmaxstats, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - softmaxstats, - o, - dropoutseed, - dropoutoffset, - handle); + auto mha_graph = build_graph_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + return_softmaxstats, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + softmaxstats, + o, + dropoutseed, + dropoutoffset, + handle); auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); auto rag_stats_off = cum_seqlen_q.mul(h_q); - std::unordered_map, void*> - variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {attn_scale, &scaling_factor}, - {seed, dropoutseed.data_ptr()}, - {offset, dropoutoffset.data_ptr()}, - {O, o.data_ptr()}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + std::unordered_map variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {SCALE, &scaling_factor}, + {O, o.data_ptr()}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; if (return_softmaxstats) { - variant_pack[Stats] = softmaxstats.data_ptr(); - variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); + variant_pack[LSE] = softmaxstats.data_ptr(); + variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr(); + } + if (dropout_probability != 0.0f) { + variant_pack[SEED] = dropoutseed.data_ptr(); + variant_pack[OFFSET] = dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { TORCH_CHECK("bias not supported with nestedtensor"); @@ -1053,12 +1323,12 @@ void run_cudnn_SDP_bprop( dropout_probability, is_causal, true); - auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); - graph_and_tensors_backward graph_and_tensors_backward_values; - if (graph_and_tensors_backward_ptr) { - graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr; + auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key); + std::shared_ptr mha_graph; + if (graph_backward_ptr) { + mha_graph = *graph_backward_ptr; } else { - graph_and_tensors_backward_values = build_graph_and_tensors_backward( + mha_graph = build_graph_backward( b, h, s_q, @@ -1082,49 +1352,153 @@ void run_cudnn_SDP_bprop( _dropoutoffset, handle); } - auto - [mha_graph, - Q, - K, - V, - bias, - attn_scale, - Seed, - Offset, - O, - Do, - Stats, - Dq, - Dk, - Dv] = graph_and_tensors_backward_values; - std::unordered_map, void*> - variant_pack = {// inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {Do, dO_.data_ptr()}, - {Stats, softmaxstats.data_ptr()}, - // outputs - {Dq, dQ.data_ptr()}, - {Dk, dK.data_ptr()}, - {Dv, dV.data_ptr()}, - // pass by value - {attn_scale, &scaling_factor}}; + std::unordered_map variant_pack = { + // inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {DO, dO_.data_ptr()}, + {LSE, softmaxstats.data_ptr()}, + // outputs + {DQ, dQ.data_ptr()}, + {DK, dK.data_ptr()}, + {DV, dV.data_ptr()}, + {SCALE, &scaling_factor}}; if (dropout_probability != 0.0f) { - variant_pack[Seed] = _dropoutseed.data_ptr(); - variant_pack[Offset] = _dropoutoffset.data_ptr(); + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[bias.value()] = attn_bias.value().data_ptr(); + variant_pack[BIAS] = attn_bias.value().data_ptr(); + } + auto workspace_size = mha_graph->get_workspace_size(); + auto workspace_ptr = + c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); + TORCH_CHECK(!workspace_size || workspace_ptr.get()); + TORCH_CHECK( + mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); + getMHAGraphBackwardCache_().update(key, mha_graph); +} + +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset) { + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || + !softmaxstats.numel()) { + return; } + + Tensor dO_ = dO; + const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1]; + if (innermost_dO_stride != 1) { + permute_to_matching_layout(o, dO_); + } + + auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); + auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); + auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); + auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); + auto rag_stats_off = cum_seqlen_q.mul(h_q); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + auto _dropoutseed = dropoutseed; + auto _dropoutoffset = dropoutoffset; + // cuDNN dropout bug requires these to be in int64 + if (dprops->major == 10 && dprops->minor == 0) { + _dropoutseed = dropoutseed.to(kLong); + _dropoutoffset = dropoutoffset.to(kLong); + } + + cudnnHandle_t handle = getCudnnHandle(); + + auto mha_graph = build_graph_backward_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + o, + dO_, + softmaxstats, + dQ, + dK, + dV, + dropoutseed, + dropoutoffset, + handle); + + std::unordered_map variant_pack = { + // inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {DO, dO_.data_ptr()}, + {LSE, softmaxstats.data_ptr()}, + // outputs + {DQ, dQ.data_ptr()}, + {DK, dK.data_ptr()}, + {DV, dV.data_ptr()}, + {SCALE, &scaling_factor}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {RAG_LSE_OFF, rag_stats_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + if (dropout_probability != 0.0f) { + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + } + TORCH_CHECK( + !attn_bias.has_value(), + "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - mhagraphbackwardcache.update(key, graph_and_tensors_backward_values); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 045e8cf6dee9..620abc1aa0a8 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop( const Tensor& dropoutseed, const Tensor& dropoutoffset); +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset); + } // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm.h b/aten/src/ATen/native/hip/ck_gemm.h index 176cbabd5e01..0d42cad56fcd 100644 --- a/aten/src/ATen/native/hip/ck_gemm.h +++ b/aten/src/ATen/native/hip/ck_gemm.h @@ -10,6 +10,7 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); } +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); template <> @@ -18,7 +19,7 @@ template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); - +#endif } // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip index 79cb14be4103..7561cede386f 100644 --- a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ - #include + +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -781,3 +782,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip index b8301a47981c..c4fea6088d3f 100644 --- a/aten/src/ATen/native/hip/ck_gemm_float.hip +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -484,3 +485,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index 552f0de84541..ebe044c38972 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -606,3 +607,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index 5a6e59fad786..44c06a74a222 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -1,7 +1,6 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include -#include #include #include @@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ } } -template -bool use_mkldnn_typed_matmul( +bool use_mkldnn_bf16_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { - bool dtype_check = false; - if constexpr (std::is_same_v) { #if defined(__aarch64__) - if (mkldnn_bf16_device_check_arm()) { - // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. - // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16 - // inputs, allow it for float as well - dtype_check = use_mkldnn_bf16_matmul() && - ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)); - } -#else - dtype_check = dtype_check && use_mkldnn_bf16_matmul() && - (mat1.scalar_type() == kBFloat16); + if (mkldnn_bf16_device_check_arm()) { + // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. + // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16 + // inputs, allow it for float as well + return ( + use_mkldnn_bf16_matmul() && + (mat1.scalar_type() == mat2.scalar_type()) && + (!result.defined() || (mat1.scalar_type() == result.scalar_type())) && + ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) && + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); + } else #endif - } else if constexpr (std::is_same_v) { - dtype_check = dtype_check && use_mkldnn_fp16_matmul() && - (mat1.scalar_type() == kHalf); - } else if constexpr (std::is_same_v) { - dtype_check = dtype_check && - (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) && - (mat1.scalar_type() == kFloat); + { + return ( + use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 && + mat2.scalar_type() == kBFloat16 && + (!result.defined() || result.scalar_type() == kBFloat16) && + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); } - if (!dtype_check) { - return false; - } - bool size_check = - mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2); - dtype_check = (mat1.scalar_type() == mat2.scalar_type()) && - (!result.defined() || result.scalar_type() == mat1.scalar_type()); - return dtype_check && size_check; +} + +bool use_mkldnn_fp16_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result) { + return ( + use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf && + mat2.scalar_type() == kHalf && + (!result.defined() || result.scalar_type() == kHalf) && + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); +} + +bool use_mkldnn_bf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result) { + return ( + use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat && + mat2.scalar_type() == kFloat && + (!result.defined() || result.scalar_type() == kFloat) && + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); +} + +bool use_mkldnn_tf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result) { + return ( + use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat && + mat2.scalar_type() == kFloat && + (!result.defined() || result.scalar_type() == kFloat) && + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); } bool use_mkldnn_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { - auto mat1_type = mat1.scalar_type(); - if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) { - return false; - } - AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] { - return use_mkldnn_typed_matmul(mat1, mat2, result); - }); - return false; + return ( + use_mkldnn_bf16_matmul(mat1, mat2, result) || + use_mkldnn_fp16_matmul(mat1, mat2, result) || + use_mkldnn_bf32_matmul(mat1, mat2, result) || + use_mkldnn_tf32_matmul(mat1, mat2, result)); } static void _mkldnn_matmul_i8i8i32_with_primitive( diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index a1314a141473..6a66abc7b062 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -469,4 +469,94 @@ Tensor _weight_int4pack_mm_xpu( return C; } + +Tensor& _int_mm_out_xpu( + const Tensor& self, + const Tensor& mat2, + Tensor& result) { + TORCH_CHECK( + self.dim() == 2, + "Expected self to be of dimension 2 but got ", + self.dim()); + TORCH_CHECK( + mat2.dim() == 2, + "Expected mat2 to be of dimension 2 but got ", + mat2.dim()); + TORCH_CHECK( + self.size(1) == mat2.size(0), + "self.size(1) needs to match mat2.size(0) but got ", + self.size(1), + " and ", + mat2.size(0)); + + TORCH_CHECK( + self.dtype() == at::kChar, + "Expected self dtype to be of type int8 but got ", + self.dtype()); + TORCH_CHECK( + mat2.dtype() == at::kChar, + "Expected mat2 dtype to be of type int8 but got ", + mat2.dtype()); + TORCH_CHECK( + result.dtype() == at::kInt, + "Expected result dtype to be of type kInt but got ", + result.dtype()); + TORCH_CHECK( + result.size(0) == self.size(0), + "Expected result.size(0) to be ", + self.size(0), + " but got ", + result.size(0)); + TORCH_CHECK( + result.size(1) == mat2.size(1), + "Expected result.size(1) to be ", + mat2.size(1), + " but got ", + result.size(1)); + + TORCH_CHECK( + result.dim() == 2, + "Expected result to be of dimension 2 but got ", + result.dim()); + + TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous."); + + if (result.numel() == 0 || self.size(1) == 0) { + return result.zero_(); + } + + Tensor bias = at::Tensor(); + Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat)); + Tensor mat2_zero_points = at::Tensor(); + auto post_op_args = torch::List>(); + + at::native::onednn::quantized_matmul( + self.contiguous(), + 1.0, + 0, + mat2.contiguous(), + mat2_scales, + mat2_zero_points, + bias, + result, + 1.0, + 0, + result.scalar_type(), + /*other*/ std::nullopt, + /*other scale*/ 1.0, + /*other zp*/ 0, + /*binary post op*/ "none", + /*binary alpha*/ 1.0, + /*post_op_name*/ "none", + post_op_args, + /*post_op_algorithm*/ "none", + /*m2_trans*/ true); + return result; +} + +Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) { + Tensor result = + at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt)); + return _int_mm_out_xpu(self, mat2, result); +} } // namespace at::native diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index e6f87f5499a4..f9cd28ca06fa 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -88,14 +88,8 @@ std::string getArrayRefString(const IntArrayRef s); // use has_storage() on the returned tensor to determine if src actually is a view Tensor gatherViewTensor(const Tensor& src, Tensor& dst); Tensor& scatterViewTensor(const Tensor& src, Tensor& output); -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64 = false); -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64 = false); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray); MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); @@ -435,14 +429,6 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::functionexecuteMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE); } -static inline void checkSupportsComplex() { - TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer."); -} - MPSDataType getMPSDataType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Float: @@ -100,7 +96,6 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: - checkSupportsBFloat16(); return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -119,10 +114,8 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. " "Please use float32 instead.") case ScalarType::ComplexHalf: - checkSupportsComplex(); return MPSDataTypeComplexFloat16; case ScalarType::ComplexFloat: - checkSupportsComplex(); return MPSDataTypeComplexFloat32; // Unsigned types case ScalarType::UInt64: @@ -140,16 +133,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast to these // types. -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64) { +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = - (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); - if (includesInt64) { - condition = condition && (dataType != MPSDataTypeInt64); - } + bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && + (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); if (condition) { dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; @@ -160,16 +147,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast from these // types. -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64) { +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = - (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); - if (includesInt64) { - condition = condition && (dataType != MPSDataTypeInt64); - } + bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && + (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); if (condition) { inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } @@ -186,7 +167,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: - checkSupportsBFloat16(); return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -201,13 +181,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Bool: return MPSDataTypeBool; case ScalarType::ComplexHalf: - checkSupportsComplex(); return MPSDataTypeComplexFloat16; // This is an intentional fallthrough supporting ComplexDouble for Scalar // types as they are casted to Complex64 currently. case ScalarType::ComplexDouble: case ScalarType::ComplexFloat: - checkSupportsComplex(); return MPSDataTypeComplexFloat32; // Unsigned types case ScalarType::UInt64: @@ -267,7 +245,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return "half"; case ScalarType::BFloat16: - checkSupportsBFloat16(); return "bfloat"; case ScalarType::Int: return "int"; @@ -879,9 +856,7 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} MTLCompileOptions* options = compile_options; if (!options) { options = [[MTLCompileOptions new] autorelease]; - // Need 3.0 for atomic oprations, 3.1 introduces bfloat support - [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 - : MTLLanguageVersion3_0]; + [options setLanguageVersion:MTLLanguageVersion3_1]; if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe; options.mathFloatingPointFunctions = @@ -953,8 +928,7 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} if (C10_UNLIKELY(!library)) { auto device = MPSDevice::getInstance()->device(); NSError* error = nil; - auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic"; - library = [device newLibraryWithData:getSectionData(section_name) error:&error]; + library = [device newLibraryWithData:getSectionData("metal_basic") error:&error]; TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); } return library; diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index f7335d150d40..ae1fda66c3b3 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -33,21 +33,15 @@ struct shrink_backward_functor { REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float); REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat); -#endif REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float); REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat); -#endif REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat); -#endif struct hardsigmoid_functor { template @@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor { REGISTER_UNARY_OP(hardsigmoid, float, float); REGISTER_UNARY_OP(hardsigmoid, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat); -#endif REGISTER_BINARY_OP(hardsigmoid_backward, float, float); REGISTER_BINARY_OP(hardsigmoid_backward, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat); -#endif struct hardswish_functor { template @@ -103,15 +93,11 @@ struct hardswish_backward_functor { REGISTER_UNARY_OP(hardswish, float, float); REGISTER_UNARY_OP(hardswish, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_OP(hardswish, bfloat, bfloat); -#endif REGISTER_BINARY_OP(hardswish_backward, float, float); REGISTER_BINARY_OP(hardswish_backward, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); -#endif struct leaky_relu_functor { template @@ -135,12 +121,8 @@ struct leaky_relu_backward_functor { REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float); REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat); -#endif REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half); -#if __METAL_VERSION__ >= 310 REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/Amp.metal b/aten/src/ATen/native/mps/kernels/Amp.metal index abe852798f44..653c2057d498 100644 --- a/aten/src/ATen/native/mps/kernels/Amp.metal +++ b/aten/src/ATen/native/mps/kernels/Amp.metal @@ -113,18 +113,12 @@ kernel void ampUpdateScale( INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat); -#endif INSTANTIATE_AMP_UPDATE_SCALE(float); INSTANTIATE_AMP_UPDATE_SCALE(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_AMP_UPDATE_SCALE(bfloat); -#endif INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/Attention.metal b/aten/src/ATen/native/mps/kernels/Attention.metal index c18ef3711ea3..6bb2cbfb3d71 100644 --- a/aten/src/ATen/native/mps/kernels/Attention.metal +++ b/aten/src/ATen/native/mps/kernels/Attention.metal @@ -590,9 +590,7 @@ kernel void attention( INSTANTIATE_SDPA_VECTOR_HEADS(float); INSTANTIATE_SDPA_VECTOR_HEADS(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); -#endif #define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \ template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \ @@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); INSTANTIATE_ATTN_SHAPES_HELPER(float); INSTANTIATE_ATTN_SHAPES_HELPER(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_ATTN_SHAPES_HELPER(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index a178ba257123..f6f4935608e4 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor { }; struct nextafter_functor { -#if __METAL_VERSION__ < 310 - template - struct bit_type {}; - template <> - struct bit_type { - using type = int; - }; - template <> - struct bit_type { - using type = short; - }; -#endif template inline T operator()(const T a, const T b) { -#if __METAL_VERSION__ >= 310 return static_cast(::metal::nextafter(a, b)); -#else - using U = typename bit_type::type; - if (a == b) { - return a; - } - if (::metal::isunordered(a, b)) { - return NAN; - } - if (a == 0) { - constexpr auto eps = as_type(static_cast(1)); - return b > 0 ? eps : -eps; - } - auto bits = as_type(a); - (a > 0) ^ (a > b) ? bits++ : bits--; - return as_type(bits); -#endif } }; @@ -344,13 +315,6 @@ struct fmod_functor { } }; -// Some helper defines -#if __METAL_VERSION__ >= 310 -#define _METAL_310_PLUS(x) x -#else -#define _METAL_310_PLUS(x) -#endif - #define REGISTER_INTEGER_BINARY_OP(NAME) \ REGISTER_BINARY_OP(NAME, long, long); \ REGISTER_BINARY_OP(NAME, int, int); \ @@ -370,12 +334,12 @@ struct fmod_functor { #define REGISTER_FLOAT_BINARY_OP(NAME) \ REGISTER_BINARY_OP(NAME, float, float); \ REGISTER_BINARY_OP(NAME, half, half); \ - _METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat)) + REGISTER_BINARY_OP(NAME, bfloat, bfloat) #define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \ REGISTER_OPMATH_BINARY_OP(NAME, float, float); \ REGISTER_OPMATH_BINARY_OP(NAME, half, half); \ - _METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)) + REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) REGISTER_FLOAT_BINARY_OP(copysign); REGISTER_INT2FLOAT_BINARY_OP(copysign); @@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar); REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool); -#if __METAL_VERSION__ >= 310 REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat); -#endif // Complex binary functions REGISTER_BINARY_OP(polar, float, float2); diff --git a/aten/src/ATen/native/mps/kernels/Bucketization.metal b/aten/src/ATen/native/mps/kernels/Bucketization.metal index f4054aebc64a..a84698d77f57 100644 --- a/aten/src/ATen/native/mps/kernels/Bucketization.metal +++ b/aten/src/ATen/native/mps/kernels/Bucketization.metal @@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int); REGISTER_SEARCHSORTED_OP(float, long); REGISTER_SEARCHSORTED_OP(half, int); REGISTER_SEARCHSORTED_OP(half, long); -#if __METAL_VERSION__ >= 310 REGISTER_SEARCHSORTED_OP(bfloat, int); REGISTER_SEARCHSORTED_OP(bfloat, long); -#endif REGISTER_SEARCHSORTED_OP(char, int); REGISTER_SEARCHSORTED_OP(char, long); REGISTER_SEARCHSORTED_OP(uchar, int); diff --git a/aten/src/ATen/native/mps/kernels/Col2Im.metal b/aten/src/ATen/native/mps/kernels/Col2Im.metal index f784fc3053a6..61f596a9250f 100644 --- a/aten/src/ATen/native/mps/kernels/Col2Im.metal +++ b/aten/src/ATen/native/mps/kernels/Col2Im.metal @@ -96,6 +96,4 @@ kernel void col2im_kernel( INSTANTIATE_COL2IM(bool); INSTANTIATE_COL2IM(float); INSTANTIATE_COL2IM(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_COL2IM(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/CrossKernel.metal b/aten/src/ATen/native/mps/kernels/CrossKernel.metal index 7ee93250a5e1..bceae51c02db 100644 --- a/aten/src/ATen/native/mps/kernels/CrossKernel.metal +++ b/aten/src/ATen/native/mps/kernels/CrossKernel.metal @@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short); REGISTER_CROSS_FUNC(char); REGISTER_CROSS_FUNC(uchar); REGISTER_CROSS_FUNC(bool); -#if __METAL_VERSION__ >= 310 REGISTER_CROSS_FUNC(bfloat); -#endif template kernel void cross( @@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short); REGISTER_CROSS_OP(char); REGISTER_CROSS_OP(uchar); REGISTER_CROSS_OP(bool); -#if __METAL_VERSION__ >= 310 REGISTER_CROSS_OP(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal index fe5605226748..f46b10f99bf4 100644 --- a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal +++ b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal @@ -1,11 +1,9 @@ #include using metal::max; -#if __METAL_VERSION__ >= 310 bfloat max(bfloat a, bfloat b) { return a > b ? a : b; } -#endif #define kmaxThreadGroups 32 #define kmaxTensors 32 @@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float); REGISTER_ADAM_OPS_QUART(float, half); REGISTER_ADAM_OPS_QUART(half, float); REGISTER_ADAM_OPS_QUART(half, half); -#if __METAL_VERSION__ >= 310 REGISTER_ADAM_OPS_QUART(float, bfloat); REGISTER_ADAM_OPS_QUART(bfloat, bfloat); REGISTER_ADAM_OPS_QUART(bfloat, float); -#endif template inline void sgd_momentum_math( @@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float); REGISTER_FUSED_SGD_OP(half); REGISTER_FUSED_SGD_MOMENTUM_OP(float); REGISTER_FUSED_SGD_MOMENTUM_OP(half); -#if __METAL_VERSION__ >= 310 REGISTER_FUSED_SGD_OP(bfloat); REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/Gamma.metal b/aten/src/ATen/native/mps/kernels/Gamma.metal index 549576b62d28..1c150a726edb 100644 --- a/aten/src/ATen/native/mps/kernels/Gamma.metal +++ b/aten/src/ATen/native/mps/kernels/Gamma.metal @@ -106,9 +106,7 @@ kernel void polygamma( constant int64_t& order [[buffer(2)]], \ uint id [[thread_position_in_grid]]); -#if __METAL_VERSION__ >= 310 INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat); -#endif INSTANTIATE_GAMMA_KERNELS(half, half); INSTANTIATE_GAMMA_KERNELS(float, float); INSTANTIATE_GAMMA_KERNELS(bool, float); diff --git a/aten/src/ATen/native/mps/kernels/Im2Col.metal b/aten/src/ATen/native/mps/kernels/Im2Col.metal index 566bfd12a234..191462bbd3d0 100644 --- a/aten/src/ATen/native/mps/kernels/Im2Col.metal +++ b/aten/src/ATen/native/mps/kernels/Im2Col.metal @@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float); INSTANTIATE_IM2COL(float2); INSTANTIATE_IM2COL(half); INSTANTIATE_IM2COL(half2); -#if __METAL_VERSION__ >= 310 INSTANTIATE_IM2COL(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal index e61f33dc5444..048b2e5ae7c9 100644 --- a/aten/src/ATen/native/mps/kernels/Indexing.metal +++ b/aten/src/ATen/native/mps/kernels/Indexing.metal @@ -5,29 +5,6 @@ using namespace metal; using namespace c10::metal; -namespace c10 { -namespace metal { -// There are no atomic 64-bit add in Metal yet, but this implements a consistent -// add I.e. if multiple threads are modify the same 64-bit value, results stored -// at the address will eventually be equal to its original value plus sum of all -// operands -template <> -struct AtomicType { - using type = ::metal::atomic; - static inline void atomic_add(device type* data, long offset, long value) { - const auto value_bits = as_type(value); - const uint low = static_cast(value_bits); - uint high = static_cast(value_bits >> 32); - auto ptr = data + (offset << 1); - auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed); - high += (old_low + low < old_low) ? 1 : 0; - atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed); - } -}; - -} // namespace metal -} // namespace c10 - struct IndexAB { constant int64_t* indexArray; }; @@ -234,15 +211,15 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial); REGISTER_INDEX_OP(put_accumulate, float, float); REGISTER_INDEX_OP(put_accumulate, half, half); +REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); REGISTER_INDEX_OP(put_accumulate, long, long); REGISTER_INDEX_OP(put_accumulate, int, int); REGISTER_INDEX_OP(put_accumulate, short, short); REGISTER_INDEX_OP(put_accumulate, char, char); REGISTER_INDEX_OP(put_accumulate, uchar, uchar); REGISTER_INDEX_OP(put_accumulate, bool, bool); -#if __METAL_VERSION__ >= 310 -REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); -#endif +REGISTER_INDEX_OP(put_accumulate, float2, float2); +REGISTER_INDEX_OP(put_accumulate, half2, half2); template kernel void kernel_index_offsets( @@ -477,10 +454,8 @@ INSTANTIATE_INDEX_COPY(char, long); INSTANTIATE_INDEX_COPY(uchar, int); INSTANTIATE_INDEX_COPY(uchar, long); -#if __METAL_VERSION__ >= 310 INSTANTIATE_INDEX_COPY(bfloat, int); INSTANTIATE_INDEX_COPY(bfloat, long); -#endif INSTANTIATE_INDEX_COPY(float2, int); INSTANTIATE_INDEX_COPY(float2, long); INSTANTIATE_INDEX_COPY(half2, int); diff --git a/aten/src/ATen/native/mps/kernels/LayerNorm.metal b/aten/src/ATen/native/mps/kernels/LayerNorm.metal index 1ca9f916c2c0..7b4a789ed292 100644 --- a/aten/src/ATen/native/mps/kernels/LayerNorm.metal +++ b/aten/src/ATen/native/mps/kernels/LayerNorm.metal @@ -288,7 +288,6 @@ kernel void layer_norm_looped( #define instantiate_layer_norm(DTYPE) \ instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE) -instantiate_layer_norm(float) instantiate_layer_norm(half) -#if __METAL_VERSION__ >= 310 - instantiate_layer_norm(bfloat) -#endif +instantiate_layer_norm(float); +instantiate_layer_norm(half); +instantiate_layer_norm(bfloat); diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index 15d46d8c8d8e..4ba2bca720db 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -68,6 +68,37 @@ kernel void matmul( } } +template +kernel void addmm( + constant T* mat1Data [[buffer(0)]], + constant T* mat2Data [[buffer(1)]], + device T* outputData [[buffer(2)]], + constant T* biasData [[buffer(3)]], + constant array, 2>& alpha_beta [[buffer(4)]], + constant array& strides [[buffer(5)]], + constant uint3& sizes [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 thread_id [[thread_position_in_grid]]) { + threadgroup T A_tile[TILE_DIM][TILE_DIM]; + threadgroup T B_tile[TILE_DIM][TILE_DIM]; + + auto sum = matmul_inner( + mat1Data, + mat2Data, + reinterpret_cast&>(strides), + sizes, + A_tile, + B_tile, + tid, + thread_id); + if (thread_id.y < sizes.x && thread_id.x < sizes.z) { + auto bias = + biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y]; + outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] = + static_cast(alpha_beta[0] * sum + alpha_beta[1] * bias); + } +} + template kernel void naive_bmm( constant T* mat1Data [[buffer(0)]], @@ -613,17 +644,15 @@ kernel void applyPivots( } } -#define INSTANTIATE_NAIVE_MM(DTYPE) \ - template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ - constant DTYPE * mat1Data [[buffer(0)]], \ - constant DTYPE * mat2Data [[buffer(1)]], \ - device DTYPE * outputData [[buffer(2)]], \ - constant array & strides [[buffer(3)]], \ - constant uint3 & sizes [[buffer(4)]], \ - uint2 tid [[thread_position_in_threadgroup]], \ - uint2 group_id [[threadgroup_position_in_grid]]) - -#define INSTANTIATE_NAIVE_BMM(DTYPE) \ +#define INSTANTIATE_MM_OPS(DTYPE) \ + template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant array & strides [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]); \ template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm( \ constant DTYPE * mat1Data [[buffer(0)]], \ constant DTYPE * mat2Data [[buffer(1)]], \ @@ -631,22 +660,26 @@ kernel void applyPivots( constant array & strides [[buffer(3)]], \ constant uint4 & sizes [[buffer(4)]], \ uint3 tid [[thread_position_in_threadgroup]], \ - uint3 group_id [[threadgroup_position_in_grid]]) + uint3 group_id [[threadgroup_position_in_grid]]); \ + template [[host_name("addmm_" #DTYPE)]] kernel void addmm( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant DTYPE * biasData [[buffer(3)]], \ + constant array, 2> & \ + alpha_beta [[buffer(4)]], \ + constant array & strides [[buffer(5)]], \ + constant uint3 & sizes [[buffer(6)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]) -INSTANTIATE_NAIVE_MM(float); -INSTANTIATE_NAIVE_MM(half); -#if __METAL_VERSION__ >= 310 -INSTANTIATE_NAIVE_MM(bfloat); -#endif +INSTANTIATE_MM_OPS(float); +INSTANTIATE_MM_OPS(half); +INSTANTIATE_MM_OPS(bfloat); // Integral MM -INSTANTIATE_NAIVE_MM(short); -INSTANTIATE_NAIVE_MM(int); -INSTANTIATE_NAIVE_MM(long); -INSTANTIATE_NAIVE_MM(char); -INSTANTIATE_NAIVE_MM(uchar); -INSTANTIATE_NAIVE_BMM(short); -INSTANTIATE_NAIVE_BMM(int); -INSTANTIATE_NAIVE_BMM(long); -INSTANTIATE_NAIVE_BMM(char); -INSTANTIATE_NAIVE_BMM(uchar); +INSTANTIATE_MM_OPS(long); +INSTANTIATE_MM_OPS(int); +INSTANTIATE_MM_OPS(short); +INSTANTIATE_MM_OPS(char); +INSTANTIATE_MM_OPS(uchar); diff --git a/aten/src/ATen/native/mps/kernels/Pooling.h b/aten/src/ATen/native/mps/kernels/Pooling.h index 6a7b1d3a116a..303388110a1b 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.h +++ b/aten/src/ATen/native/mps/kernels/Pooling.h @@ -48,3 +48,14 @@ struct PoolingBackwardParams { ::c10::metal::array grad_output_strides; ::c10::metal::array indices_strides; }; + +template +struct MaxUnpoolingParams { + int32_t dims; + int32_t pooling_dims; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array indices_strides; +}; diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal index 34abc3b78e78..45a8d680afcd 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.metal +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -88,6 +88,53 @@ void max_pool_3d_input_iter( } } +template +void max_pool_2d_input_iter( + constant T* input, + device T* output, + device int64_t* indices, + constant int32_t* input_sizes, + constant int32_t* input_strides, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation) { + auto bounds0 = get_input_iter_bounds<0>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds1 = get_input_iter_bounds<1>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + + auto d0 = dilation[0]; + auto d1 = dilation[1]; + + T max_value = input + [input_strides[0] * bounds0.start + input_strides[1] * bounds1.start]; + auto max_index = bounds0.start * input_sizes[1] + bounds1.start; + + for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) { + auto offset0 = input_strides[0] * i0; + + for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) { + auto offset1 = input_strides[1] * i1; + + auto input_value = input[offset0 + offset1]; + bool is_greater = input_value > max_value; + + max_value = is_greater ? input_value : max_value; + + if (return_indices) { + auto input_index = i0 * input_sizes[1] + i1; + max_index = is_greater ? input_index : max_index; + } + } + } + *output = max_value; + if (return_indices) { + *indices = max_index; + } +} + struct PoolOffsets { int32_t output; int32_t indices; @@ -168,6 +215,16 @@ PoolOffsets find_pool_offsets( leading_dims, return_indices, tid); + case 3: + return find_pool_offsets_dim_specific<3>( + output_sizes, + output_strides, + indices_strides, + input_strides, + pooling_dim_indices, + leading_dims, + return_indices, + tid); } return PoolOffsets(); } @@ -202,7 +259,7 @@ kernel void max_pool( PoolOffsets offsets = find_pool_offsets( output_sizes, output_strides, - indices_strides, + return_indices ? indices_strides : nullptr, input_strides, pooling_dim_indices, dims, @@ -214,18 +271,47 @@ kernel void max_pool( indices += offsets.indices; input += offsets.input_leading; - max_pool_3d_input_iter( - input, - output, - indices, - input_sizes + leading_dims, - input_strides + leading_dims, - pooling_dim_indices, - kernel_size, - stride, - padding, - dilation, - return_indices); + switch (pooling_dims) { + case 2: + if (return_indices) { + return max_pool_2d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation); + } else { + return max_pool_2d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation); + } + case 3: + return max_pool_3d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation, + return_indices); + } } // Finds the element in the grad input which corresponds to the index into the @@ -292,6 +378,68 @@ kernel void max_pool_backward( pooling_dims); } +template +void max_unpool_impl( + device T* output, + T input_element, + int32_t input_index, + constant int32_t* output_sizes, + constant int32_t* output_strides, + int32_t pooling_dims) { + int32_t size_prod = 1; + int32_t pool_offset = 0; + + for (auto dim = pooling_dims - 1; dim >= 0; dim--) { + auto next_size_prod = output_sizes[dim] * size_prod; + pool_offset += + output_strides[dim] * ((input_index % next_size_prod) / size_prod); + size_prod *= output_sizes[dim]; + } + + output[pool_offset] = input_element; +} + +// Kernel computes one element of the grad input per kernel call. +template +kernel void max_unpool( + device T* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant int64_t* indices [[buffer(2)]], + constant MaxUnpoolingParams<5>& params [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto input_sizes = params.input_sizes.data(); + auto input_strides = params.input_strides.data(); + auto output_sizes = params.output_sizes.data(); + auto output_strides = params.output_strides.data(); + auto indices_strides = params.indices_strides.data(); + + auto leading_dims = dims - pooling_dims; + + // NOTE: Since we're doing unpooling, the variable names "input" and "output" + // are reversed compared to the pooling operations. So in `find_pool_offsets`, + // we need to map "input" -> "output" and "output" -> "input". + PoolOffsets offsets = find_pool_offsets( + /*output_sizes=*/input_sizes, + /*output_strides=*/input_strides, + indices_strides, + /*input_strides=*/output_strides, + /*pooling_dim_indices=*/nullptr, + dims, + leading_dims, + /*return_indices=*/true, + tid); + + max_unpool_impl( + output + offsets.input_leading, + input[offsets.output], + indices[offsets.indices], + output_sizes + leading_dims, + output_strides + leading_dims, + pooling_dims); +} + template struct AvgPoolIterBounds { T start; @@ -358,7 +506,6 @@ void avg_pool_3d_input_iter( auto divisor = has_divisor_override ? divisor_override : (bounds0.count) * (bounds1.count) * (bounds2.count); - auto size12 = input_sizes[1] * input_sizes[2]; for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) { auto offset0 = input_strides[0] * i0; @@ -376,6 +523,64 @@ void avg_pool_3d_input_iter( *output = value_sum / static_cast(divisor); } +template +void avg_pool_backward_3d_input_iter( + device AtomicType_t* grad_input, + constant T* grad_output, + constant int32_t* grad_input_sizes, + constant int32_t* grad_input_strides, + int32_t grad_input_leading_offset, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + bool count_include_pad, + bool has_divisor_override, + int32_t divisor_override) { + auto bounds0 = get_avg_pool_input_iter_bounds<0>( + grad_input_sizes, + pooling_dim_indices, + kernel_size, + stride, + padding, + count_include_pad); + auto bounds1 = get_avg_pool_input_iter_bounds<1>( + grad_input_sizes, + pooling_dim_indices, + kernel_size, + stride, + padding, + count_include_pad); + auto bounds2 = get_avg_pool_input_iter_bounds<2>( + grad_input_sizes, + pooling_dim_indices, + kernel_size, + stride, + padding, + count_include_pad); + + auto divisor = has_divisor_override + ? divisor_override + : (bounds0.count) * (bounds1.count) * (bounds2.count); + auto grad_val = *grad_output / static_cast(divisor); + + for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) { + auto offset0 = grad_input_strides[0] * i0; + + for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) { + auto offset1 = grad_input_strides[1] * i1; + + for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) { + auto offset2 = grad_input_strides[2] * i2; + auto pool_offset = offset0 + offset1 + offset2; + + AtomicType::atomic_add( + grad_input, grad_input_leading_offset + pool_offset, grad_val); + } + } + } +} + // Kernel computes one element of the output per kernel call. template kernel void avg_pool( @@ -428,31 +633,97 @@ kernel void avg_pool( params.divisor_override); } -#define REGISTER_POOL_OP(DTYPE) \ - template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * output [[buffer(1)]], \ - device int64_t* indices [[buffer(2)]], \ - constant PoolingParams<5>& params [[buffer(3)]], \ - uint tid [[thread_position_in_grid]]); \ - \ - template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * output [[buffer(1)]], \ - constant AvgPoolingParams<5> & params [[buffer(2)]], \ +template +kernel void avg_pool_backward( + device AtomicType_t* grad_input [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant AvgPoolingParams<5>& params [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto grad_input_sizes = params.input_sizes.data(); + auto grad_input_strides = params.input_strides.data(); + auto grad_output_sizes = params.output_sizes.data(); + auto grad_output_strides = params.output_strides.data(); + auto kernel_size = params.kernel_size.data(); + auto stride = params.stride.data(); + auto padding = params.padding.data(); + auto leading_dims = dims - pooling_dims; + + // This buffer keeps track of the pooling dimension indices of this thread's + // element of the output. We need to fill it with the proper values below. + int32_t pooling_dim_indices[3]; + + PoolOffsets offsets = find_pool_offsets( + grad_output_sizes, + grad_output_strides, + /*indices_strides=*/nullptr, + grad_input_strides, + pooling_dim_indices, + dims, + leading_dims, + /*return_indices=*/false, + tid); + + grad_output += offsets.output; + grad_input_sizes += leading_dims; + grad_input_strides += leading_dims; + + avg_pool_backward_3d_input_iter( + grad_input, + grad_output, + grad_input_sizes, + grad_input_strides, + offsets.input_leading, + pooling_dim_indices, + kernel_size, + stride, + padding, + params.count_include_pad, + params.has_divisor_override, + params.divisor_override); +} + +#define REGISTER_POOL_OP(DTYPE) \ + template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant PoolingParams<5>& params [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + \ + template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool( \ + device DTYPE * output [[buffer(0)]], \ + constant DTYPE * input [[buffer(1)]], \ + constant int64_t* indices [[buffer(2)]], \ + constant MaxUnpoolingParams<5>& params [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + \ + template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant AvgPoolingParams<5> & params [[buffer(2)]], \ uint tid [[thread_position_in_grid]]); -#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \ +#define REGISTER_POOL_BACKWARD_OP(DTYPE) \ template [[host_name("max_pool_backward_" #DTYPE)]] \ kernel void max_pool_backward( \ device AtomicType_t * grad_input [[buffer(0)]], \ constant DTYPE * grad_output_ [[buffer(1)]], \ constant int64_t* grad_indices_ [[buffer(2)]], \ constant PoolingBackwardParams<5>& params [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + \ + template [[host_name("avg_pool_backward_" #DTYPE)]] \ + kernel void avg_pool_backward( \ + device AtomicType_t * grad_input [[buffer(0)]], \ + constant DTYPE * grad_output [[buffer(1)]], \ + constant AvgPoolingParams<5> & params [[buffer(2)]], \ uint tid [[thread_position_in_grid]]); REGISTER_POOL_OP(float); REGISTER_POOL_OP(half); +REGISTER_POOL_OP(bfloat); REGISTER_POOL_OP(int); REGISTER_POOL_OP(long); REGISTER_POOL_OP(short); @@ -460,10 +731,6 @@ REGISTER_POOL_OP(char); REGISTER_POOL_OP(uchar); REGISTER_POOL_OP(bool); -REGISTER_MAX_POOL_BACKWARD_OP(float); -REGISTER_MAX_POOL_BACKWARD_OP(half); - -#if __METAL_VERSION__ >= 310 -REGISTER_POOL_OP(bfloat); -REGISTER_MAX_POOL_BACKWARD_OP(bfloat); -#endif +REGISTER_POOL_BACKWARD_OP(float); +REGISTER_POOL_BACKWARD_OP(half); +REGISTER_POOL_BACKWARD_OP(bfloat); diff --git a/aten/src/ATen/native/mps/kernels/Quantized.metal b/aten/src/ATen/native/mps/kernels/Quantized.metal index 4d57027a576c..b84c033a07f4 100644 --- a/aten/src/ATen/native/mps/kernels/Quantized.metal +++ b/aten/src/ATen/native/mps/kernels/Quantized.metal @@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128); INSTANTIATE_INT4MV(half, 128); INSTANTIATE_INT4MV(float, 256); INSTANTIATE_INT4MV(half, 256); -#if __METAL_VERSION__ >= 310 INSTANTIATE_INT4MV(bfloat, 32); INSTANTIATE_INT4MV(bfloat, 64); INSTANTIATE_INT4MV(bfloat, 128); INSTANTIATE_INT4MV(bfloat, 256); -#endif // ------------------------------ int8 MM For M >= 12 ------------------------------------ /** @@ -234,12 +232,10 @@ template <> struct BlockType { using simdgroup_type8x8 = simdgroup_half8x8; using type4 = half4; }; -#if __METAL_VERSION__ >= 310 template <> struct BlockType { using simdgroup_type8x8 = simdgroup_bfloat8x8; using type4 = bfloat4; }; -#endif template float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) { @@ -490,9 +486,7 @@ kernel void kernel_mul_mm( \ INSTANTIATE_MM(float, char, get_scale_zero_q8); INSTANTIATE_MM(half, char, get_scale_zero_q8); -#if __METAL_VERSION__ >= 310 INSTANTIATE_MM(bfloat, char, get_scale_zero_q8); -#endif // ------------------------------ int8 MM For M < 12 ------------------------------------ /* Matrix vector multiplication, used for small M size for matrix multiplication as well. @@ -646,6 +640,4 @@ kernel void kernel_mul_mv( INSTANTIATE_MV(float); INSTANTIATE_MV(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_MV(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/RMSNorm.metal b/aten/src/ATen/native/mps/kernels/RMSNorm.metal index f66dcb035dfa..d6c69217e65f 100644 --- a/aten/src/ATen/native/mps/kernels/RMSNorm.metal +++ b/aten/src/ATen/native/mps/kernels/RMSNorm.metal @@ -192,6 +192,4 @@ template instantiate_rms(float) instantiate_rms(half) -#if __METAL_VERSION__ >= 310 instantiate_rms(bfloat) -#endif // clang-format on diff --git a/aten/src/ATen/native/mps/kernels/RenormKernel.metal b/aten/src/ATen/native/mps/kernels/RenormKernel.metal index eda61867e8c7..0bfd60b04c16 100644 --- a/aten/src/ATen/native/mps/kernels/RenormKernel.metal +++ b/aten/src/ATen/native/mps/kernels/RenormKernel.metal @@ -23,6 +23,4 @@ kernel void renorm( REGISTER_RENORM_OP(float); REGISTER_RENORM_OP(half); -#if __METAL_VERSION__ >= 310 REGISTER_RENORM_OP(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/ScanKernel.metal b/aten/src/ATen/native/mps/kernels/ScanKernel.metal index e6d739cac13c..de493af7aaa0 100644 --- a/aten/src/ATen/native/mps/kernels/ScanKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ScanKernel.metal @@ -25,379 +25,6 @@ struct LogAddExp { }; }; -#if __METAL_VERSION__ < 310 -template > -struct CumMinOp { - static acc_t apply(acc_t a, acc_t b) { - return metal::min(a, b); - } - static acc_t identity() { - return static_cast( - metal::is_floating_point_v ? metal::numeric_limits::infinity() - : metal::numeric_limits::max()); - } -}; - -template > -struct CumMaxOp { - static acc_t apply(acc_t a, acc_t b) { - return metal::max(a, b); - } - static acc_t identity() { - return static_cast( - metal::is_floating_point_v ? -metal::numeric_limits::infinity() - : metal::numeric_limits::lowest()); - } -}; - -template > -struct LogCumSumExpOp { - static acc_t apply(acc_t x, acc_t y) { - return LogAddExp{}(x, y); - } - static acc_t identity() { - return -metal::numeric_limits::infinity(); - } -}; - -// Inclusive scan along innermost dimension for contiguous tensors -template > -kernel void scan_contiguous_innermost_dim( - constant T* input [[buffer(0)]], - device T* output [[buffer(1)]], - constant uint& num_rows [[buffer(2)]], - constant uint& row_size [[buffer(3)]], - uint row [[thread_position_in_grid]]) { - if (row >= num_rows) - return; - - const uint offset = row * row_size; - - acc_t accumulator = Op::identity(); - - for (uint col = 0; col < row_size; col++) { - T val = input[offset + col]; - acc_t accum_val = static_cast(val); - accumulator = Op::apply(accumulator, accum_val); - output[offset + col] = static_cast(accumulator); - } -} - -// Inclusive scan along outer dimension for contiguous tensors -template > -kernel void scan_contiguous_outer_dim( - constant T* input [[buffer(0)]], - device T* output [[buffer(1)]], - constant uint& num_orows [[buffer(2)]], - constant uint& num_irows [[buffer(3)]], - constant uint& row_size [[buffer(4)]], - uint thread_index [[thread_position_in_grid]]) { - const uint orow = thread_index / num_irows; - const uint irow = thread_index % num_irows; - - if (orow >= num_orows) - return; - - acc_t accumulator = Op::identity(); - - const uint idx_base = orow * row_size * num_irows + irow; - for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { - T val = input[idx]; - acc_t accum_val = static_cast(val); - accumulator = Op::apply(accumulator, accum_val); - output[idx] = static_cast(accumulator); - } -} - -// Inclusive scan with indices along innermost dimension for contiguous tensors -template > -kernel void scan_with_indices_contiguous_innermost_dim( - constant T* input [[buffer(0)]], - device T* values [[buffer(1)]], - device int64_t* indices [[buffer(2)]], - constant uint& num_rows [[buffer(3)]], - constant uint& row_size [[buffer(4)]], - uint row [[thread_position_in_grid]]) { - if (row >= num_rows) - return; - - const uint offset = row * row_size; - - acc_t accumulator = Op::identity(); - int64_t best_idx = 0; - - for (uint col = 0; col < row_size; col++) { - T val = input[offset + col]; - acc_t accum_val = static_cast(val); - if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { - accumulator = accum_val; - best_idx = col; - } - values[offset + col] = static_cast(accumulator); - indices[offset + col] = best_idx; - } -} - -// Inclusive scan with indices along outer dimension for contiguous tensors -template > -kernel void scan_with_indices_contiguous_outer_dim( - constant T* input [[buffer(0)]], - device T* values [[buffer(1)]], - device int64_t* indices [[buffer(2)]], - constant uint& num_orows [[buffer(3)]], - constant uint& num_irows [[buffer(4)]], - constant uint& row_size [[buffer(5)]], - uint thread_index [[thread_position_in_grid]]) { - const uint orow = thread_index / num_irows; - const uint irow = thread_index % num_irows; - - if (orow >= num_orows) - return; - - acc_t accumulator = Op::identity(); - int64_t best_idx = 0; - - const uint idx_base = orow * row_size * num_irows + irow; - for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { - T val = input[idx]; - acc_t accum_val = static_cast(val); - if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { - accumulator = accum_val; - best_idx = col; - } - values[idx] = static_cast(accumulator); - indices[idx] = best_idx; - } -} - -// Shared utility functions for strided kernels -inline long calculate_non_scan_elements( - constant long* sizes, - uint ndim, - uint scan_dim) { - long total = 1; - for (uint i = 0; i < ndim; ++i) { - if (i != scan_dim) { - total *= sizes[i]; - } - } - return total; -} - -inline void thread_index_to_coordinates( - uint index, - int pos[c10::metal::max_ndim], - constant long* sizes, - uint ndim, - uint scan_dim) { - long remaining_index = index; - for (uint i = 0; i < ndim; ++i) { - if (i != scan_dim) { - pos[i] = remaining_index % sizes[i]; - remaining_index /= sizes[i]; - } else { - pos[i] = 0; - } - } -} - -inline long calculate_base_offset( - int pos[c10::metal::max_ndim], - constant long* strides, - uint ndim, - uint scan_dim) { - long offset = 0; - for (uint i = 0; i < ndim; ++i) { - if (i != scan_dim) { - offset += pos[i] * strides[i]; - } - } - return offset; -} - -// Generic strided scan kernel -template > -kernel void scan_strided( - constant T* input [[buffer(0)]], - device T* output [[buffer(1)]], - constant long* sizes [[buffer(2)]], - constant long* input_strides [[buffer(3)]], - constant long* output_strides [[buffer(4)]], - constant uint& ndim [[buffer(5)]], - constant uint& scan_dim [[buffer(6)]], - uint thread_index [[thread_position_in_grid]]) { - const long total_non_scan_elements = - calculate_non_scan_elements(sizes, ndim, scan_dim); - if (thread_index >= total_non_scan_elements) { - return; - } - - int pos[c10::metal::max_ndim]; - thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); - - const long input_base_offset = - calculate_base_offset(pos, input_strides, ndim, scan_dim); - const long output_base_offset = - calculate_base_offset(pos, output_strides, ndim, scan_dim); - - acc_t accumulator = Op::identity(); - const long scan_size = sizes[scan_dim]; - const long input_scan_stride = input_strides[scan_dim]; - const long output_scan_stride = output_strides[scan_dim]; - - for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { - const long input_offset = input_base_offset + scan_idx * input_scan_stride; - const long output_offset = - output_base_offset + scan_idx * output_scan_stride; - - T val = input[input_offset]; - acc_t accum_val = static_cast(val); - accumulator = Op::apply(accumulator, accum_val); - output[output_offset] = static_cast(accumulator); - } -} - -// Generic strided scan with indices kernel -template > -kernel void scan_with_indices_strided( - constant T* input [[buffer(0)]], - device T* values [[buffer(1)]], - device int64_t* indices [[buffer(2)]], - constant long* sizes [[buffer(3)]], - constant long* input_strides [[buffer(4)]], - constant long* values_strides [[buffer(5)]], - constant long* indices_strides [[buffer(6)]], - constant uint& ndim [[buffer(7)]], - constant uint& scan_dim [[buffer(8)]], - uint thread_index [[thread_position_in_grid]]) { - const long total_non_scan_elements = - calculate_non_scan_elements(sizes, ndim, scan_dim); - if (thread_index >= total_non_scan_elements) { - return; - } - - int pos[c10::metal::max_ndim]; - thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); - - const long input_base_offset = - calculate_base_offset(pos, input_strides, ndim, scan_dim); - const long values_base_offset = - calculate_base_offset(pos, values_strides, ndim, scan_dim); - const long indices_base_offset = - calculate_base_offset(pos, indices_strides, ndim, scan_dim); - - acc_t accumulator = Op::identity(); - int64_t best_idx = 0; - const long scan_size = sizes[scan_dim]; - const long input_scan_stride = input_strides[scan_dim]; - const long values_scan_stride = values_strides[scan_dim]; - const long indices_scan_stride = indices_strides[scan_dim]; - - for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { - const long input_offset = input_base_offset + scan_idx * input_scan_stride; - const long values_offset = - values_base_offset + scan_idx * values_scan_stride; - const long indices_offset = - indices_base_offset + scan_idx * indices_scan_stride; - - T val = input[input_offset]; - acc_t accum_val = static_cast(val); - if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) { - accumulator = accum_val; - best_idx = scan_idx; - } - values[values_offset] = static_cast(accumulator); - indices[indices_offset] = best_idx; - } -} - -#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \ - template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ - scan_contiguous_innermost_dim>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * output [[buffer(1)]], \ - constant uint & num_rows [[buffer(2)]], \ - constant uint & row_size [[buffer(3)]], \ - uint row [[thread_position_in_grid]]); \ - \ - template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ - scan_contiguous_outer_dim>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * output [[buffer(1)]], \ - constant uint & num_orows [[buffer(2)]], \ - constant uint & num_irows [[buffer(3)]], \ - constant uint & row_size [[buffer(4)]], \ - uint thread_index [[thread_position_in_grid]]); \ - \ - template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ - scan_strided>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * output [[buffer(1)]], \ - constant long* sizes [[buffer(2)]], \ - constant long* input_strides [[buffer(3)]], \ - constant long* output_strides [[buffer(4)]], \ - constant uint& ndim [[buffer(5)]], \ - constant uint& scan_dim [[buffer(6)]], \ - uint thread_index [[thread_position_in_grid]]); - -#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \ - template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ - scan_with_indices_contiguous_innermost_dim>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * values [[buffer(1)]], \ - device int64_t* indices [[buffer(2)]], \ - constant uint& num_rows [[buffer(3)]], \ - constant uint& row_size [[buffer(4)]], \ - uint row [[thread_position_in_grid]]); \ - \ - template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ - scan_with_indices_contiguous_outer_dim>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * values [[buffer(1)]], \ - device int64_t* indices [[buffer(2)]], \ - constant uint& num_orows [[buffer(3)]], \ - constant uint& num_irows [[buffer(4)]], \ - constant uint& row_size [[buffer(5)]], \ - uint thread_index [[thread_position_in_grid]]); \ - \ - template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ - scan_with_indices_strided>( \ - constant DTYPE * input [[buffer(0)]], \ - device DTYPE * values [[buffer(1)]], \ - device int64_t* indices [[buffer(2)]], \ - constant long* sizes [[buffer(3)]], \ - constant long* input_strides [[buffer(4)]], \ - constant long* values_strides [[buffer(5)]], \ - constant long* indices_strides [[buffer(6)]], \ - constant uint& ndim [[buffer(7)]], \ - constant uint& scan_dim [[buffer(8)]], \ - uint thread_index [[thread_position_in_grid]]); - -// Simple scan operations -REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float); -REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half); - -// Scan operations with indices -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar); -REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool); - -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar); -REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool); - -#else // __METAL_VERSION__ >= 310 - C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size; // The reminder of this file contains cummin and cummax implementations adapted @@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4); - -#endif diff --git a/aten/src/ATen/native/mps/kernels/SpecialOps.metal b/aten/src/ATen/native/mps/kernels/SpecialOps.metal index 4b339c9a92b6..1e37573a36e8 100644 --- a/aten/src/ATen/native/mps/kernels/SpecialOps.metal +++ b/aten/src/ATen/native/mps/kernels/SpecialOps.metal @@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float); REGISTER_SPECIAL(int, float); REGISTER_SPECIAL(long, float); REGISTER_SPECIAL(half, half); -#if __METAL_VERSION__ >= 310 REGISTER_SPECIAL(bfloat, bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/TriangularOps.metal b/aten/src/ATen/native/mps/kernels/TriangularOps.metal index 27ad50602848..ad1a0f93a217 100644 --- a/aten/src/ATen/native/mps/kernels/TriangularOps.metal +++ b/aten/src/ATen/native/mps/kernels/TriangularOps.metal @@ -100,9 +100,7 @@ kernel void triul( INSTANTIATE_TRIUL_KERNELS(float, int); INSTANTIATE_TRIUL_KERNELS(half, int); -#if __METAL_VERSION__ >= 310 INSTANTIATE_TRIUL_KERNELS(bfloat, int); -#endif INSTANTIATE_TRIUL_KERNELS(float2, int); INSTANTIATE_TRIUL_KERNELS(half2, int); diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal index 37a61397467f..23c4810a2496 100644 --- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal @@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half); REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0) -#if __METAL_VERSION__ >= 310 INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat); REGISTER_UNARY_OP(neg, bfloat, bfloat); REGISTER_UNARY_OP(abs, bfloat, bfloat); -#endif INSTANTIATE_UNARY_KERNELS2(half, half); INSTANTIATE_UNARY_KERNELS2(float, float); INSTANTIATE_UNARY_KERNELS2(float, bool); @@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float); REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float); REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half); -#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal b/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal index 3685278e6765..8369258a30a6 100644 --- a/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal +++ b/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal @@ -70,6 +70,4 @@ kernel void unfold_backward( INSTANTIATE_UNFOLD_BACKWARD(float); INSTANTIATE_UNFOLD_BACKWARD(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_UNFOLD_BACKWARD(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index d778e50ede65..393c9e1b4d42 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar); INSTANTIATE_UPSAMPLE_3D(uchar); INSTANTIATE_UPSAMPLE_ALL(float); INSTANTIATE_UPSAMPLE_ALL(half); -#if __METAL_VERSION__ >= 310 INSTANTIATE_UPSAMPLE_ALL(bfloat); -#endif diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 806eeb82e1d1..b2a1b2757b13 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -53,6 +53,7 @@ void binary_op_kernel(const std::string func_name, .add_input(input) .add_input(other) .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(true) .build(); lib.exec_binary_kernel(iter, func_name, alpha); diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index a9589ecc490e..06b6edcff940 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -48,28 +48,11 @@ #define BinaryOpFn(graph, primary, secondary) \ MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) -static inline Tensor legacy_complex_as_view(const Tensor& t) { - // Convert non-complex types (and cdouble CPU scalars) to cfloat - if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) { - return at::view_as_real(t.to(kMPS, kComplexFloat)); - } - return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS)); -} - static void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) { - TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) && - (self.scalar_type() == ScalarType::Long || - (other.scalar_type() == ScalarType::Long && - (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), - "MPS: ", - op_name, - " op with int64 input is supported natively starting from macOS 13.2"); - TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(), - "Complex types are supported starting from MacOS 14.0+"); MPSStream* mpsStream = getCurrentMPSStream(); const bool is_self_scalar = self.dim() == 0; diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index f167067216d4..101ef5feb224 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -51,9 +51,6 @@ inline void dot_check(const Tensor& self, const Tensor& other) { } // namespace mps Tensor dot_mps(const Tensor& self, const Tensor& other) { - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long, - "MPS: dot op doesn't support int64 input on MacOS13") - using namespace mps; using CachedGraph = MPSBinaryCachedGraph; diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 97d562730dd8..d572d52d103a 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -124,7 +124,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, IntArrayRef dilation, int64_t groups, std::optional input_shape) { - const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); Tensor input_t = input_t_; bool is3DConv = input_t.dim() == 5; @@ -132,9 +131,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, input_t = input_t.contiguous(); } - TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer), - "Conv3D is only supported on MPS for MacOS_13_2 or newer"); - TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types"); using namespace at::native::mps; diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 4f879c3b63b0..0c121cee8fb6 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -60,7 +60,6 @@ static void copy_cast_mps(at::Tensor& dst, outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"]; } if (needs_conj) { - TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+"); outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil]; } @@ -275,24 +274,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // for GPU to GPU copies we only encode to stream's command buffer (no flushing) stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id); } else { - // Simulate cast to Complex on older MacOS by initializing real and imag parts - if (dst_.is_complex() && !supportsComplex()) { - if (!src.is_complex()) { - at::real(dst_).copy_(src); - at::imag(dst_).fill_(0); - } else if (src.is_conj() || dst_.is_conj()) { - // One cannot take view of conjugated tensor, but for some reason real and imag views are fine - // Use this to implement a conjugation - at::real(dst_).copy_(at::real(src)); - if (src.is_conj() != dst_.is_conj()) { - at::imag(dst_).copy_(at::neg(at::imag(src))); - } else { - at::imag(dst_).copy_(at::imag(src)); - } - } else { - at::view_as_real(dst_).copy_(at::view_as_real(src)); - } - } else if (dst_byte_offset) { + if (dst_byte_offset) { auto maybeCastedSource = at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource); diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index d072e5a40ac9..4d3f99ea9e02 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -87,7 +87,6 @@ case kFloat: return MPSDataTypeFloat32; case kBFloat16: { - checkSupportsBFloat16(); return MPSDataTypeBFloat16; } default: diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm index a9ac70110617..7e9867c9b948 100644 --- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm +++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm @@ -88,7 +88,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237 Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(onesided); @autoreleasepool { @@ -129,7 +128,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t normalization, int64_t last_dim_size, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(last_dim_size); @autoreleasepool { @@ -155,7 +153,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, } Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(forward); @autoreleasepool { diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 1e701d314354..8f51474e7a2c 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -127,15 +127,6 @@ Tensor grid_sampler_2d_mps(const Tensor& input, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) { - TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ", - "Falling back on CPU. This may have performance implications."); - - return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners) - .clone() - .to("mps"); - } - auto in_size = input.sizes(); auto grid_size = grid.sizes(); auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index f00d155559da..a73866dc4357 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -108,26 +108,12 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, - const std::string& op, - bool accumulate) { - using namespace mps; - + const std::string& op) { const auto num_indices = index_size.size(); TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels"); AT_ASSERT(num_indices == index_stride.size()); AT_ASSERT(static_cast(num_indices) == iter.ntensors() - 2); - const Tensor& inputTensor = iter.tensor(1); - const auto scalar_type = inputTensor.scalar_type(); - - if (accumulate) { - // No atomic support for the complex dtypes - TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type)); - } else { - TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) || - scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf, - getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out")); - } } static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) { @@ -158,7 +144,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, IntArrayRef index_stride, const std::string& kernel_name, const bool serial = false) { - validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); + validateInputData(iter, index_size, index_stride, "index.Tensor_out"); if (iter.numel() == 0) return; if (!iter.can_use_32bit_indexing()) { @@ -200,7 +186,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, } static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { - validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); + validateInputData(iter, index_size, index_stride, "index.Tensor_out"); dispatch_index_kernel( iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0)))); } @@ -210,7 +196,7 @@ static void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_stride, bool accumulate) { @autoreleasepool { - validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate); + validateInputData(iter, index_size, index_stride, "index_put_impl"); if (accumulate) { dispatch_index_kernel(iter, index_size, @@ -353,14 +339,7 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", - "Falling back on CPU. This may have performance implications."); - Tensor out_fallback = nonzero_fallback(self); - at::native::resize_output(out_, out_fallback.sizes()); - out_.copy_(out_fallback); - return out_; - } else if (self.is_complex()) { + if (self.is_complex()) { TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ", "Falling back on CPU. This may have performance implications."); Tensor out_fallback = nonzero_fallback(self); @@ -445,11 +424,7 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor nonzero_mps(const Tensor& self) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", - "Falling back on CPU. This may have performance implications."); - return nonzero_fallback(self); - } else if (self.is_complex()) { + if (self.is_complex()) { TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ", "Falling back on CPU. This may have performance implications."); return nonzero_fallback(self); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 3cdf0021e987..7a3dde679c05 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -112,6 +112,61 @@ return output; } +Tensor& do_metal_addmm(const Tensor& self, + const Tensor& other, + Tensor& output, + const Scalar& alpha, + const Scalar& beta, + const Tensor& bias) { + if (beta.toDouble() == 0 && alpha.toDouble() == 1) { + return do_metal_mm(self, other, output); + } + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other}); + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:matmulPSO]; + std::array sizes = {static_cast(self.size(0)), + static_cast(self.size(1)), + static_cast(output.size(1))}; + std::array strides = {self.stride(0), + self.stride(1), + other.stride(0), + other.stride(1), + output.stride(0), + output.stride(1), + bias.stride(0), + bias.stride(1)}; + union { + std::array i64; + std::array i32; + std::array f32; + } alpha_beta; + if (output.scalar_type() == kLong) { + alpha_beta.i64 = {alpha.toLong(), beta.toLong()}; + } else if (c10::isIntegralType(output.scalar_type(), true)) { + alpha_beta.i32 = {alpha.toInt(), beta.toInt()}; + } else { + TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type())); + alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()}; + } + constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs + uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM; + uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM; + + MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1); + mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + getMPSProfiler().endProfileKernel(matmulPSO); + } + }); + return output; +} + std::tuple do_mm(MPSGraph* graph, const Tensor& self, const Tensor& other) { @@ -644,7 +699,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input"); TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -671,6 +725,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const return output; } + if (use_metal_mm(self, other, output)) { + return do_metal_addmm(self, other, output, alpha, beta, *bias_); + } + bool is_beta_non_zero = beta.toDouble() != 0.0; struct CachedGraph : public mps::MPSCachedGraph { diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 7982645e8fe5..6ae3122cf3d1 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -21,6 +22,8 @@ #include #include #include +#include +#include #endif namespace at::native { @@ -294,13 +297,13 @@ static PoolSizes process_pool_sizes(const Tensor& input, pooling_dims, " ints"); - TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3, + TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == pooling_dims, op_name, ": stride must either be omitted, a single int, or a tuple of ", pooling_dims, " ints"); - TORCH_CHECK(padding.size() == 1 || padding.size() == 3, + TORCH_CHECK(padding.size() == 1 || padding.size() == pooling_dims, op_name, ": padding must either be a single int, or a tuple of ", pooling_dims, @@ -330,6 +333,22 @@ static PoolSizes process_pool_sizes(const Tensor& input, ": pad should be at most half of effective kernel size"); } + if (pooling_dims == 2) { + const auto memory_format = input.suggest_memory_format(); + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + if (memory_format == at::MemoryFormat::ChannelsLast) { + // Expect tensor in NHWC format and allow 0-dim only for N. + TORCH_CHECK((dims == 4 && valid_dims && input.size(3) != 0), + "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: ", + input.sizes()); + } else { + TORCH_CHECK((dims == 3 && input.size(0) != 0 && valid_dims) || (dims == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:", + input.sizes()); + } + } + for (const auto dim : c10::irange(static_cast(leading_dims == 2), dims)) { TORCH_CHECK(input.size(dim) > 0, op_name, ": Expected input's non-batch dimensions to have positive length"); } @@ -492,6 +511,60 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input, }); } +static void max_unpool_out_mps_template(const Tensor& input, + const Tensor& indices, + IntArrayRef output_size_, + IntArrayRef stride, + IntArrayRef padding, + Tensor& output, + const int32_t pooling_dims, + const std::string& op_name) { + auto dims = input.dim(); + auto leading_dims = input.dim() - pooling_dims; + + const auto memory_format = input.suggest_memory_format(); + std::vector output_size(dims); + for (int dim : c10::irange(leading_dims)) { + output_size[dim] = input.sizes()[dim]; + } + for (int dim : c10::irange(pooling_dims)) { + output_size[leading_dims + dim] = output_size_[dim]; + } + + output.resize_(output_size, memory_format); + output.fill_(0); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + const auto numThreads = input.numel(); + MaxUnpoolingParams<5> params; + + params.dims = dims; + params.pooling_dims = pooling_dims; + + for (const auto dim : c10::irange(dims)) { + params.output_sizes[dim] = safe_downcast(output.size(dim)); + params.output_strides[dim] = safe_downcast(output.stride(dim)); + params.input_sizes[dim] = safe_downcast(input.size(dim)); + params.input_strides[dim] = safe_downcast(input.stride(dim)); + params.indices_strides[dim] = safe_downcast(indices.stride(dim)); + } + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input)); + + getMPSProfiler().beginProfileKernel(PSO, op_name, {input}); + [computeEncoder setComputePipelineState:PSO]; + mtl_setArgs(computeEncoder, output, input, indices, params); + + mtl_dispatch1DJob(computeEncoder, PSO, numThreads); + getMPSProfiler().endProfileKernel(PSO); + } + }); +} + static void avg_pool2d_template(const Tensor& input, const Tensor& output, const std::optional& grad_output_opt, @@ -669,8 +742,76 @@ static void avg_pool_out_mps_template(const Tensor& output, }); } +static void avg_pool_backward_out_mps_template(const Tensor& grad_input, + const Tensor& input, + const Tensor& grad_output, + IntArrayRef _kernel_size, + IntArrayRef _stride, + IntArrayRef _padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const int32_t pooling_dims, + const std::string& op_name) { + auto [dims, _, kernel_size, stride, padding, __] = + process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name); + + const auto memory_format = input.suggest_memory_format(); + grad_input.resize_(input.sizes(), memory_format); + grad_input.fill_(0); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + const auto numThreads = grad_output.numel(); + + AvgPoolingParams<5> params; + + params.dims = dims; + params.pooling_dims = pooling_dims; + params.count_include_pad = count_include_pad; + params.has_divisor_override = divisor_override.has_value(); + if (divisor_override.has_value()) { + params.divisor_override = safe_downcast(divisor_override.value()); + } + + for (const auto dim : c10::irange(dims)) { + params.output_sizes[dim] = safe_downcast(grad_output.size(dim)); + params.output_strides[dim] = safe_downcast(grad_output.stride(dim)); + params.input_sizes[dim] = safe_downcast(grad_input.size(dim)); + params.input_strides[dim] = safe_downcast(grad_input.stride(dim)); + } + + memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t)); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input)); + + getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output}); + [computeEncoder setComputePipelineState:PSO]; + mtl_setArgs(computeEncoder, grad_input, grad_output, params); + + mtl_dispatch1DJob(computeEncoder, PSO, numThreads); + getMPSProfiler().endProfileKernel(PSO); + } + }); +} + } // namespace mps +// TODO: The MPS graph impl can sometimes give significantly better performance +// than the Metal impl for cases where the stride is 1 in all dimensions. There +// may be a code path in the graph kernel that specifically optimizes for that +// case. We should look into implementing a specialized case in Metal so we can +// avoid using the graph impl. +static bool use_graph_for_max_pool2d(IntArrayRef kernel_size, IntArrayRef stride_) { + IntArrayRef stride = stride_.empty() ? kernel_size : stride_; + return (stride[0] == 1) && (stride.size() == 1 || stride[1] == 1); +} + Tensor mps_max_pool2d(const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -678,24 +819,37 @@ Tensor mps_max_pool2d(const Tensor& input, IntArrayRef dilation, bool ceil_mode) { Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); - mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { - MPSGraph* mpsGraph = cachedGraph.graph(); - return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; - }; - mps::pool2d_template(input, - output, - std::nullopt, - std::nullopt, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - false, - std::nullopt, - pooling_op_block, - "max_pool2d"); - + bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); + if (use_graph) { + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; + }; + mps::pool2d_template(input, + output, + std::nullopt, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d"); + } else { + mps::max_pool_with_indices_out_mps_template(output, + std::nullopt, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/2, + "max_pool2d"); + } return output; } @@ -740,32 +894,45 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { - auto indices_memory_format = indices.suggest_memory_format(); - - mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { - MPSGraph* mpsGraph = cachedGraph.graph(); - NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor - descriptor:desc - name:nil]; - cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); - return poolOutputs[0]; - }; - mps::pool2d_template(input, - output, - indices, - std::nullopt, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - false, - std::nullopt, - pooling_op_block, - "max_pool2d_indices"); + bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); + if (use_graph) { + auto indices_memory_format = indices.suggest_memory_format(); + + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + NSArray* poolOutputs = + [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; + cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); + return poolOutputs[0]; + }; + mps::pool2d_template(input, + output, + indices, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d_indices"); + if (indices_memory_format == MemoryFormat::ChannelsLast) { + const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); + } - if (indices_memory_format == MemoryFormat::ChannelsLast) { - const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); + } else { + mps::max_pool_with_indices_out_mps_template(output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/2, + "max_pool2d"); } } @@ -896,6 +1063,68 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output, return grad_input; } +Tensor& max_unpooling2d_forward_out_mps(const Tensor& self, + const Tensor& indices, + IntArrayRef output_size, + Tensor& output) { + mps::max_unpool_out_mps_template(self, + indices, + output_size, + /*stride=*/{}, + /*padding=*/{}, + output, + /*pooling_dims=*/2, + "max_unpool2d"); + return output; +} + +Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) { + auto output = at::empty({0}, self.options()); + mps::max_unpool_out_mps_template(self, + indices, + output_size, + /*stride=*/{}, + /*padding=*/{}, + output, + /*pooling_dims=*/2, + "max_unpool2d"); + return output; +} + +Tensor& max_unpooling3d_forward_out_mps(const Tensor& self, + const Tensor& indices, + IntArrayRef output_size, + IntArrayRef stride, + IntArrayRef padding, + Tensor& output) { + mps::max_unpool_out_mps_template(self, + indices, + output_size, + stride, + padding, + output, + /*pooling_dims=*/3, + "max_unpool3d"); + return output; +} + +Tensor max_unpooling3d_forward_mps(const Tensor& self, + const Tensor& indices, + IntArrayRef output_size, + IntArrayRef stride, + IntArrayRef padding) { + auto output = at::empty({0}, self.options()); + mps::max_unpool_out_mps_template(self, + indices, + output_size, + stride, + padding, + output, + /*pooling_dims=*/3, + "max_unpool3d"); + return output; +} + TORCH_IMPL_FUNC(avg_pool2d_out_mps) (const Tensor& input, int64_t kH, @@ -965,4 +1194,26 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output, "avg_pool3d"); } +TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const Tensor& grad_input) { + mps::avg_pool_backward_out_mps_template(grad_input, + input, + grad_output, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + /*pooling_dims=*/3, + "avg_pool3d_backward"); +} + } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 21020bad467d..4b209403f853 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -152,8 +152,6 @@ static void reduction_out_mps(const Tensor& input_t, const Tensor& output_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor bool canSqueezeLastDim = true; IntArrayRef input_shape = input_t.sizes(); @@ -236,12 +234,10 @@ static void reduction_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; MPSDataType inputCastType = MPSDataTypeInvalid; if (dtype.has_value() && - (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || - (dtype.value() == kLong && macOS13_3_plus))) { + (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) { inputCastType = getMPSDataType(dtype.value()); } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && - (inputScalarType != kLong || !macOS13_3_plus)) { + inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) { inputCastType = getMPSDataType(kFloat); } @@ -615,9 +611,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, } static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median"); - IntArrayRef input_shape = input_t.sizes(); int64_t num_in_elements = c10::multiply_integers(input_shape); @@ -634,8 +627,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { auto medianCachedGraph = LookUpOrCreateCachedGraph(medianKey, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil]; @@ -693,9 +685,6 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { } static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max"); - using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); @@ -713,8 +702,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castOutputTensor = nil; - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); NSArray* axes = getTensorAxes(input_t); if (reduction_type == MPSReductionType::MAX) { @@ -749,9 +737,6 @@ static void min_max_out_mps(const Tensor& input_t, const Tensor& indices_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out"); - if (output_t.numel() == 0) { return; } @@ -789,8 +774,7 @@ static void min_max_out_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); if (reduction_type == MPSReductionType::MAX) { outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; @@ -896,9 +880,6 @@ static void argmax_argmin_out_mps(const Tensor& input_t, const std::string& func_name) { using CachedGraph = MPSUnaryCachedGraph; - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out"); - int64_t dim_ = -1; if (dim.has_value()) { @@ -953,7 +934,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - (inputScalarType != kLong || !macOS13_3_plus)) { + inputScalarType != kLong) { castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat); } if (reduction_type == MPSReductionType::MAX) { @@ -1282,9 +1263,6 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name); - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, op_name.c_str()); @@ -1303,7 +1281,7 @@ static void all_any_common_impl_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 MPSGraphTensor* outputTensor = nil; @@ -1369,14 +1347,11 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out"); - @autoreleasepool { std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.dim() > 4) { @@ -1420,14 +1395,11 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out"); - @autoreleasepool { std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.ndimension() > 4) { @@ -1512,9 +1484,6 @@ static void median_out_mps_common(const Tensor& input_t, Tensor& indices, const std::string& func_name, bool nanmedian) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out"); - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "max()"); @@ -1585,8 +1554,7 @@ static void median_out_mps_common(const Tensor& input_t, getTensorsStringKey(indices); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); MPSGraphTensor* effectiveLengthTensor = nil; if (nanmedian) { diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 10668309a8c2..40afa15b4f70 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -129,16 +129,8 @@ void computeRepeatIndices(const index_t* repeat_ptr, }); } -Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional output_size) { +Tensor repeat_interleave_mps(const Tensor& repeat, std::optional output_size) { Tensor output; - Tensor repeat = repeat_; - if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { - // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output, - // which currently doesn't support int64_t as input. Casting internally the indices to int32_t. - TORCH_WARN_ONCE( - "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3"); - repeat = repeat.to(kInt); - } AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() { output = repeat_interleave_common>(repeat, output_size); }); diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.mm b/aten/src/ATen/native/mps/operations/ScanKernel.mm index 9e3269d97014..80495ba9d501 100644 --- a/aten/src/ATen/native/mps/operations/ScanKernel.mm +++ b/aten/src/ATen/native/mps/operations/ScanKernel.mm @@ -23,125 +23,6 @@ #include #endif -// Generic scan implementation that handles both simple scans and scans with indices -static void scan_mps_impl(const Tensor& self, - const std::vector& outputs, - int64_t dim, - const std::string& op_name) { - if (outputs[0].numel() == 0) { - return; - } - - const int64_t ndim = self.dim(); - const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); - - // Calculate dimensions for scan operation - int64_t row_size = self.size(wrapped_dim); - auto sizes = self.sizes(); - - bool is_innermost = (wrapped_dim == ndim - 1); - - // Check if all tensors are contiguous - bool is_contiguous = self.is_contiguous(); - for (const auto& output : outputs) { - is_contiguous = is_contiguous && output.is_contiguous(); - } - - uint32_t num_rows, num_orows, num_irows, num_threads; - - if (is_innermost) { - // Treat all outer dimensions as a single dimension - num_rows = self.numel() / row_size; - num_threads = num_rows; - } else { - // Treat all outer dimensions (i.e. dim_ < dim) as one - num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim); - // Treat all inner dimensions (i.e. dim > dimension) as one - num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end()); - num_threads = num_orows * num_irows; - } - - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync_with_rethrow(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - - // Choose kernel based on contiguity and dimension - std::string kernel_name; - if (is_contiguous) { - kernel_name = - op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self); - } else { - kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self); - } - - id scanPSO = lib.getPipelineStateForFunc(kernel_name); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() { - std::vector all_tensors = {self}; - all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end()); - return all_tensors; - }()); - - [computeEncoder setComputePipelineState:scanPSO]; - - // Set input tensor - mtl_setBuffer(computeEncoder, self, 0); - - // Set output tensors - for (size_t i = 0; i < outputs.size(); ++i) { - mtl_setBuffer(computeEncoder, outputs[i], i + 1); - } - - if (is_contiguous) { - // Contiguous kernels - if (is_innermost) { - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, num_rows, static_cast(row_size)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, num_rows, static_cast(row_size)); - } - } else { - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast(row_size)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast(row_size)); - } - } - } else { - // Strided kernels - pass full tensor information - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, - self.sizes(), - self.strides(), - outputs[0].strides(), - static_cast(self.ndimension()), - static_cast(wrapped_dim)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, - self.sizes(), - self.strides(), - outputs[0].strides(), - outputs[1].strides(), - static_cast(self.ndimension()), - static_cast(wrapped_dim)); - } - } - - mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads); - - getMPSProfiler().endProfileKernel(scanPSO); - } - }); -} - // Utility function to get 2D grid dimensions for dispatch static std::pair get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) { size_t grid_x = 1; @@ -375,19 +256,11 @@ static void scan_with_indices_mps_impl(const Tensor& self, } // namespace mps void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); - } else { - mps::scan_mps_impl(self, {values, indices}, dim, "cummax"); - } + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); } void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin"); - } else { - mps::scan_mps_impl(self, {values, indices}, dim, "cummin"); - } + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin"); } Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) { @@ -402,11 +275,7 @@ void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int6 return result; } - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp"); - } else { - mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp"); - } + mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp"); return result; } diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index c73b7c33098f..cfec1e443e25 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -26,9 +26,6 @@ const Tensor& indices) { using namespace mps; - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out"); - if (self.numel() == 0) { return; } @@ -55,8 +52,7 @@ auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self); MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:(NSInteger)dim descending:(BOOL)descending diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 6e030c99d035..16e0608012f3 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -297,9 +297,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, const auto common_type = at::result_type(elements, test_elements); TORCH_CHECK(elements.is_mps() && test_elements.is_mps()); - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type), - "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ", - common_type); @autoreleasepool { std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert); diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index edf45a5ff80d..8fbefcb6ab8a 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -208,28 +208,12 @@ static void unary_op(const Tensor& self, } Tensor& angle_out_mps(const Tensor& self, Tensor& output) { - if (mps::supportsComplex()) { - mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; - auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; - return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; - }); - return output; - } else { - TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13") - mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - // On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are - // not available, and NaN is not propagated correctly: - auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType]; - auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil]; - auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil]; - return [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:inputTensor - falsePredicateTensor:result - name:nil]; - }); - return output; - } + mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; + auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; + return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; + }); + return output; } Tensor angle_mps(const Tensor& self) { @@ -362,7 +346,6 @@ static void cumulative_op_impl(const Tensor& self, const Tensor& result, MPSCumulativeOpType cumulativeOpType, const std::string& op_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); auto nDims = self.dim(); auto wrapped_dim = maybe_wrap_dim(dim, nDims); TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), @@ -381,11 +364,6 @@ static void cumulative_op_impl(const Tensor& self, bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int && input.scalar_type() != ScalarType::Long); - TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long, - "MPS does not support ", - op_name, - " op with int64 input. Support has been added in macOS 13.3"); - mps::unary_op( input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { if (castInputData) { @@ -440,17 +418,9 @@ static void cumulative_op_impl(const Tensor& self, Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) { TORCH_CHECK(self.is_complex()); - if (!mps::supportsComplex()) { - if (!result.is_same_size(self)) { - result.resize_(self.sizes()); - } - at::real(result).copy_(at::real(self)); - at::imag(result).copy_(at::neg(at::imag(self))); - } else { - mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return [mpsGraph conjugateWithTensor:inputTensor name:nil]; - }); - } + mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph conjugateWithTensor:inputTensor name:nil]; + }); return result; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b1bb48647743..1bb8fe52512c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -719,6 +719,7 @@ dispatch: CPU, CUDA: all_out MPS: all_out_mps + MTIA: all_out_mtia - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -808,6 +809,7 @@ CPU, Meta: arange_out CUDA: arange_cuda_out MPS: arange_mps_out + MTIA: arange_mtia_out cpp_no_default_args: ['step'] # This function is a temporary hack to allow tracing of arange like constructs with dynamic @@ -1889,7 +1891,10 @@ - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm - autogen: cudnn_batch_norm.out + +- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + dispatch: + CUDA: cudnn_batch_norm_out # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) @@ -4182,11 +4187,13 @@ dispatch: CPU: _int_mm_cpu CUDA: _int_mm_cuda + XPU: _int_mm_xpu - func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _int_mm_out_cpu CUDA: _int_mm_out_cuda + XPU: _int_mm_out_xpu - func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor dispatch: @@ -4223,6 +4230,7 @@ - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor dispatch: CPU: _weight_int8pack_mm_cpu + CUDA: _weight_int8pack_mm_cuda MPS: _weight_int8pack_mm_mps - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor @@ -7124,18 +7132,21 @@ dispatch: CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda + tags: needs_exact_strides - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda + tags: needs_exact_strides - func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: CUDA: _scaled_grouped_mm_cuda + tags: needs_exact_strides - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor variants: function @@ -7412,6 +7423,7 @@ dispatch: SparseCPU: _coalesce_sparse_cpu SparseCUDA: _coalesce_sparse_cuda + SparseMPS: _coalesce_sparse_mps autogen: _coalesce.out - func: is_coalesced(Tensor self) -> bool @@ -7450,7 +7462,7 @@ - func: indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: indices_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: indices_sparse CompositeExplicitAutograd: indices_default device_check: NoCheck device_guard: False @@ -7458,7 +7470,7 @@ - func: values(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: values_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: values_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: values_nested CompositeExplicitAutograd: values_default @@ -10487,6 +10499,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ + MTIA: foreach_tensor_add_scalar_kernel_mtia_ autogen: _foreach_add.Scalar_out - func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10495,6 +10508,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda + MTIA: foreach_tensor_add_list_kernel_mtia - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10502,6 +10516,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ + MTIA: foreach_tensor_add_list_kernel_mtia_ autogen: _foreach_add.List_out - func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10532,6 +10547,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ + MTIA: foreach_tensor_add_tensor_kernel_mtia_ autogen: _foreach_add.Tensor_out - func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10592,6 +10608,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ + MTIA: foreach_tensor_mul_scalar_kernel_mtia_ autogen: _foreach_mul.Scalar_out - func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10600,6 +10617,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda + MTIA: foreach_tensor_mul_list_kernel_mtia - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10607,6 +10625,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ + MTIA: foreach_tensor_mul_list_kernel_mtia_ autogen: _foreach_mul.List_out - func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10630,6 +10649,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda + MTIA: foreach_tensor_mul_tensor_kernel_mtia - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10637,6 +10657,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ + MTIA: foreach_tensor_mul_tensor_kernel_mtia_ autogen: _foreach_mul.Tensor_out - func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10933,6 +10954,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda + MTIA: foreach_tensor_addcmul_scalar_mtia - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10954,6 +10976,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ + MTIA: foreach_tensor_addcmul_scalar_mtia_ autogen: _foreach_addcmul.Scalar_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -10978,6 +11001,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda + MTIA: foreach_tensor_abs_mtia - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10985,6 +11009,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ + MTIA: foreach_tensor_abs_mtia_ autogen: _foreach_abs.out - func: _foreach_acos(Tensor[] self) -> Tensor[] @@ -11319,6 +11344,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_norm_slow CUDA: foreach_tensor_norm_cuda + MTIA: foreach_tensor_norm_mtia autogen: _foreach_norm.Scalar_out - func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] @@ -11491,6 +11517,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ + MTIA: foreach_tensor_sqrt_mtia_ autogen: _foreach_sqrt.out - func: _foreach_tan(Tensor[] self) -> Tensor[] @@ -11552,6 +11579,7 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ + MTIA: foreach_tensor_copy_list_kernel_mtia_ autogen: _foreach_copy.out - func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out @@ -11559,6 +11587,7 @@ variants: function dispatch: CompositeExplicitAutograd: _foreach_copy + MTIA: foreach_tensor_copy_list_kernel_mtia - func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor dispatch: @@ -12351,6 +12380,7 @@ dispatch: CPU: avg_pool3d_backward_out_cpu CUDA: avg_pool3d_backward_out_cuda + MPS: avg_pool3d_backward_out_mps MkldnnCPU: mkldnn_avg_pool3d_backward_out - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor @@ -12476,24 +12506,28 @@ dispatch: CPU: max_unpooling2d_forward_out_cpu CUDA: max_unpooling2d_forward_out_cuda + MPS: max_unpooling2d_forward_out_mps - func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor python_module: nn dispatch: CPU: max_unpooling2d_forward_cpu CUDA: max_unpooling2d_forward_cuda + MPS: max_unpooling2d_forward_mps - func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: max_unpooling3d_forward_out_cpu CUDA: max_unpooling3d_forward_out_cuda + MPS: max_unpooling3d_forward_out_mps - func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor python_module: nn dispatch: CPU: max_unpooling3d_forward_cpu CUDA: max_unpooling3d_forward_cuda + MPS: max_unpooling3d_forward_mps - func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -14979,6 +15013,7 @@ - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda + NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda tags: nondeterministic_seeded - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) @@ -15011,6 +15046,11 @@ CUDA: _cudnn_attention_forward tags: nondeterministic_seeded +- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _cudnn_attention_backward + tags: nondeterministic_seeded + - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 5b7476453407..96c6ab8310f8 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } +std::tuple _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& philox_seed, + const Tensor& philox_offset, + const Tensor& attn_bias, + const Tensor& cum_seq_q, + const Tensor& cum_seq_k, + const int64_t max_q, + const int64_t max_k, + double dropout_p, + bool is_causal, + std::optional scale) { + if (!grad_out.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + auto [ + grad_out_buffer_reshaped, + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + output_buffer_reshaped] = + preprocessing::sdpa_nested_preprocessing_backward( + grad_out, + query, + key, + value, + out, + cum_seq_q, + cum_seq_k, + max_q, + max_k); + + auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped, + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + output_buffer_reshaped, + logsumexp, + philox_seed, + philox_offset, + attn_bias, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + scale); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); +} + + std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 1e91fecd4500..807a9b25d377 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half, "'embedding_bag_byte_prepack' only support float32 or float16."); - const auto weight_sizes = weight.sizes(); - const auto cols_dim = weight_sizes.size() - 1; - const int32_t embedding_cols = static_cast(weight_sizes[cols_dim]); + const auto weight_sizes = weight.sym_sizes(); + const auto cols_dim = weight.ndimension() - 1; + const auto embedding_cols = weight_sizes[cols_dim]; // Add 8 bytes per column to store FP32 scale and zero_point per row. - const int32_t output_columns = static_cast(embedding_cols + 2 * sizeof(float)); + const auto output_columns = embedding_cols + 2 * sizeof(float); // Adjust output dimensions to account for FP32 scale and zero_points. - std::vector output_shape = weight_sizes.vec(); + auto output_shape = weight_sizes.vec(); output_shape.at(cols_dim) = output_columns; at::SymDimVector output_shape_vec(output_shape); diff --git a/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm b/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm new file mode 100644 index 000000000000..7ccdf4077542 --- /dev/null +++ b/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm @@ -0,0 +1,220 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native { + +using namespace mps; +using namespace at::sparse; + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + + +static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) { + + TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D"); + TORCH_CHECK(static_cast(indices.size(0)) == size.size(), + "flatten_indices: indices.size(0) must equal size.size()"); + + int64_t sparse_dim = indices.size(0); + int64_t nnz = indices.size(1); + + if (nnz == 0) { + return at::empty({0}, indices.options().dtype(kLong)); + } + + std::vector strides(sparse_dim); + strides[sparse_dim - 1] = 1; + for (int64_t i = sparse_dim - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * size[i + 1]; + } + + Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return flat_indices; +} + +static Tensor compute_output_positions(const Tensor& is_unique) { + + int64_t nnz = is_unique.size(0); + if (nnz == 0) { + return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt)); + } + + Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, is_unique, positions); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return positions; +} + +static Tensor compute_output_positions_parallel(const Tensor& is_unique) { + + int64_t nnz = is_unique.size(0); + if (nnz == 0) { + return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt)); + } + + // for small arrays, use simple kernel + // speed of the naive kernel drops off after 4096 nnz elements + if (nnz <= 4096) { + return compute_output_positions(is_unique); + } + auto stream = getCurrentMPSStream(); + Tensor positions = is_unique.to(kInt); + // Kogge-Stone parallel prefix sum + Tensor positions_cloned = positions.clone(); + + for (int64_t stride = 1; stride < nnz; stride *= 2) { + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, positions, positions_cloned, stride); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + std::swap(positions, positions_cloned); + } + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, positions, positions_cloned); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return positions_cloned; +} + +static std::pair mark_unique_and_count(const Tensor& flat_indices) { + + int64_t nnz = flat_indices.size(0); + if (nnz == 0) { + return {at::empty({0}, flat_indices.options().dtype(kBool)), 0}; + } + + Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool)); + Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, flat_indices, is_unique, count_result); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + int32_t num_unique = count_result.item(); + + return {is_unique, num_unique}; +} + +SparseTensor _coalesce_sparse_mps(const SparseTensor& self) { + int64_t nnz = self._nnz(); + TORCH_INTERNAL_ASSERT(!self.is_coalesced()); + if (nnz < 2) { + SparseTensor dst = self.clone(); + dst._coalesced_(true); + return dst; + } + + Tensor indices = self._indices(); + Tensor values = self._values(); + + Tensor flat_indices = flatten_indices(indices, self.sizes()); + Tensor sorted_order = flat_indices.argsort(); + Tensor flat_indices_sorted = flat_indices.index({sorted_order}); + values = values.index({sorted_order}); + indices = indices.index_select(1, sorted_order); + + auto unique_info = mark_unique_and_count(flat_indices_sorted); + Tensor is_unique = unique_info.first; + int32_t newNnz = unique_info.second; + + Tensor output_positions = compute_output_positions_parallel(is_unique); + + Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options()); + auto outValuesSize = values.sizes().vec(); + outValuesSize[0] = newNnz; + Tensor out_values = at::zeros(outValuesSize, values.options()); + + Tensor is_unique_local = is_unique; + int64_t sparse_dim = indices.size(0); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values)); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + const uint32_t numThreads = static_cast(nnz); + const uint32_t valueSize = static_cast(values.numel() / nnz); + mtl_setArgs(encoder, + flat_indices_sorted, + indices, + values, + is_unique_local, + output_positions, + out_indices, + out_values, + numThreads, + valueSize, + sparse_dim, + newNnz); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true); + return result; +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal new file mode 100644 index 000000000000..8b85950e393a --- /dev/null +++ b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal @@ -0,0 +1,128 @@ +#include +#include +using namespace metal; + +kernel void flatten_indices_kernel( + device const int64_t* indices [[buffer(0)]], + device const int64_t* strides [[buffer(1)]], + device int64_t* flat_indices [[buffer(2)]], + constant uint& sparse_dim [[buffer(3)]], + constant uint& nnz [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + int64_t flat_idx = 0; + for (uint d = 0; d < sparse_dim; d++) { + flat_idx += indices[d * nnz + gid] * strides[d]; + } + flat_indices[gid] = flat_idx; +} + +kernel void compute_output_positions_kernel( + device const bool* is_unique [[buffer(0)]], + device int* positions [[buffer(1)]], + uint gid [[thread_position_in_grid]]) { + int pos = 0; + for (uint i = 0; i < gid; i++) { + if (is_unique[i]) + pos++; + } + positions[gid] = pos; +} + +kernel void mark_unique_positions_and_count_kernel( + device const int64_t* flat_indices [[buffer(0)]], + device bool* is_unique [[buffer(1)]], + device atomic_int* count [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + bool unique = (tid == 0) || (flat_indices[tid] != flat_indices[tid - 1]); + is_unique[tid] = unique; + + if (unique) { + atomic_fetch_add_explicit(count, 1, memory_order_relaxed); + } +} + +// Kogge-Stone parallel prefix sum step +kernel void kogge_stone_step( + device const int* input [[buffer(0)]], + device int* output [[buffer(1)]], + constant uint& stride [[buffer(2)]], + uint gid [[thread_position_in_grid]]) { + int val = input[gid]; + if (gid >= stride) { + val += input[gid - stride]; + } + output[gid] = val; +} + +// Shift right for exclusive scan +kernel void shift_right_kernel( + device const int* input [[buffer(0)]], + device int* output [[buffer(1)]], + uint gid [[thread_position_in_grid]]) { + output[gid] = (gid == 0) ? 0 : input[gid - 1]; +} + +template +kernel void coalesce_with_positions_kernel( + device const int64_t* flat_indices [[buffer(0)]], + device const int64_t* indices [[buffer(1)]], + device const T* in_values [[buffer(2)]], + device const bool* is_unique [[buffer(3)]], + device const int* output_positions [[buffer(4)]], + device int64_t* out_indices [[buffer(5)]], + device T* out_values [[buffer(6)]], + constant uint& nnz [[buffer(7)]], + constant uint& value_size [[buffer(8)]], + constant uint& sparse_dim [[buffer(9)]], + constant uint& total_unique [[buffer(10)]], + uint gid [[thread_position_in_grid]]) { + if (!is_unique[gid]) + return; + + int out_pos = output_positions[gid]; + + for (uint d = 0; d < sparse_dim; d++) { + out_indices[d * total_unique + out_pos] = indices[d * nnz + gid]; + } + + int64_t current_index = flat_indices[gid]; + uint end = gid + 1; + while (end < nnz && flat_indices[end] == current_index) { + end++; + } + + for (uint elem = 0; elem < value_size; elem++) { + T sum = 0; + for (uint j = gid; j < end; j++) { + sum += in_values[j * value_size + elem]; + } + out_values[out_pos * value_size + elem] = sum; + } +} + +#define INSTANTIATE_COALESCE_WITH_POSITIONS(DTYPE) \ + template \ + [[host_name("coalesce_with_positions_kernel_" #DTYPE)]] [[kernel]] void \ + coalesce_with_positions_kernel( \ + device const int64_t* flat_indices [[buffer(0)]], \ + device const int64_t* indices [[buffer(1)]], \ + device const DTYPE* in_values [[buffer(2)]], \ + device const bool* is_unique [[buffer(3)]], \ + device const int* output_positions [[buffer(4)]], \ + device int64_t* out_indices [[buffer(5)]], \ + device DTYPE* out_values [[buffer(6)]], \ + constant uint& nnz [[buffer(7)]], \ + constant uint& value_size [[buffer(8)]], \ + constant uint& sparse_dim [[buffer(9)]], \ + constant uint& total_unique [[buffer(10)]], \ + uint gid [[thread_position_in_grid]]); + +INSTANTIATE_COALESCE_WITH_POSITIONS(float); +INSTANTIATE_COALESCE_WITH_POSITIONS(half); +INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat); +INSTANTIATE_COALESCE_WITH_POSITIONS(bool); +INSTANTIATE_COALESCE_WITH_POSITIONS(long); +INSTANTIATE_COALESCE_WITH_POSITIONS(char); +INSTANTIATE_COALESCE_WITH_POSITIONS(uchar); +INSTANTIATE_COALESCE_WITH_POSITIONS(short); +INSTANTIATE_COALESCE_WITH_POSITIONS(int); \ No newline at end of file diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 80049aa9a832..1a3e2825d4fa 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -849,16 +849,6 @@ std::tuple _efficient_ if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) std::optional out(res); std::optional seqused_k = std::nullopt; std::optional alibi_slopes = std::nullopt; diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 3888df64ad80..55e86e0240db 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -26,6 +26,8 @@ #else #include #include +#include +#include #include #include #include @@ -184,7 +186,7 @@ std::tuple _flash_attention_backward( return std::make_tuple(Tensor(), Tensor(), Tensor()); } -std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( +std::tuple _cudnn_attention_backward( const Tensor& grad_out, const Tensor& query, const Tensor& key, @@ -211,57 +213,117 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ } } - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_qk = query.size(3); - const int64_t head_dim_v = value.size(3); + const bool is_nested = cum_seq_q.defined(); const int64_t max_seqlen_batch_q = query.size(2); const int64_t max_seqlen_batch_k = key.size(2); - // This is needed because SaveVariable automatically converts - // std::optional to undefined tensor - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + if (!is_nested) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + } } - } - const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - run_cudnn_SDP_bprop(batch_size /*int64_t b*/, - num_heads /*int64_t h*/, - max_q/*int64_t s_q*/, - max_k/*int64_t s_kv*/, - head_dim_qk /*int64_t d_qk*/, - head_dim_v /*int64_t d_v*/, - softmax_scale /*float scaling_factor*/, - is_causal /*bool is_causal*/, - dropout_p /*float dropout_probability*/, - query /*const Tensor& q*/, - key /*const Tensor& k*/, - value /*const Tensor& v*/, - attn_bias_ /*const std::optional& attn_bias*/, - out /*const Tensor& o*/, - grad_out/*const Tensor& dO*/, - logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, - dq/*Tensor& dQ*/, - dk/*Tensor& dK*/, - dv/*Tensor& dV*/, - philox_seed/*Tensor& dropoutseed*/, - philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + } else { + // BHSD ... + const int64_t batch_size = cum_seq_q.size(0) - 1; + const int64_t num_heads_q = query.size(-2); + const int64_t num_heads_k = key.size(-2); + const int64_t num_heads_v = value.size(-2); + const int64_t head_dim_qk = query.size(-1); + const int64_t head_dim_v = value.size(-1); + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + } + } + + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + run_cudnn_SDP_bprop_nestedtensor( + batch_size, + num_heads_q, + num_heads_k, + num_heads_v, + max_seqlen_batch_q, + max_seqlen_batch_k, + head_dim_qk, + head_dim_v, + softmax_scale, + is_causal, + dropout_p, + cum_seq_q, + cum_seq_k, + query, + key, + value, + attn_bias_, + out, + grad_out, + logsumexp, + dq, + dk, + dv, + philox_seed, + philox_offset); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + } } std::tuple @@ -431,7 +493,7 @@ _efficient_attention_backward( // ROCM Implementation if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float(); // Store grad_bias in optional std::optional opt_grad_bias = grad_bias; @@ -1063,4 +1125,40 @@ std::tuple _scaled_dot_product_e } } +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& philox_seed, + const Tensor& philox_offset, + const Tensor& attn_bias, + const Tensor& cum_seq_q, + const Tensor& cum_seq_k, + const int64_t max_q, + const int64_t max_k, + double dropout_p, + bool is_causal, + std::optional scale) { + return at::_cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp, + philox_seed, + philox_offset, + attn_bias, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + scale); +} + } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 4b198f4d6d2d..4b85b2d28753 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -57,21 +57,28 @@ namespace sdp { namespace { +// tracks whether we've set the default priority order once, to avoid setting +// it redundantly or overwriting a user-specified priority order +// when the priority order context manager is used before the default priority +// order is initialized the following happens: +// (1) the current priority order is queried +// (2) priority_order() is called, which initializes it to the default as init_ is false +// (3) the user-specified priority order is set +// (3.1) we are in the priority context... +// (3.2) we exit the priority context... +// (4) the previous priority order (default) is restored +bool priority_order_init_ = false; + // TODO(eqy): more benchmarking to determine whether this should include sm86/89 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py bool check_prefer_cudnn_attention() { - // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0 - // see context: https://github.com/pytorch/pytorch/issues/138340 - // return false; -#if defined(CUDNN_VERSION) - -#if CUDNN_VERSION > 90000 + static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true; + if (!prefer_cudnn) { + return false; + } +#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000)) auto dprops = at::cuda::getCurrentDeviceProperties(); - return dprops->major >= 9; -#else - return false; -#endif - + return dprops->major >= 9 && !dprops->minor; #else return false; #endif @@ -79,6 +86,16 @@ bool check_prefer_cudnn_attention() { // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { + if (!priority_order_init_) { + priority_order_init_ = true; + if (check_prefer_cudnn_attention()) { + const std::vector cudnn_order = {static_cast(at::SDPBackend::cudnn_attention), + static_cast(at::SDPBackend::flash_attention), + static_cast(at::SDPBackend::efficient_attention), + static_cast(at::SDPBackend::math)}; + at::globalContext().setSDPPriorityOrder(cudnn_order); + } + } return at::globalContext().sDPPriorityOrder(); } @@ -414,12 +431,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } auto head_dim_limit = 128; - if (cudnn_version >= 90501) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (dprops->major == 9 && !dprops->minor) { - head_dim_limit = 256; - } - } + // TODO(eqy): add head dim >= 256 cases once support is finalized if (d_qk > head_dim_limit || d_v > head_dim_limit) { if (debug) { TORCH_WARN("head_dim should be no more than ", head_dim_limit); @@ -453,9 +465,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } } - if (s_q == 1 || s_k == 1) { + if (s_k == 1) { + if (debug) { + TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1."); + } + return false; + } + if (s_q == 1 && params.dropout != 0.0) { if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1."); + TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout."); } return false; } @@ -563,9 +581,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested - if (dprop->major != 9 && has_for_nested_inputs(params)) { + if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) { if (debug) { - TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); + TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0."); } return false; } @@ -589,7 +607,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // sdp kernels if (!at::globalContext().userEnabledCuDNNSDP()) { if (debug) { - TORCH_WARN("CuDNN attention has been runtime disabled."); + TORCH_WARN("cuDNN attention has been runtime disabled."); } return false; } @@ -620,7 +638,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { - TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); } return false; #endif @@ -630,10 +648,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { c10::array_of( check_runtime_disabled_cudnn, check_for_nested_inputs, - check_nonzero_sequence_lengths_dense, check_all_tensors_on_device, check_tensor_shapes, - check_cudnn_tensor_shapes, check_cudnn_deterministic, check_dtypes_low_precision, check_attn_mask_shape, @@ -646,8 +662,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } constexpr auto dense_constraints = c10::array_of( + check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense + check_batch_size_and_num_heads_dense, + check_cudnn_tensor_shapes ); if (has_only_dense_inputs(params)) { @@ -864,7 +882,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); - TORCH_WARN("CuDNN attention kernel not used because:"); + TORCH_WARN("cuDNN attention kernel not used because:"); sdp::can_use_cudnn_attention(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt index b30c39340036..819880cf3bc5 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt @@ -1,7 +1,7 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt + --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) @@ -11,7 +11,27 @@ endif() execute_process( COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt + --api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt + RESULT_VARIABLE ret +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.") +endif() + +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt + RESULT_VARIABLE ret +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.") +endif() + +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) @@ -19,15 +39,29 @@ if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.") endif() -# Generate the files for both fwd and bwd -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} +# Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.") +endif() + +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.") +endif() + +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret ) @@ -44,6 +78,22 @@ if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass") endif() +execute_process( + COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt" + RESULT_VARIABLE ret) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass") +endif() + +execute_process( + COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt" + RESULT_VARIABLE ret) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass") +endif() + # Change make_kernel to make_kernel_pt for bwd execute_process( COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt" diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh b/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh index 672bea143751..849613f79569 100755 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh @@ -21,6 +21,8 @@ while IFS= read -r file; do if [ -f "$file" ]; then # Use sed to replace "make_kernel" with "make_kernel_pt" in place sed -i 's/make_kernel/make_kernel_pt/g' "$file" + sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file" + sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file" echo "Updated: $file" else echo "Skipping: $file (not found)" diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp deleted file mode 100644 index 8115288fb887..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp +++ /dev/null @@ -1,100 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -// keep sync with BlockAttentionBiasEnum -enum class bias_enum -{ - no_bias = 0, - elementwise_bias = 1, - alibi = 2, -}; - -struct bias_info -{ - bias_enum type; - /* - * simple dispatch logic - * - * if type == elementwise_bias: - * if rank_info == 0: - * bias is 1*1*s*s - * elif rank_info == 1: - * bias is 1*h*s*s - * elif rank_info == 2: - * bias is b*h*s*s - * - * elif type == alibi: - * if rank_info == 0: - * alibi in 1*h - * elif rank_info == 1: - * alibi in b*h - */ - int rank_info; - - void serialize(std::ostream& os) const - { - if(type == bias_enum::no_bias) - os << "n"; - else if(type == bias_enum::elementwise_bias) - { - os << "e"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - else if(type == bias_enum::alibi) - { - os << "alibi"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - } - - static bias_info decode(std::string str) - { - bias_info info{bias_enum::no_bias, 0}; - if(str == "0" || str == "n") - { - info.type = bias_enum::no_bias; - } - else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || - str.compare(0, 11, "elementwise") == 0) - { - info.type = bias_enum::elementwise_bias; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } - } - else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || - str.compare(0, 5, "alibi") == 0) - { - info.type = bias_enum::alibi; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } - } - return info; - } - - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) - { - bi.serialize(os); - return os; - } -}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp deleted file mode 100644 index affa40619b59..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp +++ /dev/null @@ -1,457 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -struct FmhaBwdFp16 -{ -}; - -struct FmhaBwdBf16 -{ -}; - -template -struct FmhaBwdTypeConfig; - -template <> -struct FmhaBwdTypeConfig -{ - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using GemmDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using LSEDataType = float; - using AccDataType = float; // data type for gemm accumulation - using DDataType = float; - using RandValOutputDataType = uint8_t; - using ODataType = ck_tile::half_t; - using OGradDataType = ck_tile::half_t; - using QGradDataType = ck_tile::half_t; - using KGradDataType = ck_tile::half_t; - using VGradDataType = ck_tile::half_t; - using BiasGradDataType = ck_tile::half_t; -}; - -template <> -struct FmhaBwdTypeConfig -{ - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using GemmDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; - using AccDataType = float; // data type for gemm accumulation - using DDataType = float; - using RandValOutputDataType = uint8_t; - using ODataType = ck_tile::bf16_t; - using OGradDataType = ck_tile::bf16_t; - using QGradDataType = ck_tile::bf16_t; - using KGradDataType = ck_tile::bf16_t; - using VGradDataType = ck_tile::bf16_t; - using BiasGradDataType = ck_tile::bf16_t; -}; - -struct FmhaMasks -{ - using NoMask = ck_tile::GenericAttentionMask; - using GenericMask = ck_tile::GenericAttentionMask; - using CausalMask = ck_tile::GenericAttentionMask; -}; - -// runtime args, some will passed to karg, some will used to compute grids/blocks -struct fmha_bwd_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; // bias or alibi_slope pointer - const void* o_ptr; - const void* lse_ptr; - const void* do_ptr; - void* d_ptr; - void* rand_val_ptr; - void* dq_ptr; - void* dk_ptr; - void* dv_ptr; - void* dbias_ptr; - void* dq_acc_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t max_seqlen_k; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - float scale; - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_o; - ck_tile::index_t stride_randval; - ck_tile::index_t stride_do; - ck_tile::index_t stride_dq_acc; - ck_tile::index_t stride_dq; - ck_tile::index_t stride_dk; - ck_tile::index_t stride_dv; - ck_tile::index_t stride_dbias; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_randval; - ck_tile::index_t nhead_stride_do; - ck_tile::index_t nhead_stride_lsed; - ck_tile::index_t nhead_stride_dq_acc; - ck_tile::index_t nhead_stride_dq; - ck_tile::index_t nhead_stride_dk; - ck_tile::index_t nhead_stride_dv; - ck_tile::index_t nhead_stride_dbias; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_o; - ck_tile::index_t batch_stride_randval; - ck_tile::index_t batch_stride_do; - ck_tile::index_t batch_stride_lsed; - ck_tile::index_t batch_stride_dq_acc; - ck_tile::index_t batch_stride_dq; - ck_tile::index_t batch_stride_dk; - ck_tile::index_t batch_stride_dv; - ck_tile::index_t batch_stride_dbias; - ck_tile::index_t split_stride_dq_acc; - ck_tile::index_t window_size_left; - ck_tile::index_t window_size_right; - ck_tile::index_t mask_type; - float p_drop; - float p_undrop; - std::variant, std::pair> - drop_seed_offset; -}; - -template -auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) - { - return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); - } - else - { // create batch mode kernel arguments - return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_do, - args.batch_stride_lsed, - args.batch_stride_dq_acc, - args.batch_stride_dk, - args.batch_stride_dv, - args.batch_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); - } - }(); - - dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); - return ck_tile::make_tuple(kargs, grids); -} - -template -auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) -{ - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) - { - return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, - args.do_ptr, - args.d_ptr, - args.p_undrop, - args.seqstart_q_ptr, - args.hdim_v, - args.stride_do, - args.stride_o, - args.nhead_stride_do, - args.nhead_stride_o, - args.nhead_stride_lsed); - } - else - { // create batch mode kernel arguments - return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, - args.do_ptr, - args.d_ptr, - args.p_undrop, - args.seqlen_q, - args.hdim_v, - args.stride_do, - args.stride_o, - args.nhead_stride_do, - args.nhead_stride_o, - args.nhead_stride_lsed, - args.batch_stride_do, - args.batch_stride_o, - args.batch_stride_lsed); - } - }(); - - dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); - return ck_tile::make_tuple(kargs, grids); -} - -template -auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) -{ - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) - { - return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, - args.dq_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.hdim_q, - args.stride_dq, - args.stride_dq_acc, - args.nhead_stride_dq, - args.nhead_stride_dq_acc, - args.split_stride_dq_acc); - } - else - { // create batch mode kernel arguments - return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, - args.dq_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.stride_dq, - args.stride_dq_acc, - args.nhead_stride_dq, - args.nhead_stride_dq_acc, - args.batch_stride_dq, - args.batch_stride_dq_acc, - args.split_stride_dq_acc); - } - }(); - - dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); - return ck_tile::make_tuple(kargs, grids); -} - -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct fmha_bwd_dq_dk_dv_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; - using FmhaMask = ck_tile::remove_cvref_t; - using FmhaDropout = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsDeterministic = kIsDeterministic_; -}; - -template -float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); - -template -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); - -template -std::string fmha_bwd_dq_dk_dv_get_name_(); - -template -struct fmha_bwd_dot_do_o_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadDv = kPadDv_; -}; - -template -float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); - -template -void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); - -template -std::string fmha_bwd_dot_do_o_get_name_(); - -template -struct fmha_bwd_convert_dq_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kIsDeterministic = kIsDeterministic_; -}; - -template -float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); - -template -void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); - -template -std::string fmha_bwd_convert_dq_get_name_(); - -// This is the public API, will be generated by script -struct fmha_bwd_traits -{ - int hdim_q; - int hdim_v; - std::string data_type; - bool is_group_mode; - mask_enum mask_type; - bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum - bool has_dbias; - bool has_dropout; - bool is_store_randval; - bool is_deterministic; - // TODO: padding check is inside this api -}; -template -float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp deleted file mode 100644 index 2de70cd49bbb..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp +++ /dev/null @@ -1,824 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include - -struct FmhaFwdFp16 -{ -}; - -struct FmhaFwdBf16 -{ -}; - -struct FmhaFwdFp8 -{ -}; - -struct FmhaFwdBf8 -{ -}; - -struct FmhaFwdFp8Fp16 -{ -}; - -struct FmhaFwdFp8Bf16 -{ -}; - -template -struct FmhaFwdTypeConfig; - -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; -}; - -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; -}; - -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; -}; - -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; -}; - -struct FmhaMasks -{ - using NoMask = ck_tile::GenericAttentionMask; - using GenericMask = ck_tile::GenericAttentionMask; - using CausalMask = ck_tile::GenericAttentionMask; -}; - -// runtime args, some will passed to karg, some will used to compute grids/blocks -struct fmha_fwd_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; // bias or alibi_slope pointer - void* rand_val_ptr; - void* lse_ptr; - void* o_ptr; - - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - - float scale_s; - float scale_p; - float scale_o; - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_randval; - ck_tile::index_t stride_o; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_randval; - ck_tile::index_t nhead_stride_lse; - ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_randval; - ck_tile::index_t batch_stride_lse; - ck_tile::index_t batch_stride_o; - - ck_tile::index_t window_size_left; - ck_tile::index_t window_size_right; - ck_tile::index_t mask_type; - - float p_drop; - bool s_randval; - - std::variant, std::pair> - drop_seed_offset; -}; - -struct fmha_fwd_splitkv_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; // bias or alibi_slope pointer - void* lse_acc_ptr; - void* o_acc_ptr; - void* lse_ptr; - void* o_ptr; - - void* block_table_ptr; - ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr - ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not - // nullptr. - - const void* cache_batch_idx; - - // the real seqlen_q & seqlen_k are decided by following: - // batch mode: seqlen_q = kargs.seqlen_q - // seqlen_k = kargs.seqlen_k - // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] - // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // or kargs.seqlen_k_ptr[b] - // - // batch mode (kvcache): - // seqlen_q = kargs.seqlen_q - // seqlen_k = kargs.seqlen_k_ptr[b] - // group mode (kvcache): - // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] - // - // when is_gappy=true: - // seqlen_k = kargs.seqlen_k_ptr[b] - // seqstart_k_ptr[b] now store local offset of each batch - // - // when is_gappy=false: - // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // or kargs.seqlen_k_ptr[b] - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - ck_tile::index_t num_splits; - - float scale_s; - float scale_p; - float scale_o; - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_o_acc; - ck_tile::index_t stride_o; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_lse; - ck_tile::index_t nhead_stride_lse_acc; - ck_tile::index_t nhead_stride_o_acc; - ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_lse; - ck_tile::index_t batch_stride_lse_acc; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t batch_stride_o; - ck_tile::index_t split_stride_lse_acc; - ck_tile::index_t split_stride_o_acc; - - ck_tile::index_t window_size_left; - ck_tile::index_t window_size_right; - ck_tile::index_t mask_type; -}; - -struct fmha_fwd_appendkv_args -{ - void* q_ptr; - void* k_ptr; - const void* knew_ptr; - void* v_ptr; - const void* vnew_ptr; - - const void* seqlen_k_ptr; - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_knew; - ck_tile::index_t batch; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - - const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 - const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 - ck_tile::index_t rotary_dim; - bool has_mask; - - void* block_table_ptr; - ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr - ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - - const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_knew; - ck_tile::index_t stride_v; - ck_tile::index_t stride_vnew; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t nhead_stride_knew; - ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_vnew; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_knew; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_vnew; -}; - -template -auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } - }(); - - if constexpr(FmhaKernel::kIsGroupMode) - { - dim3 grids = FmhaKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); - return ck_tile::make_tuple(kargs, grids); - } - else - { - dim3 grids = - FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); - return ck_tile::make_tuple(kargs, grids); - } -} - -template -auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(Kernel::kIsGroupMode) - { - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.is_gappy, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_k, // only used for paged-kvcache - args.batch_stride_v, // only used for paged-kvcache - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - else - { // create batch mode kernel arguments - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.seqlen_q, - args.seqlen_k, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.cache_batch_idx, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - }(); - - dim3 grids = Kernel::GridSize( - args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); - - return ck_tile::make_tuple(kargs, grids); -} - -template -auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel argumentszs - if constexpr(Kernel::kIsGroupMode) - { - return Kernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.seqstart_q_ptr, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o_acc, - args.stride_o, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.nhead_stride_lse, - args.nhead_stride_o, - args.split_stride_lse_acc, - args.split_stride_o_acc); - } - else - { // create batch mode kernel arguments - return Kernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.seqlen_q, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o_acc, - args.stride_o, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, - args.batch_stride_lse, - args.batch_stride_o, - args.split_stride_lse_acc, - args.split_stride_o_acc); - } - }(); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - - return ck_tile::make_tuple(kargs, grids); -} - -template -auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.knew_ptr, - args.v_ptr, - args.vnew_ptr, - args.seqlen_q, - args.seqlen_k_ptr, - args.seqlen_knew, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.rotary_cos_ptr, - args.rotary_sin_ptr, - args.rotary_dim, - args.has_mask, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.cache_batch_idx, - args.stride_q, - args.stride_k, - args.stride_knew, - args.stride_v, - args.stride_vnew, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_knew, - args.nhead_stride_v, - args.nhead_stride_vnew, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_knew, - args.batch_stride_v, - args.batch_stride_vnew); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); - - return ck_tile::make_tuple(kargs, grids); -} - -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct fmha_fwd_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr ck_tile::index_t kK1 = kK1_; - static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kHasDropout = kHasDropout_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; -}; - -template -float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); - -template -struct fmha_fwd_splitkv_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr ck_tile::index_t kK1 = kK1_; - static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsPagedKV = kIsPagedKV_; -}; - -template -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); - -template -std::string fmha_fwd_splitkv_get_name_(); - -template -struct fmha_fwd_splitkv_combine_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadDv = kPadDv_; -}; - -template -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); - -template -std::string fmha_fwd_splitkv_combine_get_name_(); - -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct fmha_fwd_appendkv_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; - static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; - static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; - static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSk = kPadSk_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr auto RotaryEnum = RotaryEnum_; - static constexpr bool kIsPagedKV = kIsPagedKV_; -}; - -template -float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); - -// This is the public API, will be generated by script -struct fmha_fwd_traits -{ - int hdim_q; - int hdim_v; - std::string data_type; - bool is_group_mode; - bool is_v_rowmajor; - mask_enum mask_type; - bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum - bool has_lse; - bool has_dropout; - bool do_fp8_static_quant; - // TODO: padding check is inside this api -}; -float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); - -struct fmha_fwd_splitkv_traits -{ - int hdim_q; - int hdim_v; - std::string data_type; - bool is_group_mode; - bool is_v_rowmajor; - mask_enum mask_type; - bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum - bool has_lse; - bool do_fp8_static_quant; - // TODO: padding check is inside this api -}; -float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, - fmha_fwd_splitkv_args, - const ck_tile::stream_config&); - -struct fmha_fwd_appendkv_traits -{ - int hdim_q; - int hdim_v; - std::string data_type; - bool is_v_rowmajor; - rope_enum rope_type; -}; -float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, - fmha_fwd_appendkv_args, - const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp deleted file mode 100644 index 133049057d78..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp +++ /dev/null @@ -1,157 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include -#include - -// keep this in sync with ck_tile::GenericAttentionMaskEnum -enum class mask_enum -{ - no_mask = 0, - mask_top_left, - mask_bottom_right, - window_generic, -}; - -struct mask_info -{ - mask_enum type; - ck_tile::index_t y, x; - ck_tile::index_t left, right; // FA style SWA left/right - - void serialize(std::ostream& os) const - { - if(type == mask_enum::no_mask) - os << "n"; - else if(type == mask_enum::mask_top_left) - os << "t(" << left << ":" << right << ")"; - else if(type == mask_enum::mask_bottom_right) - os << "b(" << left << ":" << right << ")"; - else - { - os << "g(" << y << ":" << x << ")"; - } - } - static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) - { - ck_tile::index_t x_total = seqlen_k; - ck_tile::index_t y_total = seqlen_q; - mask_info tmp; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string t = str.substr(0, found_0); - std::string v = str.substr(found_0 + 1); - if(t == "xt" || t == "xb") - { - // xformer style sliding window attn from top-left - ck_tile::index_t window_size = atoi(v.c_str()); - ck_tile::index_t left_size = -1; - ck_tile::index_t right_size = 0; - if(window_size > 0) - { - left_size = window_size / 2; - right_size = window_size - 1 - left_size; - } - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, t == "xt"); - - tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = left_size; - tmp.right = right_size; - } - else - { - auto found_1 = v.find(","); - if(found_1 == std::string::npos) - { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); - } - tmp.type = mask_enum::window_generic; - ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation - if(t == "t") - { - tmp.type = mask_enum::mask_top_left; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "b") - { - tmp.type = mask_enum::mask_bottom_right; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "g") - { - tmp.y = v0; - tmp.x = v1; - tmp.left = v0; // TODO: don't use this? - tmp.right = v1; - } - else - { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); - } - } - } - else - { - auto set_causal_top_left = [&]() { - tmp.type = mask_enum::mask_top_left; - tmp.y = seqlen_q; - tmp.x = 1; - tmp.left = -1; - tmp.right = 0; - }; - auto set_causal_bottom_right = [&]() { - tmp.type = mask_enum::mask_bottom_right; - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; - tmp.left = -1; - tmp.right = 0; - }; - if(str == "t") - set_causal_top_left(); - else if(str == "b") - set_causal_bottom_right(); - else - { - tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::mask_top_left) - { - set_causal_top_left(); - } - else if(tmp.type == mask_enum::mask_bottom_right) - { - set_causal_bottom_right(); - } - } - } - return tmp; - } - - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) - { - mi.serialize(os); - return os; - } -}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip index 601ffd2d0752..59669afb93d2 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip @@ -1,7 +1,7 @@ #include #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< at::Tensor, // dQ @@ -117,4 +117,4 @@ mem_eff_backward_ck( } } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h index 6fd46467bc07..e92006ef6315 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -3,7 +3,7 @@ #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< @@ -64,4 +64,4 @@ mem_eff_backward_ck( const at::Tensor philox_offset); } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip index fac77821a56c..d15c5105d0b4 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip @@ -1,7 +1,7 @@ #include #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< at::Tensor, // output @@ -93,4 +93,4 @@ mem_eff_forward_ck( } } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip index b0c2a31df099..05f97414acdd 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -22,6 +22,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, dtype, false, // is_group_mode true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, has_lse, @@ -85,6 +86,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, ck_tile::index_t stride_attn_bias = 0; ck_tile::index_t batch_stride_bias = 0; ck_tile::index_t nhead_stride_bias = 0; + if (attn_bias_.has_value()) { auto a_b = attn_bias_.value(); CHECK_DEVICE(a_b); @@ -94,7 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, nhead_stride_bias = a_b.stride(1); batch_stride_bias = a_b.stride(0); } - return fmha_fwd_args{q.data_ptr(), k.data_ptr(), v.data_ptr(), @@ -116,6 +117,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -139,6 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + -1, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index ece6f29877ab..ee6261df8a91 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -20,6 +20,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, dtype, true, // is_group_mode true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, has_lse, @@ -117,6 +118,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -140,6 +142,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + -1, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp deleted file mode 100644 index 85754c037872..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -// keep sync with RotaryEmbeddingEnum -enum class rope_enum -{ - none = 0, - interleaved = 1, - half_rotated = 2, -}; - -template -std::tuple, ck_tile::HostTensor> -generate_rotary_cos_sin(ck_tile::index_t seqlen, - ck_tile::index_t rotary_dim, - std::optional seed = std::nullopt) -{ - // return dummy tensors if we won't apply RoPE at all - if(rotary_dim <= 0) - { - ck_tile::HostTensor dummy({1, 1}); - return std::make_tuple(dummy, dummy); - } - - std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); - std::uniform_real_distribution generator(0.0f, 1.0f); - - const ck_tile::index_t num_rows = seqlen * 2; - const ck_tile::index_t num_cols = rotary_dim / 2; - - using std::begin, std::end; - - ck_tile::HostTensor angle({num_rows, num_cols}); - std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); - - ck_tile::HostTensor cos({num_rows, num_cols}); - std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { - return ck_tile::type_convert(std::cos(origin_value)); - }); - - ck_tile::HostTensor sin({num_rows, num_cols}); - std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { - return ck_tile::type_convert(std::sin(origin_value)); - }); - - return std::make_tuple(cos, sin); -} - -template -std::tuple, ck_tile::HostTensor> -slice_rotary_cos_sin(const ck_tile::HostTensor& cos, - const ck_tile::HostTensor& sin, - ck_tile::index_t seqlen_offset, - ck_tile::index_t seqlen) -{ - assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); - assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); - - assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); - - const ck_tile::index_t num_rows = seqlen; - const ck_tile::index_t num_cols = cos.get_length(1); - - ck_tile::HostTensor cos_pt({num_rows, num_cols}); - cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); - - ck_tile::HostTensor sin_pt({num_rows, num_cols}); - sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); - - return std::make_tuple(cos_pt, sin_pt); -} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 17298aae9485..f6f2240d4f09 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -147,7 +147,7 @@ std::tuple mha_varlen_bwd_aot( const at::Tensor& philox_seed, const at::Tensor& philox_offset); -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) // CK implementation TORCH_API std::tuple< @@ -295,7 +295,7 @@ mha_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { const int non_null_window_left = window_size_left.value_or(-1); @@ -368,7 +368,7 @@ mha_varlen_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional dummy_attn_bias = std::nullopt; @@ -441,9 +441,10 @@ inline std::tuple mha_bwd( const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) std::optional non_null_dbias = std::nullopt; const int non_null_window_left = window_size_left.value_or(-1); const int non_null_window_right = window_size_right.value_or(-1); @@ -474,10 +475,8 @@ inline std::tuple mha_bwd( philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); -#else - TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); -#endif } +#endif return mha_bwd_aot( dout, q, @@ -530,7 +529,7 @@ inline std::tuple mha_varlen_bwd const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional non_null_dbias = std::nullopt; diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 7ad7a18e9c66..60dd52d1dffc 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,7 +1,8 @@ +#include + #include #include #include -#include #include @@ -9,7 +10,7 @@ // numbers of threads set and also whether the scheduler // will throw an exception when multiple threads call // their first parallel construct. -void test(int given_num_threads) { +static void test(int given_num_threads) { auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); @@ -19,7 +20,7 @@ void test(int given_num_threads) { } } -int main() { +TEST(ThreadInitTest, ThreadInit) { at::init_num_threads(); at::set_num_threads(4); @@ -32,13 +33,11 @@ int main() { #if !AT_PARALLEL_NATIVE at::set_num_threads(5); - ASSERT_TRUE(at::get_num_threads() == 5); + ASSERT_EQ(at::get_num_threads(), 5); #endif // test inter-op settings at::set_num_interop_threads(5); ASSERT_EQ(at::get_num_interop_threads(), 5); ASSERT_ANY_THROW(at::set_num_interop_threads(6)); - - return 0; } diff --git a/benchmarks/data/dataloader_benchmark.py b/benchmarks/data/dataloader_benchmark.py new file mode 100644 index 000000000000..7d1dd3afc7e9 --- /dev/null +++ b/benchmarks/data/dataloader_benchmark.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +Benchmark script for PyTorch DataLoader with different worker methods. + +This script measures: +1. Dataloader initialization time +2. Dataloading speed (time per batch) +3. CPU memory utilization + +Usage: + python dataloader_benchmark.py --data_path /path/to/dataset --batch_size 32 --num_workers 4 +""" + +import argparse +import copy +import gc +import time + +import psutil +import torchvision +import torchvision.transforms as transforms +from torchvision.models import resnet18 + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.utils.data.dataset import ConcatDataset + + +def get_memory_usage(): + """ + Get current memory usage in MB. This includes all child processes. + + Returns: + Total memory usage in MB + """ + process = psutil.Process() + + main_memory = process.memory_full_info().pss + + # Add memory usage of all child processes + for child in process.children(recursive=True): + try: + child_mem = child.memory_full_info().pss + main_memory += child_mem + except (psutil.NoSuchProcess, psutil.AccessDenied, AttributeError): + # Process might have terminated or doesn't support PSS, fall back to USS + print(f"Failed to get PSS for {child}, falling back to USS") + child_mem = child.memory_info().uss + main_memory += child_mem + + return main_memory / (1024 * 1024) + + +def print_detailed_memory(): + """Print detailed memory information.""" + process = psutil.Process() + print("\nDetailed memory information:") + try: + print( + f" USS (Unique Set Size): {process.memory_full_info().uss / (1024 * 1024):.2f} MB" + ) + print( + f" PSS (Proportional Set Size): {process.memory_full_info().pss / (1024 * 1024):.2f} MB" + ) + print( + f" RSS (Resident Set Size): {process.memory_info().rss / (1024 * 1024):.2f} MB" + ) + except Exception: + print(" Detailed memory info not available") + + +def create_model(): + """Create a simple model for benchmarking.""" + model = resnet18() + return model + + +def benchmark_dataloader( + dataset, + batch_size, + num_workers, + num_epochs=1, + max_batches=10, + multiprocessing_context=None, + logging_freq=10, +): + """Benchmark a dataloader with specific configuration.""" + print("\n--- Benchmarking DataLoader ---") + + # Clear memory before starting + gc.collect() + torch.cuda.empty_cache() + + # Create model + model = create_model() + + # Measure memory before dataloader creation + memory_before = get_memory_usage() + print(f"Memory before DataLoader creation: {memory_before:.2f} MB") + print_detailed_memory() + + # Measure dataloader initialization time + start = time.perf_counter() + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + prefetch_factor=2 if num_workers > 0 else None, + multiprocessing_context=multiprocessing_context, + ) + it = iter(dataloader) + dataloader_init_time = time.perf_counter() - start + + # Measure memory after dataloader creation + memory_after = get_memory_usage() + print(f"Memory after DataLoader creation: {memory_after:.2f} MB") + print(f"Memory increase: {memory_after - memory_before:.2f} MB") + + # Create model and optimizer + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + # Benchmark dataloading speed + model.train() + total_batches = 0 + total_samples = 0 + total_time = 0 + total_data_load_time = 0 + + # Measure peak memory during training + peak_memory = memory_after + + print( + f"\nStarting training loop with {num_epochs} epochs (max {max_batches} batches per epoch)" + ) + + for epoch in range(num_epochs): + while total_batches < max_batches: + batch_start = time.perf_counter() + + try: + inputs, labels = next(it) + except StopIteration: + break + + # Move data to device + inputs = inputs.to(device) + labels = labels.to(device) + + # Capture data fetch time (including sending to device) + data_load_time = time.perf_counter() - batch_start + + # Forward pass + outputs = model(inputs) + loss = criterion(outputs, labels) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Capture batch time + batch_time = time.perf_counter() - batch_start + + total_batches += 1 + total_samples += inputs.size(0) + total_data_load_time += data_load_time + total_time += batch_time + + # Update peak memory and log memory usage periodically + if total_batches % 5 == 0: + # Force garbage collection before measuring memory + gc.collect() + current_memory = get_memory_usage() + + if current_memory > peak_memory: + peak_memory = current_memory + + if total_batches % logging_freq == 0: + print( + f"Epoch {epoch + 1}, Batch {total_batches}, " + f"Time: {batch_time:.4f}s, " + f"Memory: {current_memory:.2f} MB" + ) + + # Calculate statistics + avg_data_load_time = ( + total_data_load_time / total_batches if total_batches > 0 else 0 + ) + avg_batch_time = total_time / total_batches if total_batches > 0 else 0 + samples_per_second = total_samples / total_time if total_time > 0 else 0 + + results = { + "dataloader_init_time": dataloader_init_time, + "num_workers": num_workers, + "batch_size": batch_size, + "total_batches": total_batches, + "avg_batch_time": avg_batch_time, + "avg_data_load_time": avg_data_load_time, + "samples_per_second": samples_per_second, + "peak_memory_mb": peak_memory, + "memory_increase_mb": peak_memory - memory_before, + } + + print("\nResults:") + print(f" DataLoader init time: {dataloader_init_time:.4f} seconds") + print(f" Average data loading time: {avg_data_load_time:.4f} seconds") + print(f" Average batch time: {avg_batch_time:.4f} seconds") + print(f" Samples per second: {samples_per_second:.2f}") + print(f" Peak memory usage: {peak_memory:.2f} MB") + print(f" Memory increase: {peak_memory - memory_before:.2f} MB") + + # Clean up + del model, optimizer + del dataloader + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark PyTorch DataLoader with different worker methods" + ) + parser.add_argument("--data_path", required=True, help="Path to dataset") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") + parser.add_argument("--num_workers", type=int, default=4, help="Number of workers") + parser.add_argument( + "--max_batches", + type=int, + default=100, + help="Maximum number of batches per epoch", + ) + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument( + "--multiprocessing_context", + choices=["fork", "spawn", "forkserver"], + default="forkserver", + help="Multiprocessing context to use (fork, spawn, forkserver)", + ) + parser.add_argument( + "--dataset_copies", + type=int, + default=1, + help="Number of copies of the dataset to concatenate (for testing memory usage)", + ) + parser.add_argument( + "--logging_freq", + type=int, + default=10, + help="Frequency of logging memory usage during training", + ) + args = parser.parse_args() + + # Print system info + print("System Information:") + # The following are handy for debugging if building from source worked correctly + print(f" PyTorch version: {torch.__version__}") + print(f" PyTorch location: {torch.__file__}") + print(f" Torchvision version: {torchvision.__version__}") + print(f" Torchvision location: {torchvision.__file__}") + print(f" CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f" CUDA device: {torch.cuda.get_device_name(0)}") + print(f" CPU count: {psutil.cpu_count(logical=True)}") + print(f" Physical CPU cores: {psutil.cpu_count(logical=False)}") + print(f" Total system memory: {psutil.virtual_memory().total / (1024**3):.2f} GB") + + # Define transforms + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + # Load dataset + print(f"\nLoading dataset from {args.data_path} ({args.dataset_copies} copies)") + + # Try to load as ImageFolder + datasets = [] + for _ in range(args.dataset_copies): + base_dataset = torchvision.datasets.ImageFolder( + args.data_path, transform=transform + ) + datasets.append(copy.deepcopy(base_dataset)) + del base_dataset + dataset = ConcatDataset(datasets) + + print(f"Dataset size: {len(dataset)}") + + # Run benchmark with specified worker method + benchmark_dataloader( + dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + multiprocessing_context=args.multiprocessing_context, + num_epochs=args.num_epochs, + max_batches=args.max_batches, + logging_freq=args.logging_freq, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/benchmarks.py b/benchmarks/dynamo/benchmarks.py index 55b03429bcc5..25c1e8203a0a 100755 --- a/benchmarks/dynamo/benchmarks.py +++ b/benchmarks/dynamo/benchmarks.py @@ -5,6 +5,12 @@ import sys +# Run only this selected group of models, leave this empty to run everything +TORCHBENCH_ONLY_MODELS = [ + m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip() +] + + # Note - hf and timm have their own version of this, torchbench does not # TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this... def model_names(filename: str) -> set[str]: @@ -17,6 +23,8 @@ def model_names(filename: str) -> set[str]: if len(line_parts) == 1: line_parts = line.split(",") model_name = line_parts[0] + if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: + continue names.add(model_name) return names diff --git a/benchmarks/dynamo/check_accuracy.py b/benchmarks/dynamo/check_accuracy.py index 7fa24ae7346b..5cd714fe02e9 100644 --- a/benchmarks/dynamo/check_accuracy.py +++ b/benchmarks/dynamo/check_accuracy.py @@ -14,6 +14,7 @@ "detectron2_maskrcnn_r_101_c4", "timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699 "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148 + "moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291 } diff --git a/benchmarks/dynamo/check_graph_breaks.py b/benchmarks/dynamo/check_graph_breaks.py index 173f11acb132..57814dacd00b 100644 --- a/benchmarks/dynamo/check_graph_breaks.py +++ b/benchmarks/dynamo/check_graph_breaks.py @@ -13,6 +13,7 @@ "gluon_inception_v3", "detectron2_maskrcnn_r_101_c4", "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148 + "detectron2_fcos_r_50_fpn", } diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index af605accecf6..01762c5f5f29 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 33ede2b914b4..54b7d63f3a4b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv index 1cafcbe55675..ce334e22c698 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 1cafcbe55675..ce334e22c698 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index faafea393ede..9620a79f91a9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index a2b7c1a7b15c..aec659fdcd65 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 697fe04cd91a..4f2eec149352 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 7f11e1398027..f9874a7a4b90 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index cb8cead2ba03..81ed3080dd3e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index 6f9e9e0ed5a7..c8db4d582320 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -122,7 +122,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -142,7 +142,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv index a0bbb3b62ecc..f4c9ffddd997 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -346,7 +346,7 @@ vgg16,pass,0 -vision_maskrcnn,fail_accuracy,30 +vision_maskrcnn,fail_accuracy,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 7f11e1398027..f9874a7a4b90 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 05eb7e3546ee..188f3dd00cac 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index af605accecf6..01762c5f5f29 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 44983e8ecc21..0985e42fc5cb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 9a9a68629f87..fbd169539ab7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 33ede2b914b4..54b7d63f3a4b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv index 9fdb41506e3b..08061de428d7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv index b3a3265baa16..6f316b219bb9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv index d2300bdac05b..48d0b111788f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv index 1cafcbe55675..ce334e22c698 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv index 9fdb41506e3b..08061de428d7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv index 624f29562478..4b5138ce9c36 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv index 1605a26b7ce5..643a02fdca8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv index 6776cc5f5d7a..a3fc7cf19237 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -174,7 +174,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv index b43e38b7d822..ced88884720b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv index 9fdb41506e3b..08061de428d7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv index b3a3265baa16..6f316b219bb9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv index 754f5f718e43..d1606b622639 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv index fd57a3b4cbf3..0f088e7892d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv index 66e088f33407..f65909f3a24e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv index 3e4e9ee702aa..8ccf95da9659 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -174,7 +174,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv index 86ad955b5a2c..e842ac7cb8e1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index b20d82ba9b24..469ece2958df 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -9,6 +9,7 @@ import csv import dataclasses import functools +import gc import importlib import itertools import json @@ -20,6 +21,7 @@ import signal import subprocess import sys +import tempfile import time import weakref from contextlib import contextmanager @@ -40,6 +42,7 @@ import torch.distributed import torch.multiprocessing as mp from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU +from torch._C._nativert import PyModelRunner from torch._dynamo.profiler import fx_insert_profiling, Profiler from torch._dynamo.testing import ( dummy_fx_compile, @@ -201,7 +204,6 @@ class CI(NamedTuple): "PLBartForCausalLM", "PLBartForConditionalGeneration", "PegasusForCausalLM", - "Speech2Text2ForCausalLM", "TrOCRForCausalLM", "XGLMForCausalLM", # TIMM @@ -1099,6 +1101,8 @@ def maybe_mark_profile(*args, **kwargs): frozen_model_iter_fn = export_aot_inductor( model, example_inputs, args.inductor_compile_mode ) + elif args.export_nativert: + frozen_model_iter_fn = export_nativert(model, example_inputs) else: frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) @@ -1445,6 +1449,38 @@ def get_excess_memory(cls, model) -> float: return cls.cache.get(weakref.ref(model), (None, 0.0))[1] +class NativeRTCache: + cache: dict[weakref.ref, Any] = {} + + @classmethod + def load(cls, model, example_inputs): + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path + + key = weakref.ref(model) + if key not in cls.cache: + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) + example_outputs = model(*example_args, **example_kwargs) + _register_dataclass_output_as_pytree(example_outputs) + + combined_args = _combine_args(model, example_args, example_kwargs) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + + ep = torch.export.export( + model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes + ) + ep = ep.run_decompositions({}) + with tempfile.NamedTemporaryFile(delete=False) as f: + torch.export.pt2_archive._package.package_pt2( + f, exported_programs={"forward": ep} + ) + filename = f.name + cls.cache[key] = PyModelRunner(filename, "forward") + + return cls.cache[key] + + def export(model, example_inputs): from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path @@ -1471,6 +1507,16 @@ def opt_export(_, example_inputs): return opt_export +def export_nativert(model, example_inputs): + optimized = NativeRTCache.load(model, example_inputs) + + def opt_nativert(_, example_inputs, collect_outputs=False): + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) + return optimized.run(*example_args, **example_kwargs) + + return opt_nativert + + def export_aot_inductor(model, example_inputs, mode): optimized = AOTInductorModelCache.load(model, example_inputs, mode) @@ -2227,7 +2273,11 @@ def record_status(accuracy_status, dynamo_start_stats): try: model_copy = self.deepcopy_and_maybe_parallelize(model) self.init_optimizer(name, current_device, model_copy.parameters()) - if self.args.export or self.args.export_aot_inductor: + if ( + self.args.export + or self.args.export_aot_inductor + or self.args.export_nativert + ): # apply export on module directly # no need for n iterations # the logic should be the same to self.model_iter_fn (forward_pass) @@ -2387,6 +2437,7 @@ def run_performance_test_non_alternate( ) def warmup(fn, model, example_inputs, mode, niters=10): + gc.collect() peak_mem = 0 start_stats = get_dynamo_stats() try: @@ -2548,6 +2599,7 @@ def run_performance_test( return experiment(*self.maybe_cast(model, example_inputs)) def warmup(fn, model, example_inputs, mode, niters=5): + gc.collect() peak_mem = 0 start_stats = get_dynamo_stats() try: @@ -2621,7 +2673,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): niters=1, ) - if self.args.export_aot_inductor: + if self.args.export_aot_inductor or self.args.export_nativert: optimized_model_iter_fn = optimize_ctx else: optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) @@ -3374,6 +3426,11 @@ def get_example_inputs(self): action="store_true", help="Measure pass rate with Export+AOTInductor", ) + group.add_argument( + "--export-nativert", + action="store_true", + help="Measure pass rate with Export+NativeRT", + ) group.add_argument( "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch" ) @@ -3815,6 +3872,10 @@ def run(runner, args, original_dir=None): optimize_ctx = export experiment = speedup_experiment output_filename = "export.csv" + elif args.export_nativert: + optimize_ctx = export_nativert + experiment = speedup_experiment + output_filename = "export_nativert.csv" elif args.xla: (dev,) = args.devices os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev] diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index e5e9a57f5382..aa81832a8831 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -106,6 +106,11 @@ def process_hf_reformer_output(out): # on A100 GPUs - 40 GB. BATCH_SIZE_KNOWN_MODELS = {} +# Run only this selected group of models, leave this empty to run everything +TORCHBENCH_ONLY_MODELS = [ + m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip() +] + # TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py # Get the list of models and their batch sizes @@ -116,6 +121,8 @@ def process_hf_reformer_output(out): lines = [line.rstrip() for line in lines] for line in lines: model_name, batch_size = line.split(",") + if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: + continue batch_size = int(batch_size) BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size assert len(BATCH_SIZE_KNOWN_MODELS) @@ -452,6 +459,12 @@ def load_model( else: model.eval() + # Turning off kv cache for torchbench models. This is not the right + # thing to do, but the pt2 dashboard is outdated. Real transformers + # benchmarks will be added soon using a different infra. + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + self.validate_model(model, example_inputs) return device, model_name, model, example_inputs, batch_size diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index f0ee57a58965..564077611709 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -31,8 +31,6 @@ batch_size: BlenderbotSmallForCausalLM: 4 BlenderbotSmallForConditionalGeneration: 2 CamemBert: 2 - DebertaForMaskedLM: 4 - DebertaForQuestionAnswering: 2 DebertaV2ForMaskedLM: 4 DebertaV2ForQuestionAnswering: 8 DistilBertForMaskedLM: 2 @@ -63,7 +61,6 @@ batch_size: PegasusForConditionalGeneration: 2 RobertaForCausalLM: 2 RobertaForQuestionAnswering: 2 - Speech2Text2ForCausalLM: 4 T5ForConditionalGeneration: 2 T5Small: 2 TrOCRForCausalLM: 2 diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 6e3cf19a783d..12ceedd5c4cc 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -10,8 +10,6 @@ BlenderbotForConditionalGeneration,16 BlenderbotSmallForCausalLM,256 BlenderbotSmallForConditionalGeneration,128 CamemBert,32 -DebertaForMaskedLM,32 -DebertaForQuestionAnswering,32 DebertaV2ForMaskedLM,8 DebertaV2ForQuestionAnswering,8 DistilBertForMaskedLM,256 @@ -42,7 +40,6 @@ PegasusForCausalLM,128 PegasusForConditionalGeneration,64 RobertaForCausalLM,32 RobertaForQuestionAnswering,32 -Speech2Text2ForCausalLM,1024 T5ForConditionalGeneration,8 T5Small,8 TrOCRForCausalLM,64 diff --git a/benchmarks/dynamo/huggingface_models_list_cpu.txt b/benchmarks/dynamo/huggingface_models_list_cpu.txt index cabd79ac830f..4078368a69c4 100644 --- a/benchmarks/dynamo/huggingface_models_list_cpu.txt +++ b/benchmarks/dynamo/huggingface_models_list_cpu.txt @@ -10,8 +10,6 @@ BlenderbotForCausalLM,32 BlenderbotSmallForCausalLM,64 BlenderbotSmallForConditionalGeneration,64 CamemBert,16 -DebertaForMaskedLM,32 -DebertaForQuestionAnswering,8 DebertaV2ForMaskedLM,16 DebertaV2ForQuestionAnswering,2 DistilBertForMaskedLM,128 @@ -38,7 +36,6 @@ PLBartForCausalLM,8 PLBartForConditionalGeneration,4 RobertaForCausalLM,16 RobertaForQuestionAnswering,16 -Speech2Text2ForCausalLM,32 T5ForConditionalGeneration,4 T5Small,1 TrOCRForCausalLM,32 diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index a0ef356d29df..debddc5c7fa3 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1 @@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1 -basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1 +basic_NestedModule_eager,compile_time_instruction_count,9199000000,0.1 diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index ecedaf681e41..b63c41947b9a 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -39,13 +39,20 @@ def pip_install(package): from timm.models import create_model TIMM_MODELS = {} -filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") +# Run only this selected group of models, leave this empty to run everything +TORCHBENCH_ONLY_MODELS = [ + m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip() +] + +filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") with open(filename) as fh: lines = fh.readlines() lines = [line.rstrip() for line in lines] for line in lines: model_name, batch_size = line.split(" ") + if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: + continue TIMM_MODELS[model_name] = int(batch_size) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index c2568aa1daa1..1f10ecc661d8 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -382,6 +382,22 @@ def load_model( if self.args.trace_on_xla: # work around for: https://github.com/pytorch/xla/issues/4174 import torch_xla # noqa: F401 + + # Turning off kv cache for torchbench models. This is not the right + # thing to do, but the torchbench models are way outdated, and since we + # are using torchbench pt2 dashboard to track regressions (rather than + # improving performance), we are just setting the kv cache to false. + # Real transformers benchmarks will be added soon using a different + # infra. + if ( + model_name.startswith("hf") + and hasattr(model, "config") + and hasattr(model.config, "use_cache") + ): + model.config.use_cache = False + if model_name == "hf_T5_generate": + model.model.config.use_cache = False + self.validate_model(model, example_inputs) return device, benchmark.name, model, example_inputs, batch_size diff --git a/build_variables.bzl b/build_variables.bzl index 1dda77b63750..a226249db708 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -599,6 +599,7 @@ libtorch_nativert_sources = [ "torch/nativert/graph/GraphSignature.cpp", "torch/nativert/graph/Serialization.cpp", "torch/nativert/graph/TensorMeta.cpp", + "torch/nativert/graph/GraphUtils.cpp", "torch/nativert/executor/DelegateExecutor.cpp", "torch/nativert/executor/Placement.cpp", "torch/nativert/executor/ExecutionPlanner.cpp", diff --git a/c10/core/AllocatorConfig.cpp b/c10/core/AllocatorConfig.cpp index 5cb0ce273838..e154338d501b 100644 --- a/c10/core/AllocatorConfig.cpp +++ b/c10/core/AllocatorConfig.cpp @@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) { // check if the key is unrecognized. if (device_config_parser_hook_) { TORCH_CHECK( - keys_.find(key) != keys_.end(), + getKeys().find(key) != getKeys().end(), "Unrecognized key '", key, "' in Accelerator allocator config."); diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h index 14d94d242f59..efde5e3a8ff9 100644 --- a/c10/core/AllocatorConfig.h +++ b/c10/core/AllocatorConfig.h @@ -220,11 +220,24 @@ class C10_API AcceleratorAllocatorConfig { return instance().last_allocator_settings_; } + // Use `Construct On First Use Idiom` to avoid `Static Initialization Order` + // issue. + static std::unordered_set& getMutableKeys() { + static std::unordered_set keys{ + "max_split_size_mb", + "max_non_split_rounding_mb", + "garbage_collection_threshold", + "roundup_power2_divisions", + "expandable_segments", + "pinned_use_background_threads"}; + return keys; + } + // Returns the set of valid keys for the allocator configuration. // This set is used to validate the presence and correctness of keys in // device-specific configuration parsers. static const std::unordered_set& getKeys() { - return keys_; + return getMutableKeys(); } // Registers a device-specific configuration parser hook and its key. This @@ -238,9 +251,10 @@ class C10_API AcceleratorAllocatorConfig { std::function&& hook, const std::unordered_set& keys) { device_config_parser_hook_ = std::move(hook); + auto& mutable_keys = getMutableKeys(); for (auto& key : keys) { TORCH_CHECK( - keys_.insert(key).second, + mutable_keys.insert(key).second, "Duplicated key '", key, "' found in device-specific configuration parser hook registration"); @@ -326,17 +340,6 @@ class C10_API AcceleratorAllocatorConfig { // their own environment configuration extensions. inline static std::function device_config_parser_hook_{nullptr}; - - // A set of valid configuration keys, including both common and - // device-specific options. This set is used to validate the presence and - // correctness of keys during parsing. - inline static std::unordered_set keys_{ - "max_split_size_mb", - "max_non_split_rounding_mb", - "garbage_collection_threshold", - "roundup_power2_divisions", - "expandable_segments", - "pinned_use_background_threads"}; }; C10_API inline void setAllocatorSettings(const std::string& env) { diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 67c9276313bb..0497d72b9570 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -237,8 +237,6 @@ inline DeviceType backendToDeviceType(Backend b) { return DeviceType::CPU; case Backend::CUDA: case Backend::SparseCUDA: - case Backend::SparseMPS: - case Backend::SparseCsrMPS: case Backend::QuantizedCUDA: case Backend::SparseCsrCUDA: return DeviceType::CUDA; @@ -276,6 +274,8 @@ inline DeviceType backendToDeviceType(Backend b) { case Backend::Meta: return DeviceType::Meta; case Backend::MPS: + case Backend::SparseMPS: + case Backend::SparseCsrMPS: return DeviceType::MPS; case Backend::HPU: return DeviceType::HPU; diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp new file mode 100644 index 000000000000..582efd59cf1b --- /dev/null +++ b/c10/core/CachingDeviceAllocator.cpp @@ -0,0 +1,10 @@ +#include + +namespace c10 { + +// Ensures proper DLL export of this pure virtual base class on Windows, +// since it's mainly used in other DLLs outside c10.dll. +DeviceAllocator::DeviceAllocator() = default; +DeviceAllocator::~DeviceAllocator() = default; + +} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index b23490de693a..0bec03ae417f 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10::CachingDeviceAllocator { @@ -59,3 +60,55 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator + +namespace c10 { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by Graph mode capture_begin. +// second is set if the instance is created by Graph mode graph_pool_handle. +using MempoolId_t = std::pair; + +struct C10_API DeviceAllocator : public c10::Allocator { + DeviceAllocator(); + ~DeviceAllocator() override; + + // Returns true if the allocator has been properly initialized and is ready + // for use + virtual bool initialized() = 0; + + // Releases all cached device memory from the specified memory pool back to + // the system + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; + + // Associates a memory allocation with a stream to establish dependency + // tracking. Prevents memory reuse until all operations on the specified + // stream complete + virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; + + // Retrieves comprehensive memory statistics for the specified device, + // including allocation patterns, usage metrics + virtual CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + + // Resets cumulative allocation statistics for the specified device to zero + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + + // Resets peak memory usage statistics for the specified device + virtual void resetPeakStats(c10::DeviceIndex device) = 0; +}; + +// This function is used to get the DeviceAllocator for a specific device type +// and keep backward compatibility with c10::GetAllocator. +C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { + TORCH_CHECK( + t != DeviceType::CPU, + "getDeviceAllocator is not supported for CPU device type."); + auto* allocator = c10::GetAllocator(t); + auto* device_allocator = dynamic_cast(allocator); + TORCH_INTERNAL_ASSERT( + device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); + return device_allocator; +} + +} // namespace c10 diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 0daa129bb5a4..0d09e0ed46f4 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -33,7 +33,6 @@ inline Layout layout_from_backend(Backend backend) { case Backend::SparseCPU: case Backend::SparseCUDA: case Backend::SparseMPS: - case Backend::SparseCsrMPS: case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparseXPU: @@ -43,6 +42,7 @@ inline Layout layout_from_backend(Backend backend) { return Layout::Mkldnn; case Backend::SparseCsrCPU: case Backend::SparseCsrCUDA: + case Backend::SparseCsrMPS: case Backend::SparseCsrHIP: case Backend::SparseCsrVE: case Backend::SparseCsrXPU: diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 3b483c86bc88..646a1dde3994 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -191,11 +191,17 @@ class C10_API Scalar { isIntegral() const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; } + bool isIntegral(bool includeBool) const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || (includeBool && isBoolean()); } + // See Note [Meaning of HAS_u] + bool isUnsigned() const { + return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0); + } + bool isComplex() const { return Tag::HAS_z == tag; } diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 3d8a2b0074e9..4a15eb23ac63 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -19,25 +19,16 @@ #include #include -#include #include #include #include #include -namespace c10 { - -// dummy struct for uint1 to uint7, actual functionality -// of these dtypes will be implemented in python with Tensor subclass -template -struct dummy_uint1_7_t {}; +#include -// dummy struct for int1 to int7, actual functionality -// of these dtypes will be implemented in python with Tensor subclass -template -struct dummy_int1_7_t {}; +namespace c10 { -// For the macros below: +// [dtype Macros note] For the macros below: // // For users: If you want to macro some code for all non-QInt scalar types // (i.e. types with complete information, you probably want one of the @@ -57,56 +48,6 @@ struct dummy_int1_7_t {}; // some old PRs where we added new dtypes (check history of this file) can // help give you an idea where to start. -// NB: Order matters for this macro; it is relied upon in -// _promoteTypesLookup and the serialization format. -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ - _(uint8_t, Byte) /* 0 */ \ - _(int8_t, Char) /* 1 */ \ - _(int16_t, Short) /* 2 */ \ - _(int, Int) /* 3 */ \ - _(int64_t, Long) /* 4 */ \ - _(at::Half, Half) /* 5 */ \ - _(float, Float) /* 6 */ \ - _(double, Double) /* 7 */ \ - _(c10::complex, ComplexHalf) /* 8 */ \ - _(c10::complex, ComplexFloat) /* 9 */ \ - _(c10::complex, ComplexDouble) /* 10 */ \ - _(bool, Bool) /* 11 */ \ - _(c10::qint8, QInt8) /* 12 */ \ - _(c10::quint8, QUInt8) /* 13 */ \ - _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ \ - _(c10::quint4x2, QUInt4x2) /* 16 */ \ - _(c10::quint2x4, QUInt2x4) /* 17 */ \ - _(c10::bits1x8, Bits1x8) /* 18 */ \ - _(c10::bits2x4, Bits2x4) /* 19 */ \ - _(c10::bits4x2, Bits4x2) /* 20 */ \ - _(c10::bits8, Bits8) /* 21 */ \ - _(c10::bits16, Bits16) /* 22 */ \ - _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ - _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ - _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ - _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ - _(uint16_t, UInt16) /* 27 */ \ - _(uint32_t, UInt32) /* 28 */ \ - _(uint64_t, UInt64) /* 29 */ \ - _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ - _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ - _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ - _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ - _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ - _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ - _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ - _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ - _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ - _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ - _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ - _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ - _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ - _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ - _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ - _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ - // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() // doesn't work for all the conversions you need... @@ -152,17 +93,6 @@ struct dummy_int1_7_t {}; _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ _(at::Float8_e8m0fnu, Float8_e8m0fnu) -enum class ScalarType : int8_t { -#define DEFINE_ST_ENUM_VAL_(_1, n) n, - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) -#undef DEFINE_ENUM_ST_ENUM_VAL_ - Undefined, - NumOptions -}; - -constexpr uint16_t NumScalarTypes = - static_cast(ScalarType::NumOptions); - namespace impl { // These are used to map ScalarTypes to C++ types. diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 381bc65b27fb..fcd7b4b4b31d 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2090,6 +2090,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { constexpr auto sparse_backends = DispatchKeySet( {BackendComponent::CPUBit, BackendComponent::CUDABit, + BackendComponent::MPSBit, BackendComponent::HIPBit, BackendComponent::XPUBit}); constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 8c4b613473c0..21d72e4b6831 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -110,8 +110,22 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_use_async_allocator; } + // Use `Construct On First Use Idiom` to avoid `Static Initialization Order` + // issue. static const std::unordered_set& getKeys() { - return keys_; + static std::unordered_set keys{ + "backend", + // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues + // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_cud" + "amalloc", + "pinned_use_cud" + "a_host_register", + // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_hipmalloc", + "pinned_use_hip_host_register", + "pinned_num_register_threads"}; + return keys; } static CUDAAllocatorConfig& instance() { @@ -163,18 +177,6 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_pinned_use_cuda_host_register{false}; std::atomic m_use_async_allocator{false}; std::atomic m_is_allocator_loaded{false}; - inline static std::unordered_set keys_{ - "backend", - // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues - // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_cud" - "amalloc", - "pinned_use_cud" - "a_host_register", - // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_hipmalloc", - "pinned_use_hip_host_register", - "pinned_num_register_threads"}; }; // Keep this for backwards compatibility diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index b0b1be8937a9..59b62dcac07f 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -368,14 +368,12 @@ struct ExpandableSegment { ExpandableSegment( c10::DeviceIndex device, std::optional stream, - size_t address_space_size, size_t segment_size, std::vector peers) : device_(device), stream_(stream), // 2MB for small pool, 20MB for large pool segment_size_(segment_size), - max_handles_(numSegments(address_space_size)), peers_(std::move(peers)) { cudaDeviceProp prop{}; C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_)); @@ -544,11 +542,7 @@ struct ExpandableSegment { ShareHeader header{}; buf.read((char*)&header, sizeof(ShareHeader)); auto segment = std::make_unique( - device, - std::nullopt, - header.num_handles * header.segment_size, - header.segment_size, - std::move(peers)); + device, std::nullopt, header.segment_size, std::move(peers)); // older build setups (e.g. multiwheels) do not have this syscall, added 2020 // but the kernel on the system might still support it. #ifndef SYS_pidfd_open @@ -746,7 +740,6 @@ struct ExpandableSegment { ExpandableSegment( c10::DeviceIndex device, std::optional stream, - size_t address_space_size, size_t segment_size, std::vector peers) { TORCH_INTERNAL_ASSERT(false, "expandable segment not supported"); @@ -2420,19 +2413,8 @@ class DeviceCachingAllocator { } } auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer; - cudaDeviceProp prop{}; - C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - // we allocate enough address space for 1 1/8 the total memory on the GPU. - // This allows for some cases where we have to unmap pages earlier in the - // segment to put them at the end. - size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8; - expandable_segments_.emplace_back(new ExpandableSegment( - device, - stream, - address_space_size, - segment_size, - devices_with_peer_access_)); + device, stream, segment_size, devices_with_peer_access_)); ExpandableSegment* es = expandable_segments_.back(); Block* candidate = new Block(device, stream, es->size(), pool, es->ptr()); @@ -4136,7 +4118,18 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); +// Register this HIP allocator as the CUDA allocator to allow it to work +// with both c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) +// APIs. We don't perform this masquerading inside +// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static +// initialization, and doing so there may introduce static initialization +// order (SIOF) issues. +#define HIP_MASQUERADING_AS_CUDA \ + "cud" \ + "a" + at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0); allocator.store(r); +#undef HIP_MASQUERADING_AS_CUDA } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe2282..75a2d4c8e481 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,25 +202,24 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public Allocator { +class CUDAAllocator : public DeviceAllocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; - virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; - virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - virtual void resetPeakStats(c10::DeviceIndex device) = 0; + // Keep for BC only + virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; + void recordStream(const DataPtr& ptr, c10::Stream stream) override { + CUDAStream cuda_stream = CUDAStream(stream); + recordStream(ptr, cuda_stream); + } virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -525,6 +524,10 @@ inline void enablePeerAccess( namespace c10::cuda { +// Keep BC only +using c10::CaptureId_t; +using c10::MempoolId_t; + // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 0e8cabf61859..683ed9b76845 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -53,13 +53,12 @@ int device_count_impl(bool fail_if_no_driver) { "https://pytorch.org to install a PyTorch version that has been " "compiled with your version of the CUDA driver."); } - } break; + } case cudaErrorInitializationError: TORCH_CHECK( false, "CUDA driver initialization failed, you might not " "have a CUDA gpu."); - break; case cudaErrorUnknown: TORCH_CHECK( false, @@ -67,7 +66,6 @@ int device_count_impl(bool fail_if_no_driver) { "incorrectly set up environment, e.g. changing env " "variable CUDA_VISIBLE_DEVICES after program start. " "Setting the available devices to be zero."); - break; #if C10_ASAN_ENABLED case cudaErrorMemoryAllocation: // In ASAN mode, we know that a cudaErrorMemoryAllocation error will diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index eb29ca8bc9f0..936875fd71d5 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,12 +9,6 @@ namespace c10::cuda { -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by CUDAGraph::capture_begin. -// second is set if the instance is created by at::cuda::graph_pool_handle. -using MempoolId_t = std::pair; - // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 0cde2d9de01c..8eca673cd3a4 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -216,9 +216,6 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) { // Creates the low and high priority stream pools for the specified device // Warning: only call once per device! static void initDeviceStreamState(DeviceIndex device_index) { - // Switches to the requested device so streams are properly associated - // with it. - CUDAGuard device_guard{device_index}; for (const auto i : c10::irange(kStreamsPerPool)) { for (const auto p : c10::irange(max_stream_priorities)) { initSingleStream(p, device_index, i); diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 9800809d1e53..6702cb9b532d 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -53,7 +53,8 @@ #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ - _(cuMulticastCreate, 12030) + _(cuMulticastCreate, 12030) \ + _(cuMulticastUnbind, 12030) #else #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif diff --git a/c10/metal/atomic.h b/c10/metal/atomic.h index b2bce2b8b5fd..d0cbc0391698 100644 --- a/c10/metal/atomic.h +++ b/c10/metal/atomic.h @@ -85,7 +85,6 @@ struct AtomicType { } }; -#if __METAL_VERSION__ >= 310 template <> struct AtomicType { using type = ::metal::atomic; @@ -93,7 +92,6 @@ struct AtomicType { atomic_add_helper(data, offset, value); } }; -#endif // Metal supports atomic_store_explicit for bools, but // sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to @@ -126,5 +124,54 @@ struct AtomicType { } }; +// ComplexHalf atomic op +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, half2 value) { + auto ptr = data + offset; + auto old = + ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + while (!::metal::atomic_compare_exchange_weak_explicit( + ptr, + &old, + as_type(as_type(old) + value), + ::metal::memory_order_relaxed, + ::metal::memory_order_relaxed)) + ; + } +}; + +// There are no atomic 64-bit add in Metal yet, but templates below implements a +// consistent add I.e. if multiple threads are modify the same 64-bit value, +// results stored at the address will eventually be equal to its original value +// plus sum of all operands +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, long value) { + const auto value_bits = as_type(value); + const uint low = static_cast(value_bits); + uint high = static_cast(value_bits >> 32); + auto ptr = data + (offset << 1); + auto old_low = + atomic_fetch_add_explicit(ptr, low, ::metal::memory_order_relaxed); + high += (old_low + low < old_low) ? 1 : 0; + atomic_fetch_add_explicit(ptr + 1, high, ::metal::memory_order_relaxed); + } +}; + +// ComplexFloat atomic op, which again is not really atomic, but eventually +// consistent +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, float2 value) { + auto ptr = data + (offset << 1); + atomic_fetch_add_explicit(ptr + 0, value.x, ::metal::memory_order_relaxed); + atomic_fetch_add_explicit(ptr + 1, value.y, ::metal::memory_order_relaxed); + } +}; + } // namespace metal } // namespace c10 diff --git a/c10/metal/common.h b/c10/metal/common.h index b953dec90252..e4b4d1a38ca4 100644 --- a/c10/metal/common.h +++ b/c10/metal/common.h @@ -9,7 +9,6 @@ #define C10_METAL_CONSTEXPR constexpr #endif -#if !defined(__METAL__) || __METAL_VERSION__ >= 310 #define C10_METAL_ALL_TYPES_FUNCTOR(_) \ _(Byte, 0) \ _(Char, 1) \ @@ -22,19 +21,6 @@ _(ComplexFloat, 9) \ _(Bool, 11) \ _(BFloat16, 15) -#else -#define C10_METAL_ALL_TYPES_FUNCTOR(_) \ - _(Byte, 0) \ - _(Char, 1) \ - _(Short, 2) \ - _(Int, 3) \ - _(Long, 4) \ - _(Half, 5) \ - _(Float, 6) \ - _(ComplexHalf, 8) \ - _(ComplexFloat, 9) \ - _(Bool, 11) -#endif namespace c10 { namespace metal { diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index cd7de5b54766..9cfe65f6a03a 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -186,10 +186,8 @@ inline T val_at_offs(constant void* ptr, long offs, ScalarType type) { return cast_to(val_at_offs(ptr, offs)); case ScalarType::Half: return cast_to(val_at_offs(ptr, offs)); -#if __METAL_VERSION__ >= 310 case ScalarType::BFloat16: return cast_to(val_at_offs(ptr, offs)); -#endif // Complex case ScalarType::ComplexHalf: return cast_to(val_at_offs(ptr, offs)); diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 785dc431b57b..2d9782019166 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -15,12 +15,10 @@ struct simd_type { template using simd_type_t = typename simd_type::t; -#if __METAL_VERSION__ >= 310 template <> struct simd_type { using t = float; }; -#endif } // namespace detail template @@ -140,7 +138,7 @@ template < inline ::c10::metal::pair simd_argmin(T val) { const auto rc = simd_min(val); const auto vote = ::metal::simd_ballot(val == rc); - return {rc, ::metal::ctz(static_cast(static_cast(vote)))}; + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; } template < @@ -149,7 +147,7 @@ template < inline ::c10::metal::pair simd_argmin(T val) { const auto rc = simd_min(val); const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val)); - return {rc, ::metal::ctz(static_cast(static_cast(vote)))}; + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; } template < @@ -158,7 +156,7 @@ template < inline ::c10::metal::pair simd_argmax(T val) { const auto rc = simd_max(val); const auto vote = ::metal::simd_ballot(val == rc); - return {rc, ::metal::ctz(static_cast(static_cast(vote)))}; + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; } template < @@ -167,7 +165,7 @@ template < inline ::c10::metal::pair simd_argmax(T val) { const auto rc = simd_max(val); const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val)); - return {rc, ::metal::ctz(static_cast(static_cast(vote)))}; + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; } template @@ -303,30 +301,58 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) { return rc; } -template -int threadgroup_argmax(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee +template +IDX_T threadgroup_argmax( + threadgroup ARG_T* arg_data, + threadgroup IDX_T* idx_data, + ARG_T val, + IDX_T idx_val, + unsigned idx, + unsigned size) { + auto rc = simd_argmax(val, idx_val); + if (size <= simdgroup_size) { + return rc.second; + } + if (idx % simdgroup_size == 0) { + arg_data[idx / simdgroup_size] = rc.first; + idx_data[idx / simdgroup_size] = rc.second; + } ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - int rc = 0; - for (unsigned idx = 1; idx < size; ++idx) { - if (data[idx] > data[rc]) { - rc = idx; + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_argmax(arg_data[idx], idx_data[idx]); + if (idx == 0) { + idx_data[0] = rc1.second; } } - return rc; + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return idx_data[0]; } -template -int threadgroup_argmin(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee +template +IDX_T threadgroup_argmin( + threadgroup ARG_T* arg_data, + threadgroup IDX_T* idx_data, + ARG_T val, + IDX_T idx_val, + unsigned idx, + unsigned size) { + auto rc = simd_argmin(val, idx_val); + if (size <= simdgroup_size) { + return rc.second; + } + if (idx % simdgroup_size == 0) { + arg_data[idx / simdgroup_size] = rc.first; + idx_data[idx / simdgroup_size] = rc.second; + } ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - int rc = 0; - for (unsigned idx = 1; idx < size; ++idx) { - if (data[idx] < data[rc]) { - rc = idx; + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_argmin(arg_data[idx], idx_data[idx]); + if (idx == 0) { + idx_data[0] = rc1.second; } } - return rc; + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return idx_data[0]; } } // namespace metal diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 3e401d36da27..aaa0e1741240 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -24,14 +24,12 @@ struct vectypes { using type2 = half2; }; -#if __METAL_VERSION__ >= 310 template <> struct vectypes { using type4 = bfloat4; using type3 = bfloat3; using type2 = bfloat2; }; -#endif template <> struct vectypes { @@ -79,12 +77,10 @@ struct OpMathType { using type = int; }; -#if __METAL_VERSION__ >= 310 template <> struct OpMathType { using type = float; }; -#endif // Type promotion structure for higher precision accumulation template @@ -98,13 +94,11 @@ struct AccumulationType { using type = float; }; -#if __METAL_VERSION__ >= 310 // Specialization for bfloat - promote to float for accumulation template <> struct AccumulationType { using type = float; }; -#endif } // namespace detail @@ -130,7 +124,6 @@ min(T a, U b) { return ::metal::min(a, static_cast(b)); } -#if __METAL_VERSION__ >= 310 template <> inline bfloat min(bfloat a, bfloat b) { return bfloat( @@ -142,7 +135,6 @@ inline bfloat max(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b))); } -#endif template using vec2type_t = typename detail::vectypes::type2; diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index 1ed866f78d9a..6d3510cd5be8 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -1,340 +1 @@ -#pragma once - -#include -#include - -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -#if defined(CL_SYCL_LANGUAGE_VERSION) -#include // for SYCL 1.2.1 -#elif defined(SYCL_LANGUAGE_VERSION) -#include // for SYCL 2020 -#endif - -namespace c10 { - -/// Constructors -inline C10_HOST_DEVICE BFloat16::BFloat16(float value) - : -#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 800 - x(__bfloat16_as_ushort(__float2bfloat16(value))) -#elif defined(__SYCL_DEVICE_ONLY__) && \ - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) - x(c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) -#else - // RNE by default - x(detail::round_to_nearest_even(value)) -#endif -{ -} - -/// Implicit conversions -inline C10_HOST_DEVICE BFloat16::operator float() const { -#if defined(__CUDACC__) && !defined(USE_ROCM) - return __bfloat162float(*reinterpret_cast(&x)); -#elif defined(__SYCL_DEVICE_ONLY__) && \ - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) - return float(*reinterpret_cast(&x)); -#else - return detail::f32_from_bits(x); -#endif -} - -#if defined(__CUDACC__) && !defined(USE_ROCM) -inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { - x = *reinterpret_cast(&value); -} -inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { - return *reinterpret_cast(&x); -} -#endif - -#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) -inline C10_HOST_DEVICE BFloat16::BFloat16( - const sycl::ext::oneapi::bfloat16& value) { - x = *reinterpret_cast(&value); -} -inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { - return *reinterpret_cast(&x); -} -#endif - -// CUDA intrinsics - -#if defined(__CUDACC__) || defined(__HIPCC__) -inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { -#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __ldg(reinterpret_cast(ptr)); -#else - return *ptr; -#endif -} -#endif - -/// Arithmetic - -inline C10_HOST_DEVICE BFloat16 -operator+(const BFloat16& a, const BFloat16& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 -operator-(const BFloat16& a, const BFloat16& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 -operator*(const BFloat16& a, const BFloat16& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { - return -static_cast(a); -} - -inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { - a = a / b; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { - a.x = a.x | b.x; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { - a.x = a.x ^ b.x; - return a; -} - -inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { - a.x = a.x & b.x; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { - return static_cast(a) / b; -} - -// Overloading < and > operators, because std::max and std::min use them. - -inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { - return float(lhs) > float(rhs); -} - -inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { - return float(lhs) < float(rhs); -} - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_signed = true; - static constexpr bool is_specialized = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = true; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = true; - static constexpr auto has_denorm = numeric_limits::has_denorm; - static constexpr auto has_denorm_loss = - numeric_limits::has_denorm_loss; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 8; - static constexpr int digits10 = 2; - static constexpr int max_digits10 = 4; - static constexpr int radix = 2; - static constexpr int min_exponent = -125; - static constexpr int min_exponent10 = -37; - static constexpr int max_exponent = 128; - static constexpr int max_exponent10 = 38; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = - numeric_limits::tinyness_before; - - static constexpr c10::BFloat16 min() { - return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 lowest() { - return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 max() { - return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 epsilon() { - return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 round_error() { - return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 infinity() { - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 quiet_NaN() { - return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 signaling_NaN() { - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 denorm_min() { - return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index 06236df1fc81..6d3510cd5be8 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -1,116 +1 @@ -#pragma once - -// Defines the bloat16 type (brain floating-point). This representation uses -// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. - -#include -#include -#include -#include -#include -#include -#include - -#if defined(__CUDACC__) && !defined(USE_ROCM) -#include -#endif - -#if defined(CL_SYCL_LANGUAGE_VERSION) -#include // for SYCL 1.2.1 -#elif defined(SYCL_LANGUAGE_VERSION) -#include // for SYCL 2020 -#endif - -namespace c10 { - -namespace detail { -inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { - float res = 0; - uint32_t tmp = src; - tmp <<= 16; - -#if defined(USE_ROCM) && defined(__HIPCC__) - float* tempRes; - - // We should be using memcpy in order to respect the strict aliasing rule - // but it fails in the HIP environment. - tempRes = reinterpret_cast(&tmp); - res = *tempRes; -#else - std::memcpy(&res, &tmp, sizeof(tmp)); -#endif - - return res; -} - -inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { - uint32_t res = 0; - -#if defined(USE_ROCM) && defined(__HIPCC__) - // We should be using memcpy in order to respect the strict aliasing rule - // but it fails in the HIP environment. - uint32_t* tempRes = reinterpret_cast(&src); - res = *tempRes; -#else - std::memcpy(&res, &src, sizeof(res)); -#endif - - return res >> 16; -} - -inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { -#if defined(USE_ROCM) && defined(__HIPCC__) - if (src != src) { -#elif defined(_MSC_VER) - if (isnan(src)) { -#else - if (std::isnan(src)) { -#endif - return UINT16_C(0x7FC0); - } else { - const uint32_t U32 = c10::bit_cast(src); - uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - return static_cast((U32 + rounding_bias) >> 16); - } -} -} // namespace detail - -struct alignas(2) BFloat16 { - uint16_t x; - - // HIP wants __host__ __device__ tag, CUDA does not -#if defined(USE_ROCM) && defined(__HIPCC__) - C10_HOST_DEVICE BFloat16() = default; -#else - BFloat16() = default; -#endif - - struct from_bits_t {}; - static constexpr C10_HOST_DEVICE from_bits_t from_bits() { - return from_bits_t(); - } - - constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) - : x(bits) {} - /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); - inline C10_HOST_DEVICE operator float() const; - -#if defined(__CUDACC__) && !defined(USE_ROCM) - inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); - explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; -#endif - -#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) - inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); - explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; -#endif -}; - -inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/Float8_e4m3fn-inl.h b/c10/util/Float8_e4m3fn-inl.h index e2d6a36da179..ef52e38f506d 100644 --- a/c10/util/Float8_e4m3fn-inl.h +++ b/c10/util/Float8_e4m3fn-inl.h @@ -1,274 +1 @@ -#pragma once - -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -/// Constructors - -inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) - : x(detail::fp8e4m3fn_from_fp32_value(value)) {} - -/// Implicit conversions - -inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const { - return detail::fp8e4m3fn_to_fp32_value(x); -} - -/// Special values helper - -inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const { - return (x & 0b01111111) == 0b01111111; -} - -/// Arithmetic - -inline C10_HOST_DEVICE Float8_e4m3fn -operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn -operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn -operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn operator/( - const Float8_e4m3fn& a, - const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { - return -static_cast(a); -} - -inline C10_HOST_DEVICE Float8_e4m3fn& operator+=( - Float8_e4m3fn& a, - const Float8_e4m3fn& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fn& operator-=( - Float8_e4m3fn& a, - const Float8_e4m3fn& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fn& operator*=( - Float8_e4m3fn& a, - const Float8_e4m3fn& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fn& operator/=( - Float8_e4m3fn& a, - const Float8_e4m3fn& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) { - return static_cast(a) / b; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Float8_e4m3fn to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_specialized = true; - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = false; - static constexpr auto has_denorm = true; - static constexpr auto has_denorm_loss = true; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 4; - static constexpr int digits10 = 0; - static constexpr int max_digits10 = 3; - static constexpr int radix = 2; - static constexpr int min_exponent = -5; - static constexpr int min_exponent10 = -1; - static constexpr int max_exponent = 8; - static constexpr int max_exponent10 = 2; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = false; - - static constexpr c10::Float8_e4m3fn min() { - return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn lowest() { - return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn max() { - return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn epsilon() { - return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn round_error() { - return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn quiet_NaN() { - return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits()); - } - static constexpr c10::Float8_e4m3fn denorm_min() { - return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h index 529a04f24d56..ef52e38f506d 100644 --- a/c10/util/Float8_e4m3fn.h +++ b/c10/util/Float8_e4m3fn.h @@ -1,238 +1 @@ -#pragma once - -/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions -/// to standard C types and basic arithmetic operations. Note that arithmetic -/// operations are implemented by converting to floating point and -/// performing the operation in float32. -/// Binary configuration: -/// s eeee mmm -/// 1 sign bit -/// 4 exponent bits -/// 3 mantissa bits -/// bias = 7 -/// -/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf -/// and inspired by Half implementation from pytorch/c10/util/Half.h - -#include -#include - -#if defined(__cplusplus) -#include -#include -#elif !defined(__OPENCL_VERSION__) -#include -#include -#endif - -#ifdef _MSC_VER -#include -#endif - -#include -#include - -namespace c10 { - -namespace detail { - -/* - * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit - * representation, to a 32-bit floating-point number in IEEE single-precision - * format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ -inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { - /* - * Extend the fp8 E4M3FN number to 32 bits and shift to the - * upper part of the 32-bit word: - * +---+----+---+-----------------------------+ - * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| - * +---+----+---+-----------------------------+ - * Bits 31 27-30 24-26 0-23 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - * - zero bits. - */ - const uint32_t w = (uint32_t)input << 24; - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = w & UINT32_C(0x80000000); - /* - * Extract mantissa and biased exponent of the input number into the bits 0-30 - * of the 32-bit word: - * - * +---+----+---+-----------------------------+ - * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| - * +---+----+---+-----------------------------+ - * Bits 31 27-30 24-26 0-23 - */ - const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); - /* - * Renorm shift is the number of bits to shift mantissa left to make the - * half-precision number normalized. If the initial number is normalized, some - * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case - * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note - * that if we shift denormalized nonsign by renorm_shift, the unit bit of - * mantissa will shift into exponent, turning the biased exponent into 1, and - * making mantissa normalized (i.e. without leading 1). - */ -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - uint32_t renorm_shift = __clz(nonsign); -#elif defined(__SYCL_DEVICE_ONLY__) - // Note: zero is not a supported input into `__builtin_clz` - uint32_t renorm_shift = - nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; -#elif defined(_MSC_VER) && !defined(__clang__) - unsigned long nonsign_bsr; - _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); - uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; -#else - // Note: zero is not a supported input into `__builtin_clz` - uint32_t renorm_shift = - nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; -#endif - renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; - /* - * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, - * the addition overflows it into bit 31, and the subsequent shift turns the - * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number - * is Nan, 0x00000000 otherwise - */ - const int32_t inf_nan_mask = - ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); - /* - * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 - * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 - * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == - * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) - * 0x00000000 otherwise - */ - const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; - /* - * 1. Shift nonsign left by renorm_shift to normalize it (if the input - * was denormal) - * 2. Shift nonsign right by 4 so the exponent (4 bits originally) - * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high - * bits of the 23-bit mantissa of IEEE single-precision number. - * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the - * different in exponent bias (0x7F for single-precision number less 0x07 - * for fp8e4m3fn number). - * 4. Subtract renorm_shift from the exponent (starting at bit 23) to - * account for renormalization. As renorm_shift is less than 0x78, this - * can be combined with step 3. - * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the - * input was NaN or infinity. - * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent - * into zero if the input was zero. - * 7. Combine with the sign of the input number. - */ - uint32_t result = sign | - ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | - inf_nan_mask) & - ~zero_mask); - return fp32_from_bits(result); -} - -/* - * Convert a 32-bit floating-point number in IEEE single-precision format to a - * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. - */ -inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { - /* - * Binary representation of 480.0f, which is the first value - * not representable in fp8e4m3fn range: - * 0 1111 111 - fp8e4m3fn - * 0 10000111 11100000000000000000000 - fp32 - */ - constexpr uint32_t fp8_max = UINT32_C(1087) << 20; - - /* - * A mask for converting fp32 numbers lower than fp8e4m3fn normal range - * into denorm representation - * magic number: ((127 - 7) + (23 - 3) + 1) - */ - constexpr uint32_t denorm_mask = UINT32_C(141) << 23; - - uint32_t f_bits = fp32_to_bits(f); - - uint8_t result = 0u; - - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = f_bits & UINT32_C(0x80000000); - - /* - * Set sign bit to 0 - */ - f_bits ^= sign; - - if (f_bits >= fp8_max) { - // NaN - all exponent and mantissa bits set to 1 - result = 0x7f; - } else { - if (f_bits < (UINT32_C(121) << 23)) { - // Input number is smaller than 2^(-6), which is the smallest - // fp8e4m3fn normal number - f_bits = - fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); - result = static_cast(f_bits - denorm_mask); - } else { - // resulting mantissa is odd - uint8_t mant_odd = (f_bits >> 20) & 1; - - // update exponent, rounding bias part 1 - f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; - - // rounding bias part 2 - f_bits += mant_odd; - - // take the bits! - result = static_cast(f_bits >> 20); - } - } - - result |= static_cast(sign >> 24); - return result; -} - -} // namespace detail - -struct alignas(1) Float8_e4m3fn { - uint8_t x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - Float8_e4m3fn() = default; - - constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) - : x(bits) {} - inline C10_HOST_DEVICE Float8_e4m3fn(float value); - inline C10_HOST_DEVICE operator float() const; - inline C10_HOST_DEVICE bool isnan() const; -}; - -inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/Float8_e4m3fnuz-inl.h b/c10/util/Float8_e4m3fnuz-inl.h index e89eaeadd47b..f8fab7180e1e 100644 --- a/c10/util/Float8_e4m3fnuz-inl.h +++ b/c10/util/Float8_e4m3fnuz-inl.h @@ -1,279 +1 @@ -#pragma once - -#include -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -/// Constructors - -inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) - : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} - -/// Implicit conversions - -inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const { - return detail::fp8_fnuz_to_fp32_value<4, 3>(x); -} - -/// Special values helper - -inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { - return x == 0b10000000; -} - -/// Arithmetic - -inline C10_HOST_DEVICE Float8_e4m3fnuz -operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz -operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz -operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator/( - const Float8_e4m3fnuz& a, - const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) { - return -static_cast(a); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=( - Float8_e4m3fnuz& a, - const Float8_e4m3fnuz& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=( - Float8_e4m3fnuz& a, - const Float8_e4m3fnuz& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=( - Float8_e4m3fnuz& a, - const Float8_e4m3fnuz& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=( - Float8_e4m3fnuz& a, - const Float8_e4m3fnuz& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) { - return static_cast(a) / b; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Float8_e4m3fnuz to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_specialized = true; - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = false; - static constexpr auto has_denorm = true; - static constexpr auto has_denorm_loss = true; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 4; - static constexpr int digits10 = 0; - static constexpr int max_digits10 = 3; - static constexpr int radix = 2; - static constexpr int min_exponent = -6; - static constexpr int min_exponent10 = -1; - static constexpr int max_exponent = 8; - static constexpr int max_exponent10 = 2; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = false; - - static constexpr c10::Float8_e4m3fnuz min() { - return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz lowest() { - return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz max() { - return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz epsilon() { - return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz round_error() { - return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz infinity() { - // NaN (no infinities) - return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz quiet_NaN() { - return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); - } - static constexpr c10::Float8_e4m3fnuz denorm_min() { - return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Float8_e4m3fnuz.h b/c10/util/Float8_e4m3fnuz.h index f5de58f12a11..f8fab7180e1e 100644 --- a/c10/util/Float8_e4m3fnuz.h +++ b/c10/util/Float8_e4m3fnuz.h @@ -1,139 +1 @@ -#pragma once - -/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including -/// conversions to standard C types and basic arithmetic operations. Note that -/// arithmetic operations are implemented by converting to floating point and -/// performing the operation in float32. -/// Binary configuration remains the same as Float8_e4m3fn: -/// s eeee mmm -/// 1 sign bit -/// 4 exponent bits -/// 3 mantissa bits -/// The key differences versus Float8_e4m3fn are: -/// bias = 8 -/// no infinities or negative zero -/// NaN only when sign bit is 1, rest all 0s -/// -/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and -/// the existing Float8_e4m3fn implementation. - -#include -#include -#include -#include - -#if defined(__cplusplus) -#include -#elif !defined(__OPENCL_VERSION__) -#include -#include -#endif - -#include -#include - -namespace c10 { - -namespace detail { - -/* - * Convert a 32-bit floating-point number in IEEE single-precision format to a - * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. - */ -inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { - /* - * Binary representation of 256.0f, which is the first value not representable - * (i.e. the first value which would overflow in to the sign bit, resulting in - * a NaN) in fp8e4m3fnuz range: - * 1 0000 000 - fp8e4m3fnuz - * 0 10000111 00000000000000000000000 - fp32 - */ - constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; - - /* - * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range - * into denorm representation - * magic number: ((127 - 8) + (23 - 3) + 1) - */ - constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; - - uint32_t f_bits = fp32_to_bits(f); - - uint32_t result = 0u; - - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = f_bits & UINT32_C(0x80000000); - - /* - * Set sign bit to 0 - */ - f_bits ^= sign; - - if (f_bits >= fnuz_max) { - // NaN -- sign bit set to 1, rest 0s. - return 0x80; - } - - if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { - // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the - // number will become subnormal. - f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); - result = static_cast(f_bits - denorm_mask); - if (result == 0) { - // fnuz types don't have negative zero. - return 0; - } - } else { - // resulting mantissa is odd - uint8_t mant_odd = (f_bits >> 20) & 1; - - // update exponent, rounding bias part 1 - f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; - - // rounding bias part 2 - f_bits += mant_odd; - - // take the bits! - result = static_cast(f_bits >> 20); - } - - result |= sign >> 24; - return result; -} - -} // namespace detail - -struct alignas(1) Float8_e4m3fnuz { - uint8_t x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - Float8_e4m3fnuz() = default; - - constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) - : x(bits) {} - inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); - inline C10_HOST_DEVICE operator float() const; - inline C10_HOST_DEVICE bool isnan() const; -}; - -inline std::ostream& operator<<( - std::ostream& out, - const Float8_e4m3fnuz& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/Float8_e5m2-inl.h b/c10/util/Float8_e5m2-inl.h index 5a5c1a5fc9b5..2e21840fba37 100644 --- a/c10/util/Float8_e5m2-inl.h +++ b/c10/util/Float8_e5m2-inl.h @@ -1,286 +1 @@ -#pragma once - -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -#define EXP_WIDTH_FP8 5 -#define MAN_WIDTH_FP8 2 -#define EXP_BIAS_FP8 15 - -namespace c10 { - -/// Constructors - -inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) - : x(detail::fp8e5m2_from_fp32_value(value)) {} - -/// Implicit conversions - -inline C10_HOST_DEVICE Float8_e5m2::operator float() const { - return detail::fp8e5m2_to_fp32_value(x); -} - -/// Special values helpers - -inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const { - return (x & 0b01111111) > 0b01111100; -} - -inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const { - return (x & 0b01111111) == 0b01111100; -} - -/// Arithmetic - -inline C10_HOST_DEVICE Float8_e5m2 -operator+(const Float8_e5m2& a, const Float8_e5m2& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 -operator-(const Float8_e5m2& a, const Float8_e5m2& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 -operator*(const Float8_e5m2& a, const Float8_e5m2& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 operator/( - const Float8_e5m2& a, - const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { - return -static_cast(a); -} - -inline C10_HOST_DEVICE Float8_e5m2& operator+=( - Float8_e5m2& a, - const Float8_e5m2& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2& operator-=( - Float8_e5m2& a, - const Float8_e5m2& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2& operator*=( - Float8_e5m2& a, - const Float8_e5m2& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2& operator/=( - Float8_e5m2& a, - const Float8_e5m2& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { - return static_cast(a) / b; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Float8_e5m2 to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_specialized = true; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = true; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = false; - static constexpr auto has_denorm = true; - static constexpr auto has_denorm_loss = true; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 3; - static constexpr int digits10 = 0; - static constexpr int max_digits10 = 2; - static constexpr int radix = 2; - static constexpr int min_exponent = -13; - static constexpr int min_exponent10 = -4; - static constexpr int max_exponent = 16; - static constexpr int max_exponent10 = 4; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = - numeric_limits::tinyness_before; - - static constexpr c10::Float8_e5m2 min() { - return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 max() { - return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 lowest() { - return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 epsilon() { - return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 round_error() { - return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 infinity() { - return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 quiet_NaN() { - return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits()); - } - static constexpr c10::Float8_e5m2 denorm_min() { - return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Float8_e5m2.h b/c10/util/Float8_e5m2.h index 8f70b77bcd6e..2e21840fba37 100644 --- a/c10/util/Float8_e5m2.h +++ b/c10/util/Float8_e5m2.h @@ -1,146 +1 @@ -#pragma once - -/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions -/// to standard C types and basic arithmetic operations. Note that arithmetic -/// operations are implemented by converting to floating point and -/// performing the operation in float32. -/// Binary configuration: -/// s eeeee mm -/// 1 sign bit -/// 5 exponent bits -/// 2 mantissa bits -/// bias = 15 -/// -/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf -/// and inspired by Half implementation from pytorch/c10/util/Half.h - -#include - -namespace c10 { - -namespace detail { - -/* - * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit - * representation, to a 32-bit floating-point number in IEEE single-precision - * format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ -inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { - /* - * Extend the fp8 E5M2 number to 32 bits and shift to the - * upper part of the 32-bit word: - * +---+----+---+-----------------------------+ - * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| - * +---+----+---+-----------------------------+ - * Bits 31 26-30 24-25 0-23 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - * - zero bits. - */ - uint16_t half_representation = input; - half_representation <<= 8; - return fp16_ieee_to_fp32_value(half_representation); -} - -/* - * Convert a 32-bit floating-point number in IEEE single-precision format to a - * 8-bit floating-point number in fp8 E5M2 format, in bit representation. - */ -inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { - /* - * Binary representation of fp32 infinity - * 0 11111111 00000000000000000000000 - */ - constexpr uint32_t fp32_inf = UINT32_C(255) << 23; - - /* - * Binary representation of 65536.0f, which is the first value - * not representable in fp8e5m2 range: - * 0 11111 00 - fp8e5m2 - * 0 10001111 00000000000000000000000 - fp32 - */ - constexpr uint32_t fp8_max = UINT32_C(143) << 23; - - /* - * A mask for converting fp32 numbers lower than fp8e5m2 normal range - * into denorm representation - * magic number: ((127 - 15) + (23 - 2) + 1) - */ - constexpr uint32_t denorm_mask = UINT32_C(134) << 23; - - uint32_t f_bits = fp32_to_bits(f); - uint8_t result = 0u; - - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = f_bits & UINT32_C(0x80000000); - - /* - * Set sign bit to 0 - */ - f_bits ^= sign; - - if (f_bits >= fp8_max) { - // NaN - all exponent and mantissa bits set to 1 - result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); - } else { - if (f_bits < (UINT32_C(113) << 23)) { - // Input number is smaller than 2^(-14), which is the smallest - // fp8e5m2 normal number - f_bits = - fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); - result = static_cast(f_bits - denorm_mask); - } else { - // resulting mantissa is odd - uint32_t mant_odd = (f_bits >> 21) & 1; - - // update exponent, rounding bias part 1 - f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; - - // rounding bias part 2 - f_bits += mant_odd; - - // take the bits! - result = static_cast(f_bits >> 21); - } - } - - result |= static_cast(sign >> 24); - return result; -} - -} // namespace detail - -struct alignas(1) Float8_e5m2 { - uint8_t x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - Float8_e5m2() = default; - - constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {} - inline C10_HOST_DEVICE Float8_e5m2(float value); - inline C10_HOST_DEVICE operator float() const; - inline C10_HOST_DEVICE bool isnan() const; - inline C10_HOST_DEVICE bool isinf() const; -}; - -inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/Float8_e5m2fnuz-inl.h b/c10/util/Float8_e5m2fnuz-inl.h index d81054cbee35..1f2d3db723d0 100644 --- a/c10/util/Float8_e5m2fnuz-inl.h +++ b/c10/util/Float8_e5m2fnuz-inl.h @@ -1,285 +1 @@ -#pragma once - -#include -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -/// Constructors - -inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) - : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} - -/// Implicit conversions - -inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const { - return detail::fp8_fnuz_to_fp32_value<5, 2>(x); -} - -/// Special values helpers - -inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { - return x == 0b10000000; -} - -inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { - return false; -} - -/// Arithmetic - -inline C10_HOST_DEVICE Float8_e5m2fnuz -operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz -operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz -operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator/( - const Float8_e5m2fnuz& a, - const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) { - return -static_cast(a); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=( - Float8_e5m2fnuz& a, - const Float8_e5m2fnuz& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=( - Float8_e5m2fnuz& a, - const Float8_e5m2fnuz& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=( - Float8_e5m2fnuz& a, - const Float8_e5m2fnuz& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=( - Float8_e5m2fnuz& a, - const Float8_e5m2fnuz& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) { - return static_cast(a) / b; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Float8_e5m2fnuz to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_specialized = true; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = false; - static constexpr auto has_denorm = true; - static constexpr auto has_denorm_loss = true; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 3; - static constexpr int digits10 = 0; - static constexpr int max_digits10 = 2; - static constexpr int radix = 2; - static constexpr int min_exponent = -14; - static constexpr int min_exponent10 = -4; - static constexpr int max_exponent = 16; - static constexpr int max_exponent10 = 4; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = - numeric_limits::tinyness_before; - - static constexpr c10::Float8_e5m2fnuz min() { - return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz max() { - return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz lowest() { - return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz epsilon() { - return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz round_error() { - return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz infinity() { - return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); - } - // TODO(future): we are mapping neg_zero to both inf and NaN, this is - // surprising and we should figure out what to do about it. - static constexpr c10::Float8_e5m2fnuz quiet_NaN() { - return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); - } - static constexpr c10::Float8_e5m2fnuz denorm_min() { - return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Float8_e5m2fnuz.h b/c10/util/Float8_e5m2fnuz.h index 9b8c2505ab1f..1f2d3db723d0 100644 --- a/c10/util/Float8_e5m2fnuz.h +++ b/c10/util/Float8_e5m2fnuz.h @@ -1,138 +1 @@ -#pragma once - -/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including -/// conversions to standard C types and basic arithmetic operations. Note that -/// arithmetic operations are implemented by converting to floating point and -/// performing the operation in float32. -/// Binary configuration remains the same as e5m2: -/// s eeeee mm -/// 1 sign bit -/// 5 exponent bits -/// 2 mantissa bits -/// The key differences that e5m2fnuz brings are: -/// bias = 16 -/// no infinities or negative zero -/// NaN only when sign bit is 1, rest all 0s -/// -/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and -/// the existing Float8_e4m3fn implementation. - -#include -#include -#include - -#if defined(__cplusplus) -#include -#elif !defined(__OPENCL_VERSION__) -#include -#include -#endif - -#include -#include - -namespace c10 { - -namespace detail { - -/* - * Convert a 32-bit floating-point number in IEEE single-precision format to a - * 8-bit floating-point number in fp8 E5M2 format, in bit representation. - */ -inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { - /* - * Binary representation of 65536.0f, which is the first value not - * representable (i.e. the first value which would overflow in to the sign - * bit, resulting in a NaN) in fp8e4m3fnuz range: - * 1 00000 00 - fp8e5m2fnuz - * 0 10001111 00000000000000000000000 - fp32 - */ - constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; - - /* - * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range - * into denormalized representation. - * magic number: ((127 - 16) + (23 - 2) + 1) - */ - constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; - - uint32_t f_bits = fp32_to_bits(f); - uint32_t result = 0u; - - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = f_bits & UINT32_C(0x80000000); - - /* - * Set sign bit to 0 - */ - f_bits ^= sign; - - if (f_bits >= fnuz_max) { - // NaN -- sign bit set to 1, rest 0s - return 0x80; - } - - if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { - // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the - // number will become subnormal. - f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); - result = static_cast(f_bits - denorm_mask); - if (result == 0) { - // fnuz types don't have negative zero. - return 0; - } - } else { - // resulting mantissa is odd - uint8_t mant_odd = (f_bits >> 21) & 1; - - // update exponent, rounding bias part 1 - f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; - - // rounding bias part 2 - f_bits += mant_odd; - - // take the bits! - result = static_cast(f_bits >> 21); - } - - result |= sign >> 24; - return result; -} - -} // namespace detail - -struct alignas(1) Float8_e5m2fnuz { - uint8_t x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - Float8_e5m2fnuz() = default; - - constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) - : x(bits) {} - inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); - inline C10_HOST_DEVICE operator float() const; - inline C10_HOST_DEVICE bool isnan() const; - inline C10_HOST_DEVICE bool isinf() const; -}; - -inline std::ostream& operator<<( - std::ostream& out, - const Float8_e5m2fnuz& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/Float8_e8m0fnu-inl.h b/c10/util/Float8_e8m0fnu-inl.h index 7d67934abd14..9982faa07976 100644 --- a/c10/util/Float8_e8m0fnu-inl.h +++ b/c10/util/Float8_e8m0fnu-inl.h @@ -1,112 +1 @@ -#pragma once - -#include -#include -#include -#include - -// TODO(#146647): Can we remove the below warning? -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -/// Constructors - -inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) - : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} - -/// Implicit conversions - -inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const { - // TODO(#146647): maybe rewrite without control flow - - // if exponent is zero, need to special case to return 2^-127 instead of zero - if (x == 0) { - return c10::detail::fp32_from_bits(0x00400000); - } - - // if exponent is NaN, need to special case to return properly encoded NaN - if (isnan()) { - return c10::detail::fp32_from_bits(0x7f800001); - } - - // leave sign at 0, set the exponent bits, leave stored mantissa at 0 - uint32_t res = x << 23; - - return c10::detail::fp32_from_bits(res); -} - -/// Special values helper - -inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { - return x == 0b11111111; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Float8_e8m0fnu to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_specialized = true; - static constexpr bool is_signed = false; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = false; - static constexpr auto has_denorm = false; - static constexpr auto has_denorm_loss = false; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 1; - static constexpr int digits10 = 0; - static constexpr int max_digits10 = 1; // just a 2! - static constexpr int radix = 2; - static constexpr int min_exponent = -126; - static constexpr int min_exponent10 = -38; - static constexpr int max_exponent = 128; - static constexpr int max_exponent10 = 38; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = false; - - static constexpr c10::Float8_e8m0fnu min() { - // 2^-127 - return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); - } - static constexpr c10::Float8_e8m0fnu lowest() { - // 2^-127 - return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); - } - static constexpr c10::Float8_e8m0fnu max() { - // 254 biased, which is 127 unbiased, so 2^127 - return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits()); - } - static constexpr c10::Float8_e8m0fnu epsilon() { - // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this - // is "the difference between 1.0 and the next representable value of the - // given floating-point type". The next representable value is 2.0, so the - // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. - return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits()); - } - static constexpr c10::Float8_e8m0fnu round_error() { - // 0.5 in float, which is 2^-1, and -1 + 127 = 126 - return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits()); - } - static constexpr c10::Float8_e8m0fnu quiet_NaN() { - return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Float8_e8m0fnu.h b/c10/util/Float8_e8m0fnu.h index 0ae3b5012f8f..9982faa07976 100644 --- a/c10/util/Float8_e8m0fnu.h +++ b/c10/util/Float8_e8m0fnu.h @@ -1,120 +1 @@ -#pragma once - -/// Defines the Float8_e8m0fnu type (8-bit floating-point) including -/// conversions to standard C types -/// Binary configuration : -/// eeeeeeee -/// no sign bits -/// 8 exponent bits -/// no mantissa bits -/// -/// This is the E8M0 dtype from the OCP MX format spec -/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, -/// Section 5.4.1) - -#include -#include -#include -#include - -// TODO(#146647): do we need to special case OPENCL? -#if defined(__cplusplus) -#include -#elif !defined(__OPENCL_VERSION__) -#include -#include -#endif - -#include -#include - -namespace c10 { - -namespace detail { - -/* - * Convert a 32-bit floating-point number in IEEE single-precision format to a - * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. - */ -inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { - // TODO(#146647): maybe rewrite without control flow - - uint32_t f_bits = c10::detail::fp32_to_bits(f); - - // extract the exponent - uint32_t exponent = (f_bits >> 23) & 0b11111111; - - // special case float32 NaN and +-inf to map to e8m0 nan - if (exponent == 0b11111111) { - return exponent; - } - - // next, we use guard, round, sticky bits and the LSB to implement round to - // nearest, with ties to even - - // guard bit - bit 23, or 22 zero-indexed - uint8_t g = (f_bits & 0x400000) > 0; - // round bit - bit 22, or 21 zero-indexed - uint8_t r = (f_bits & 0x200000) > 0; - // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed - uint8_t s = (f_bits & 0x1FFFFF) > 0; - // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the - // original float32 is denormal, and to 1 if the original float32 is normal. - uint8_t lsb = exponent > 0; - - // implement the RNE logic - bool round_up = false; - - // if g == 0, round down (no-op) - if (g == 1) { - if ((r == 1) || (s == 1)) { - // round up - round_up = true; - } else { - if (lsb == 1) { - // round up - round_up = true; - } - // if lsb == 0, round down (no-op) - } - } - - if (round_up) { - // adjust exponent - // note that if exponent was 255 we would have already returned earlier, so - // we know we can add one safely without running out of bounds - exponent++; - } - - return exponent; -} - -} // namespace detail - -struct alignas(1) Float8_e8m0fnu { - uint8_t x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - Float8_e8m0fnu() = default; - - constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) - : x(bits) {} - inline C10_HOST_DEVICE Float8_e8m0fnu(float value); - inline C10_HOST_DEVICE operator float() const; - inline C10_HOST_DEVICE bool isnan() const; -}; - -inline std::ostream& operator<<( - std::ostream& out, - const Float8_e8m0fnu& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep +#include diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 58c050678302..28520225d4b2 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -1,140 +1 @@ -#pragma once - -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wstring-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion") -#endif -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -/// Returns false since we cannot have x < 0 if x is unsigned. -template -inline constexpr bool is_negative( - const T& /*x*/, - std::true_type /*is_unsigned*/) { - return false; -} - -/// Returns true if a signed variable x < 0 -template -inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { - return x < T(0); -} - -/// Returns true if x < 0 -/// NOTE: Will fail on an unsigned custom type -/// For the most part it's possible to fix this if -/// the custom type has a constexpr constructor. -/// However, notably, c10::Half does not :-( -template -inline constexpr bool is_negative(const T& x) { - return is_negative(x, std::is_unsigned()); -} - -/// Returns the sign of an unsigned variable x as 0, 1 -template -inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { - return T(0) < x; -} - -/// Returns the sign of a signed variable x as -1, 0, 1 -template -inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { - return (T(0) < x) - (x < T(0)); -} - -/// Returns the sign of x as -1, 0, 1 -/// NOTE: Will fail on an unsigned custom type -/// For the most part it's possible to fix this if -/// the custom type has a constexpr constructor. -/// However, notably, c10::Half does not :-( -template -inline constexpr int signum(const T& x) { - return signum(x, std::is_unsigned()); -} - -/// Returns true if a and b are not both negative -template -inline constexpr bool signs_differ(const T& a, const U& b) { - return is_negative(a) != is_negative(b); -} - -// Suppress sign compare warning when compiling with GCC -// as later does not account for short-circuit rule before -// raising the warning, see https://godbolt.org/z/Tr3Msnz99 -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#endif - -/// Returns true if x is greater than the greatest value of the type Limit -template -inline constexpr bool greater_than_max(const T& x) { - constexpr bool can_overflow = - std::numeric_limits::digits > std::numeric_limits::digits; - return can_overflow && x > (std::numeric_limits::max)(); -} - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif - -/// Returns true if x < lowest(Limit). Standard comparison -template -inline constexpr bool less_than_lowest( - const T& x, - std::false_type /*limit_is_unsigned*/, - std::false_type /*x_is_unsigned*/) { - return x < std::numeric_limits::lowest(); -} - -/// Returns false since all the limit is signed and therefore includes -/// negative values but x cannot be negative because it is unsigned -template -inline constexpr bool less_than_lowest( - const T& /*x*/, - std::false_type /*limit_is_unsigned*/, - std::true_type /*x_is_unsigned*/) { - return false; -} - -/// Returns true if x < 0, where 0 is constructed from T. -/// Limit is not signed, so its lower value is zero -template -inline constexpr bool less_than_lowest( - const T& x, - std::true_type /*limit_is_unsigned*/, - std::false_type /*x_is_unsigned*/) { - return x < T(0); -} - -/// Returns false sign both types are unsigned -template -inline constexpr bool less_than_lowest( - const T& /*x*/, - std::true_type /*limit_is_unsigned*/, - std::true_type /*x_is_unsigned*/) { - return false; -} - -/// Returns true if x is less than the lowest value of type T -/// NOTE: Will fail on an unsigned custom type -/// For the most part it's possible to fix this if -/// the custom type has a constexpr constructor. -/// However, notably, c10::Half does not : -template -inline constexpr bool less_than_lowest(const T& x) { - return less_than_lowest( - x, std::is_unsigned(), std::is_unsigned()); -} - -} // namespace c10 - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/complex.h b/c10/util/complex.h index b63710d9458f..4e699684bc38 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -4,531 +4,7 @@ #include #include - -#if defined(__CUDACC__) || defined(__HIPCC__) -#include -#endif - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") -#endif -#if C10_CLANG_HAS_WARNING("-Wfloat-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") -#endif - -namespace c10 { - -// c10::complex is an implementation of complex numbers that aims -// to work on all devices supported by PyTorch -// -// Most of the APIs duplicates std::complex -// Reference: https://en.cppreference.com/w/cpp/numeric/complex -// -// [NOTE: Complex Operator Unification] -// Operators currently use a mix of std::complex, thrust::complex, and -// c10::complex internally. The end state is that all operators will use -// c10::complex internally. Until then, there may be some hacks to support all -// variants. -// -// -// [Note on Constructors] -// -// The APIs of constructors are mostly copied from C++ standard: -// https://en.cppreference.com/w/cpp/numeric/complex/complex -// -// Since C++14, all constructors are constexpr in std::complex -// -// There are three types of constructors: -// - initializing from real and imag: -// `constexpr complex( const T& re = T(), const T& im = T() );` -// - implicitly-declared copy constructor -// - converting constructors -// -// Converting constructors: -// - std::complex defines converting constructor between float/double/long -// double, -// while we define converting constructor between float/double. -// - For these converting constructors, upcasting is implicit, downcasting is -// explicit. -// - We also define explicit casting from std::complex/thrust::complex -// - Note that the conversion from thrust is not constexpr, because -// thrust does not define them as constexpr ???? -// -// -// [Operator =] -// -// The APIs of operator = are mostly copied from C++ standard: -// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D -// -// Since C++20, all operator= are constexpr. Although we are not building with -// C++20, we also obey this behavior. -// -// There are three types of assign operator: -// - Assign a real value from the same scalar type -// - In std, this is templated as complex& operator=(const T& x) -// with specialization `complex& operator=(T x)` for float/double/long -// double Since we only support float and double, on will use `complex& -// operator=(T x)` -// - Copy assignment operator and converting assignment operator -// - There is no specialization of converting assignment operators, which type -// is -// convertible is solely dependent on whether the scalar type is convertible -// -// In addition to the standard assignment, we also provide assignment operators -// with std and thrust -// -// -// [Casting operators] -// -// std::complex does not have casting operators. We define casting operators -// casting to std::complex and thrust::complex -// -// -// [Operator ""] -// -// std::complex has custom literals `i`, `if` and `il` defined in namespace -// `std::literals::complex_literals`. We define our own custom literals in the -// namespace `c10::complex_literals`. Our custom literals does not follow the -// same behavior as in std::complex, instead, we define _if, _id to construct -// float/double complex literals. -// -// -// [real() and imag()] -// -// In C++20, there are two overload of these functions, one it to return the -// real/imag, another is to set real/imag, they are both constexpr. We follow -// this design. -// -// -// [Operator +=,-=,*=,/=] -// -// Since C++20, these operators become constexpr. In our implementation, they -// are also constexpr. -// -// There are two types of such operators: operating with a real number, or -// operating with another complex number. For the operating with a real number, -// the generic template form has argument type `const T &`, while the overload -// for float/double/long double has `T`. We will follow the same type as -// float/double/long double in std. -// -// [Unary operator +-] -// -// Since C++20, they are constexpr. We also make them expr -// -// [Binary operators +-*/] -// -// Each operator has three versions (taking + as example): -// - complex + complex -// - complex + real -// - real + complex -// -// [Operator ==, !=] -// -// Each operator has three versions (taking == as example): -// - complex == complex -// - complex == real -// - real == complex -// -// Some of them are removed on C++20, but we decide to keep them -// -// [Operator <<, >>] -// -// These are implemented by casting to std::complex -// -// -// -// TODO(@zasdfgbnm): c10::complex is not currently supported, -// because: -// - lots of members and functions of c10::Half are not constexpr -// - thrust::complex only support float and double - -template -struct alignas(sizeof(T) * 2) complex { - using value_type = T; - - T real_ = T(0); - T imag_ = T(0); - - constexpr complex() = default; - C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) - : real_(re), imag_(im) {} - template - explicit constexpr complex(const std::complex& other) - : complex(other.real(), other.imag()) {} -#if defined(__CUDACC__) || defined(__HIPCC__) - template - explicit C10_HOST_DEVICE complex(const thrust::complex& other) - : real_(other.real()), imag_(other.imag()) {} -// NOTE can not be implemented as follow due to ROCm bug: -// explicit C10_HOST_DEVICE complex(const thrust::complex &other): -// complex(other.real(), other.imag()) {} -#endif - - // Use SFINAE to specialize casting constructor for c10::complex and - // c10::complex - template - C10_HOST_DEVICE explicit constexpr complex( - const std::enable_if_t, complex>& other) - : real_(other.real_), imag_(other.imag_) {} - template - C10_HOST_DEVICE constexpr complex( - const std::enable_if_t, complex>& other) - : real_(other.real_), imag_(other.imag_) {} - - constexpr complex& operator=(T re) { - real_ = re; - imag_ = 0; - return *this; - } - - constexpr complex& operator+=(T re) { - real_ += re; - return *this; - } - - constexpr complex& operator-=(T re) { - real_ -= re; - return *this; - } - - constexpr complex& operator*=(T re) { - real_ *= re; - imag_ *= re; - return *this; - } - - constexpr complex& operator/=(T re) { - real_ /= re; - imag_ /= re; - return *this; - } - - template - constexpr complex& operator=(const complex& rhs) { - real_ = rhs.real(); - imag_ = rhs.imag(); - return *this; - } - - template - constexpr complex& operator+=(const complex& rhs) { - real_ += rhs.real(); - imag_ += rhs.imag(); - return *this; - } - - template - constexpr complex& operator-=(const complex& rhs) { - real_ -= rhs.real(); - imag_ -= rhs.imag(); - return *this; - } - - template - constexpr complex& operator*=(const complex& rhs) { - // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i - T a = real_; - T b = imag_; - U c = rhs.real(); - U d = rhs.imag(); - real_ = a * c - b * d; - imag_ = a * d + b * c; - return *this; - } - -#ifdef __APPLE__ -#define FORCE_INLINE_APPLE __attribute__((always_inline)) -#else -#define FORCE_INLINE_APPLE -#endif - template - constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) - __ubsan_ignore_float_divide_by_zero__ { - // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i - // the calculation below follows numpy's complex division - T a = real_; - T b = imag_; - U c = rhs.real(); - U d = rhs.imag(); - -#if defined(__GNUC__) && !defined(__clang__) - // std::abs is already constexpr by gcc - auto abs_c = std::abs(c); - auto abs_d = std::abs(d); -#else - auto abs_c = c < 0 ? -c : c; - auto abs_d = d < 0 ? -d : d; -#endif - - if (abs_c >= abs_d) { - if (abs_c == U(0) && abs_d == U(0)) { - /* divide by zeros should yield a complex inf or nan */ - real_ = a / abs_c; - imag_ = b / abs_d; - } else { - auto rat = d / c; - auto scl = U(1.0) / (c + d * rat); - real_ = (a + b * rat) * scl; - imag_ = (b - a * rat) * scl; - } - } else { - auto rat = c / d; - auto scl = U(1.0) / (d + c * rat); - real_ = (a * rat + b) * scl; - imag_ = (b * rat - a) * scl; - } - return *this; - } -#undef FORCE_INLINE_APPLE - - template - constexpr complex& operator=(const std::complex& rhs) { - real_ = rhs.real(); - imag_ = rhs.imag(); - return *this; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - template - C10_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { - real_ = rhs.real(); - imag_ = rhs.imag(); - return *this; - } -#endif - - template - explicit constexpr operator std::complex() const { - return std::complex(std::complex(real(), imag())); - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - template - C10_HOST_DEVICE explicit operator thrust::complex() const { - return static_cast>(thrust::complex(real(), imag())); - } -#endif - - // consistent with NumPy behavior - explicit constexpr operator bool() const { - return real() || imag(); - } - - C10_HOST_DEVICE constexpr T real() const { - return real_; - } - constexpr void real(T value) { - real_ = value; - } - C10_HOST_DEVICE constexpr T imag() const { - return imag_; - } - constexpr void imag(T value) { - imag_ = value; - } -}; - -namespace complex_literals { - -constexpr complex operator""_if(long double imag) { - return complex(0.0f, static_cast(imag)); -} - -constexpr complex operator""_id(long double imag) { - return complex(0.0, static_cast(imag)); -} - -constexpr complex operator""_if(unsigned long long imag) { - return complex(0.0f, static_cast(imag)); -} - -constexpr complex operator""_id(unsigned long long imag) { - return complex(0.0, static_cast(imag)); -} - -} // namespace complex_literals - -template -constexpr complex operator+(const complex& val) { - return val; -} - -template -constexpr complex operator-(const complex& val) { - return complex(-val.real(), -val.imag()); -} - -template -constexpr complex operator+(const complex& lhs, const complex& rhs) { - complex result = lhs; - return result += rhs; -} - -template -constexpr complex operator+(const complex& lhs, const T& rhs) { - complex result = lhs; - return result += rhs; -} - -template -constexpr complex operator+(const T& lhs, const complex& rhs) { - return complex(lhs + rhs.real(), rhs.imag()); -} - -template -constexpr complex operator-(const complex& lhs, const complex& rhs) { - complex result = lhs; - return result -= rhs; -} - -template -constexpr complex operator-(const complex& lhs, const T& rhs) { - complex result = lhs; - return result -= rhs; -} - -template -constexpr complex operator-(const T& lhs, const complex& rhs) { - complex result = -rhs; - return result += lhs; -} - -template -constexpr complex operator*(const complex& lhs, const complex& rhs) { - complex result = lhs; - return result *= rhs; -} - -template -constexpr complex operator*(const complex& lhs, const T& rhs) { - complex result = lhs; - return result *= rhs; -} - -template -constexpr complex operator*(const T& lhs, const complex& rhs) { - complex result = rhs; - return result *= lhs; -} - -template -constexpr complex operator/(const complex& lhs, const complex& rhs) { - complex result = lhs; - return result /= rhs; -} - -template -constexpr complex operator/(const complex& lhs, const T& rhs) { - complex result = lhs; - return result /= rhs; -} - -template -constexpr complex operator/(const T& lhs, const complex& rhs) { - complex result(lhs, T()); - return result /= rhs; -} - -// Define operators between integral scalars and c10::complex. std::complex does -// not support this when T is a floating-point number. This is useful because it -// saves a lot of "static_cast" when operate a complex and an integer. This -// makes the code both less verbose and potentially more efficient. -#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ - typename std::enable_if_t< \ - std::is_floating_point_v && std::is_integral_v, \ - int> = 0 - -template -constexpr c10::complex operator+(const c10::complex& a, const iT& b) { - return a + static_cast(b); -} - -template -constexpr c10::complex operator+(const iT& a, const c10::complex& b) { - return static_cast(a) + b; -} - -template -constexpr c10::complex operator-(const c10::complex& a, const iT& b) { - return a - static_cast(b); -} - -template -constexpr c10::complex operator-(const iT& a, const c10::complex& b) { - return static_cast(a) - b; -} - -template -constexpr c10::complex operator*(const c10::complex& a, const iT& b) { - return a * static_cast(b); -} - -template -constexpr c10::complex operator*(const iT& a, const c10::complex& b) { - return static_cast(a) * b; -} - -template -constexpr c10::complex operator/(const c10::complex& a, const iT& b) { - return a / static_cast(b); -} - -template -constexpr c10::complex operator/(const iT& a, const c10::complex& b) { - return static_cast(a) / b; -} - -#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION - -template -constexpr bool operator==(const complex& lhs, const complex& rhs) { - return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); -} - -template -constexpr bool operator==(const complex& lhs, const T& rhs) { - return (lhs.real() == rhs) && (lhs.imag() == T()); -} - -template -constexpr bool operator==(const T& lhs, const complex& rhs) { - return (lhs == rhs.real()) && (T() == rhs.imag()); -} - -template -constexpr bool operator!=(const complex& lhs, const complex& rhs) { - return !(lhs == rhs); -} - -template -constexpr bool operator!=(const complex& lhs, const T& rhs) { - return !(lhs == rhs); -} - -template -constexpr bool operator!=(const T& lhs, const complex& rhs) { - return !(lhs == rhs); -} - -template -std::basic_ostream& operator<<( - std::basic_ostream& os, - const complex& x) { - return (os << static_cast>(x)); -} - -template -std::basic_istream& operator>>( - std::basic_istream& is, - complex& x) { - std::complex tmp; - is >> tmp; - x = tmp; - return is; -} - -} // namespace c10 +#include // std functions // @@ -594,72 +70,6 @@ constexpr c10::complex conj(const c10::complex& z) { } // namespace std -namespace c10 { - -template -C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { -#if defined(__CUDACC__) || defined(__HIPCC__) - return static_cast>(thrust::polar(r, theta)); -#else - // std::polar() requires r >= 0, so spell out the explicit implementation to - // avoid a branch. - return complex(r * std::cos(theta), r * std::sin(theta)); -#endif -} - -template <> -struct alignas(4) complex { - Half real_; - Half imag_; - - // Constructors - complex() = default; - // Half constructor is not constexpr so the following constructor can't - // be constexpr - C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) - : real_(real), imag_(imag) {} - C10_HOST_DEVICE inline complex(const c10::complex& value) - : real_(value.real()), imag_(value.imag()) {} - - // Conversion operator - inline C10_HOST_DEVICE operator c10::complex() const { - return {real_, imag_}; - } - - constexpr C10_HOST_DEVICE Half real() const { - return real_; - } - constexpr C10_HOST_DEVICE Half imag() const { - return imag_; - } - - C10_HOST_DEVICE complex& operator+=(const complex& other) { - real_ = static_cast(real_) + static_cast(other.real_); - imag_ = static_cast(imag_) + static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator-=(const complex& other) { - real_ = static_cast(real_) - static_cast(other.real_); - imag_ = static_cast(imag_) - static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator*=(const complex& other) { - auto a = static_cast(real_); - auto b = static_cast(imag_); - auto c = static_cast(other.real()); - auto d = static_cast(other.imag()); - real_ = a * c - b * d; - imag_ = a * d + b * c; - return *this; - } -}; - -} // namespace c10 - -C10_CLANG_DIAGNOSTIC_POP() - #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H // math functions are included in a separate file #include // IWYU pragma: keep diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b..04ab3cabcbc2 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -539,7 +539,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public Allocator { +class XPUAllocator : public DeviceAllocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -575,6 +575,10 @@ class XPUAllocator : public Allocator { } } + bool initialized() override { + return !device_allocators.empty(); + } + void malloc( void** devPtr, DeviceIndex device, @@ -609,13 +613,13 @@ class XPUAllocator : public Allocator { } } - void emptyCache() { + void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, XPUStream stream) { + void recordStream(const DataPtr& ptr, c10::Stream stream) override { if (!ptr.get()) { return; } @@ -625,7 +629,8 @@ class XPUAllocator : public Allocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - device_allocators[block->device]->recordStream(block, stream); + c10::xpu::XPUStream xpu_stream{stream}; + device_allocators[block->device]->recordStream(block, xpu_stream); } DataPtr allocate(size_t size) override { @@ -678,17 +683,17 @@ class XPUAllocator : public Allocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) { + DeviceStats getDeviceStats(DeviceIndex device) override { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) { + void resetPeakStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) { + void resetAccumulatedStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index db10db0ea7c0..96ed0c3b918e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -825,7 +825,6 @@ if(USE_MPS) if(CAN_COMPILE_METAL) add_dependencies(torch_cpu metallibs) target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib) - target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_bfloat,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_bfloat.metallib) else() target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS) endif() @@ -1346,10 +1345,6 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - add_subdirectory( - ${TORCH_ROOT}/test/cpp/tensorexpr - ${CMAKE_BINARY_DIR}/test_tensorexpr - ) if(USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) if(NOT WIN32) @@ -1447,8 +1442,8 @@ if(USE_ROCM) if(USE_MEM_EFF_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) endif() - if(USE_CK_FLASH_ATTENTION) - target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION) + if(USE_ROCM_CK_SDPA) + target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA) endif() endif() diff --git a/cmake/BLAS_ABI.cmake b/cmake/BLAS_ABI.cmake index bb0b5949d73d..45a15af1027a 100644 --- a/cmake/BLAS_ABI.cmake +++ b/cmake/BLAS_ABI.cmake @@ -1,3 +1,4 @@ +include(CMakePushCheckState) # Push host architecture when cross-compiling otherwise check would fail # when cross-compiling for arm64 on x86_64 cmake_push_check_state(RESET) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 16ee19a91d48..e4973c849a18 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -91,30 +91,28 @@ if(INTERN_BUILD_ATEN_OPS) torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags) set(_file_compile_flags "") - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0) - foreach(_arch ${archs}) - if("${_arch}" STREQUAL "89") - if(_existing_arch_flags MATCHES ".*compute_86.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89") - endif() + foreach(_arch ${archs}) + if("${_arch}" STREQUAL "89") + if(_existing_arch_flags MATCHES ".*compute_86.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89") endif() - if("${_arch}" STREQUAL "90a") - if(_existing_arch_flags MATCHES ".*compute_90.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a") - endif() + endif() + if("${_arch}" STREQUAL "90a") + if(_existing_arch_flags MATCHES ".*compute_90.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a") endif() - if("${_arch}" STREQUAL "100a") - if(_existing_arch_flags MATCHES ".*compute_100.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") - endif() + endif() + if("${_arch}" STREQUAL "100a") + if(_existing_arch_flags MATCHES ".*compute_100.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") endif() - if("${_arch}" STREQUAL "120a") - if(_existing_arch_flags MATCHES ".*compute_120.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") - endif() + endif() + if("${_arch}" STREQUAL "120a") + if(_existing_arch_flags MATCHES ".*compute_120.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") endif() - endforeach() - endif() + endif() + endforeach() list(JOIN _file_compile_flags " " _file_compile_flags) set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 6208ab77286b..26d882f2f7f1 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -260,7 +260,7 @@ endif() # Determine if blas was compiled with the f2c conventions if(BLAS_LIBRARIES AND BLAS_CHECK_F2C) include(cmake/BLAS_ABI.cmake) -endif(BLAS_LIBRARIES) +endif() if(NOT INTERN_BUILD_MOBILE) set(AT_MKL_SEQUENTIAL 0) @@ -576,7 +576,7 @@ elseif(NOT TARGET XNNPACK AND USE_SYSTEM_XNNPACK) find_library(microkernels-prod_LIBRARY microkernels-prod) set_property(TARGET XNNPACK PROPERTY IMPORTED_LOCATION "${XNNPACK_LIBRARY}") set_property(TARGET microkernels-prod PROPERTY IMPORTED_LOCATION "${microkernels-prod_LIBRARY}") - if(NOT XNNPACK_LIBRARY or NOT microkernels-prod_LIBRARY) + if(NOT XNNPACK_LIBRARY OR NOT microkernels-prod_LIBRARY) message(FATAL_ERROR "Cannot find XNNPACK") endif() message("-- Found XNNPACK: ${XNNPACK_LIBRARY}") @@ -664,55 +664,20 @@ if(USE_FBGEMM) if(NOT DEFINED FBGEMM_SOURCE_DIR) set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory") endif() - if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) - message(WARNING - "A compiler with AVX512 support is required for FBGEMM. " - "Not compiling with FBGEMM. " - "Turn this warning off by USE_FBGEMM=OFF.") - set(USE_FBGEMM OFF) - endif() - if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - message(WARNING - "x64 operating system is required for FBGEMM. " - "Not compiling with FBGEMM. " - "Turn this warning off by USE_FBGEMM=OFF.") - set(USE_FBGEMM OFF) - endif() if(USE_FBGEMM AND NOT TARGET fbgemm) set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "") set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "") - if(MSVC AND BUILD_SHARED_LIBS) - set(FBGEMM_LIBRARY_TYPE "shared" CACHE STRING "") - else() - set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") - endif() - if(USE_ASAN) - set(USE_SANITIZER "address,undefined" CACHE STRING "-fsanitize options for FBGEMM") - endif() + set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") add_subdirectory("${FBGEMM_SOURCE_DIR}") - set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON) - - # Disabling autovec in fbgemm due to large library size causing symbol relocation issues, which is only allowed in static builds. - # Long-term solution involves modularizing fbgemm targets. - target_compile_definitions(fbgemm_generic PUBLIC DISABLE_FBGEMM_AUTOVEC) - target_compile_definitions(fbgemm_avx2 PUBLIC DISABLE_FBGEMM_AUTOVEC) - target_compile_definitions(fbgemm_avx512 PUBLIC DISABLE_FBGEMM_AUTOVEC) - - if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) - # See https://github.com/pytorch/pytorch/issues/74352 - target_compile_options_if_supported(asmjit -Wno-deprecated-copy) - target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) - endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") target_compile_options_if_supported(asmjit -Wno-extra-semi) target_compile_options_if_supported(fbgemm -Wno-extra-semi) endif() + target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) + target_compile_options_if_supported(asmjit -Wno-unused-variable) endif() if(USE_FBGEMM) - target_compile_definitions(fbgemm PUBLIC DISABLE_FBGEMM_AUTOVEC) list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm) endif() endif() @@ -721,9 +686,6 @@ if(USE_FBGEMM) caffe2_update_option(USE_FBGEMM ON) else() caffe2_update_option(USE_FBGEMM OFF) - message(WARNING - "Turning USE_FAKELOWP off as it depends on USE_FBGEMM.") - caffe2_update_option(USE_FAKELOWP OFF) endif() if(USE_OPENCL) @@ -1045,6 +1007,9 @@ if(USE_ROCM) if(HIPBLASLT_VEC_EXT) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT) endif() + if(USE_ROCM_CK_GEMM) + list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM) + endif() list(APPEND HIP_HIPCC_FLAGS --offload-compress) if(WIN32) add_definitions(-DROCM_ON_WINDOWS) @@ -1143,7 +1108,7 @@ if(USE_UCC) endif() # ---[ CUB -if(USE_CUDA) +if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) find_package(CUB) if(NOT CUB_FOUND) message(FATAL_ERROR "Cannot find CUB.") @@ -1166,17 +1131,10 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE) # Tensorpipe uses cuda_add_library torch_update_find_cuda_flags() - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - message(WARNING "Archived TensorPipe forces CMake compatibility mode") - set(CMAKE_POLICY_VERSION_MINIMUM 3.5) - endif() add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe) # Suppress warning to unblock libnop compilation by clang-17 # See https://github.com/pytorch/pytorch/issues/151316 target_compile_options_if_supported(tensorpipe -Wno-missing-template-arg-list-after-template-kw) - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - unset(CMAKE_POLICY_VERSION_MINIMUM) - endif() list(APPEND Caffe2_DEPENDENCY_LIBS tensorpipe) list(APPEND Caffe2_DEPENDENCY_LIBS nlohmann) @@ -1242,10 +1200,17 @@ if(USE_GLOO) if(NOT Gloo_FOUND) message(FATAL_ERROR "Cannot find gloo") endif() - message("Found gloo: ${Gloo_LIBRARY}") + message("Found gloo: ${Gloo_NATIVE_LIBRARY}, cuda lib: ${Gloo_CUDA_LIBRARY}, hip lib: ${Gloo_HIP_LIBRARY}") message("Found gloo include directories: ${Gloo_INCLUDE_DIRS}") add_library(gloo SHARED IMPORTED) - set_target_properties(gloo PROPERTIES IMPORTED_LOCATION ${Gloo_LIBRARY}) + set_target_properties(gloo PROPERTIES IMPORTED_LOCATION ${Gloo_NATIVE_LIBRARY}) + if(USE_CUDA) + add_library(gloo_cuda SHARED IMPORTED) + set_target_properties(gloo_cuda PROPERTIES IMPORTED_LOCATION ${Gloo_CUDA_LIBRARY}) + elseif(USE_ROCM) + add_library(gloo_hip SHARED IMPORTED) + set_target_properties(gloo_hip PROPERTIES IMPORTED_LOCATION ${Gloo_HIP_LIBRARY}) + endif() # need to use Gloo_INCLUDE_DIRS over third_party/gloo to find Gloo's auto-generated config.h include_directories(BEFORE SYSTEM ${Gloo_INCLUDE_DIRS}) endif() diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 871a23487f29..54126b1f130d 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -2,24 +2,6 @@ include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) -# ---[ Check if we want to turn off deprecated warning due to glog. -if(USE_GLOG) - cmake_push_check_state(RESET) - set(CMAKE_REQUIRED_FLAGS "-std=c++17") - CHECK_CXX_SOURCE_COMPILES( - "#include - int main(int argc, char** argv) { - return 0; - }" CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING - FAIL_REGEX ".*-Wno-deprecated.*") - - if(NOT CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING AND NOT MSVC) - message(STATUS "Turning off deprecation warning due to glog.") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated") - endif() - cmake_pop_check_state() -endif() - # ---[ Check if the compiler has AVX/AVX2 support. We only check AVX2. if(NOT INTERN_BUILD_MOBILE) find_package(AVX) # checks AVX and AVX2 @@ -30,46 +12,6 @@ if(NOT INTERN_BUILD_MOBILE) set(CAFFE2_PERF_WITH_AVX2 1) endif() endif() -# ---[ Check if the compiler has AVX512 support. -cmake_push_check_state(RESET) -if(MSVC AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - # We could've used MSVC's hidden option /arch:AVX512 that defines __AVX512F__, - # __AVX512DQ__, and __AVX512VL__, and /arch:AVX512F that defines __AVX512F__. - # But, we chose not to do that not to rely on hidden options. - set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__ /D__AVX512DQ__ /D__AVX512VL__") -else() - # We only consider the case where all of avx512f, avx512dq, and avx512vl are - # supported. - # Platforms where avx512f is supported by not avx512dq and avx512vl as of - # Jan 15 2019 : linux_manywheel_2.7mu_cpu_build and - # linux_conda_3.7_cu100_build - set(CMAKE_REQUIRED_FLAGS "-mavx512f -mavx512dq -mavx512vl") -endif() -CHECK_CXX_SOURCE_COMPILES( - "#if defined(_MSC_VER) - #include - #else - #include - #endif - // check avx512f - __m512 addConstant(__m512 arg) { - return _mm512_add_ps(arg, _mm512_set1_ps(1.f)); - } - // check avx512dq - __m512 andConstant(__m512 arg) { - return _mm512_and_ps(arg, _mm512_set1_ps(1.f)); - } - int main() { - __m512i a = _mm512_set1_epi32(1); - __m256i ymm = _mm512_extracti64x4_epi64(a, 0); - ymm = _mm256_abs_epi64(ymm); // check avx512vl - __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ); - __m512i r = _mm512_andnot_si512(a, a); - }" CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) -if(CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) - message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.") -endif() -cmake_pop_check_state() # ---[ Checks if compiler supports -fvisibility=hidden check_cxx_compiler_flag("-fvisibility=hidden" COMPILER_SUPPORTS_HIDDEN_VISIBILITY) diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index e965326e2e8a..944cd4d8d257 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -1,7 +1,8 @@ # Try to find the Gloo library and headers. # Gloo_FOUND - system has Gloo lib # Gloo_INCLUDE_DIRS - the Gloo include directory -# Gloo_LIBRARY/Gloo_NATIVE_LIBRARY - libraries needed to use Gloo +# Gloo_NATIVE_LIBRARY - base gloo library, needs to be linked +# Gloo_CUDA_LIBRARY/Gloo_HIP_LIBRARY - CUDA/HIP support library in Gloo find_path(Gloo_INCLUDE_DIR NAMES gloo/common/common.h @@ -10,40 +11,32 @@ find_path(Gloo_INCLUDE_DIR find_library(Gloo_NATIVE_LIBRARY NAMES gloo - DOC "The Gloo library (without CUDA)" + DOC "The Gloo library" ) +# Gloo has optional CUDA support +# if Gloo + CUDA is desired, Gloo_CUDA_LIBRARY +# needs to be linked into desired target find_library(Gloo_CUDA_LIBRARY NAMES gloo_cuda - DOC "The Gloo library (with CUDA)" + DOC "Gloo's CUDA support/code" +) + +# Gloo has optional HIP support +# if Gloo + HIP is desired, Gloo_HIP_LIBRARY +# needs to be linked to desired target +find_library(Gloo_HIP_LIBRARY + NAMES gloo_hiop + DOC "Gloo's HIP support/code" ) set(Gloo_INCLUDE_DIRS ${Gloo_INCLUDE_DIR}) -# use the CUDA library depending on the Gloo_USE_CUDA variable -if (DEFINED Gloo_USE_CUDA) - if (${Gloo_USE_CUDA}) - set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - else() - set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - endif() -else() - # else try to use the CUDA library if found - if (${Gloo_CUDA_LIBRARY} STREQUAL "Gloo_CUDA_LIBRARY-NOTFOUND") - set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - else() - set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - endif() -endif() include(FindPackageHandleStandardArgs) find_package_handle_standard_args(Gloo FOUND_VAR Gloo_FOUND - REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_LIBRARY + REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_NATIVE_LIBRARY ) mark_as_advanced(Gloo_FOUND) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 3c2ec74f14d1..63e501bcb5ab 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -127,15 +127,15 @@ function(caffe2_print_configuration_summary) endif() message(STATUS " USE_ROCM : ${USE_ROCM}") if(${USE_ROCM}) - message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") - message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") - message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}") + message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") + message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") + message(STATUS " USE_ROCM_CK_SDPA : ${USE_ROCM_CK_SDPA}") + message(STATUS " USE_ROCM_CK_GEMM : ${USE_ROCM_CK_GEMM}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") message(STATUS " USE_FBGEMM : ${USE_FBGEMM}") - message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}") message(STATUS " USE_KINETO : ${USE_KINETO}") message(STATUS " USE_GFLAGS : ${USE_GFLAGS}") message(STATUS " USE_GLOG : ${USE_GLOG}") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 132f9670ff34..018bca837a5a 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -6,7 +6,7 @@ set(PYTORCH_FOUND_HIP FALSE) # In the latter case, if /opt/rocm does not exist emit status # message and return. if(DEFINED ENV{ROCM_PATH}) - set(ROCM_PATH $ENV{ROCM_PATH}) + file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH) if(NOT EXISTS ${ROCM_PATH}) message(FATAL_ERROR "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n" @@ -31,7 +31,7 @@ if(NOT DEFINED ENV{MAGMA_HOME}) set(MAGMA_HOME ${ROCM_PATH}/magma) set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma) else() - set(MAGMA_HOME $ENV{MAGMA_HOME}) + file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME) endif() # MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different diff --git a/codex_setup.sh b/codex_setup.sh new file mode 100755 index 000000000000..85c7b93e8979 --- /dev/null +++ b/codex_setup.sh @@ -0,0 +1,14 @@ +set -ex +uv venv +source .venv/bin/activate +uv pip install -r requirements.txt +uv pip install numpy +lintrunner init +NIGHTLY_PATCH=$(curl -s https://github.com/pytorch/pytorch/commit/nightly.patch | head -n20) +COMMIT=$(grep -oE '[0-9a-f]{40}' <<< "$NIGHTLY_PATCH" | head -1) +COMMIT_DATE=$(echo "$NIGHTLY_PATCH" | grep '^Date:' | sed -E 's/Date: .*, ([0-9]+) ([A-Za-z]+) ([0-9]+) .*/\3 \2 \1/' | awk 'BEGIN{split("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec", months, " "); for(i=1;i<=12;i++) month[months[i]]=sprintf("%02d",i)} {print $1 month[$2] sprintf("%02d",$3)}') +VERSION_STRING="2.9.0.dev${COMMIT_DATE}+cpu" +git rev-parse HEAD > /tmp/orig_work.txt +git reset --hard $COMMIT +USE_NIGHTLY=$VERSION_STRING python setup.py develop +echo "source $PWD/.venv/bin/activate" >> ~/.bashrc diff --git a/docs/source/_static/js/runllm-widget.js b/docs/source/_static/js/runllm-widget.js new file mode 100644 index 000000000000..45632613722c --- /dev/null +++ b/docs/source/_static/js/runllm-widget.js @@ -0,0 +1,17 @@ +document.addEventListener("DOMContentLoaded", function () { + var script = document.createElement("script"); + script.type = "module"; + script.id = "runllm-widget-script" + + script.src = "https://widget.runllm.com"; + + script.setAttribute("version", "stable"); + script.setAttribute("crossorigin", "true"); + script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); + script.setAttribute("runllm-name", "PyTorch"); + script.setAttribute("runllm-position", "BOTTOM_RIGHT"); + script.setAttribute("runllm-assistant-id", "834"); + + script.async = true; + document.head.appendChild(script); +}); diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c6f2fb108040..ce593a9acf51 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,3 +25,26 @@ synchronize device_index ``` + +```{eval-rst} +.. automodule:: torch.accelerator.memory +``` +```{eval-rst} +.. currentmodule:: torch.accelerator.memory +``` + +## Memory management +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + reset_accumulated_memory_stats + reset_peak_memory_stats +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 3268f54de47d..4f47652e88d2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1793,12 +1793,6 @@ # torch.optim.optimizer "register_optimizer_step_post_hook", "register_optimizer_step_pre_hook", - # torch.optim.swa_utils - "get_ema_avg_fn", - "get_ema_multi_avg_fn", - "get_swa_avg_fn", - "get_swa_multi_avg_fn", - "update_bn", # torch.overrides "enable_reentrant_dispatch", # torch.package.analyze.find_first_use_of_broken_modules @@ -2909,31 +2903,6 @@ # torch.onnx.verification "OnnxBackend", "OnnxTestCaseRepro", - # torch.optim.adamax - "Adamax", - # torch.optim.adamw - "AdamW", - # torch.optim.asgd - "ASGD", - # torch.optim.lbfgs - "LBFGS", - # torch.optim.lr_scheduler - "ChainedScheduler", - "ConstantLR", - "CosineAnnealingLR", - "CosineAnnealingWarmRestarts", - "CyclicLR", - "ExponentialLR", - "LRScheduler", - "LambdaLR", - "LinearLR", - "MultiStepLR", - "MultiplicativeLR", - "OneCycleLR", - "PolynomialLR", - "ReduceLROnPlateau", - "SequentialLR", - "StepLR", # torch.optim.optimizer "Optimizer", # torch.overrides @@ -3339,6 +3308,8 @@ def linkcode_resolve(domain, info): "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css", ] +html_js_files = ["js/runllm-widget.js"] + from sphinx.ext.coverage import CoverageBuilder diff --git a/docs/source/elastic/numa.rst b/docs/source/elastic/numa.rst index b6caa8a94c0e..d56c99cf422e 100644 --- a/docs/source/elastic/numa.rst +++ b/docs/source/elastic/numa.rst @@ -3,8 +3,8 @@ NUMA Binding Utilities ====================== -.. automodule:: torch.distributed.numa +.. automodule:: torch.numa :members: -.. automodule:: torch.distributed.numa.binding +.. automodule:: torch.numa.binding :members: diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 5210eb4ad149..8ad4c87a7139 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -896,6 +896,130 @@ APIs can be used for debugging purposes: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#memory-allocator +Tuning NVLink Performance with Custom Memory Allocator on H100/H200 GPUs +------------------------------------------------------------------------ +In rare cases, performance of NVLink on H100/H200 GPUs can be influenced by the physical memory +layout of data, creating an opportunity for developers to tune their applications for optimal +throughput. + +An example of how physical memory layout of data affects performance is when communication +kernels issue unbalanced NVLink read/write operations. In the following figure, we can see +that each warp accesses memory addresses with a consistent strided pattern in each single wave. +We can have a more balanced load by tuning the stride size in the workload or we can implement +a custom CUDA allocator. + +.. code:: + + _______________________________ _______________________________ _______________________________ + | Warp 0 Reading | No-reading | | Warp 1 Reading | No-reading | ... Warp N Reading | No-reading | + _______________________________ _______________________________ _______________________________ + <-----------------------------> + Stride size + +Such an allocator can maintain contiguous virtual memory addresses for the kernel while strategically +arranging the mapping to physical memory addresses (e.g., through shuffling). This technique allows +developers to explore different physical access patterns to find the most efficient one, unlocking +higher performance without modifying the kernel's logic. A practical implementation of such an allocator +can be achieved using PyTorch’s custom allocator support as mentioned before, where the malloc and free +functions are: + +.. code:: C++ + + // assuming a system with 8 GPUs + struct CustomAllocInfo { + void** devPtr; // This will be the usable virtual memory address + CUdeviceptr dptr; + size_t totalSize; // Total size of the allocated memory + size_t padded_size; + int device_id; + std::vector handles; // Handles to physical memory allocations + }; + + // loop over pages + cudaError_t customCudaMalloc(CustomAllocInfo* info) { + if (!info) return cudaErrorInvalidValue; + + CUdeviceptr dptr; + + // Handles to redundant physical memory allocations which help truncate stride pattern in physical memory + std::vector handles_redundant; + + size_t granularity = 0; + CUmemAllocationProp prop = {}; + + int currentDev = info->device_id; + size_t totalSize = info->totalSize; + + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = currentDev; + cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); + size_t padded_size = ROUND_UP(totalSize, granularity); + + info->padded_size = padded_size; + + // loop over pages + size_t iter_granularity = granularity * 64; // 64 * granularity with shift_size = 2 works + uint32_t iteration_count = (totalSize + iter_granularity - 1) / iter_granularity; + + cuMemAddressReserve(&dptr, padded_size, 0ULL, 0ULL, 0ULL); + + const int shift_size = 2; + for (size_t i = 0; i < iteration_count; i+=shift_size) { + + CUmemGenericAllocationHandle allocHandle[shift_size]; + for (int shift = 0; (shift < shift_size)&&(i+shift < iteration_count); shift++){ + CHECK_CUDA(cuMemCreate(&allocHandle[shift], iter_granularity, &prop, 0)); + info->handles.push_back(allocHandle[shift]); + } + + for (int shift = 0; (shift < shift_size)&&(i+shift < iteration_count); shift++){ + + // mapping makes the shift (shift -> (shift+1)%shift_size ) + CHECK_CUDA(cuMemMap(dptr + (i+shift) * iter_granularity, iter_granularity, 0, allocHandle[(shift+1)%shift_size], 0)); + + setupMultiGPUAccess(dptr + (i+shift) * iter_granularity, iter_granularity, {0, 1, 2, 3, 4, 5, 6, 7}); // Enable access for all 8 GPUs + } + + // std::cout << "Here we allocate one redundant page (2MB)..." << std::endl; + // this is an extra optimization on top of the swizzling. It helps "break" + // the physical access pattern even more. It can be left out if workload is already + // performing at SOL with just swizzling. + CUmemGenericAllocationHandle allocHandle_redundant; + CHECK_CUDA(cuMemCreate(&allocHandle_redundant, granularity, &prop, 0)); + handles_redundant.push_back(allocHandle_redundant); + } + + *info->devPtr = (void*)dptr; + info->dptr = dptr; + + // Release each redundant allocation + for (auto handle : handles_redundant) { + // std::cout << "Here we release one redundant page (2MB)..." << std::endl; + CHECK_CUDA(cuMemRelease(handle)); + } + + return cudaSuccess; + } + + void customCudaFree(CustomAllocInfo* info) { + if (!info) return; + + // CHECK_CUDA(cudaSetDevice(info->device_id)); + + CHECK_CUDA(cuMemUnmap(info->dptr, info->padded_size)); + + // Unmap and release each allocation + for (auto handle : info->handles) { + CHECK_CUDA(cuMemRelease(handle)); + } + + // Unreserve the virtual address space + // CHECK_CUDA(cuMemAddressFree((CUdeviceptr)*info->devPtr, info->padded_size)); + CHECK_CUDA(cuMemAddressFree(info->dptr, info->padded_size)); + } + + cuBLAS workspaces ----------------- diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index 5ca51833f025..6414730c28d4 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -107,7 +107,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to use torch.compile on Windows CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples diff --git a/docs/source/notes/hip.rst b/docs/source/notes/hip.rst index a34535d67fc9..7ee596b53f9c 100644 --- a/docs/source/notes/hip.rst +++ b/docs/source/notes/hip.rst @@ -179,3 +179,30 @@ by recompiling the PyTorch from source. Please add below line as an argument to cmake command parameters:: -DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON + +Enabling/Disabling ROCm Composable Kernel +----------------------------------------- + +Enabling composable_kernel (CK) for both SDPA and GEMMs is a two-part process. First the user must have built +pytorch while setting the corresponding environment variable to '1' + +SDPA: +``USE_ROCM_CK_SDPA=1`` + +GEMMs: +``USE_ROCM_CK_GEMM=1`` + +Second, the user must explicitly request that CK be used as the backend library via the corresponding python +call + +SDPA: +``setROCmFAPreferredBackend('')`` + +GEMMs: +``setBlasPreferredBackend('')`` + +To enable CK in either scenario, simply pass 'ck' to those functions. + +In order to set the backend to CK, the user MUST have built with the correct environment variable. If not, +PyTorch will print a warning and use the "default" backend. For GEMMs, this will route to hipblas and +for SDPA it routes to aotriton. diff --git a/docs/source/torch.compiler_dynamo_deepdive.md b/docs/source/torch.compiler_dynamo_deepdive.md index 6bbb03170e54..9fa7654023ca 100644 --- a/docs/source/torch.compiler_dynamo_deepdive.md +++ b/docs/source/torch.compiler_dynamo_deepdive.md @@ -285,7 +285,7 @@ appear in the errors, and the `VariableTracker` method that throws the exception when you encounter a Dynamo error. In particular, sometimes we find that an object is tracked as a `UserDefinedObjectVariable` (this is Dynamo’s catch-all class), when it should have been tracked as -something more specific. In these cases, the `SourceBuilder.__call__` +something more specific. In these cases, the `VariableBuilder` logic is often to blame. **Debugging tip**. When running a program with `TORCH_LOGS=dynamo`, diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md index 03555d74e817..ef13fc177237 100644 --- a/docs/source/torch.compiler_troubleshooting_old.md +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -717,5 +717,5 @@ backtrace is slow and very spammy so it is not included by default with extended In order to measure the cold start compilation time or debug a cache corruption, it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set -`torch._inductor.config.force_disable_caches = True` which will override any +`torch.compiler.config.force_disable_caches = True` which will override any other caching config option and disable all compile time caching. diff --git a/docs/source/torch_cuda_memory.md b/docs/source/torch_cuda_memory.md index bb50e5fd5751..e5fa147ee785 100644 --- a/docs/source/torch_cuda_memory.md +++ b/docs/source/torch_cuda_memory.md @@ -32,7 +32,7 @@ torch.cuda.memory._dump_snapshot("my_snapshot.pickle") ## Using the visualizer -Open [pytorch.org/memory_viz](https://pytorch.org/memory_viz>) and drag/drop the pickled snapshot file into the visualizer. +Open [pytorch.org/memory_viz](https://pytorch.org/memory_viz) and drag/drop the pickled snapshot file into the visualizer. The visualizer is a javascript application that runs locally on your computer. It does not upload any snapshot data. diff --git a/functorch/csrc/dim/arena.h b/functorch/csrc/dim/arena.h index aaaf7e772a3a..ec2cfef66895 100644 --- a/functorch/csrc/dim/arena.h +++ b/functorch/csrc/dim/arena.h @@ -8,7 +8,7 @@ #include #include "minpybind.h" -#ifdef _WIN32 +#if defined(_MSC_VER) && !defined(__clang__) #include // https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code inline unsigned int __builtin_clz(unsigned int x) { diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index f52d417d2ba2..95747181e848 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -24,10 +24,6 @@ class DimensionBindError(Exception): # use dict to avoid writing C++ bindings for set pointwise = dict.fromkeys(op_properties.pointwise, True) -use_c = True -if not use_c: - from . import reference - class _Tensor: # fast path around slow wrapping/unwrapping logic for simply queries used @@ -40,12 +36,8 @@ def dims(self): def dim(self): return self.ndim - if use_c: - __torch_function__ = classmethod(_C.__torch_function__) - expand = _C._instancemethod(_C.expand) - else: - __torch_function__ = reference.__torch_function__ - expand = reference.expand + __torch_function__ = classmethod(_C.__torch_function__) + expand = _C._instancemethod(_C.expand) index = _C._instancemethod(_C.index) @@ -64,8 +56,6 @@ class Dim(_C.Dim, _Tensor): class Tensor(_Tensor, _C.Tensor): - if not use_c: - from_batched = staticmethod(_C.Tensor_from_batched) from_positional = staticmethod(_C.Tensor_from_positional) sum = _C._instancemethod(_C.Tensor_sum) @@ -75,21 +65,17 @@ def cat(tensors, dim, new_dim): return stack(tensors, n, dim).index([n, dim], new_dim) -if use_c: - _wrap = _C._wrap +_wrap = _C._wrap + + +def _def(name, *args, **kwargs): + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) - def _def(name, *args, **kwargs): - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) - t__getitem__ = _C._instancemethod(_C.__getitem__) - stack = _C.stack - split = _C._instancemethod(_C.split) -else: - _wrap, _def = reference._wrap, reference._def - t__getitem__ = reference.t__getitem__ - stack = reference.stack - split = reference.split +t__getitem__ = _C._instancemethod(_C.__getitem__) +stack = _C.stack +split = _C._instancemethod(_C.split) # note: there is no python reference t__setitem__ = _C._instancemethod(_C.__setitem__) @@ -105,13 +91,10 @@ def _def(name, *args, **kwargs): _Tensor.split = split torch.Tensor.expand = _C._instancemethod(_C.expand) torch.Tensor.index = _C._instancemethod(_C.index) -wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__) +wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__) del _Tensor.ndim -if use_c: - _Tensor.order = _C._instancemethod(_C.order) -else: - _Tensor.order = reference.positional +_Tensor.order = _C._instancemethod(_C.order) _def("mean") _def("sum") diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py deleted file mode 100644 index dae9b270896e..000000000000 --- a/functorch/dim/batch_tensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from contextlib import contextmanager - -from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers - - -_enabled = False - - -@contextmanager -def _enable_layers(dims): - global _enabled - assert not _enabled - input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) - n = len(input) - try: - _vmap_add_layers(input) - _enabled = True - yield - finally: - _enabled = False - _vmap_remove_layers(n) diff --git a/functorch/dim/delayed_mul_tensor.py b/functorch/dim/delayed_mul_tensor.py deleted file mode 100644 index 3c136cfe1247..000000000000 --- a/functorch/dim/delayed_mul_tensor.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from . import _Tensor, Tensor -from .reference import _dims, _enable_layers, llist, ltuple - - -class DelayedMulTensor(_Tensor): - def __init__(self, lhs, rhs): - self._lhs, self._rhs = lhs, rhs - self._data = None - self._levels_data = None - self._has_device = lhs._has_device or rhs._has_device - self._batchtensor_data = None - self._tensor_data = None - - @property - def _levels(self): - if self._levels_data is None: - levels = llist(self._lhs._levels) - for l in self._rhs._levels: - if l not in levels: - levels.append(l) - self._levels_data = ltuple(levels) - return self._levels_data - - @property - def _batchtensor(self): - if self._batchtensor_data is None: - with _enable_layers(self._levels): - print("bt multiply fallback") - self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor - return self._batchtensor_data - - @property - def _tensor(self): - if self._tensor_data is None: - self._tensor_data = Tensor.from_batched( - self._batchtensor, self._has_device - )._tensor - return self._tensor_data - - @property - def ndim(self): - return self._batchtensor.ndim - - @property - def dims(self): - return ltuple(super().dims) - - def sum(self, dim): - dims = _dims(dim, 0, False, False) - n = ord("a") - all_levels = self._levels - - def to_char(d): - return chr(n + all_levels.index(d)) - - plhs, levelslhs = self._lhs._tensor, self._lhs._levels - prhs, levelsrhs = self._rhs._tensor, self._rhs._levels - new_levels = [l for l in self._levels if l not in dims] - fmt = "".join( - [ - *(to_char(d) for d in levelslhs), - ",", - *(to_char(d) for d in levelsrhs), - "->", - *(to_char(d) for d in new_levels), - ] - ) - result_data = torch.einsum(fmt, (plhs, prhs)) - return Tensor.from_positional(result_data, new_levels, True) diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py deleted file mode 100644 index 9a4b56866484..000000000000 --- a/functorch/dim/dim.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import dis -import inspect -from dataclasses import dataclass -from typing import Union - -from . import DimList - - -_vmap_levels = [] - - -@dataclass -class LevelInfo: - level: int - alive: bool = True - - -class Dim: - def __init__(self, name: str, size: Union[None, int] = None): - self.name = name - self._size = None - self._vmap_level = None - if size is not None: - self.size = size - - def __del__(self): - if self._vmap_level is not None: - _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 - while ( - not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 - ): - _vmap_decrement_nesting() # noqa: F821 - _vmap_levels.pop() - - @property - def size(self): - assert self.is_bound - return self._size - - @size.setter - def size(self, size: int): - from . import DimensionBindError - - if self._size is None: - self._size = size - self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821 - self._vmap_stack = len(_vmap_levels) - _vmap_levels.append(LevelInfo(self._vmap_level)) - - elif self._size != size: - raise DimensionBindError( - f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}" - ) - - @property - def is_bound(self): - return self._size is not None - - def __repr__(self): - return self.name - - -def extract_name(inst): - assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME" - return inst.argval - - -_cache = {} - - -def dims(lists=0): - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - code, lasti = calling_frame.f_code, calling_frame.f_lasti - key = (code, lasti) - if key not in _cache: - first = lasti // 2 + 1 - instructions = list(dis.get_instructions(calling_frame.f_code)) - unpack = instructions[first] - - if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME": - # just a single dim, not a list - name = unpack.argval - ctor = Dim if lists == 0 else DimList - _cache[key] = lambda: ctor(name=name) - else: - assert unpack.opname == "UNPACK_SEQUENCE" - ndims = unpack.argval - names = tuple( - extract_name(instructions[first + 1 + i]) for i in range(ndims) - ) - first_list = len(names) - lists - _cache[key] = lambda: tuple( - Dim(n) if i < first_list else DimList(name=n) - for i, n in enumerate(names) - ) - return _cache[key]() - - -def _dim_set(positional, arg): - def convert(a): - if isinstance(a, Dim): - return a - else: - assert isinstance(a, int) - return positional[a] - - if arg is None: - return positional - elif not isinstance(arg, (Dim, int)): - return tuple(convert(a) for a in arg) - else: - return (convert(arg),) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py deleted file mode 100644 index fd934011d823..000000000000 --- a/functorch/dim/reference.py +++ /dev/null @@ -1,645 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# reference python implementations for C ops -import torch -from functorch._C import dim as _C - -from . import op_properties -from .batch_tensor import _enable_layers -from .tree_map import tree_flatten, tree_map - - -DimList = _C.DimList -import operator -from functools import reduce - - -# use dict to avoid writing C++ bindings for set -pointwise = set(op_properties.pointwise) - - -def prod(x): - return reduce(operator.mul, x, 1) - - -def _wrap_dim(d, N, keepdim): - from . import Dim - - if isinstance(d, Dim): - assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" - return d - elif d >= 0: - return d - N - else: - return d - - -def _dims(d, N, keepdim, single_dim): - from . import Dim - - if isinstance(d, (Dim, int)): - return ltuple((_wrap_dim(d, N, keepdim),)) - assert not single_dim, f"expected a single dimension or int but found: {d}" - return ltuple(_wrap_dim(x, N, keepdim) for x in d) - - -def _bind_dims_to_size(lhs_size, rhs, lhs_debug): - from . import DimensionMismatchError - - not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) - if len(not_bound) == 1: - idx, d = not_bound[0] - rhs_so_far = prod(r.size for r in rhs if r.is_bound) - if lhs_size % rhs_so_far != 0: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}" - ) - new_size = lhs_size // rhs_so_far - d.size = new_size - elif len(not_bound) > 1: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}" - ) - else: - rhs_size = prod(r.size for r in rhs) - if lhs_size != rhs_size: - raise DimensionMismatchError( - f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}" - ) - - -def _tensor_levels(inp): - from . import _Tensor - - if isinstance(inp, _Tensor): - return inp._tensor, llist(inp._levels), inp._has_device - else: - return inp, llist(range(-inp.ndim, 0)), True - - -def _match_levels(v, from_levels, to_levels): - view = [] - permute = [] - requires_view = False - size = v.size() - for t in to_levels: - try: - idx = from_levels.index(t) - permute.append(idx) - view.append(size[idx]) - except ValueError: - view.append(1) - requires_view = True - if permute != list(range(len(permute))): - v = v.permute(*permute) - if requires_view: - v = v.view(*view) - return v - - -# make a single dimension positional but do not permute it, -# used to do multi-tensor operators where the dim being acted on -# should not physically move if possible -def _positional_no_permute(self, dim, expand_dim=False): - from . import Tensor - - ptensor, levels = self._tensor, llist(self._levels) - try: - idx = levels.index(dim) - except ValueError: - if not expand_dim: - raise - idx = 0 - ptensor = ptensor.expand(dim.size, *ptensor.size()) - levels.insert(0, 0) - idx_batched = 0 - for i in range(idx): - if isinstance(levels[i], int): - levels[i] -= 1 - idx_batched += 1 - levels[idx] = -idx_batched - 1 - return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched - - -def seq(a, b): - from . import Dim - - if isinstance(a, Dim) != isinstance(b, Dim): - return False - if isinstance(a, Dim): - return a is b - else: - return a == b - - -class isin: - __slots__ = () - - def __contains__(self, item): - for x in self: - if seq(item, x): - return True - return False - - def index(self, item): - for i, x in enumerate(self): - if seq(item, x): - return i - raise ValueError - - -class llist(isin, list): - __slots__ = () - - -class ltuple(isin, tuple): - __slots__ = () - - -empty_dict = {} - - -@classmethod -def __torch_function__(self, orig, cls, args, kwargs=empty_dict): - from . import _Tensor, Tensor, TensorLike - from .delayed_mul_tensor import DelayedMulTensor - - if orig is torch.Tensor.__mul__: - lhs, rhs = args - if ( - isinstance(lhs, _Tensor) - and isinstance(rhs, _Tensor) - and lhs.ndim == 0 - and rhs.ndim == 0 - ): - return DelayedMulTensor(lhs, rhs) - all_dims = llist() - flat_args, unflatten = tree_flatten((args, kwargs)) - device_holding_tensor = None - for f in flat_args: - if isinstance(f, _Tensor): - if f._has_device: - device_holding_tensor = f._batchtensor - for d in f.dims: - if d not in all_dims: - all_dims.append(d) - - def unwrap(t): - if isinstance(t, _Tensor): - r = t._batchtensor - if device_holding_tensor is not None and not t._has_device: - r = r.to(device=device_holding_tensor.device) - return r - return t - - if orig in pointwise: - result_levels = llist() - to_expand = [] - for i, f in enumerate(flat_args): - if isinstance(f, TensorLike): - ptensor, levels, _ = _tensor_levels(f) - if ( - isinstance(f, _Tensor) - and not f._has_device - and device_holding_tensor is not None - ): - ptensor = ptensor.to(device=device_holding_tensor.device) - flat_args[i] = ptensor - for l in levels: - if l not in result_levels: - result_levels.append(l) - to_expand.append((i, levels)) - - for i, levels in to_expand: - flat_args[i] = _match_levels(flat_args[i], levels, result_levels) - args, kwargs = unflatten(flat_args) - result = orig(*args, **kwargs) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional( - t, result_levels, device_holding_tensor is not None - ) - return t - - return tree_map(wrap, result) - else: - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_batched(t, device_holding_tensor is not None) - return t - - with _enable_layers(all_dims): - print(f"batch_tensor for {orig}") - args, kwargs = unflatten(unwrap(f) for f in flat_args) - result = orig(*args, **kwargs) - # print("END", orig) - return tree_map(wrap, result) - - -def positional(self, *dims): - from . import Dim, DimensionBindError, Tensor - - ptensor, levels = self._tensor, llist(self._levels) - flat_dims = llist() - view = [] - needs_view = False - ndim = self.ndim - for d in dims: - if isinstance(d, DimList): - flat_dims.extend(d) - view.extend(e.size for e in d) - elif isinstance(d, Dim): - flat_dims.append(d) - view.append(d.size) - elif isinstance(d, int): - d = _wrap_dim(d, ndim, False) - flat_dims.append(d) - view.append(ptensor.size(d)) - else: - flat_dims.extend(d) - view.append(prod(e.size for e in d)) - needs_view = True - - permute = list(range(len(levels))) - for i, d in enumerate(flat_dims): - try: - idx = levels.index(d) - except ValueError as e: - raise DimensionBindError( - f"tensor of dimensions {self.dims} does not contain dim {d}" - ) from e - p = permute[idx] - del levels[idx] - del permute[idx] - levels.insert(i, 0) - permute.insert(i, p) - ptensor = ptensor.permute(*permute) - seen = 0 - for i in range(len(levels) - 1, -1, -1): - if isinstance(levels[i], int): - seen += 1 - levels[i] = -seen - result = Tensor.from_positional(ptensor, levels, self._has_device) - if needs_view: - result = result.reshape(*view, *result.size()[len(flat_dims) :]) - return result - - -def _contains_dim(input): - from . import Dim - - for i in input: - if isinstance(i, Dim): - return True - - -def expand(self, *sizes): - if not _contains_dim(sizes): - return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) - dims = sizes - sizes = [d.size for d in dims] + [-1] * self.ndim - self = self.expand(*sizes) - return self[dims] - - -_not_present = object() - - -def _getarg(name, offset, args, kwargs, default): - if len(args) > offset: - return args[offset] - return kwargs.get(name, default) - - -def _patcharg(name, offset, args, kwargs, value): - if len(args) > offset: - args[offset] = value - else: - kwargs[name] = value - - -def _wrap( - orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True -): - from . import Dim, Tensor, TensorLike - - def fn(self, *args, **kwargs): - dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) - if dim is _not_present or (single_dim and not isinstance(dim, Dim)): - with _enable_layers(self.dims): - print(f"dim fallback batch_tensor for {orig}") - return Tensor.from_batched( - orig(self._batchtensor, *args, **kwargs), self._has_device - ) - keepdim = ( - _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False - ) - t, levels = self._tensor, llist(self._levels) - dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) - dim_indices = tuple(levels.index(d) for d in dims) - if reduce and not keepdim: - new_levels = [l for i, l in enumerate(levels) if i not in dim_indices] - else: - new_levels = levels - - if len(dim_indices) == 1: - dim_indices = dim_indices[ - 0 - ] # so that dims that really only take a single argument work... - args = list(args) - _patcharg(dim_name, dim_offset, args, kwargs, dim_indices) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional(t, new_levels, self._has_device) - return t - - with _enable_layers(new_levels): - print(f"dim used batch_tensor for {orig}") - r = orig(t, *args, **kwargs) - return tree_map(wrap, r) - - return fn - - -def _def(name, *args, **kwargs): - from . import _Tensor - - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) - - -no_slice = slice(None) - -_orig_getitem = torch.Tensor.__getitem__ - - -class dim_tracker: - def __init__(self) -> None: - self.dims = llist() - self.count = [] - - def record(self, d): - if d not in self.dims: - self.dims.append(d) - self.count.append(1) - - def __getitem__(self, d): - return self.count[self.dims.index(d)] - - -def t__getitem__(self, input): - from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike - - # * bail to original example if we have a single non-Dim tensor, or a non-tensor - # * locate ... or an unbound tensor list, and determine its size, bind dim list - # (remember that None does not count to the total dim count) - # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim, - # produce the re-view if needed - # * for each single-use dim index, replace with no_slice and mark that it will be added - # (keep track of whether we have to call super) - # * call super if needed - # * if we have dims to bind, bind them (it will help if we eliminated ... and None before) - # this handles bool indexing handling, as well as some other simple cases. - - is_simple = ( - not isinstance(input, Dim) - and not isinstance(input, (tuple, list)) - and - # WAR for functorch bug where zero time tensors in getitem are not handled correctly. - not (isinstance(input, TensorLike) and input.ndim == 0) - ) - - if is_simple: - if isinstance(self, _Tensor): - return _Tensor.__torch_function__(_orig_getitem, None, (self, input)) - else: - return _orig_getitem(self, input) - - # can further optimize this case - if not isinstance(input, tuple): - input = [input] - else: - input = list(input) - - dims_indexed = 0 - expanding_object = None - dimlists = [] - for i, s in enumerate(input): - if s is ... or isinstance(s, DimList) and not s.is_bound: - if expanding_object is not None: - msg = ( - "at most one ... or unbound dimension list can exist in indexing list but" - f" found 2 at offsets {i} and {expanding_object}" - ) - raise DimensionBindError(msg) - expanding_object = i - - if isinstance(s, DimList): - dims_indexed += len(s) if s.is_bound else 0 - dimlists.append(i) - elif s is not None and s is not ...: - dims_indexed += 1 - - ndim = self.ndim - if dims_indexed > ndim: - raise IndexError( - f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." - ) - if expanding_object is not None: - expanding_ndims = ndim - dims_indexed - obj = input[expanding_object] - if obj is ...: - input[expanding_object : expanding_object + 1] = [ - no_slice - ] * expanding_ndims - else: - obj.bind_len(expanding_ndims) - # flatten the dimslists into the indexing - for i in reversed(dimlists): - input[i : i + 1] = input[i] - dims_indexed = 0 - requires_view = False - size = self.size() - view_sizes = [] - dims_seen = dim_tracker() - - def add_dims(t): - if not isinstance(t, _Tensor): - return - for d in t.dims: - dims_seen.record(d) - - add_dims(self) - dim_packs = [] - for i, idx in enumerate(input): - if idx is None: - input[i] = no_slice - view_sizes.append(1) - requires_view = True - else: - sz = size[dims_indexed] - if isinstance(idx, Dim): - idx.size = sz - dims_seen.record(idx) - view_sizes.append(sz) - elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): - for d in idx: - dims_seen.record(idx) - _bind_dims_to_size(sz, idx, f"offset {i}") - view_sizes.extend(d.size for d in idx) - requires_view = True - dim_packs.append(i) - else: - add_dims(idx) - view_sizes.append(sz) - dims_indexed += 1 - if requires_view: - self = self.view(*view_sizes) - for i in reversed(dim_packs): - input[i : i + 1] = input[i] - - # currently: - # input is flat, containing either Dim, or Tensor, or something valid for standard indexing - # self may have first-class dims as well. - - # to index: - # drop the first class dims from self, they just become direct indices of their positions - - # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. - # these dimensions will appear and need to be bound at the first place tensor occurs - - if isinstance(self, _Tensor): - ptensor_self, levels = self._tensor, list(self._levels) - # indices to ptensor rather than self which has first-class dimensions - input_it = iter(input) - flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] - has_device = self._has_device - to_pad = 0 - else: - ptensor_self, flat_inputs = self, input - to_pad = ptensor_self.ndim - len(flat_inputs) - has_device = True - - result_levels = [] - index_levels = [] - tensor_insert_point = None - to_expand = {} - requires_getindex = False - for i, inp in enumerate(flat_inputs): - if isinstance(inp, Dim) and dims_seen[inp] == 1: - flat_inputs[i] = no_slice - result_levels.append(inp) - elif isinstance(inp, TensorLike): - requires_getindex = True - if tensor_insert_point is None: - tensor_insert_point = len(result_levels) - ptensor, levels, _ = _tensor_levels(inp) - to_expand[i] = levels - flat_inputs[i] = ptensor - for l in levels: - if l not in index_levels: - index_levels.append(l) - else: - requires_getindex = True - result_levels.append(0) - - if tensor_insert_point is not None: - result_levels[tensor_insert_point:tensor_insert_point] = index_levels - - for i, levels in to_expand.items(): - flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) - - if requires_getindex: - result = _orig_getitem(ptensor_self, flat_inputs) - else: - result = ptensor_self - - next_positional = -1 - if to_pad > 0: - result_levels.extend([0] * to_pad) - for i, r in enumerate(reversed(result_levels)): - if isinstance(r, int): - result_levels[-1 - i] = next_positional - next_positional -= 1 - - return Tensor.from_positional(result, result_levels, has_device) - - -# XXX - dim is optional and can be the outer-most dimension... -def stack(tensors, new_dim, dim=0, out=None): - if isinstance(dim, int): - return torch.stack(tensors, dim, out).index(dim, new_dim) - index = None - if out is not None: - out, index = _positional_no_permute(out, dim, expand_dim=True) - ptensors = [] - for t in tensors: - pt, pi = _positional_no_permute(t, dim, expand_dim=True) - if index is not None and pi != index: - pt = pt.move_dim(pi, index) - else: - index = pi - ptensors.append(pt) - pr = torch.stack(ptensors, index, out=out) - return pr.index((index, index + 1), (new_dim, dim)) - - -_orig_split = torch.Tensor.split - - -def split(self, split_size_or_sections, dim=0): - from . import _Tensor, Dim - - if isinstance(split_size_or_sections, int) or any( - isinstance(t, int) for t in split_size_or_sections - ): - if isinstance(dim, Dim): - raise ValueError( - "when dim is specified as a Dim object, split sizes must also be dimensions." - ) - return _orig_split(self, split_size_or_sections, dim=dim) - - if isinstance(dim, Dim): - assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" - self, dim = _positional_no_permute(self, dim) - - size = self.size(dim) - total_bound_size = 0 - unbound = [] - sizes = [] - for i, d in enumerate(split_size_or_sections): - if d.is_bound: - sizes.append(d.size) - total_bound_size += d.size - else: - sizes.append(0) - unbound.append(i) - - if unbound: - assert total_bound_size <= size, ( - f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - remaining_size = size - total_bound_size - chunk_size = -(-remaining_size // len(unbound)) - for u in unbound: - sz = min(chunk_size, remaining_size) - split_size_or_sections[u].size = sz - sizes[u] = sz - remaining_size -= sz - else: - assert total_bound_size == size, ( - f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - return tuple( - t.index(dim, d) - for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) - ) diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py index aae543b91a89..b9ebda47c4cf 100644 --- a/functorch/dim/wrap_type.py +++ b/functorch/dim/wrap_type.py @@ -26,18 +26,8 @@ PROPERTY_TYPES = (GetSetDescriptorType, property) -def _py_wrap_method(orig, __torch_function__): - def impl(*args, **kwargs): - return __torch_function__(orig, None, args, kwargs) - - return impl - - -def wrap_type(use_c, to_patch, pattern, __torch_function__): - if use_c: - wrap_method = _wrap_method - else: - wrap_method = _py_wrap_method +def wrap_type(to_patch, pattern, __torch_function__): + wrap_method = _wrap_method all = {} for t in reversed(pattern.mro()[:-1]): # skip object diff --git a/pyproject.toml b/pyproject.toml index c42aa782407f..a911a2a723b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,9 +69,6 @@ pyyaml = ["pyyaml"] # Linter tools ################################################################# -[tool.black] -line-length = 88 - [tool.isort] src_paths = ["caffe2", "torch", "torchgen", "functorch", "test"] extra_standard_library = ["typing_extensions"] diff --git a/scripts/lintrunner.py b/scripts/lintrunner.py new file mode 100644 index 000000000000..2e3ad2bc219a --- /dev/null +++ b/scripts/lintrunner.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Wrapper script to run the isolated hook version of lintrunner. + +This allows developers to easily run lintrunner (including with -a for auto-fixes) +using the same isolated environment that the pre-push hook uses, without having +to manually activate/deactivate virtual environments. + +Usage: + python scripts/lintrunner.py # Check mode (same as git push) + python scripts/lintrunner.py -a # Auto-fix mode + python scripts/lintrunner.py --help # Show lintrunner help + +This module also provides shared functionality for lintrunner hash management. +""" + +from __future__ import annotations + +import hashlib +import os +import shlex +import shutil +import subprocess +import sys +from pathlib import Path + + +def find_repo_root() -> Path: + """Find repository root using git.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + text=True, + check=True, + ) + return Path(result.stdout.strip()) + except subprocess.CalledProcessError: + sys.exit("āŒ Not in a git repository") + + +def compute_file_hash(path: Path) -> str: + """Returns SHA256 hash of a file's contents.""" + hasher = hashlib.sha256() + with path.open("rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() + + +def read_stored_hash(path: Path) -> str | None: + if not path.exists(): + return None + try: + return path.read_text().strip() + except Exception: + return None + + +# Venv location - change this if the path changes +HOOK_VENV_PATH = ".git/hooks/linter/.venv" + + +def get_hook_venv_path() -> Path: + """Get the path to the hook virtual environment.""" + repo_root = find_repo_root() + return repo_root / HOOK_VENV_PATH + + +def find_hook_venv() -> Path: + """Locate the isolated hook virtual environment.""" + venv_dir = get_hook_venv_path() + + if not venv_dir.exists(): + sys.exit( + f"āŒ Hook virtual environment not found at {venv_dir}\n" + " Please set this up by running: python scripts/setup_hooks.py" + ) + + return venv_dir + + +def check_lintrunner_installed(venv_dir: Path) -> None: + """Check if lintrunner is installed in the given venv, exit if not.""" + result = subprocess.run( + [ + "uv", + "pip", + "show", + "--python", + str(venv_dir / "bin" / "python"), + "lintrunner", + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if result.returncode != 0: + sys.exit( + "āŒ lintrunner is required but was not found in the hook environment. " + "Please run `python scripts/setup_hooks.py` to reinstall." + ) + print("āœ… lintrunner is already installed") + + +def run_lintrunner(venv_dir: Path, args: list[str]) -> int: + """Run lintrunner command in the specified venv and return exit code.""" + # Run lintrunner directly from the venv's bin directory with environment setup + lintrunner_exe = venv_dir / "bin" / "lintrunner" + cmd = [str(lintrunner_exe)] + args + env = os.environ.copy() + + # PATH: Ensures lintrunner can find other tools in the venv (like python, pip, etc.) + env["PATH"] = str(venv_dir / "bin") + os.pathsep + env.get("PATH", "") + # VIRTUAL_ENV: Tells tools like pip_init.py that we're in a venv (prevents --user flag issues) + env["VIRTUAL_ENV"] = str(venv_dir) + + # Note: Progress tends to be slightly garbled due to terminal control sequences, + # but functionality and final results will be correct + return subprocess.call(cmd, env=env) + + +def initialize_lintrunner_if_needed(venv_dir: Path) -> None: + """Check if lintrunner needs initialization and run init if needed.""" + repo_root = find_repo_root() + lintrunner_toml_path = repo_root / ".lintrunner.toml" + initialized_hash_path = venv_dir / ".lintrunner_plugins_hash" + + if not lintrunner_toml_path.exists(): + print("āš ļø No .lintrunner.toml found. Skipping init.") + return + + current_hash = compute_file_hash(lintrunner_toml_path) + stored_hash = read_stored_hash(initialized_hash_path) + + if current_hash != stored_hash: + print("šŸ” Running `lintrunner init` …", file=sys.stderr) + result = run_lintrunner(venv_dir, ["init"]) + if result != 0: + sys.exit(f"āŒ lintrunner init failed") + initialized_hash_path.write_text(current_hash) + else: + print("āœ… Lintrunner plugins already initialized and up to date.") + + +def main() -> None: + """Run lintrunner in the isolated hook environment.""" + venv_dir = find_hook_venv() + python_exe = venv_dir / "bin" / "python" + + if not python_exe.exists(): + sys.exit(f"āŒ Python executable not found at {python_exe}") + + try: + print(f"šŸ Virtual env being used: {venv_dir}", file=sys.stderr) + + # 1. Ensure lintrunner binary is available in the venv + check_lintrunner_installed(venv_dir) + + # 2. Check for plugin updates and re-init if needed + initialize_lintrunner_if_needed(venv_dir) + + # 3. Run lintrunner with any passed arguments and propagate its exit code + args = sys.argv[1:] + result = run_lintrunner(venv_dir, args) + + # If lintrunner failed and we're not already in auto-fix mode, suggest the wrapper + if result != 0 and "-a" not in args: + print( + "\nšŸ’” To auto-fix these issues, run: python scripts/lintrunner.py -a", + file=sys.stderr, + ) + + sys.exit(result) + + except KeyboardInterrupt: + print("\n Lintrunner interrupted by user (KeyboardInterrupt)", file=sys.stderr) + sys.exit(1) # Tell git push to fail + + +if __name__ == "__main__": + main() diff --git a/scripts/run_lintrunner.py b/scripts/run_lintrunner.py deleted file mode 100644 index 60d5b545cf91..000000000000 --- a/scripts/run_lintrunner.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -""" -Pre‑push hook wrapper for Lintrunner. - -āœ“ Stores a hash of .lintrunner.toml in the venv -āœ“ Re-runs `lintrunner init` if that file's hash changes -""" - -from __future__ import annotations - -import hashlib -import os -import shutil -import subprocess -import sys -from pathlib import Path - - -REPO_ROOT = Path(__file__).resolve().parents[1] -LINTRUNNER_TOML_PATH = REPO_ROOT / ".lintrunner.toml" - -# This is the path to the pre-commit-managed venv -VENV_ROOT = Path(sys.executable).parent.parent -# Stores the hash of .lintrunner.toml from the last time we ran `lintrunner init` -INITIALIZED_LINTRUNNER_TOML_HASH_PATH = VENV_ROOT / ".lintrunner_plugins_hash" - - -def ensure_lintrunner() -> None: - """Fail if Lintrunner is not on PATH.""" - if shutil.which("lintrunner"): - print("āœ… lintrunner is already installed") - return - sys.exit( - "āŒ lintrunner is required but was not found on your PATH. Please run the `python scripts/setup_hooks.py` to install to configure lintrunner before using this script. If `git push` still fails, you may need to open an new terminal" - ) - - -def ensure_virtual_environment() -> None: - """Fail if not running within a virtual environment.""" - in_venv = ( - os.environ.get("VIRTUAL_ENV") is not None - or hasattr(sys, "real_prefix") - or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix) - ) - - if not in_venv: - sys.exit( - "āŒ This script must be run from within a virtual environment. " - "Please activate your virtual environment before running this script." - ) - - -def compute_file_hash(path: Path) -> str: - """Returns SHA256 hash of a file's contents.""" - hasher = hashlib.sha256() - with path.open("rb") as f: - while chunk := f.read(8192): - hasher.update(chunk) - return hasher.hexdigest() - - -def read_stored_hash(path: Path) -> str | None: - if not path.exists(): - return None - try: - return path.read_text().strip() - except Exception: - return None - - -def initialize_lintrunner_if_needed() -> None: - """Runs lintrunner init if .lintrunner.toml changed since last run.""" - if not LINTRUNNER_TOML_PATH.exists(): - print("āš ļø No .lintrunner.toml found. Skipping init.") - return - - print( - f"INITIALIZED_LINTRUNNER_TOML_HASH_PATH = {INITIALIZED_LINTRUNNER_TOML_HASH_PATH}" - ) - current_hash = compute_file_hash(LINTRUNNER_TOML_PATH) - stored_hash = read_stored_hash(INITIALIZED_LINTRUNNER_TOML_HASH_PATH) - - if current_hash == stored_hash: - print("āœ… Lintrunner plugins already initialized and up to date.") - return - - print("šŸ” Running `lintrunner init` …", file=sys.stderr) - subprocess.check_call(["lintrunner", "init"]) - INITIALIZED_LINTRUNNER_TOML_HASH_PATH.write_text(current_hash) - - -def main() -> None: - # 0. Ensure we're running in a virtual environment - ensure_virtual_environment() - print(f"šŸ Virtual env being used: {VENV_ROOT}", file=sys.stderr) - - # 1. Ensure lintrunner binary is available - ensure_lintrunner() - - # 2. Check for plugin updates and re-init if needed - initialize_lintrunner_if_needed() - - # 3. Run lintrunner with any passed arguments and propagate its exit code - args = sys.argv[1:] # Forward all arguments to lintrunner - result = subprocess.call(["lintrunner"] + args) - sys.exit(result) - - -if __name__ == "__main__": - main() diff --git a/scripts/setup_hooks.py b/scripts/setup_hooks.py index 41f08d45e98b..e8effe7f8232 100644 --- a/scripts/setup_hooks.py +++ b/scripts/setup_hooks.py @@ -1,31 +1,51 @@ #!/usr/bin/env python3 """ -Bootstrap Git pre‑push hook. +Bootstrap Git pre‑push hook with isolated virtual environment. āœ“ Requires uv to be installed (fails if not available) -āœ“ Installs/updates pre‑commit with uv (global, venv‑proof) -āœ“ Registers the repo's pre‑push hook and freezes hook versions +āœ“ Creates isolated venv in .git/hooks/linter/.venv/ for hook dependencies +āœ“ Installs lintrunner only in the isolated environment +āœ“ Creates direct git hook that bypasses pre-commit Run this from the repo root (inside or outside any project venv): python scripts/setup_hooks.py + +IMPORTANT: The generated git hook references scripts/lintrunner.py. If users checkout +branches that don't have this file, git push will fail with "No such file or directory". +Users would need to either: +1. Re-run the old setup_hooks.py from that branch, or +2. Manually delete .git/hooks/pre-push to disable hooks temporarily, or +3. Switch back to a branch with the new scripts/lintrunner.py """ from __future__ import annotations +import shlex import shutil import subprocess import sys from pathlib import Path -from typing import Tuple + + +# Add scripts directory to Python path so we can import lintrunner module +scripts_dir = Path(__file__).parent +sys.path.insert(0, str(scripts_dir)) + +# Import shared functions from lintrunner module +from lintrunner import find_repo_root, get_hook_venv_path + + +# Restore sys.path to avoid affecting other imports +sys.path.pop(0) # ─────────────────────────────────────────── # Helper utilities # ─────────────────────────────────────────── -def run(cmd: list[str]) -> None: +def run(cmd: list[str], cwd: Path = None) -> None: print(f"$ {' '.join(cmd)}") - subprocess.check_call(cmd) + subprocess.check_call(cmd, cwd=cwd) def which(cmd: str) -> bool: @@ -34,28 +54,7 @@ def which(cmd: str) -> bool: def ensure_uv() -> None: if which("uv"): - # Ensure the path uv installs binaries to is part of the system path - print("$ uv tool update-shell") - result = subprocess.run( - ["uv", "tool", "update-shell"], capture_output=True, text=True - ) - if result.returncode == 0: - # Check if the output indicates changes were made - if ( - "Updated" in result.stdout - or "Added" in result.stdout - or "Modified" in result.stdout - ): - print( - "āš ļø Shell configuration updated. You may need to restart your terminal for changes to take effect." - ) - elif result.stdout.strip(): - print(result.stdout) - return - else: - sys.exit( - f"āŒ Warning: uv tool update-shell failed: {result.stderr}. uv installed tools may not be available." - ) + return sys.exit( "\nāŒ uv is required but was not found on your PATH.\n" @@ -65,29 +64,6 @@ def ensure_uv() -> None: ) -def ensure_tool_installed( - tool: str, force_update: bool = False, python_ver: Tuple[int, int] = None -) -> None: - """ - Checks to see if the tool is available and if not (or if force update requested) then - it reinstalls it. - - Returns: Whether or not the tool is available on PATH. If it's not, a new terminal - needs to be opened before git pushes work as expected. - """ - if force_update or not which(tool): - print(f"Ensuring latest {tool} via uv …") - command = ["uv", "tool", "install", "--force", tool] - if python_ver: - # Add the Python version to the command if specified - command.extend(["--python", f"{python_ver[0]}.{python_ver[1]}"]) - run(command) - if not which(tool): - print( - f"\nāš ļø {tool} installation succeed, but it's not on PATH. Launch a new terminal if your git pushes don't work.\n" - ) - - if sys.platform.startswith("win"): print( "\nāš ļø Lintrunner is not supported on Windows, so there are no pre-push hooks to add. Exiting setup.\n" @@ -95,52 +71,61 @@ def ensure_tool_installed( sys.exit(0) # ─────────────────────────────────────────── -# 1. Install dependencies +# 1. Setup isolated hook environment # ─────────────────────────────────────────── ensure_uv() -# Ensure pre-commit is installed globally via uv -ensure_tool_installed("pre-commit", force_update=True, python_ver=(3, 9)) +# Find repo root and setup hook directory +repo_root = find_repo_root() +venv_dir = get_hook_venv_path() +hooks_dir = venv_dir.parent.parent # Go from .git/hooks/linter/.venv to .git/hooks + -# Don't force a lintrunner update because it might break folks -# who already have it installed in a different way -ensure_tool_installed("lintrunner") +print(f"Setting up isolated hook environment in {venv_dir}") + +# Create isolated virtual environment for hooks +if venv_dir.exists(): + print("Removing existing hook venv...") + shutil.rmtree(venv_dir) + +run(["uv", "venv", str(venv_dir), "--python", "3.9"]) + +# Install lintrunner in the isolated environment +print("Installing lintrunner in isolated environment...") +run( + ["uv", "pip", "install", "--python", str(venv_dir / "bin" / "python"), "lintrunner"] +) # ─────────────────────────────────────────── -# 2. Activate (or refresh) the pre‑push hook +# 2. Create direct git pre-push hook # ─────────────────────────────────────────── -# ── Activate (or refresh) the repo’s pre‑push hook ────────────────────────── -# Creates/overwrites .git/hooks/pre‑push with a tiny shim that will call -# `pre-commit run --hook-stage pre-push` on every `git push`. -# This is why we need to install pre-commit globally. -# -# The --allow-missing-config flag lets pre-commit succeed if someone changes to -# a branch that doesn't have pre-commit installed -run( - [ - "uv", - "tool", - "run", - "pre-commit", - "install", - "--hook-type", - "pre-push", - "--allow-missing-config", - ] +pre_push_hook = hooks_dir / "pre-push" +python_exe = venv_dir / "bin" / "python" +lintrunner_script_path_quoted = shlex.quote( + str(repo_root / "scripts" / "lintrunner.py") ) -# ── Pin remote‑hook versions for reproducibility ──────────────────────────── -# (Note: we don't have remote hooks right now, but it future-proofs this script) -# 1. `autoupdate` bumps every remote hook’s `rev:` in .pre-commit-config.yaml -# to the latest commit on its default branch. -# 2. `--freeze` immediately rewrites each `rev:` to the exact commit SHA, -# ensuring all contributors and CI run identical hook code. -run(["uv", "tool", "run", "pre-commit", "autoupdate", "--freeze"]) +hook_script = f"""#!/bin/bash +set -e + +# Check if lintrunner script exists (user might be on older commit) +if [ ! -f {lintrunner_script_path_quoted} ]; then + echo "āš ļø {lintrunner_script_path_quoted} not found - skipping linting (likely on an older commit)" + exit 0 +fi + +# Run lintrunner wrapper using the isolated venv's Python +{shlex.quote(str(python_exe))} {lintrunner_script_path_quoted} +""" +print(f"Creating git pre-push hook at {pre_push_hook}") +pre_push_hook.write_text(hook_script) +pre_push_hook.chmod(0o755) # Make executable print( - "\nāœ… pre‑commit is installed globally via uv and the pre‑push hook is active.\n" + "\nāœ… Isolated hook environment created and pre‑push hook is active.\n" " Lintrunner will now run automatically on every `git push`.\n" + f" Hook dependencies are isolated in {venv_dir}\n" ) diff --git a/setup.py b/setup.py index 189a78c23bbb..23ef58124139 100644 --- a/setup.py +++ b/setup.py @@ -156,6 +156,12 @@ # USE_ROCM_KERNEL_ASSERT=1 # Enable kernel assert in ROCm platform # +# USE_ROCM_CK_GEMM=1 +# Enable building CK GEMM backend in ROCm platform +# +# USE_ROCM_CK_SDPA=1 +# Enable building CK SDPA backend in ROCm platform +# # Environment variables we respect (these environment variables are # conventional and are often understood/set by other software.) # @@ -229,6 +235,11 @@ # # BUILD_PYTHON_ONLY # Builds pytorch as a wheel using libtorch.so from a separate wheel +# +# USE_NIGHTLY=VERSION +# Skip cmake build and instead download and extract nightly PyTorch wheel +# matching the specified version (e.g., USE_NIGHTLY="2.8.0.dev20250608+cpu") +# into the local directory for development use from __future__ import annotations @@ -266,8 +277,10 @@ import shutil import subprocess import sysconfig +import tempfile import textwrap import time +import zipfile from collections import defaultdict from pathlib import Path from typing import Any, ClassVar, IO @@ -588,9 +601,372 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +# ATTENTION: THIS IS AI SLOP +def extract_variant_from_version(version: str) -> str: + """Extract variant from version string, defaulting to 'cpu'.""" + import re + + variant_match = re.search(r"\+([^-\s,)]+)", version) + return variant_match.group(1) if variant_match else "cpu" + + +# ATTENTION: THIS IS AI SLOP +def get_nightly_git_hash(version: str) -> str: + """Download a nightly wheel and extract the git hash from its version.py file.""" + # Extract variant from version to construct correct URL + variant = extract_variant_from_version(version) + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + torch_version_spec = f"torch=={version}" + + # Create a temporary directory for downloading + with tempfile.TemporaryDirectory(prefix="pytorch-hash-extract-") as temp_dir: + temp_path = Path(temp_dir) + + # Download the wheel + report(f"-- Downloading {version} wheel to extract git hash...") + download_cmd = [ + "uvx", + "pip", + "download", + "--index-url", + nightly_index_url, + "--pre", + "--no-deps", + "--dest", + str(temp_path), + torch_version_spec, + ] + + result = subprocess.run(download_cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Failed to download {version} wheel for git hash extraction: {result.stderr}" + ) + + # Find the downloaded wheel file + wheel_files = list(temp_path.glob("torch-*.whl")) + if not wheel_files: + raise RuntimeError(f"No torch wheel found after downloading {version}") + + wheel_file = wheel_files[0] + + # Extract the wheel and look for version.py + with tempfile.TemporaryDirectory( + prefix="pytorch-wheel-extract-" + ) as extract_dir: + extract_path = Path(extract_dir) + + with zipfile.ZipFile(wheel_file, "r") as zip_ref: + zip_ref.extractall(extract_path) + + # Find torch directory and version.py + torch_dirs = list(extract_path.glob("torch")) + if not torch_dirs: + torch_dirs = list(extract_path.glob("*/torch")) + + if not torch_dirs: + raise RuntimeError(f"Could not find torch directory in {version} wheel") + + version_file = torch_dirs[0] / "version.py" + if not version_file.exists(): + raise RuntimeError(f"Could not find version.py in {version} wheel") + + # Read and parse version.py to extract git_version (nightly branch commit) + from ast import literal_eval + + nightly_commit = None + with version_file.open(encoding="utf-8") as f: + for line in f: + if line.strip().startswith("git_version"): + try: + # Parse the git_version assignment, e.g., git_version = "abc123def456" + nightly_commit = literal_eval( + line.partition("=")[2].strip() + ) + break + except (ValueError, SyntaxError): + continue + + if not nightly_commit: + raise RuntimeError( + f"Could not parse git_version from {version} wheel's version.py" + ) + + # Now fetch the nightly branch and extract the real source commit from the message + report("-- Fetching nightly branch to extract source commit...") + + # Fetch only the nightly branch + subprocess.check_call(["git", "fetch", "origin", "nightly"], cwd=str(CWD)) + + # Get the commit message from the nightly commit + commit_message = subprocess.check_output( + ["git", "show", "--no-patch", "--format=%s", nightly_commit], + cwd=str(CWD), + text=True, + ).strip() + + # Parse the commit message to extract the real hash + # Format: "2025-08-06 nightly release (74a754aae98aabc2aca67e5edb41cc684fae9a82)" + import re + + hash_match = re.search(r"\(([0-9a-fA-F]{40})\)", commit_message) + if hash_match: + real_commit = hash_match.group(1) + report(f"-- Extracted source commit: {real_commit[:12]}...") + return real_commit + else: + raise RuntimeError( + f"Could not parse commit hash from nightly commit message: {commit_message}" + ) + + +# ATTENTION: THIS IS AI SLOP +def get_latest_nightly_version(variant: str = "cpu") -> str: + """Get the latest available nightly version using pip to query the PyTorch nightly index.""" + # Get the latest available nightly version for the specified variant + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + # Run pip index to get available versions + output = subprocess.check_output( + [ + "uvx", + "pip", + "index", + "versions", + "--index-url", + nightly_index_url, + "--pre", + "torch", + ], + text=True, + timeout=30, + ) + + # Parse the first line to get the latest version + # Format: "torch (2.9.0.dev20250806)" or "torch (2.9.0.dev20250806+cpu)" + first_line = output.strip().split("\n")[0] + if "(" in first_line and ")" in first_line: + # Extract version from parentheses exactly as reported + version = first_line.split("(")[1].split(")")[0] + return version + + raise RuntimeError(f"Could not parse version from pip index output: {first_line}") + + +# ATTENTION: THIS IS AI SLOP +def download_and_extract_nightly_wheel(version: str) -> None: + """Download and extract nightly PyTorch wheel for USE_NIGHTLY=VERSION builds.""" + + # Extract variant from version (e.g., cpu, cu121, cu118, rocm5.7) + variant = extract_variant_from_version(version) + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + # Construct the full torch version spec + torch_version_spec = f"torch=={version}" + + # Create a temporary directory for downloading + with tempfile.TemporaryDirectory(prefix="pytorch-nightly-") as temp_dir: + temp_path = Path(temp_dir) + + # Use pip to download the specific nightly wheel + download_cmd = [ + "uvx", + "pip", + "download", + "--index-url", + nightly_index_url, + "--pre", + "--no-deps", + "--dest", + str(temp_path), + torch_version_spec, + ] + + report("-- Downloading nightly PyTorch wheel...") + result = subprocess.run(download_cmd, capture_output=True, text=True) + if result.returncode != 0: + # Try to get the latest nightly version for the same variant to help the user + variant = extract_variant_from_version(version) + try: + report(f"-- Detecting latest {variant} nightly version...") + latest_version = get_latest_nightly_version(variant) + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += ( + f"\n\nLatest available {variant} nightly version: {latest_version}" + ) + error_msg += f'\nTry: USE_NIGHTLY="{latest_version}"' + + # Also get the git hash for the latest version + git_hash = get_nightly_git_hash(latest_version) + error_msg += f"\n\nIMPORTANT: You must checkout the matching source commit:\ngit checkout {git_hash}" + except Exception: + # If we can't get latest for this variant, try CPU as fallback + try: + report("-- Detecting latest CPU nightly version...") + latest_version = get_latest_nightly_version("cpu") + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += f"\n\nCould not find {variant} nightlies. Latest available CPU nightly version: {latest_version}" + error_msg += f'\nTry: USE_NIGHTLY="{latest_version}"' + except Exception: + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += "\n\nCould not determine latest nightly version. " + error_msg += "Check https://download.pytorch.org/whl/nightly/ for available versions." + + raise RuntimeError(error_msg) + + # Find the downloaded wheel file + wheel_files = list(temp_path.glob("torch-*.whl")) + if not wheel_files: + raise RuntimeError("No torch wheel found after download") + elif len(wheel_files) > 1: + raise RuntimeError(f"Multiple torch wheels found: {wheel_files}") + + wheel_file = wheel_files[0] + report(f"-- Downloaded wheel: {wheel_file.name}") + + # Extract the wheel + with tempfile.TemporaryDirectory( + prefix="pytorch-wheel-extract-" + ) as extract_dir: + extract_path = Path(extract_dir) + + # Use Python's zipfile to extract the wheel + with zipfile.ZipFile(wheel_file, "r") as zip_ref: + zip_ref.extractall(extract_path) + + # Find the torch directory in the extracted wheel + torch_dirs = list(extract_path.glob("torch")) + if not torch_dirs: + # Sometimes the torch directory might be nested + torch_dirs = list(extract_path.glob("*/torch")) + + if not torch_dirs: + raise RuntimeError("Could not find torch directory in extracted wheel") + + source_torch_dir = torch_dirs[0] + target_torch_dir = TORCH_DIR + + report( + f"-- Extracting wheel contents from {source_torch_dir} to {target_torch_dir}" + ) + + # Copy the essential files from the wheel to our local directory + # Based on the file listing logic from tools/nightly.py + files_to_copy: list[Path] = [] + + # Get platform-specific binary files + if IS_LINUX: + files_to_copy.extend(source_torch_dir.glob("*.so")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.so*") + if (source_torch_dir / "lib").exists() + else [] + ) + elif IS_DARWIN: + files_to_copy.extend(source_torch_dir.glob("*.so")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.dylib") + if (source_torch_dir / "lib").exists() + else [] + ) + elif IS_WINDOWS: + files_to_copy.extend(source_torch_dir.glob("*.pyd")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.lib") + if (source_torch_dir / "lib").exists() + else [] + ) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.dll") + if (source_torch_dir / "lib").exists() + else [] + ) + + # Add essential directories and files + essential_items = ["version.py", "bin", "include", "lib"] + for item_name in essential_items: + item_path = source_torch_dir / item_name + if item_path.exists(): + files_to_copy.append(item_path) + + # Add testing internal generated files + testing_generated = source_torch_dir / "testing" / "_internal" / "generated" + if testing_generated.exists(): + files_to_copy.append(testing_generated) + + # Copy all the files and directories + for src_path in files_to_copy: + rel_path = src_path.relative_to(source_torch_dir) + dst_path = target_torch_dir / rel_path + + # Copy files and directories, preserving existing subdirectories + if src_path.is_dir(): + # Create destination directory if it doesn't exist + dst_path.mkdir(parents=True, exist_ok=True) + # Copy individual entries from source directory + for src_item in src_path.iterdir(): + dst_item = dst_path / src_item.name + if src_item.is_dir(): + # Recursively copy subdirectories (this will preserve existing ones) + shutil.copytree(src_item, dst_item, dirs_exist_ok=True) + else: + # Copy individual files, overwriting existing ones + shutil.copy2(src_item, dst_item) + else: + # For files, remove existing and copy new + if dst_path.exists(): + dst_path.unlink() + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_path, dst_path) + + report(f" Copied {rel_path}") + + report("-- Nightly wheel extraction completed") + + # all the work we need to do _before_ setup runs def build_deps() -> None: report(f"-- Building version {TORCH_VERSION}") + + # ATTENTION: THIS IS AI SLOP + # Check for USE_NIGHTLY=VERSION to bypass normal build and download nightly wheel + nightly_version = os.getenv("USE_NIGHTLY") + if nightly_version is not None: + import re + + if ( + nightly_version == "" + or nightly_version == "cpu" + or re.match(r"^cu\d+$", nightly_version) + or re.match(r"^rocm\d+\.\d+$", nightly_version) + ): + # Empty string or variant-only specification, show error with latest version + variant = "cpu" if nightly_version == "" else nightly_version + report(f"-- Detecting latest {variant} nightly version...") + latest_version = get_latest_nightly_version(variant) + # Also get the git hash to tell user which commit to checkout + git_hash = get_nightly_git_hash(latest_version) + + if nightly_version == "": + error_msg = f"USE_NIGHTLY cannot be empty. Latest available version: {latest_version}\n" + else: + error_msg = ( + "USE_NIGHTLY requires a specific version, not just a variant. " + "Latest available {nightly_version} version: {latest_version}\n" + ) + + error_msg += f'Try: USE_NIGHTLY="{latest_version}"' + error_msg += f"\n\nIMPORTANT: You must checkout the matching source commit for this binary:\ngit checkout {git_hash}" + raise RuntimeError(error_msg) + else: + # Full version specification + report( + f"-- USE_NIGHTLY={nightly_version} detected, downloading nightly wheel" + ) + download_and_extract_nightly_wheel(nightly_version) + return + check_submodules() check_pydep("yaml", "pyyaml") build_pytorch( @@ -750,7 +1126,7 @@ def _embed_libomp(self) -> None: def run(self) -> None: # Report build options. This is run after the build completes so # `CMakeCache.txt` exists # and we can get an accurate report on what is used and what is not. - cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) + cmake_cache_vars = get_cmake_cache_vars() if cmake_cache_vars["USE_NUMPY"]: report("-- Building with NumPy bindings") else: @@ -850,23 +1226,6 @@ def run(self) -> None: target_dir.mkdir(parents=True, exist_ok=True) self.copy_file(export_lib, target_lib) - # In ROCm on Windows case copy rocblas and hipblaslt files into - # torch/lib/rocblas/library and torch/lib/hipblaslt/library - if str2bool(os.getenv("USE_ROCM")): - rocm_dir_path = Path(os.environ["ROCM_DIR"]) - rocm_bin_path = rocm_dir_path / "bin" - rocblas_dir = rocm_bin_path / "rocblas" - target_rocblas_dir = target_dir / "rocblas" - target_rocblas_dir.mkdir(parents=True, exist_ok=True) - self.copy_tree(rocblas_dir, str(target_rocblas_dir)) - - hipblaslt_dir = rocm_bin_path / "hipblaslt" - target_hipblaslt_dir = target_dir / "hipblaslt" - target_hipblaslt_dir.mkdir(parents=True, exist_ok=True) - self.copy_tree(hipblaslt_dir, str(target_hipblaslt_dir)) - else: - report("The specified environment variable does not exist.") - def build_extensions(self) -> None: self.create_compile_commands() @@ -1310,6 +1669,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.h", "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", + "_inductor/kernel/flex/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index 05a0587b8398..e81fb555c848 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -90,7 +90,7 @@ def f(t): return reduction_func(t, dim=0) f.__name__ = reduction_func.__name__ - f_c = torch.compile(f, dynamic=False) + f_c = torch.compile(f, dynamic=False, fullgraph=True) for size in (512, 1024, 2048, 4096): x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) @@ -116,7 +116,7 @@ def bench_scan( def f(t): return scan_func(t, dim=dim) - f_c = torch.compile(f, dynamic=False) + f_c = torch.compile(f, dynamic=False, fullgraph=True) for size in (32, 128, 512, 1024): f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}" @@ -135,7 +135,7 @@ def f(t): def f_1d(t): return scan_func(t, dim=0) - f_1d_c = torch.compile(f_1d, dynamic=False) + f_1d_c = torch.compile(f_1d, dynamic=False, fullgraph=True) for size in (100, 10000, 1000000): f_1d.__name__ = f"{scan_func.__name__}-1d-{size}" @@ -154,9 +154,7 @@ def f_1d(t): def main() -> None: - dtypes = [torch.float16, torch.float32] - if torch.backends.mps.is_macos_or_newer(14, 0): - dtypes.append(torch.bfloat16) + dtypes = [torch.float16, torch.float32, torch.bfloat16] # Profile index ops B = 11 @@ -204,4 +202,5 @@ def main() -> None: if __name__ == "__main__": + torch._dynamo.config.cache_size_limit = 2**16 main() diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 34f88ca7e3e1..6898e406fb3b 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -5,6 +5,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp diff --git a/test/cpp/aoti_abi_check/test_dtype.cpp b/test/cpp/aoti_abi_check/test_dtype.cpp index 25385de5b103..e6e7e75867c8 100644 --- a/test/cpp/aoti_abi_check/test_dtype.cpp +++ b/test/cpp/aoti_abi_check/test_dtype.cpp @@ -1,16 +1,16 @@ #include -#include -#include -#include -#include -#include -#include -#include +#include +#include #include - +#include +#include +#include +#include +#include #include #include +#include #include #include #include @@ -18,12 +18,12 @@ #include TEST(TestDtype, TestBFloat16) { - c10::BFloat16 a = 1.0f; - c10::BFloat16 b = 2.0f; - c10::BFloat16 add = 3.0f; - c10::BFloat16 sub = -1.0f; - c10::BFloat16 mul = 2.0f; - c10::BFloat16 div = 0.5f; + torch::headeronly::BFloat16 a = 1.0f; + torch::headeronly::BFloat16 b = 2.0f; + torch::headeronly::BFloat16 add = 3.0f; + torch::headeronly::BFloat16 sub = -1.0f; + torch::headeronly::BFloat16 mul = 2.0f; + torch::headeronly::BFloat16 div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -32,12 +32,12 @@ TEST(TestDtype, TestBFloat16) { } TEST(TestDtype, TestFloat8_e4m3fn) { - c10::Float8_e4m3fn a = 1.0f; - c10::Float8_e4m3fn b = 2.0f; - c10::Float8_e4m3fn add = 3.0f; - c10::Float8_e4m3fn sub = -1.0f; - c10::Float8_e4m3fn mul = 2.0f; - c10::Float8_e4m3fn div = 0.5f; + torch::headeronly::Float8_e4m3fn a = 1.0f; + torch::headeronly::Float8_e4m3fn b = 2.0f; + torch::headeronly::Float8_e4m3fn add = 3.0f; + torch::headeronly::Float8_e4m3fn sub = -1.0f; + torch::headeronly::Float8_e4m3fn mul = 2.0f; + torch::headeronly::Float8_e4m3fn div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -46,12 +46,12 @@ TEST(TestDtype, TestFloat8_e4m3fn) { } TEST(TestDtype, TestFloat8_e4m3fuz) { - c10::Float8_e4m3fnuz a = 1.0f; - c10::Float8_e4m3fnuz b = 2.0f; - c10::Float8_e4m3fnuz add = 3.0f; - c10::Float8_e4m3fnuz sub = -1.0f; - c10::Float8_e4m3fnuz mul = 2.0f; - c10::Float8_e4m3fnuz div = 0.5f; + torch::headeronly::Float8_e4m3fnuz a = 1.0f; + torch::headeronly::Float8_e4m3fnuz b = 2.0f; + torch::headeronly::Float8_e4m3fnuz add = 3.0f; + torch::headeronly::Float8_e4m3fnuz sub = -1.0f; + torch::headeronly::Float8_e4m3fnuz mul = 2.0f; + torch::headeronly::Float8_e4m3fnuz div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -60,12 +60,12 @@ TEST(TestDtype, TestFloat8_e4m3fuz) { } TEST(TestDtype, TestFloat8_e5m2) { - c10::Float8_e5m2 a = 1.0f; - c10::Float8_e5m2 b = 2.0f; - c10::Float8_e5m2 add = 3.0f; - c10::Float8_e5m2 sub = -1.0f; - c10::Float8_e5m2 mul = 2.0f; - c10::Float8_e5m2 div = 0.5f; + torch::headeronly::Float8_e5m2 a = 1.0f; + torch::headeronly::Float8_e5m2 b = 2.0f; + torch::headeronly::Float8_e5m2 add = 3.0f; + torch::headeronly::Float8_e5m2 sub = -1.0f; + torch::headeronly::Float8_e5m2 mul = 2.0f; + torch::headeronly::Float8_e5m2 div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -74,12 +74,12 @@ TEST(TestDtype, TestFloat8_e5m2) { } TEST(TestDtype, TestFloat8_e5m2fnuz) { - c10::Float8_e5m2fnuz a = 1.0f; - c10::Float8_e5m2fnuz b = 2.0f; - c10::Float8_e5m2fnuz add = 3.0f; - c10::Float8_e5m2fnuz sub = -1.0f; - c10::Float8_e5m2fnuz mul = 2.0f; - c10::Float8_e5m2fnuz div = 0.5f; + torch::headeronly::Float8_e5m2fnuz a = 1.0f; + torch::headeronly::Float8_e5m2fnuz b = 2.0f; + torch::headeronly::Float8_e5m2fnuz add = 3.0f; + torch::headeronly::Float8_e5m2fnuz sub = -1.0f; + torch::headeronly::Float8_e5m2fnuz mul = 2.0f; + torch::headeronly::Float8_e5m2fnuz div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -87,6 +87,11 @@ TEST(TestDtype, TestFloat8_e5m2fnuz) { EXPECT_EQ(a / b, div); } +TEST(TestDtype, TestFloat8_e8m0fnu) { + torch::headeronly::Float8_e8m0fnu a = 1.0f; + ASSERT_FALSE(a.isnan()); +} + TEST(TestDtype, TestFloat4) { // not much you can do with this type, just make sure it compiles torch::headeronly::Float4_e2m1fn_x2 a(5); @@ -118,12 +123,12 @@ TEST(TestDtype, TestHalf) { } TEST(TestDtype, TestComplexFloat) { - c10::complex a(std::complex(1.0f, 2.0f)); - c10::complex b(std::complex(3.0f, 4.0f)); - c10::complex add(std::complex(4.0f, 6.0f)); - c10::complex sub(std::complex(-2.0f, -2.0f)); - c10::complex mul(std::complex(-5.0f, 10.0f)); - c10::complex div(std::complex(0.44f, 0.08f)); + torch::headeronly::complex a(std::complex(1.0f, 2.0f)); + torch::headeronly::complex b(std::complex(3.0f, 4.0f)); + torch::headeronly::complex add(std::complex(4.0f, 6.0f)); + torch::headeronly::complex sub(std::complex(-2.0f, -2.0f)); + torch::headeronly::complex mul(std::complex(-5.0f, 10.0f)); + torch::headeronly::complex div(std::complex(0.44f, 0.08f)); EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -145,3 +150,60 @@ TEST(TestDtype, TestQuintsQintsAndBits) { auto i = torch::headeronly::bits8(2); auto j = torch::headeronly::bits16(6); } + +TEST(TestDtype, TestScalarType) { + using torch::headeronly::ScalarType; + constexpr ScalarType expected_scalar_types[] = { + ScalarType::Byte, + ScalarType::Char, + ScalarType::Short, + ScalarType::Int, + ScalarType::Long, + ScalarType::Half, + ScalarType::Float, + ScalarType::Double, + ScalarType::ComplexHalf, + ScalarType::ComplexFloat, + ScalarType::ComplexDouble, + ScalarType::Bool, + ScalarType::QInt8, + ScalarType::QUInt8, + ScalarType::QInt32, + ScalarType::BFloat16, + ScalarType::QUInt4x2, + ScalarType::QUInt2x4, + ScalarType::Bits1x8, + ScalarType::Bits2x4, + ScalarType::Bits4x2, + ScalarType::Bits8, + ScalarType::Bits16, + ScalarType::Float8_e5m2, + ScalarType::Float8_e4m3fn, + ScalarType::Float8_e5m2fnuz, + ScalarType::Float8_e4m3fnuz, + ScalarType::UInt16, + ScalarType::UInt32, + ScalarType::UInt64, + ScalarType::UInt1, + ScalarType::UInt2, + ScalarType::UInt3, + ScalarType::UInt4, + ScalarType::UInt5, + ScalarType::UInt6, + ScalarType::UInt7, + ScalarType::Int1, + ScalarType::Int2, + ScalarType::Int3, + ScalarType::Int4, + ScalarType::Int5, + ScalarType::Int6, + ScalarType::Int7, + ScalarType::Float8_e8m0fnu, + ScalarType::Float4_e2m1fn_x2, + ScalarType::Undefined, + }; + for (int8_t i = 0; i < static_cast(torch::headeronly::NumScalarTypes); + i++) { + EXPECT_EQ(static_cast(i), expected_scalar_types[i]); + } +} diff --git a/test/cpp/aoti_abi_check/test_exception.cpp b/test/cpp/aoti_abi_check/test_exception.cpp index 74a9fee5d986..26f809293244 100644 --- a/test/cpp/aoti_abi_check/test_exception.cpp +++ b/test/cpp/aoti_abi_check/test_exception.cpp @@ -1,6 +1,7 @@ #include #include +#include namespace torch { namespace aot_inductor { @@ -15,5 +16,10 @@ TEST(TestExceptions, TestStdTorchCheck) { std::runtime_error); } +TEST(TestExceptions, TestTorchErrorCodeCheck) { + EXPECT_NO_THROW(TORCH_ERROR_CODE_CHECK(0)); + EXPECT_THROW(TORCH_ERROR_CODE_CHECK(1), std::runtime_error); +} + } // namespace aot_inductor } // namespace torch diff --git a/test/cpp/api/tensor_cuda.cpp b/test/cpp/api/tensor_cuda.cpp index 7a89f0ae5367..1c48a33fb7c0 100644 --- a/test/cpp/api/tensor_cuda.cpp +++ b/test/cpp/api/tensor_cuda.cpp @@ -1,6 +1,8 @@ #include #include +#include +#include #include @@ -124,3 +126,63 @@ TEST(TensorTest, MagmaInitializesCorrectly_CUDA) { at::inverse(tensor); } } + +#ifdef USE_CUDA +#include +#if AT_CUDNN_ENABLED() +TEST(CuDNNBatchNormTest, OutVariantMatchesFunctional) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available"; + } + if (!at::Context::hasCuDNN()) { + GTEST_SKIP() << "cuDNN is not available"; + } + + auto device = torch::device(torch::kCUDA); + + auto input = torch::rand({2, 3, 4, 4}, device); + auto weight = torch::randn({3}, device); + auto bias = torch::randn({3}, device); + auto running_mean = torch::zeros({3}, device); + auto running_var = torch::ones({3}, device); + + bool training = true; + double exponential_average_factor = 0.1; + double epsilon = 1e-5; + + auto output = torch::empty_like(input); + auto save_mean = torch::empty({3}, device); + auto save_var = torch::empty({3}, device); + auto reserve = torch::empty({0}, device.dtype(torch::kByte)); + + at::native::cudnn_batch_norm_out( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + output, + save_mean, + save_var, + reserve); + + auto ref_outputs = at::native::cudnn_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon); + + ASSERT_TRUE(torch::allclose(output, std::get<0>(ref_outputs))); + ASSERT_TRUE(torch::allclose(save_mean, std::get<1>(ref_outputs))); + ASSERT_TRUE(torch::allclose(save_var, std::get<2>(ref_outputs))); + ASSERT_TRUE(torch::equal(reserve, std::get<3>(ref_outputs))); +} +#endif // AT_CUDNN_ENABLED() +#endif // USE_CUDA diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index c05416ce0eef..0675357861f9 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -10,6 +10,7 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp + ${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt deleted file mode 100644 index 8fe6ffd525e9..000000000000 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ /dev/null @@ -1,83 +0,0 @@ -set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) - -set(TENSOREXPR_TEST_SRCS - ${TENSOREXPR_TEST_ROOT}/test_approx.cpp - ${TENSOREXPR_TEST_ROOT}/test_aten.cpp - ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp - ${TENSOREXPR_TEST_ROOT}/test_conv.cpp - ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp - ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp - ${TENSOREXPR_TEST_ROOT}/test_expr.cpp - ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp - ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp - ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp - ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp - ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp - ${TENSOREXPR_TEST_ROOT}/test_ops.cpp - ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp - ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp - ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp - ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp - ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp - ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp - ${TENSOREXPR_TEST_ROOT}/test_type.cpp - ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp -) - -if(USE_CUDA) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) -endif() - -if(USE_LLVM AND LLVM_FOUND) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) -endif() - -add_executable(test_tensorexpr - ${TORCH_ROOT}/test/cpp/common/main.cpp - ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp - ${TENSOREXPR_TEST_SRCS}) - -target_link_libraries(test_tensorexpr PRIVATE torch gtest_main) -target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) -target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) - -add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) -target_link_libraries(tutorial_tensorexpr PRIVATE torch) -target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) - -# The test case depends on the xnnpack header which in turn depends on the -# pthreadpool header. For some build environment we need add the dependency -# explicitly. -if(USE_PTHREADPOOL) - target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface) -endif() -if(USE_CUDA) - target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) -elseif(USE_ROCM) - target_link_libraries(test_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) - - target_link_libraries(tutorial_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) -endif() - -if(INSTALL_TEST) - set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS test_tensorexpr DESTINATION bin) - set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS tutorial_tensorexpr DESTINATION bin) - # Install PDB files for MSVC builds - if(MSVC AND BUILD_SHARED_LIBS) - install(FILES $ DESTINATION bin OPTIONAL) - install(FILES $ DESTINATION bin OPTIONAL) - endif() -endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md deleted file mode 100644 index f86a50a65e80..000000000000 --- a/test/cpp/tensorexpr/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# TensorExpr C++ Tests - -## How to add a new test -First, create a new test file. Test files should have be placed in this -directory, with a name that starts with `test_`, like `test_foo.cpp`. - -Here is an example test file you can copy-paste. -```cpp -#include - -// Tests go in torch::jit -namespace torch { -namespace jit { - -// 1. Test cases are void() functions. -// 2. They start with the prefix `test` -void testCaseOne() { - // ... -} - -void testCaseTwo() { - // ... -} -} -} -``` - -Then, register your test in `tests.h`: -```cpp -// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests -#define TH_FORALL_TESTS(_) \ - _(ADFormulas) \ - _(Attributes) \ - ... - _(CaseOne) // note that the `test` prefix is omitted. - _(CaseTwo) -``` - -We glob all the test files together in `CMakeLists.txt` so that you don't -have to edit it every time you add a test. Unfortunately, this means that in -order to get the build to pick up your new test file, you need to re-run -cmake: -```bash -CMAKE_FRESH=1 python setup.py build -``` - -## How do I run the tests? -The following commands assume you are in PyTorch root. - - ```bash - # (re)build the test binary - ninja build/bin/test_tensorexpr - # run - build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' - ``` diff --git a/test/cpp/tensorexpr/gtest_assert_float_eq.h b/test/cpp/tensorexpr/gtest_assert_float_eq.h deleted file mode 100644 index f85264a8f5d3..000000000000 --- a/test/cpp/tensorexpr/gtest_assert_float_eq.h +++ /dev/null @@ -1,119 +0,0 @@ -#pragma once - -#include -// Copyright 2005, Google Inc. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -// The Google C++ Testing and Mocking Framework (Google Test) -// -// This header file declares functions and macros used internally by -// Google Test. They are subject to change without notice. - -using Bits = uint32_t; - -// this avoids the "dereferencing type-punned pointer -// will break strict-aliasing rules" error -union Float { - float float_; - Bits bits_; -}; - -// # of bits in a number. -static const size_t kBitCount = 8 * sizeof(Bits); -// The mask for the sign bit. -static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); - -// GOOGLETEST_CM0001 DO NOT DELETE - -// Converts an integer from the sign-and-magnitude representation to -// the biased representation. More precisely, let N be 2 to the -// power of (kBitCount - 1), an integer x is represented by the -// unsigned number x + N. -// -// For instance, -// -// -N + 1 (the most negative number representable using -// sign-and-magnitude) is represented by 1; -// 0 is represented by N; and -// N - 1 (the biggest number representable using -// sign-and-magnitude) is represented by 2N - 1. -// -// Read http://en.wikipedia.org/wiki/Signed_number_representations -// for more details on signed number representations. -static Bits SignAndMagnitudeToBiased(const Bits& sam) { - if (kSignBitMask & sam) { - // sam represents a negative number. - return ~sam + 1; - } else { - // sam represents a positive number. - return kSignBitMask | sam; - } -} - -// Given two numbers in the sign-and-magnitude representation, -// returns the distance between them as an unsigned number. -static Bits DistanceBetweenSignAndMagnitudeNumbers( - const Bits& sam1, - const Bits& sam2) { - const Bits biased1 = SignAndMagnitudeToBiased(sam1); - const Bits biased2 = SignAndMagnitudeToBiased(sam2); - return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); -} - -// How many ULP's (Units in the Last Place) we want to tolerate when -// comparing two numbers. The larger the value, the more error we -// allow. A 0 value means that two numbers must be exactly the same -// to be considered equal. -// -// The maximum error of a single floating-point operation is 0.5 -// units in the last place. On Intel CPU's, all floating-point -// calculations are done with 80-bit precision, while double has 64 -// bits. Therefore, 4 should be enough for ordinary use. -// -// See the following article for more details on ULP: -// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ -static const size_t kMaxUlps = 4; - -// Returns true if and only if this number is at most kMaxUlps ULP's away -// from rhs. In particular, this function: -// -// - returns false if either number is (or both are) NAN. -// - treats really large numbers as almost equal to infinity. -// - thinks +0.0 and -0.0 are 0 DLP's apart. -inline bool AlmostEquals(float lhs, float rhs) { - // The IEEE standard says that any comparison operation involving - // a NAN must return false. - if (std::isnan(lhs) || std::isnan(rhs)) - return false; - - Float l = {lhs}; - Float r = {rhs}; - - return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps; -} diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp deleted file mode 100644 index 424d82c77453..000000000000 --- a/test/cpp/tensorexpr/padded_buffer.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "test/cpp/tensorexpr/padded_buffer.h" - -#include -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -int PaddedBufferBase::Index(const std::vector& indices) const { - TORCH_DCHECK_EQ(dims_.size(), indices.size()); - int total_index = 0; - for (const auto i : c10::irange(dims_.size())) { - total_index += indices[i] * strides_[i]; - } - return total_index; -} - -PaddedBufferBase::PaddedBufferBase( - const std::vector& dims, - // NOLINTNEXTLINE(modernize-pass-by-value) - const std::string& name) - : dims_(dims), name_(name), strides_(dims.size()) { - for (int i = (int)dims.size() - 1; i >= 0; --i) { - if (i == (int)dims.size() - 1) { - strides_[i] = 1; - } else { - strides_[i] = strides_[i + 1] * dims[i + 1]; - } - } - total_size_ = strides_[0] * dims[0]; -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h deleted file mode 100644 index b3e5227ae7e6..000000000000 --- a/test/cpp/tensorexpr/padded_buffer.h +++ /dev/null @@ -1,242 +0,0 @@ -#pragma once - -#include -#include - -#include -#include "torch/csrc/jit/tensorexpr/eval.h" - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -struct DefaultPaddedValue; - -template <> -struct DefaultPaddedValue { - static const int kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const uint8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const int16_t kValue = static_cast(0xBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int64_t kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static constexpr float kValue = 0.1357; -}; - -template <> -struct DefaultPaddedValue { - // at::Half ctor isn't constexpr, so just fill it with bits. - static constexpr uint16_t kValue = 1357; -}; - -template <> -struct DefaultPaddedValue { - static constexpr double kValue = 0.1357; -}; - -// A concrete base to be used in PaddedBase. -class PaddedBufferBase { - public: - const std::string& name() const { - return name_; - } - - int size() const { - return total_size_; - } - - int raw_size() const { - return total_size_ + 2 * kPaddingSize; - } - - virtual ~PaddedBufferBase() {} - - protected: - explicit PaddedBufferBase( - const std::vector& dims, - const std::string& name); - int Index(const std::vector& indices) const; - - std::vector dims_; - std::string name_; - std::vector strides_; - int total_size_; // total number of useful element, does not include the - // paddings - static constexpr int kPaddingSize = 64; -}; - -// A padded buffer with wartermarks for testing. -// The buffer carries padded watermarks on both sides to catch potential -// out-of-bounds writes. For read-only data that are not supposed to change, it -// can also make a backup and be compared later. -template -class PaddedBuffer : public PaddedBufferBase { - public: - PaddedBuffer(int d0, const std::string& name = "") - : PaddedBuffer(std::vector({d0}), name) {} - PaddedBuffer(int d0, int d1, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1}), name) {} - PaddedBuffer(int d0, int d1, int d2, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1, d2}), name) {} - PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1, d2, d3}), name) {} - PaddedBuffer(const std::vector& dims, const std::string& name = "") - : PaddedBufferBase(dims, name) { - data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); - } - PaddedBuffer(const PaddedBuffer& other, const std::string& name) - : PaddedBuffer(other) { - this->name_ = name; - } - - T* data() { - return data_.data() + kPaddingSize; - } - const T* data() const { - return const_cast(this)->data(); - } - T* raw_data() { - return data_.data(); - } - const T* raw_data() const { - return const_cast(this)->raw_data(); - } - T& operator()(int i0) { - // There is a bit performance impact with forming a vector here. But this - // data structure is for testing only, and not performance critical. - return this->operator()(std::vector({i0})); - } - const T& operator()(int i0) const { - return const_cast(this)->operator()(i0); - } - T& operator()(int i0, int i1) { - return this->operator()(std::vector({i0, i1})); - } - const T& operator()(int i0, int i1) const { - return const_cast(this)->operator()(i0, i1); - } - T& operator()(int i0, int i1, int i2) { - return this->operator()(std::vector({i0, i1, i2})); - } - const T& operator()(int i0, int i1, int i2) const { - return const_cast(this)->operator()(i0, i1, i2); - } - T& operator()(int i0, int i1, int i2, int i3) { - return this->operator()(std::vector({i0, i1, i2, i3})); - } - const T& operator()(int i0, int i1, int i2, int i3) const { - return const_cast(this)->operator()(i0, i1, i2, i3); - } - T& operator()(const std::vector& indices) { - return data_[kPaddingSize + Index(indices)]; - } - const T& operator()(const std::vector& indices) const { - return const_cast(this)->operator()(indices); - } - - template - friend void ExpectAllNear( - const PaddedBuffer& v1, - const PaddedBuffer& v2, - float abs_error); - template - friend void ExpectAllEqual( - const PaddedBuffer& v1, - const PaddedBuffer& v2); - void Backup() { - backup_data_ = data_; - } - - // Verify the watermarks in the paddings are intact. - void ValidateWatermark() const { - for (const auto i : c10::irange(kPaddingSize)) { - ASSERT_EQ(data_[i], kPaddingValue); - ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); - } - } - - void CheckBackup() const { - ValidateWatermark(); - DCHECK(backup_data_.size() == data_.size()) - << "Please make sure you have call Backup() before calling CheckBackup()"; - for (const auto i : c10::irange(total_size_)) { - ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); - } - } - - private: - std::vector data_; - std::vector backup_data_; - T kPaddingValue = DefaultPaddedValue::kValue; -}; - -template -inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) - : data_(const_cast(buffer.data())) {} - -template -std::string CompareErrorMsg( - const PaddedBuffer& v1, - const PaddedBuffer& v2, - int index) { - std::ostringstream oss; - oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) - << ")" - << ", v2: (" << v2.name() << ", " << v2(index) << ")"; - return oss.str(); -} - -template -void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); - } -} - -template -void ExpectAllNear( - const PaddedBuffer& f1, - const PaddedBuffer& f2, - float abs_error) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); - } -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp deleted file mode 100644 index e1a576aecf52..000000000000 --- a/test/cpp/tensorexpr/test_approx.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM - -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::indexing; -namespace te = torch::jit::tensorexpr; - -static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { - auto loops = ln->getLoopStmtsFor(target); - te::ForPtr inner, tail; - ln->splitWithTail(loops[0], width, &inner, &tail); - ASSERT_TRUE(te::LoopNest::vectorize(inner)); -} - -std::string diffs(const at::Tensor& a, const at::Tensor& b) { - auto diff = torch::abs(a.flatten() - b.flatten()); - auto count_diffs = torch::sum(diff > 0.f); - auto greatest_diff_index = torch::argmax(diff); - std::stringstream ss; - ss << "Found " << count_diffs << " unequal element(s). " - << "The greatest difference was " << diff.index({greatest_diff_index}) - << " at index " << greatest_diff_index; - return ss.str(); -} - -TEST(Approx, log_vml) { - te::VarHandle N("N", te::kInt); - te::BufHandle A("A", {N}, te::kFloat); - te::Tensor B = te::Compute( - "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); - - te::LoopNest ln({B}); - ln.prepareForCodegen(); - vectorize(&ln, B, 8); - te::StmtPtr s = ln.root_stmt(); - s = te::IRSimplifier::simplify(s); - te::LLVMCodeGen cg(s, {A, B, N}); - - auto eps = std::numeric_limits::epsilon(); - auto test = [&](const at::Tensor& A_t) { - at::Tensor B_ref = at::log(A_t); - at::Tensor B_t = at::empty_like(A_t); - auto ap = A_t.data_ptr(); - auto bp = B_t.data_ptr(); - cg.call({ap, bp, A_t.numel()}); - // Results should be bit-identical. - ASSERT_TRUE(torch::allclose( - B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true)) - << "Input[:8]\n" - << A_t.index({Slice(0, 8)}) << "\n" - << "Test[:8]\n" - << B_t.index({Slice(0, 8)}) << "\n" - << "Ref[:8]\n" - << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref); - }; - - // Generate every single-precision FP value in [1.0, 2.0). - at::Tensor A_t = torch::arange(1.0f, 2.0f, eps); - ASSERT_EQ(A_t.numel(), 1 << 23); - - test(A_t); - - test(A_t * 2.0f); - test(A_t * 0.5f); - - test(A_t * 4.0f); - test(A_t * 0.25f); - - test(A_t * powf(2.0f, 16)); - test(A_t * powf(2.0f, -16)); - - test(A_t * powf(2.0f, 126)); - test(A_t * powf(2.0f, -126)); - - test(torch::full({32}, INFINITY)); - test(torch::full({32}, NAN)); - - auto min = std::numeric_limits::min(); - auto denorm_min = std::numeric_limits::denorm_min(); - - // Denormals aren't bit precise, because sleef isn't bit-precise either. - A_t = torch::arange(0.0f, min, denorm_min); - ASSERT_EQ(A_t.numel(), 1 << 23); - auto B_ref = at::log(A_t); - auto B_t = at::empty_like(B_ref); - cg.call({A_t.data_ptr(), B_t.data_ptr(), A_t.numel()}); - ASSERT_TRUE(torch::allclose(B_t, B_ref)); -} - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp deleted file mode 100644 index 34ce2bd069d5..000000000000 --- a/test/cpp/tensorexpr/test_aten.cpp +++ /dev/null @@ -1,1068 +0,0 @@ -#include -#include -#include - -#include - -#include -#include -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "test/cpp/tensorexpr/test_base.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(ATen, _cast_Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Cast::make(kFloat, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), static_cast(i)); - } -} - -TEST(ATen, negInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -static_cast(i)); - } -} - -TEST(ATen, negFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -i); - } -} - -TEST(ATen, addInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, addFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, subInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, subFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, lerp) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))); - } -} - -TEST(ATen, addcmulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, addcmulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, mulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, mulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, divInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, divFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, maxInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i))); - } -} - -TEST(ATen, maxFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))); - } -} - -TEST(ATen, minInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i))); - } -} - -TEST(ATen, minFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))); - } -} - -void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 1.0f / i); - } -} - -TEST(ATen, reluInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::max(a_v(i), 0)); - } -} - -TEST(ATen, reluFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store( - {index}, Max::make(load_a, 0, false) // relu does not propagate nans - ); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0)); - } -} - -TEST(ATen, logFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log(a_v(i))); - } -} - -TEST(ATen, fastLogFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(ATen, fastTanhFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::tanh(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, fastSigmoidFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - at::Tensor t = at::ones({1}) * a_v(i); - float ref = at::sigmoid(t).item().to(); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, log10Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log10(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log10(a_v(i))); - } -} - -TEST(ATen, log2Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log2(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log2(a_v(i))); - } -} - -TEST(ATen, expFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, exp(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::exp(a_v(i))); - } -} - -TEST(ATen, erfFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, erf(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::erf(a_v(i))); - } -} - -TEST(ATen, cosFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, cos(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::cos(a_v(i))); - } -} - -TEST(ATen, eqInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, geInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, gtInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 6); - std::vector b_buffer(N, 3); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, leInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, ltInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 0); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h deleted file mode 100644 index 68b96fe6c90f..000000000000 --- a/test/cpp/tensorexpr/test_base.h +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once - -#if defined(USE_GTEST) -#include -#include -#else -#include -#include "c10/util/Exception.h" -#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" -#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__) -#define ASSERT_FLOAT_EQ(x, y, ...) \ - TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__) -#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__) -#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__) -#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__) -#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__) -#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__) - -#define ASSERT_NEAR(x, y, a, ...) \ - TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__) - -#define ASSERT_TRUE TORCH_INTERNAL_ASSERT -#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) -#define ASSERT_THROWS_WITH(statement, substring) \ - try { \ - (void)statement; \ - ASSERT_TRUE(false); \ - } catch (const std::exception& e) { \ - ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ - } -#define ASSERT_ANY_THROW(statement) \ - { \ - bool threw = false; \ - try { \ - (void)statement; \ - } catch (const std::exception& e) { \ - threw = true; \ - } \ - ASSERT_TRUE(threw); \ - } - -#endif // defined(USE_GTEST) -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -void ExpectAllNear( - const std::vector& v1, - const std::vector& v2, - V threshold, - const std::string& name = "") { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); i++) { - ASSERT_NEAR(v1[i], v2[i], threshold); - } -} - -template -void ExpectAllNear( - const std::vector& vec, - const U& val, - V threshold, - const std::string& name = "") { - for (size_t i = 0; i < vec.size(); i++) { - ASSERT_NEAR(vec[i], val, threshold); - } -} - -template -static void assertAllEqual(const std::vector& vec, const T& val) { - for (auto const& elt : vec) { - ASSERT_EQ(elt, val); - } -} - -template -static void assertAllEqual(const std::vector& v1, const std::vector& v2) { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); ++i) { - ASSERT_EQ(v1[i], v2[i]); - } -} -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp deleted file mode 100644 index 2605842d6e74..000000000000 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ /dev/null @@ -1,1019 +0,0 @@ -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -static void verifyConstBounds( - const TensorAccessBoundsInfo& access_info, - const std::vector>& ref) { - size_t ndim = ref.size(); - ASSERT_EQ(access_info.start.size(), ndim); - ASSERT_EQ(access_info.stop.size(), ndim); - for (const auto i : c10::irange(ndim)) { - if (ref[i].first >= 0) { // Negative values are used to skip the check - ASSERT_TRUE(access_info.start[i]->isConstant()); - int start_i = immediateAs(access_info.start[i]); - ASSERT_EQ(start_i, ref[i].first); - } - if (ref[i].second >= 0) { - ASSERT_TRUE(access_info.stop[i]->isConstant()); - int stop_i = immediateAs(access_info.stop[i]); - ASSERT_EQ(stop_i, ref[i].second); - } - } -} - -TEST(BoundsInference, _1) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _2) { - // Verify that bounds inference works for the following example: - // for i in 0..n: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); -} - -TEST(BoundsInference, _3) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] * a[i+10] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} - ExprHandle n(100); - BufHandle a("a", {n + 10}, kFloat); - Tensor b = Compute( - "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _4) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..200: - // for x in 0..320: - // c[y,x] = a[y,x] * b[y,x] - ExprHandle W(320); - ExprHandle H(200); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y, x) * b.load(y, x); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, _5) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // - // ==> split ==> - // - // for i_outer in 0..100/16: - // for i_inner in 0..16: - // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] - // for i_tail in 0..100%16: - // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - ForPtr inner; - ForPtr tail; - std::vector loops = l.getLoopStmtsFor(b); - LoopNest::splitWithTail(loops[0], 16, &inner, &tail); - ForPtr outer = loops[0]; - - { - // Verify inferred bounds for the outer loop - auto bounds_info = inferBounds(outer); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); - } - { - // Verify inferred bounds for the tail loop - auto bounds_info = inferBounds(tail); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); - } -} - -TEST(BoundsInference, _6) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..20: - // for x in 0..32: - // c[y,x] = a[y+100,x+100] * b[y*2,x*5] - ExprHandle W(320); - ExprHandle H(200); - ExprHandle CW(32); - ExprHandle CH(20); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = - Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, Adjacent) { - ExprHandle H(6); - BufHandle a("a", {20}, kFloat); - Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); }); - LoopNest l({b, c}); - std::vector loops = NodeFinder::find(l.root_stmt()); - - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0:5], writes to b[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0+6:5+6], writes to c[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the high level program. - auto bounds_info = inferBounds(l.root_stmt()); - ASSERT_EQ(bounds_info.size(), 3); - - // Should be union of above 2 bounds, but this time the bounds of A can be - // merged. - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } -} - -TEST(BoundsInference, MultipleTopLoopLoad) { - BufHandle a("a", {100}, kFloat); - Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); }); - Tensor d = - Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); }); - LoopNest l({b, c, d}); - - auto bounds_info = inferBounds(l.root_stmt()); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: - // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). - // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = - // 96 + 2 - 1 (d). - verifyConstBounds(bound, {{0, 97}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for b. - verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for c. - verifyConstBounds(bound, {{0, 31}}); - } - { - auto bounds = bounds_info[d.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for d. - verifyConstBounds(bound, {{0, 95}}); - } -} - -TEST(BoundsInference, MultipleTopLoopStore) { - BufHandle a("a", {100}, kFloat); - BufHandle b("b", {100}, kFloat); - BufHandle c("c", {100}, kFloat); - BufHandle d("d", {100}, kFloat); - VarHandle x("x", kInt); - - // Same as above but the offsets are on the Store now. - // Can't do this through ComputeAPI without transforms we don't have yet. - StmtPtr stmt = Block::make( - {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), - For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), - For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); - - auto bounds_info = inferBounds(stmt); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: there are no offsets, so this is just the max loop bounds. - verifyConstBounds(bound, {{0, 95}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the b loop. - // b loop has no offset, so just the loop extents. - verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the c loop. - // Offset is 10, extent is 32-1. - verifyConstBounds(bound, {{10, 41}}); - } - { - auto bounds = bounds_info[d.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the d loop. - // Offset is 2, extent is 96-1. - verifyConstBounds(bound, {{2, 97}}); - } -} - -TEST(BoundsInference, CacheReads) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}); - auto bounds_info_before = inferBounds(l.root_stmt()); - - StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - auto bounds_info_after = inferBounds(l.root_stmt()); - - // CacheAccesses should not change existing bounds, but add a new one for the - // cache. - for (auto& pair : bounds_info_after) { - auto beforeIt = bounds_info_before.find(pair.first); - if (beforeIt != bounds_info_before.end()) { - // Same number of TensorAccessBoundInfos. - ASSERT_EQ(pair.second.size(), beforeIt->second.size()); - - for (const auto i : c10::irange(pair.second.size())) { - TensorAccessBoundsInfo& after = pair.second[i]; - TensorAccessBoundsInfo& before = beforeIt->second[i]; - // Same number of dimensions. - ASSERT_EQ(before.start.size(), after.start.size()); - - // Bounds are equal. - for (const auto j : c10::irange(before.start.size())) { - ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); - ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); - } - } - } else { - // This should be the cache. - ASSERT_EQ(pair.first->name_hint(), "A_local"); - // Should have both a load and a store. - ASSERT_EQ(pair.second.size(), 2); - TensorAccessBoundsInfo& first = pair.second[0]; - TensorAccessBoundsInfo& second = pair.second[1]; - - ASSERT_NE(first.kind, second.kind); - // 2 dimensions. - ASSERT_EQ(first.start.size(), second.start.size()); - ASSERT_EQ(first.start.size(), 2); - - // bounds for load and store are equal. - for (const auto j : c10::irange(first.start.size())) { - ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); - ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); - } - } - } -} - -TEST(BoundsInference, Flattened) { - Tensor b = Compute( - "b", - {3, 4, 5}, - [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { - return x * y + z; - }); - - LoopNest l({b}); - // Flatten indices. - l.prepareForCodegen(); - auto bounds_info = inferBounds(l.root_stmt()); - - // There's only one buffer. - ASSERT_EQ(bounds_info.size(), 1); - auto& TABI = bounds_info[b.buf()][0]; - ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); - // Flattened bounds should have a single dimension. - ASSERT_EQ(TABI.start.size(), 1); - ASSERT_EQ(TABI.stop.size(), 1); - - // Bounds should be 0 -> (3*4*5)-1 - ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); - ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); -} - -TEST(BoundsInference, GetPotentialHazards) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* - * A[0] = B[0]; - * B[0] = 3; WAR on B - * A[0] = B[0]; WAW on A, RAW on B - * C[0] = 5; - */ - - StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store2 = Store::make(b, {0}, 3); - StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store4 = Store::make(c, {0}, 5); - StmtPtr stmt = Block::make({store1, store2, store3, store4}); - - MemDependencyChecker analyzer; - stmt->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterRead, - getPotentialHazards(analyzer, store1, store2)); - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, store2, store3)); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, - getPotentialHazards(analyzer, store1, store3)); - - // Fourth store has no dependencies - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store1, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store2, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store3, store4)); - } -} - -TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return (i + 1) * (j + 1); - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - // No dependencies between loops. - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopCall) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + 5; - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopSplit) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - - LoopNest l({A}); - ForPtr inner, tail; - - // Splitting with tail by something offset creates a tail which also writes to - // A. - ForPtr outer = l.getLoopStmtsFor(A)[0]; - // `outer` loop get transformed to the outer loop after splitting. - LoopNest::splitWithTail(outer, 5, &inner, &tail); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // B[k] = A[k]; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k+100] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { - // Input IR: - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { - // Input IR: - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapWithLoads) { - // Input IR: - // for (const auto k : c10::irange(10, 100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(10, 100)) { - // C[j] = 10 * A[j]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 10, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make( - j, - 10, - 100, - Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, IsOverlapping) { - // Input IR: - // for (const auto i : c10::irange(100)) { - // A[i] = i * 10; // storeA1 - // B[i] = A[99-i] * 20; // loadA1 - // C[i] = A[i + 100] * 10; // loadA2 - // A[i + 50] = i * 50; // storeA2 - // A[i + 150] = i * 150; // storeA3 - // } - BufHandle a_buf("A", {300}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle i("i", kInt); - auto storeA1 = Store::make(a_buf, {i}, i * 10); - auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); - auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); - auto loadA2 = Load::make(a_buf, {i + 100}); - auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); - auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); - auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); - auto forI = For::make( - i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); - tensorexpr::analysis::MemDependencyChecker analyzer; - forI->accept(&analyzer); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp deleted file mode 100644 index e72303873a6c..000000000000 --- a/test/cpp/tensorexpr/test_conv.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -namespace te = torch::jit::tensorexpr; -namespace F = torch::nn::functional; - -#ifdef TORCH_ENABLE_LLVM - -// Generate test data with few bits of precision, to minimize error -// accumulation from floating-point reordering. -static at::Tensor genTestData(c10::IntArrayRef args) { - return at::trunc(at::randn(args) * 256.0f) / 256.0f; -} - -TEST(Conv, DepthwiseConv2D) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::BufHandle bias("bias", {K}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto bt = genTestData({K}); - auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call( - {it.data_ptr(), - wt.data_ptr(), - bt.data_ptr(), - ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DNoBias) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call({it.data_ptr(), wt.data_ptr(), ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DDynamicShapes) { - te::VarHandle N_var("N", te::kInt); - te::VarHandle C_var("C", te::kInt); - te::VarHandle H_var("H", te::kInt); - te::VarHandle W_var("W", te::kInt); - te::VarHandle K_var("K", te::kInt); - te::VarHandle CperG_var("CperG", te::kInt); - te::VarHandle R_var("R", te::kInt); - te::VarHandle S_var("S", te::kInt); - te::VarHandle kPad_var("kPad", te::kInt); - te::VarHandle kStride_var("kStride", te::kInt); - te::VarHandle kGroups_var("kGroups", te::kInt); - - te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat); - te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat); - te::Tensor output = te::conv2d_depthwise( - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kStride_var, - kPad_var, - kGroups_var); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - std::vector buffer_args = { - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kPad_var, - kStride_var, - kGroups_var, - output}; - te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); - - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - std::vector call_args = { - it.data_ptr(), - wt.data_ptr(), - N, - C, - H, - W, - K, - CperG, - R, - S, - kPad, - kStride, - kGroups, - ot.data_ptr()}; - cg.call(call_args); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -#endif - -TEST(Conv, Conv2D) { - // Input dimensions. - constexpr int N = 1; - constexpr int C = 3; - constexpr int H = 11; - constexpr int W = 11; - - // Filter dimensions. - constexpr int K = 8; - constexpr int R = 3; - constexpr int S = 3; - - // Output dims. - constexpr int OH = H - R + 1; - constexpr int OW = W - S + 1; - - // Compute reference result. - at::Tensor input = torch::randn({N, C, H, W}); - at::Tensor filter = torch::randn({K, C, R, S}); - at::Tensor ref = F::conv2d(input, filter); - - // Double check the output size is as expected. - ASSERT_EQ(ref.size(0), N); - ASSERT_EQ(ref.size(1), K); - ASSERT_EQ(ref.size(2), OH); - ASSERT_EQ(ref.size(3), OW); - - te::BufHandle inputB("input", {N, C, H, W}, te::kFloat); - te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat); - - te::Tensor conv = te::Reduce( - "conv", - {N, K, OH, OW}, - te::Sum(), - // FIXME: We have to use a `std::vector` parameter here and then unpack - // it, because we don't have an overload allowing for an arbitrary number - // of ExprHandle/VarHandle parameters. - [&](const std::vector& v) { - auto const& n = v[0]; - auto const& k = v[1]; - auto const& oh = v[2]; - auto const& ow = v[3]; - auto const& c = v[4]; - auto const& r = v[5]; - auto const& s = v[6]; - // FIXME: We have to use `call` and construct a `std::vector` here - // because the `operator()` overload is only specialized for a small - // number of arguments. - return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); - }, - // FIXME: If you forget one of the reduction dims, you get a segfault. - // Could that be caught by a verifier? - {C, R, S}); - - // FIXME: It'd be nice to have a single header that pulls in things like - // LoopNest, IRSimplifier, etc. - te::LoopNest loop({conv}); - loop.prepareForCodegen(); - te::StmtPtr s = loop.root_stmt(); - s = te::IRSimplifier::simplify(s); - - at::Tensor result = at::empty_like(ref); - te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); - cg.call( - {input.data_ptr(), - filter.data_ptr(), - result.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp deleted file mode 100644 index ed7679053637..000000000000 --- a/test/cpp/tensorexpr/test_cpp_codegen.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include - -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -#define STR_CHECK(node, expected) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - ASSERT_EQ(ss.str(), expected) - -#define FILE_CHECK(node, pattern) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - torch::jit::testing::FileCheck().run(pattern, ss.str()) - -TEST(CppPrinter, IntImm) { - auto i = alloc(10); - STR_CHECK(i, "10"); -} - -TEST(CppPrinter, FloatImm) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, FloatImm1) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, DoubleImm) { - auto d = alloc(10); - STR_CHECK(d, "10.0"); -} - -TEST(CppPrinter, DoubleImm1) { - auto d = alloc(10.1); - STR_CHECK(d, "10.1"); -} - -TEST(CppPrinter, HalfImm) { - auto h = alloc(10); - STR_CHECK(h, "10"); -} - -TEST(CppPrinter, Add) { - auto add = alloc(alloc(1), alloc(2)); - STR_CHECK(add, "1 + 2"); -} - -TEST(CppPrinter, AddExpr1) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr2) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "0 * 1 + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr3) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc
(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + 2 / 3"); -} - -TEST(CppPrinter, Mod) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "1 % 2"); -} - -TEST(CppPrinter, ModFloat) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "std::fmod(1.f, 2.f)"); -} - -TEST(CppPrinter, Max) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1, 2)"); -} - -TEST(CppPrinter, MaxFloat) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1.f, 2.f)"); -} - -TEST(CppPrinter, MaxHalf) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "(1 < 2) ? 2 : 1"); -} - -TEST(CppPrinter, And) { - auto v = alloc(alloc(1), alloc(2)); - STR_CHECK(v, "1 & 2"); -} - -TEST(CppPrinter, CompareSelect) { - auto cs = alloc( - alloc(1), - alloc(2), - alloc(1), - alloc(2), - CompareSelectOperation::kLE); - STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); -} - -TEST(CppPrinter, IfThenElse) { - auto cond = alloc(alloc(1), alloc(2)); - auto true_value = alloc(alloc(0), alloc(1)); - auto false_value = alloc(alloc(2), alloc(3)); - auto v = alloc(cond, true_value, false_value); - STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); -} - -TEST(CppPrinter, AllocateFree) { - BufHandle buf("x", {2, 3}, kInt); - AllocatePtr alloc = Allocate::make(buf); - FreePtr free = Free::make(buf); - BlockPtr block = Block::make({alloc, free}); - - const std::string pattern = R"( - # CHECK: { - # CHECK: int* x = static_cast(malloc(24)); - # CHECK: free(x); - # CHECK: } - )"; - FILE_CHECK(block, pattern); -} - -TEST(CppPrinter, LoadStore) { - BufHandle a("A", {2, 3}, kInt); - BufHandle b("B", {3, 4}, kInt); - auto store = b.store({2, 2}, a.load(1, 1)); - STR_CHECK( - store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); -} - -TEST(CppPrinter, Var) { - auto var = alloc("x", kInt); - STR_CHECK(var, "x"); -} - -TEST(CppPrinter, Cast) { - auto cast = alloc(kFloat, alloc(1)); - STR_CHECK(cast, "static_cast(1)"); -} - -TEST(CppPrinter, BitCast) { - auto cast = alloc(kInt, alloc(20)); - STR_CHECK(cast, "std::bitcast(20.f)"); -} - -TEST(CppPrinter, Let) { - auto var = alloc("x", kFloat); - auto val = alloc(2); - auto let = alloc(var, val); - STR_CHECK(let, "float x = 2.f;\n"); -} - -TEST(CppPrinter, For) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - const std::string pattern = R"( - # CHECK: for (int i = 0; i < 1024; i++) { - # CHECK: C[i] = (A[i]) + (B[i]); - # CHECK: } - )"; - FILE_CHECK(f, pattern); -} - -TEST(CppPrinter, Cond) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - const std::string pattern = R"( - # CHECK: if (((X[0] < 10) ? 1 : 0)) { - # CHECK: X[0] = (X[0]) + 1; - # CHECK: } else { - # CHECK: X[0] = (X[0]) - 1; - # CHECK: } - )"; - FILE_CHECK(cond, pattern); -} - -TEST(CppPrinter, Intrinsics) { - const std::unordered_set> unsupported_ops{ - kRand, kSigmoid}; - for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) { - IntrinsicsOp op = static_cast(i); - if (unsupported_ops.count(op)) { - continue; - } - - if (Intrinsics::OpArgCount(op) == 1) { - auto v = alloc(op, alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); - } else { - auto v = - alloc(op, alloc(1.0f), alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); - } - } -} - -TEST(CppPrinter, ExternalCall) { - std::vector dims{alloc(2), alloc(2)}; - auto output = alloc("out", dims, kFloat); - auto buf_arg1 = alloc("a", dims, kFloat); - auto buf_arg2 = alloc("b", dims, kFloat); - auto scalar_arg = alloc(alloc(1), alloc(2)); - std::vector buf_args{buf_arg1, buf_arg2}; - std::vector scalar_args{scalar_arg}; - auto call = - alloc(output, "nnc_aten_matmul", buf_args, scalar_args); - const std::string pattern = R"( - # CHECK: { - # CHECK: void* buf_ptrs[]{out, a, b}; - # CHECK: int64_t buf_ranks[]{2, 2, 2}; - # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; - # CHECK: int8_t buf_dtypes[]{6, 6, 6}; - # CHECK: int64_t extra_args[]{1 + 2}; - # CHECK: nnc_aten_matmul( - # CHECK: 3, - # CHECK: buf_ptrs, - # CHECK: buf_ranks, - # CHECK: buf_dims, - # CHECK: buf_dtypes, - # CHECK: 1, - # CHECK: extra_args); - # CHECK: } - )"; - FILE_CHECK(call, pattern); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp deleted file mode 100644 index 2e1e84e758db..000000000000 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ /dev/null @@ -1,2344 +0,0 @@ -#ifdef USE_CUDA - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; -using namespace torch::jit::tensorexpr; - -template -static void testCudaTestVectorAdd01_impl() { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer b_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = ctype(i); - b_v(i) = ctype(i * 3 + 7); - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - ctype* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); - ctype* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); - ctype* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -float sigmoid(float x) { - return 1.0f / (1.0f + expf(-0.0f - x)); -} - -TEST(Cuda, Sigmoid_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = float(i); - c_ref(i) = sigmoid(sigmoid(a_v(i))); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd01_CUDA) { - // floating types. - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - - // integer types. - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); -} - -static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - Tensor c = Compute("c", {N}, [&](const VarHandle& n) { - return a_buf.load(n) + b_buf.load(n); - }); - LoopNest l({c}); - ForPtr n_inner; - std::vector loops = l.getLoopStmtsFor(c); - l.splitWithMask(loops[0], block_size, &n_inner); - loops[0]->set_gpu_block_index(0); - n_inner->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - PaddedBuffer a_v(N); - PaddedBuffer b_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_v(i) = i * 3 + 7; - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd02_CUDA) { - testCudaTestVectorAdd02_impl(1024, 128); - testCudaTestVectorAdd02_impl(1030, 128); -} - -TEST(Cuda, HalfCast_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, a.load(i)); - }); - - LoopNest l({b}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b}); - - std::vector aData(4, 2.0f); - std::vector bData(4, 0.0f); - at::Half* aDev = nullptr; - float* bDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = bData.size() * sizeof(bData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(bData, 2.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, DynamicShape2D_CUDA) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, m, n}); - - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - float* aDev = nullptr; - float* bDev = nullptr; - float* cDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(bData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - cDev, - cData.data(), - cData.size() * sizeof(cData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, M, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - cData.data(), - cDev, - cData.size() * sizeof(cData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - }; - testWithSize(32, 32); - testWithSize(1, 16); - testWithSize(27, 13); -} - -TEST(Cuda, TestRand01_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return Intrinsics::make(IntrinsicsOp::kRand, kFloat); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c); - const int N = block_count * block_size * num_iter; - PaddedBuffer c_v(N); - - // TODO: move gpu support into PaddedBuffer - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - float sum1 = 0; - float sum2 = 0; - float sum3 = 0; - for (const auto i : c10::irange(N)) { - float v = c_v.data()[i]; - sum1 += v; - sum2 += v * v; - sum3 += v * v * v; - ASSERT_TRUE(v >= 0 && v < 1); - } - sum1 /= N; - sum2 /= N; - sum3 /= N; - float sum1_mean = 1.f / 2; - float sum2_mean = 1.f / 3; - float sum3_mean = 1.f / 4; - - ASSERT_NEAR(sum1, sum1_mean, 2e-2); - ASSERT_NEAR(sum2, sum2_mean, 2e-2); - ASSERT_NEAR(sum3, sum3_mean, 2e-2); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, DynamicShapeSplit_CUDA) { - constexpr int64_t N = 4096; - VarHandle n("n", kLong); - BufHandle a("a", {n}, kFloat); - Tensor b = - Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); - LoopNest l({b}); - ForPtr inner; - std::vector loops = l.getLoopStmtsFor(b); - l.splitWithMask(loops[0], 1024, &inner); - loops[0]->set_gpu_block_index(0); - inner->set_gpu_thread_index(0); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, n}); - - std::vector aData(N, 1.0f); - std::vector bData(N, 1.0f); - float* aDev = nullptr; - float* bDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - bData.data(), - bDev, - bData.size() * sizeof(aData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { - const static int N = 1024; - BufHandle data_buf("data", {N}, kFloat); - BufHandle output_buf("output", {1}, kFloat); - - // The test adds the following code for trivial reduction: - // for (const auto bidx : c10::irange(1)) { // blockIdx.x - // for (const auto tidx : c10::irange(1)) { // threadIdx.x - // output[0] = 0.f; - // for (const auto i1 : c10::irange(1024)) { - // output[0] = output[0] + data[i1]; - // } - // } - // } - - StorePtr init_store = output_buf.store({0}, 0.f); - VarHandle i1("i1", kInt); - ExprHandle load_data = Load::make(data_buf, {i1}); - ExprHandle load_output = Load::make(output_buf, {0}); - ExprHandle add_value = load_output + load_data; - StorePtr store_output = output_buf.store({0}, add_value); - ForPtr for_output = For::make(i1, 0, N, store_output); - StmtPtr reduce_block = Block::make({init_store, for_output}); - VarHandle thread_idx("tidx", kInt); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr thread_idx_loop = - For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); - PaddedBuffer data_v(N); - PaddedBuffer output_v(1, "output_v"); - PaddedBuffer output_ref(1, "output_ref"); - - output_ref(0) = 0; - for (const auto i : c10::irange(N)) { - data_v(i) = i; - output_ref(0) += data_v(i); - } - - float* data_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* output_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(data_dev, output_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(output_v, output_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(data_dev)); - C10_CUDA_CHECK(cudaFree(output_dev)); -} - -TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { - const static int N = 1024; - - // This test does the following reduction: - // clang-format off - // for b in 0..1 // block-idx - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - // // implied sync_threads - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - // clang-format on - - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - StorePtr init_store = b_buf.store({0}, 0.f); - VarHandle t("t", kInt); - VarHandle b("b", kInt); - - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - ExprHandle cond_t_lt_1 = - CompareSelect::make(t, 1, CompareSelectOperation::kLT); - CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); - - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - ExprHandle load_a = Load::make(a_buf, {t}); - ExprHandle load_b = Load::make(b_buf, {0}); - ExprHandle add_value = load_b + load_a; - StorePtr store_b = b_buf.store({0}, add_value); - ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); - - StmtPtr reduce_block = Block::make({for_init, for_b}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(N); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_ref(0) += a_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, NoThreadIdxWrite_1_CUDA) { - // This test does the following reduction: - // - // for k in 0..1: // block-idx - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - // for m in 0..1024: // thread-idx - // b[m] = m - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + n - // - // note that the statements not covered by thread-idx are supposed to be - // covered by its own thread-idx - - const static int N = 1024; - BufHandle a_buf("a", {2}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - - VarHandle k("k", kInt); - VarHandle l("l", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - StorePtr store_a0_0 = a_buf.store({0}, 0.f); - ExprHandle load_a0 = Load::make(a_buf, {0}); - ExprHandle v1 = load_a0 + n; - StorePtr store_a0_v1 = a_buf.store({0}, v1); - ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); - - // for m in 0..1024: // thread-idx - // b[m] = m - StorePtr store_bm_m = b_buf.store({m}, m + 0.f); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); - - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + l - StorePtr store_a1_1 = a_buf.store({1}, 1.f); - ExprHandle load_a1 = a_buf.load(1); - ExprHandle v2 = load_a1 + l; - StorePtr store_a1_v2 = a_buf.store({1}, v2); - ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); - - StmtPtr reduce_block = - Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(2); - PaddedBuffer b_v(N, "b_v"); - PaddedBuffer a_ref(2, "a_ref"); - PaddedBuffer b_ref(N, "b_ref"); - - a_ref(0) = 0; - for (const auto i : c10::irange(2)) { - a_ref(0) += i; - } - a_ref(1) = a_ref(0) + 1; - for (const auto i : c10::irange(N)) { - b_ref(i) = i; - } - - // TODO: add check of the generated code. - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(a_v, a_ref, 1e-5); - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, SharedMemReduce_1_CUDA) { - // FIXME: this test is flaky in CI. - // This test does the following: - // for k in 0..1: // block-idx - // alloc(c, 64) - // for n in 0..64: // thread-idx - // c(n) = 0 - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - std::vector block; - std::vector dims; - dims.push_back(ExprHandle(N).node()); - BufHandle c{alloc("c", dims, kFloat)}; - { - // alloc(c, 64); - AllocatePtr alloc = Allocate::make(c); - block.push_back(alloc); - } - - { - // for n in 0..64: // thread-idx - // c(n) = 0 - StorePtr store_cn_0 = Store::make(c, {n}, 0.f); - ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); - block.push_back(loop_n1); - } - - { - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); - ExprHandle v_add = load_cn + a_kmn; - StorePtr store_cn_v = Store::make(c, {n}, v_add); - ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); - ForPtr loop_m1 = For::make(m, 0, M, loop_n2); - block.push_back(loop_m1); - } - - { - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - StorePtr store_bk_0 = b.store({k}, 0.f); - block.push_back(store_bk_0); - ExprHandle load_bk = b.load(k); - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle v_add = load_bk + load_cn; - StorePtr store_bk = b.store({k}, v_add); - ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); - block.push_back(loop_n3); - } - - { - // free(c) - FreePtr free_stmt = Free::make(c); - block.push_back(free_stmt); - } - - BlockPtr reduce_body = Block::make(block); - ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); - - // TODO: check the generated code for correctness. - CudaCodeGen cuda_cg(loop_k1, a, b); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. - const std::string& verification_pattern = - R"IR( -# CHECK: c_1 = 0 -# CHECK: for (int m = 0; m < 128 -# CHECK: c_1 = c_1 + -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: b[blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: atomicAdd(&b[blockIdx.x], c_1) -)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, LocalMemReduce_1_CUDA) { - // This test does the following: - // for k in 0..1: // block-idx - // b(k) = 0 - // for n in 0..64: // thread-idx - // alloc(c, 1) - // c(0) = 0 - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - // b(k) = b(k) + c(0) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle c{ - alloc("c", std::vector({alloc(1)}), kFloat)}; - std::vector block_k; - { - // b(k) = 0 - StorePtr store_bk_0 = b.store({k}, 0.f); - block_k.push_back(store_bk_0); - } - std::vector block_n; - { - // alloc(c, 1); - AllocatePtr alloc = Allocate::make(c); - block_n.push_back(alloc); - } - { - // c(0) = 0 - StorePtr store_c0_0 = Store::make(c, {0}, 0.f); - block_n.push_back(store_c0_0); - } - { - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); - ExprHandle v_add = load_c0 + a_kmn; - StorePtr store_c0_v = Store::make(c, {0}, v_add); - ForPtr loop_m = For::make(m, 0, M, store_c0_v); - block_n.push_back(loop_m); - } - { - // b(k) = b(k) + c(0) - ExprHandle load_bk = b.load(k); - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle v_add = load_bk + load_c0; - StorePtr store_bk = b.store({k}, v_add); - block_n.push_back(store_bk); - } - { - // free(c) - FreePtr free_stmt = Free::make(c); - block_n.push_back(free_stmt); - } - { - BlockPtr block_n_stmt = Block::make(block_n); - ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); - block_k.push_back(for_n); - } - BlockPtr block_k_stmt = Block::make(block_k); - ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); - - CudaCodeGen cuda_cg(loop_k, a, b); - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, HalfSupport_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(half, ExprHandle(2.0f) * a.load(i)); - }); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); - }); - - Tensor d = Compute("d", {4}, [&](const VarHandle& i) { - return Cast::make(half, c.load(i)); - }); - - LoopNest l({b, c, d}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, d}); - - std::vector aData(4, 2.0f); - std::vector cData(4, 0.0f); - std::vector dData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* bDev = nullptr; - at::Half* cDev = nullptr; - at::Half* dDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = aData.size() * sizeof(aData[0]); - auto cSize = cData.size() * sizeof(float); - auto dSize = dData.size() * sizeof(dData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); - C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, dDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(cData, 46.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - C10_CUDA_CHECK(cudaFree(dDev)); -} - -TEST(Cuda, HalfPropagation_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(alloc(0)), true); - }); - - LoopNest l({relu}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, relu}); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the types used by the Max are Float. - const std::string& verification_pattern = - R"IR( -# CHECK: for ( -# CHECK: float v = float(a[i]); -# CHECK: relu[i] = half(Max(v, 0.f -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector aData(4, 2.0f); - std::vector reluData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* reluDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto reluSize = reluData.size() * sizeof(reluData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, reluDev}); - C10_CUDA_CHECK( - cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(aData, reluData); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(reluDev)); -} - -TEST(Cuda, UnusedHalfArgument_CUDA) { - BufHandle a("a", {4}, kFloat); - auto half = ToDtype(); - BufHandle b("b", {4}, half); - Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(alloc(0)), true); - }); - - LoopNest l({relu}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, relu}); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the types used by the Max are Float. - const std::string& verification_pattern = - R"IR( -# CHECK: for ( -# CHECK: float v = a[i]; -# CHECK: relu[i] = Max(v, 0.f -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // Sanity Cbeck; - std::vector aData(4, 2.0f); - std::vector bData(4, 2.0f); - std::vector reluData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* bDev = nullptr; - at::Half* reluDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = bData.size() * sizeof(bData[0]); - auto reluSize = reluData.size() * sizeof(reluData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, reluDev}); - C10_CUDA_CHECK( - cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(aData, reluData); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(reluDev)); -} - -TEST(Cuda, PrioritizeDependents_CUDA) { - BufHandle a("a", {10}, kFloat); - BufHandle b("b", {12}, kFloat); - BufHandle c("c", {12}, kFloat); - - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - /* - * for (const auto i : c10::irange(12)) { - * c[i] = (i < 10 ? a[i] + b[i] : b[i]); - * } - */ - ExprHandle load_a = a.load({i}); - ExprHandle load_b = b.load({i}); - ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); - ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); - - ForPtr loop = - For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); - - CudaCodeGen cuda_cg(loop, a, b, c); - - PaddedBuffer a_v(10, "a_v"); - PaddedBuffer b_v(12, "b_v"); - PaddedBuffer c_v(12, "c_v"); - PaddedBuffer c_ref(12, "c_ref"); - - for (const auto i : c10::irange(10)) { - a_v(i) = i * 100; - b_v(i) = i; - c_v(i) = 0; - } - - for (const auto i : c10::irange(10, 12)) { - b_v(i) = i; - c_v(i) = 0; - } - - float* a_dev = nullptr; - float* b_dev = nullptr; - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); - - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - for (const auto i : c10::irange(12)) { - if (i < 10) { - c_ref(i) = i + i * 100; - } else { - c_ref(i) = i; - } - } - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -/// Tests the case where there are two loops which have different extents bound -/// to the same block dimension. We must mask the smaller extent loop body. -TEST(Cuda, MaskBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if (blockIdx -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<50 -# CHECK: d[blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); - - // Sanity check that the kernel works. - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case with two loops, which have different extents that are bound -/// to the same thread dimension. This is the same as the above - the smaller -/// rank write should be masked. But this time we also need to syncthreads. -TEST(Cuda, MaskThreadDim_CUDA) { - int A_SIZE = 50; - int B_SIZE = 100; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i / 2) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is masked, but the d write is not. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<50 -# CHECK: c[threadIdx.x] = -# CHECK: __syncthreads(); -# CHECK-NOT: if (threadIdx.x -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i / 2) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where there are two loops, and each is bound to a different -/// block dimension. In this case all writes should be masked since they occur -/// in distinct dimensions. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. -TEST(Cuda, MaskMultiBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Write to c should be masked against y, write to d against x. - const std::string& verification_pattern = - R"IR( -# CHECK: if (blockIdx.y<1 -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<1 -# CHECK: d[blockIdx.y] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where both the blockDim and threadDim are bound to different -/// loops. In this instance both stores should be masked since they are -/// distinct. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. -TEST(Cuda, MaskBlockAndThreadDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<1 -# CHECK: c[blockIdx.x] = -# CHECK: } -# CHECK: if (blockIdx.x<1 -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where the loopnest has two loops of depth two: each with the -/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In -/// this case all writes with a rank smaller than the max should be masked. -TEST(Cuda, MaskMultiDim_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case where loop extents are symbolic and not known at compile time. -// In this case both stores must be masked against the extent of the other loop, -// in case it is larger. -TEST(Cuda, MaskMultiDimSymbolic_CUDA) { - VarHandle OUTER_SIZE("OUTER_SIZE", kLong); - VarHandle A_SIZE("A_SIZE", kLong); - VarHandle B_SIZE("B_SIZE", kLong); - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x(A_SIZE.node(), B_SIZE.node(), true))); - - int64_t OUTER_EXTENT = 10; - int64_t A_EXTENT = 100; - int64_t B_EXTENT = 50; - - PaddedBuffer a_v(OUTER_EXTENT, A_EXTENT); - PaddedBuffer b_v(OUTER_EXTENT, B_EXTENT); - PaddedBuffer c_v(OUTER_EXTENT, A_EXTENT); - PaddedBuffer d_v(OUTER_EXTENT, B_EXTENT); - - PaddedBuffer c_ref(OUTER_EXTENT, A_EXTENT); - PaddedBuffer d_ref(OUTER_EXTENT, B_EXTENT); - - for (const auto o : c10::irange(OUTER_EXTENT)) { - for (const auto i : c10::irange(A_EXTENT)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_EXTENT)) { - for (const auto i : c10::irange(B_EXTENT)) { - b_v(o, i) = (float)(B_EXTENT - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case where two loops are fused at a common parent loop, which is -// bound to the block dimension. Internally the inner loops have different -// extents but are bound to the same thread dimension. The smaller loop should -// be masked. -TEST(Cuda, MaskCompoundInnerLoop_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); - - // Can't build this using Compute and transforms yet. - LoopOptions blockBound; - blockBound.set_gpu_block_index(0); - LoopOptions threadBound; - threadBound.set_gpu_thread_index(0); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - - StmtPtr stmt = For::make( - i, - 0, - OUTER_SIZE, - Block::make( - {For::make( - j, - 0, - A_SIZE, - c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), - threadBound), - For::make( - k, - 0, - B_SIZE, - d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), - threadBound)}), - blockBound); - - stmt = FlattenIndexes(stmt); - stmt = IRSimplifier::simplify(stmt); - - CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev, d_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loops fused into a common parent, which is not bound -// to any block or thread dimension - however it's two inner loops are bound to -// the first thread dimensions. This should work just like the MaskThreadDim -// test where the bigger loop is unmasked but the smaller is masked. -TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); - - // Can't build this using Compute and transforms yet. - LoopOptions blockBound; - blockBound.set_gpu_block_index(0); - LoopOptions threadBound; - threadBound.set_gpu_thread_index(0); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - - StmtPtr stmt = For::make( - i, - 0, - OUTER_SIZE, - Block::make( - {For::make( - j, - 0, - A_SIZE, - c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), - threadBound), - For::make( - k, - 0, - B_SIZE, - d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), - threadBound)})); - - stmt = FlattenIndexes(stmt); - stmt = IRSimplifier::simplify(stmt); - - CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The other loop remains the D write is masked. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 10 -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * i] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * i] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev, d_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each of which bound to the same block -// size, but with internal loops bound to different thread rank (ie x and y). In -// this case both bodies must be masked against the other dimension being > 0. -// Note: this is a bit degenerate no one would actually write this for perf. -TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Both stores masked against the other thread dim < 1. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.y<1 -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each bound to both Block and Thread but -// the second loop is smaller in both cases - the second store must be masked -// for both the block and thread dimension. -TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { - int OUTER_A_SIZE = 10; - int OUTER_B_SIZE = 5; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked twice, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (blockIdx.x<5 -# CHECK: if (threadIdx.x<15 -# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_B_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_B_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_A_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_B_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -} // namespace jit -} // namespace torch - -#endif diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp deleted file mode 100644 index 07b9872fb832..000000000000 --- a/test/cpp/tensorexpr/test_dynamic_shapes.cpp +++ /dev/null @@ -1,701 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -TEST(DynamicShapes, SimpleGraph) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - return (%4))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %SS_2 : int, - // %SS_3 : int): - // %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %4 : Float(SS(-2), SS(-3)) = aten::erf(%3) - // return (%4) - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsSameDims) { -#ifdef TORCH_ENABLE_LLVM - // The two inputs in this graph must have the same dims. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-4), SS(-5)), - // %y : Float(SS(-4), SS(-5)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x) - // %5 : Float(SS(-4), SS(-5)) = aten::erf(%4) - // %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({1, 5})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_sym_type = y_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-6), SS(-7)), - // %y : Float(1, SS(-7)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x) - // %5 : Float(SS(-6), SS(-7)) = aten::erf(%4) - // %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(1, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int): - %4 : Tensor = aten::tanh(%x) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({1, 5})); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(1, SS(-2)), - // %y : Float(1, SS(-2)), - // %SS_2 : int): - // %3 : Float(1, SS(-2)) = aten::tanh(%x) - // %4 : Float(1, SS(-2)) = aten::mul(%3, %y) - // return (%4) - - std::vector symbolic_shape_inputs({x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithSymbolicStrides) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15) - %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE}; - std::vector output_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = output_desc; - std::vector symbolic_shape_inputs = {-3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto out = - at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {out, x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithCatAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(4, 5, requires_grad=0, device=cpu), - %z : Float(1, 1, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %11 : int = prim::Constant[value=0]() - %3 : Tensor = aten::tanh(%x) - %out1 : Tensor = aten::erf(%3) - %out2 : Tensor = aten::relu(%y) - %10 : Tensor[] = prim::ListConstruct(%out1, %out2) - %25 : Tensor = aten::cat(%10, %11) - %28 : Tensor = aten::hardswish(%25) - %29 : Tensor = aten::mul(%28, %z) - return (%29))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto z_inp = graph->inputs()[2]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({4, 5})); - auto z_type = TensorType::create(at::rand({1, 1})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto y_sym_type = y_type->withSymbolicShapes( - std::vector({y_dim0_sym, x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - auto cat_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto cat_out_type = x_type->withSymbolicShapes( - std::vector({cat_dim0_sym, x_dim1_sym})); - auto nodeIt = graph->nodes().begin(); - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::tanh - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::erf - ++nodeIt; - nodeIt->output()->setType(y_sym_type); // aten::relu - ++nodeIt; - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::cat - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::hardswish - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::mul - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %y : Float(SS(-4), SS(-3)), - // %z : Float(1, 1), - // %SS_2 : int, - // %SS_3 : int, - // %SS_4 : int, - // %SS_5 : int): - // %7 : int = prim::Constant[value=0]() - // %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %9 : Float(SS(-2), SS(-3)) = aten::erf(%8) - // %10 : Float(SS(-4), SS(-3)) = aten::relu(%y) - // %11 : Tensor[] = prim::ListConstruct(%9, %10) - // %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7) - // %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12) - // %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z) - // return (%14) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), - x_dim1_sym.value(), - y_dim0_sym.value(), - cat_dim0_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[z_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul( - at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c); - - std::vector stack = fmap(std::vector({a, b, c})); - stack.push_back(10); - stack.push_back(5); - stack.push_back(4); - stack.push_back(14); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -#endif -} - -TEST(DynamicShapes, GraphFromModel) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu), - %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu), - %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu), - %4 : Float(SS(-7), requires_grad=0, device=cpu), - %5 : Float(SS(-7), requires_grad=0, device=cpu), - %SS_10 : int, - %SS_9 : int, - %SS_8 : int, - %SS_7 : int, - %SS_6 : int, - %SS_5 : int, - %SS_4 : int, - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %16 : bool = prim::Constant[value=0]() - %17 : int = prim::Constant[value=6]() - %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16) - %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2) - %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15) - %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15) - %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->inputs().at(3)] = input_desc; - symbolic_strides[graph->inputs().at(4)] = input_desc; - symbolic_strides[graph->inputs().at(5)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = { - -10, -9, -8, -7, -6, -5, -4, -3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - int64_t i2 = 10; - int64_t i3 = 32; - int64_t i4 = 19; - int64_t i5 = 71; - int64_t i6 = 139; - int64_t i7 = 261; - int64_t i8 = 261; - int64_t i9 = 261; - int64_t i10 = 261; - auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong)); - auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4); - - { - std::vector inputs = {x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto out = - at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - std::vector inputs = {out, x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, MultiThreadedExecution) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_template = R"IR( - graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %SS_2 : int, - %SS_3 : int): - %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) - %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) - %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) - return (%5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - auto device = use_cuda ? at::kCUDA : at::kCPU; - at::jit::TemplateEnv env; - env.s("device", use_cuda ? "cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - // Run the kernel in parallel to ensure that the run() method calls in - // TensorExprKernel are not changing any state. - constexpr size_t kNumThreads = 4; - std::vector threads; - for (size_t id = 0; id < kNumThreads; ++id) { - threads.emplace_back(run_kernel, id + 5, id + 20); - } - for (auto& t : threads) { - t.join(); - } - } -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp deleted file mode 100644 index eb2d6296b229..000000000000 --- a/test/cpp/tensorexpr/test_expr.cpp +++ /dev/null @@ -1,836 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using SimpleIRExprEval = ExprEval; - -TEST(Expr, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - SimpleIRExprEval eval(c); - ASSERT_EQ(eval.value(), 5); -} - -TEST(Expr, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), -4.0f); -} - -TEST(Expr, IsChannelsLastContiguous) { - std::vector vars = { - VarHandle("var1", kLong), - VarHandle("var2", kLong), - VarHandle("var3", kLong), - VarHandle("var4", kLong), - VarHandle("var5", kLong)}; - - // { - // key: ndims, - // value: [ - // ... - // [dim_2, dim_1, ..., dim_n] - // ] - // } - using shapGenInfo = std::unordered_map>>; - - // { - // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n], - // strides: [ - // ... - // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z] - // ] - // } - using shapeInfo = - std::pair, std::vector>>; - - std::vector dims = {3, 4, 5}; - - std::unordered_map> dims_expr_vec_conf = { - {3, std::vector(vars.begin(), vars.begin() + 2)}, - {4, std::vector(vars.begin(), vars.begin() + 3)}, - {5, std::vector(vars.begin(), vars.begin() + 4)}, - }; - - shapGenInfo channels_last_cont_shape_conf = { - {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}}; - shapGenInfo channels_last_non_cont_shape_conf = { - {3, {{2, 1, 0}, {1, 0, 2}}}, - {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}}, - {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}}; - - shapGenInfo cont_shape_conf = { - {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}}; - - auto shape_gen_fn = [dims_expr_vec_conf]( - int ndims, shapGenInfo shape_gen_info) -> shapeInfo { - auto dims_expr_vec = dims_expr_vec_conf.at(ndims); - std::vector> strides_expr_vec; - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - strides_expr_vec[i].resize(ndims); - } - - auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) { - if (indicator % 2 == 0) { - return a * b; - } else { - return b * a; - } - }; - - auto stride_order_vec = shape_gen_info.at(ndims); - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - auto stride_order = stride_order_vec[i]; - - strides_expr_vec[i][stride_order[0]] = 1; - for (size_t j = 1; j < stride_order.size(); j++) { - auto cur_dim_idx = stride_order[j]; - auto adjacent_dim_idx = stride_order[j - 1]; - - strides_expr_vec[i][cur_dim_idx] = stride_gen_fn( - i, - dims_expr_vec[adjacent_dim_idx], - strides_expr_vec[i][adjacent_dim_idx]); - } - } - - return {dims_expr_vec, strides_expr_vec}; - }; - - auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool { - if (ndims == 3) { - return buf_handle.is_channels_last_1d_contiguous(); - } else if (ndims == 4) { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast); - } else { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d); - } - }; - - // channels-last contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true); - } - } - - // channels-last non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false); - } - } - - // contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), true); - } - } - - // non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), false); - } - } -} - -TEST(Expr, LetTest01) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LetTest02) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = - ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(6.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); -} - -TEST(Expr, LetStmtTest01) { - BufHandle a_buf("a", {1}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - ExprHandle load_a = a_buf.load(0); - VarHandle var = VarHandle("v", kFloat); - StmtPtr let_store = Let::make(var, load_a); - StmtPtr store_b = b_buf.store({0}, var); - BlockPtr block = Block::make({let_store, store_b}); - - SimpleIREvaluator eval(block, {a_buf, b_buf}); - - PaddedBuffer a_v(1); - PaddedBuffer b_v(1); - PaddedBuffer b_ref(1); - - a_v(0) = 23; - b_ref(0) = a_v(0); - eval(a_v, b_v); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(Expr, IntTest) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, FloatTest) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ByteTest) { - VarHandle x("x", kByte); - ExprHandle body = ExprHandle((uint8_t)2) + - (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((uint8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, CharTest) { - VarHandle x("x", kChar); - ExprHandle body = ExprHandle((int8_t)2) + - (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ShortTest) { - VarHandle x("x", kShort); - ExprHandle body = ExprHandle((int16_t)2) + - (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int16_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LongTest) { - VarHandle x("x", kLong); - ExprHandle body = ExprHandle((int64_t)2) + - (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int64_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, HalfTest) { - VarHandle x("x", kHalf); - ExprHandle body = ExprHandle((at::Half)2) + - (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((at::Half)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, DoubleTest) { - VarHandle x("x", kDouble); - ExprHandle body = ExprHandle((double)2) + - (x * ExprHandle((double)3) + ExprHandle((double)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((double)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, VectorAdd01) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {kTotalSize}, kFloat); - BufHandle b_buf("B", {kTotalSize}, kFloat); - BufHandle c_buf("C", {kTotalSize}, kFloat); - - /* - Build the following: - for (const auto index : c10::irange(kVectorCount)) { - store(c_buf, ramp(index * 8, 1, 8), - load(a_buf, ramp(index * 8, 1, 8) + - load(b_buf, ramp(index * 8, 1, 8)))) - } - */ - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = - a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle load_b = - b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle value = load_a + load_b; - StmtPtr store_c = - c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); - StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); - - ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer c_ref(kTotalSize); - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i * i; - b_v(i) = i * i * 4; - c_ref(i) = a_v(i) + b_v(i); - } - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(Expr, CompareSelectEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(Expr, CompareSelectDtypes) { - // LHS and RHS expressions should have the same dtype, but this dtype could - // differ from the dtype of the return values (but dtypes of true and false - // return values should be the same). - // This test constructs a CompareSelect expression where the input dtype is - // different from the output dtype and verifies that it works correctly: - // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0.0f); - std::vector c_ref(N, 3.14f); - - VarHandle i("i", kInt); - // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f - // A and B are int, C is float. - auto select_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), - b.load(i), - FloatImm::make(3.14f), - FloatImm::make(2.78f), - CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(select_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - ExpectAllNear(c_buffer, c_ref, 1e-7); -} - -TEST(Expr, IntrinsicsDtypes) { - constexpr int N = 256; - BufHandle a("A", {N}, kDouble); - BufHandle b("B", {N}, kDouble); - std::vector a_buffer(N, -10.0); - std::vector b_buffer(N, 0.0); - std::vector b_ref(N, 10.0); - - VarHandle i("i", kInt); - auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); - - SimpleIREvaluator ir_eval(abs_expr, {a, b}); - ir_eval(a_buffer, b_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - - assertAllEqual(a_buffer, -10.0); - ExpectAllNear(b_buffer, b_ref, 1e-7); -} - -TEST(Expr, Substitute01) { - VarPtr x = alloc("x", kFloat); - VarPtr y = alloc("y", kFloat); - ExprPtr e = - alloc(alloc(x, alloc(1.0f)), alloc(x, y)); - - VarPtr z = alloc("z", kFloat); - ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); - ExprPtr e2_ref = alloc( - alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), - alloc(alloc(z, alloc(5.0f)), y)); - std::ostringstream oss; - oss << *e2; - std::string e2_str = oss.str(); - - oss.str(""); - oss << *e2_ref; - std::string e2_ref_str = oss.str(); - ASSERT_EQ(e2_str, e2_ref_str); -} - -TEST(Expr, Math01) { - ExprHandle v = sin(ExprHandle(1.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "sin(1.f)"); - - SimpleIRExprEval eval(v); - float v_ref = std::sin(1.0f); - float res = eval.value(); - ASSERT_NEAR(res, v_ref, 1e-6); -} - -TEST(Expr, UnaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return tan(v); }, - [](float v) { return std::tan(v); }}, - {[](const ExprHandle& v) { return asin(v); }, - [](float v) { return std::asin(v); }}, - {[](const ExprHandle& v) { return acos(v); }, - [](float v) { return std::acos(v); }}, - {[](const ExprHandle& v) { return atan(v); }, - [](float v) { return std::atan(v); }}, - {[](const ExprHandle& v) { return sinh(v); }, - [](float v) { return std::sinh(v); }}, - {[](const ExprHandle& v) { return cosh(v); }, - [](float v) { return std::cosh(v); }}, - {[](const ExprHandle& v) { return tanh(v); }, - [](float v) { return std::tanh(v); }}, - {[](const ExprHandle& v) { return exp(v); }, - [](float v) { return std::exp(v); }}, - {[](const ExprHandle& v) { return tensorexpr::abs(v); }, - [](float v) { return std::fabs(v); }}, - {[](const ExprHandle& v) { return log(v); }, - [](float v) { return std::log(v); }}, - {[](const ExprHandle& v) { return log2(v); }, - [](float v) { return std::log2(v); }}, - {[](const ExprHandle& v) { return log10(v); }, - [](float v) { return std::log10(v); }}, - {[](const ExprHandle& v) { return erf(v); }, - [](float v) { return std::erf(v); }}, - {[](const ExprHandle& v) { return sqrt(v); }, - [](float v) { return std::sqrt(v); }}, - {[](const ExprHandle& v) { return rsqrt(v); }, - [](float v) { return 1.0f / std::sqrt(v); }}, - {[](const ExprHandle& v) { return ceil(v); }, - [](float v) { return std::ceil(v); }}, - {[](const ExprHandle& v) { return floor(v); }, - [](float v) { return std::floor(v); }}, - {[](const ExprHandle& v) { return round(v); }, - [](float v) { return std::round(v); }}, - {[](const ExprHandle& v) { return trunc(v); }, - [](float v) { return std::trunc(v); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float input_v = 0.8765f; - ExprHandle v = test_config.func(ExprHandle(input_v)); - float v_ref = test_config.ref_func(input_v); - SimpleIRExprEval eval(v); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - for (float input_v : {std::nan("1"), 0., .5}) { - ExprHandle v = FloatImm::make(input_v); - SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); - ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); - } -} - -TEST(Expr, BinaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, - [](float v1, float v2) { return std::pow(v1, v2); }}, - {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, - [](float v1, float v2) { return std::fmod(v1, v2); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float v1 = 0.8765f; - float v2 = 1.2345f; - ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); - float v_ref = test_config.ref_func(v1, v2); - SimpleIRExprEval eval(v_expr); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } -} - -TEST(Expr, LogicalOps01) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - ExprHandle f1 = (a > b) && (c > d); - ExprHandle f2 = (a > b) && (c < d); - ExprHandle f3 = (a < b) && (c > d); - ExprHandle f4 = (a < b) && (c < d); - ExprHandle f5 = (a < b) || (c > d); - ExprHandle f6 = (a < b) || (c < d); - ExprHandle f7 = (a > b) || (c < d); - ExprHandle f8 = (a > b) || (c > d); - - SimpleIRExprEval eval1(f1); - SimpleIRExprEval eval2(f2); - SimpleIRExprEval eval3(f3); - SimpleIRExprEval eval4(f4); - SimpleIRExprEval eval5(f5); - SimpleIRExprEval eval6(f6); - SimpleIRExprEval eval7(f7); - SimpleIRExprEval eval8(f8); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 0); - ASSERT_EQ(eval3.value(), 0); - ASSERT_EQ(eval4.value(), 0); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 0); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); -} - -TEST(Expr, LogicalOps02) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.72f); - - ExprHandle f1 = (a > b) || (c > d); - ExprHandle f2 = (a > b) && (c <= d); - ExprHandle f3 = (a > b) && (c > d); - ExprHandle ff1 = f1 && f2; - ExprHandle ff2 = f2 || f3; - - SimpleIRExprEval eval1(ff1); - SimpleIRExprEval eval2(ff2); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 1); -} - -TEST(Expr, LogicalOps03) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - - // Bool types - ExprHandle bool_f1 = (a > b) && BoolImm::make(true); - ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); - - // Int types - ExprHandle int_f1 = (a > b) && IntImm::make(1); - ExprHandle int_f2 = (c <= d) || IntImm::make(1); - - // Short types - ExprHandle short_f1 = (a > b) && ShortImm::make(1); - ExprHandle short_f2 = (c <= d) || ShortImm::make(1); - - // Long types - ExprHandle long_f1 = (a > b) && LongImm::make(1); - ExprHandle long_f2 = (c <= d) || LongImm::make(1); - - // Char types - ExprHandle char_f1 = (a > b) && CharImm::make(1); - ExprHandle char_f2 = (c <= d) || CharImm::make(1); - - // Byte types - ExprHandle byte_f1 = (a > b) && ByteImm::make(1); - ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); - - SimpleIRExprEval eval1(bool_f1); - SimpleIRExprEval eval2(bool_f2); - SimpleIRExprEval eval3(int_f1); - SimpleIRExprEval eval4(int_f2); - SimpleIRExprEval eval5(short_f1); - SimpleIRExprEval eval6(short_f2); - SimpleIRExprEval eval7(long_f1); - SimpleIRExprEval eval8(long_f2); - SimpleIRExprEval eval9(char_f1); - SimpleIRExprEval eval10(char_f2); - SimpleIRExprEval eval11(byte_f1); - SimpleIRExprEval eval12(byte_f2); - - ASSERT_EQ(eval1.value(), true); - ASSERT_EQ(eval2.value(), true); - ASSERT_EQ(eval3.value(), 1); - ASSERT_EQ(eval4.value(), 1); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 1); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); - ASSERT_EQ(eval9.value(), 1); - ASSERT_EQ(eval10.value(), 1); - ASSERT_EQ(eval11.value(), 1); - ASSERT_EQ(eval12.value(), 1); -} - -TEST(Expr, BitwiseOps) { - ExprHandle a(59); - ExprHandle b(11); - ExprHandle c(101); - ExprHandle d(2); - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), 11); -} - -TEST(Expr, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(Expr, OutOfBounds) { - ExprHandle N(10); - ExprHandle start(0); - ExprHandle stop(15); - VarHandle i("i", kInt); - - BufHandle X("X", {N}, kInt); - - auto body = Store::make(X, {i}, i); - auto stmt = For::make(i, start, stop, body); - - PaddedBuffer data(20); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -TEST(Expr, OutOfBounds2d) { - std::vector> size_options = {{10, 15}, {15, 10}}; - for (auto sizes : size_options) { - ExprHandle N(sizes.first); - ExprHandle M(sizes.second); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(15); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {N, M}, kInt); - - auto body = Store::make(X, {i, j}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); - } -} - -TEST(Expr, OutOfBounds2dFlattenedIndex) { - ExprHandle buf_size(149); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(10); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {buf_size}, kInt); - - auto idx = Add::make(Mul::make(i, stopInner), j); - auto body = Store::make(X, {idx}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -void testCond01() { - const int N = 16; - PaddedBuffer a_v(N); - BufHandle a_buf("a", {N}, kFloat); - VarHandle index = VarHandle("index", kInt); - StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); - StmtPtr assign_x3 = a_buf.store({index}, cast(index) * 3); - ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); - StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); - StmtPtr for_stmt = For::make(index, 0, N, assign); - SimpleIREvaluator(for_stmt, {a_buf})(a_v); - - PaddedBuffer a_ref(N); - for (const auto i : c10::irange(N)) { - if (i % 2 == 0) { - a_ref(i) = i * 2; - } else { - a_ref(i) = i * 3; - } - } - ExpectAllNear(a_v, a_ref, 1e-5); -} - -void testIfThenElse01() { - ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 1.0f); -} - -void testIfThenElse02() { - ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testIfThenElse03() { - ExprHandle v = - ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testStmtClone() { - const int N = 16; - - BufHandle a_buf("a", {N}, kInt); - VarHandle index = VarHandle("index", kInt); - StmtPtr body = a_buf.store({index}, 5); - StmtPtr loop = For::make(index, 0, N, body); - - StmtPtr cloned_loop = Stmt::clone(loop); - std::vector orig_loop_results(N); - std::vector cloned_loop_results(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); - - assertAllEqual(orig_loop_results, 5); - assertAllEqual(cloned_loop_results, 5); - - // Let's add another assign to the body in the cloned loop and verify that the - // original statement hasn't changed while the cloned one has. - StmtPtr body_addition = a_buf.store({index}, 33); - BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); - cloned_body->append_stmt(body_addition); - - std::vector orig_loop_results_after_mutation(N); - std::vector cloned_loop_results_after_mutation(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); - - assertAllEqual(orig_loop_results_after_mutation, 5); - assertAllEqual(cloned_loop_results_after_mutation, 33); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp deleted file mode 100644 index 49f43d16b499..000000000000 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ /dev/null @@ -1,1061 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(ExternalCall, Conv1d_float) { - BufHandle Input("Input", {1, 100, 115}, kFloat); - BufHandle Weight("Weight", {100, 1, 7}, kFloat); - BufHandle Bias("Bias", {100}, kFloat); - BufHandle ResultBuf("Result", {1, 100, 115}, kFloat); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; - at::Tensor bias = at::ones({100}, options) * 11.f; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5.f); - std::vector weight_buf(100 * 1 * 7, 6.f); - std::vector bias_buf(100, 11.f); - std::vector result_buf(1 * 100 * 115, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_int) { - // A similar test, but now using kInt tensors - BufHandle Input("Input", {1, 100, 115}, kInt); - BufHandle Weight("Weight", {100, 1, 7}, kInt); - BufHandle Bias("Bias", {100}, kInt); - BufHandle ResultBuf("Result", {1, 100, 115}, kInt); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6; - at::Tensor bias = at::ones({100}, options) * 11; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5); - std::vector weight_buf(100 * 1 * 7, 6); - std::vector bias_buf(100, 11); - std::vector result_buf(1 * 100 * 115, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_nobias_noargs) { - BufHandle Input("Input", {1, 1, 115}, kFloat); - BufHandle Weight("Weight", {10, 1, 7}, kFloat); - BufHandle ResultBuf("Result", {1, 10, 109}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; - at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; - at::Tensor ref = at::conv1d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 1 * 115, 5.f); - std::vector weight_buf(10 * 1 * 7, 6.f); - std::vector result_buf(1 * 10 * 109, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_float) { - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat); - BufHandle Bias("Bias", {16}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; - at::Tensor bias = at::ones({16}, options) * 11.f; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5.f); - std::vector weight_buf(16 * 3 * 3 * 3, 6.f); - std::vector bias_buf(16, 11.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_int) { - // A similar test, but now using kInt tensors - - BufHandle Input("Input", {1, 3, 224, 224}, kInt); - BufHandle Weight("Weight", {16, 3, 3, 3}, kInt); - BufHandle Bias("Bias", {16}, kInt); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; - at::Tensor bias = at::ones({16}, options) * 11; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5); - std::vector weight_buf(16 * 3 * 3 * 3, 6); - std::vector bias_buf(16, 11); - std::vector result_buf(1 * 16 * 112 * 112, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_nobias_noargs) { - BufHandle Input("Input", {1, 16, 112, 112}, kFloat); - BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor ref = at::conv2d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 112 * 112, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Addmm_float) { - BufHandle Input("Input", {100, 300}, kFloat); - BufHandle Mat1("Mat1", {100, 200}, kFloat); - BufHandle Mat2("Mat2", {200, 300}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - int64_t beta = 2; - int64_t alpha = 2; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({100, 300}, options) * 5.f; - at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; - at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; - at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); - - at::Tensor nnc_result; - std::vector input_buf(100 * 300, 5.f); - std::vector mat1_buf(100 * 200, 6.f); - std::vector mat2_buf(200 * 300, 11.f); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Embedding) { - BufHandle Weight("Weight", {256, 100}, kFloat); - BufHandle Indices("Indices", {1, 115}, kLong); - BufHandle ResultBuf("Result", {1, 115, 100}, kFloat); - int64_t padding_idx = -1; - bool scale_grad_by_freq = false; - bool sparse = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_embedding", - {Weight, Indices}, - {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; - at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; - at::Tensor ref = - at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); - - at::Tensor nnc_result; - std::vector weight_buf(256 * 100, 5.f); - std::vector indices_buf(1 * 115, 6); - std::vector result_buf(1 * 115 * 100, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); - - llvm_codegen.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); - - ir_eval.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, MaxReduction) { - BufHandle Input("Input", {1, 115, 152}, kFloat); - BufHandle ResultBuf("Result", {1, 152}, kFloat); - int64_t dim = 1; - bool keep_dim = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; - at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); - - at::Tensor nnc_result; - std::vector input_buf(1 * 115 * 152, 5.f); - std::vector result_buf(1 * 152, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); - - llvm_codegen.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); - - ir_eval.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -#ifdef USE_XNNPACK - -TEST(ExternalCall, Prepacked_Linear_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {100, 200}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - - // Calculate reference result using at::linear. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = - at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200}); - at::Tensor bias = at::linspace(-10.0, 10.0, 300, options); - at::Tensor ref = at::linear(input, weight, bias); - - // Create prepacked xnnpack context object. - auto linear_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::linear_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - const std::optional&, - const std::optional&)>(); - auto prepacked = linear_clamp_prepack_op.call( - weight, bias, std::optional(), std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_linear_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 100 * 200); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Prepacked_Conv2d_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - // Calculate reference result using at::conv2d. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options) - .resize_({1, 3, 224, 224}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3}); - at::Tensor bias = at::linspace(-10.0, 10.0, 16, options); - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - // Create prepacked xnnpack context object. - auto conv2d_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - std::vector, - std::vector, - std::vector, - int64_t, - const std::optional&, - const std::optional&)>(); - auto prepacked = conv2d_clamp_prepack_op.call( - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups, - std::optional(), - std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_conv2d_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 1 * 3 * 224 * 224); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -} - -#endif // USE_XNNPACK - -TEST(ExternalCall, BinaryFloat) { - using TensorFunc = std::function; - using Test = std::tuple< - std::vector, - std::vector, - std::vector, - TensorFunc, - std::string>; - std::vector tests = {}; - tests.push_back( - Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"}); - tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"}); - tests.push_back(Test{ - {100, 200}, - {200, 300}, - {100, 300}, - [&](const at::Tensor& a, const at::Tensor& b) { return at::mm(a, b); }, - "nnc_aten_mm"}); - for (auto curTest : tests) { - auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle B("B", toExprHandleVec(bShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; - at::Tensor ref = torchFunc(a, b); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector b_buf(prod(bShape), 6.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); - - llvm_codegen.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result}); - ir_eval.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, UnaryFloat) { - using TensorFunc = std::function; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - using Test = std::tuple< - std::vector, - std::vector, - TensorFunc, - std::string, - std::vector>; - std::vector tests = {}; - tests.push_back(Test{ - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {1, 64, 8, 9}, - {1, 64, 5, 7}, - [](at::Tensor x) { return at::adaptive_avg_pool2d(x, {5, 7}); }, - "nnc_aten_adaptive_avg_pool2d", - toExprHandleVec({5, 7})}); - tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {100, 200}, - {100}, - [](at::Tensor x) { return at::mean(x, {1}); }, - "nnc_aten_mean", - toExprHandleVec({1, /*keepdim=*/0})}); - for (auto curTest : tests) { - auto [aShape, resShape, torchFunc, externCallName, externCallArgs] = - curTest; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs)); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor ref = torchFunc(a); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result}); - - llvm_codegen.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result}); - ir_eval.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, ComputeInterop) { - // This test verifies that Tensors using external calls can be used by and can - // use Tensors built with Compute API. - - BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); - BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); - - Tensor Input = Compute( - "Input", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(5.0f); }); - Tensor Weight = Compute( - "Weight", - {16, 16, 1, 1}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(6.0f); }); - - Tensor ConvResult = Tensor( - ConvResultBuf.node(), - ExternalCall::make( - ConvResultBuf, - "nnc_aten_conv2d", - {BufHandle(Input.buf()), BufHandle(Weight.buf())}, - {})); - Tensor MatmulResult = Tensor( - MatmulResultBuf.node(), - ExternalCall::make( - MatmulResultBuf, - "nnc_aten_matmul", - {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, - {})); - Tensor Result = Compute( - "Result", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { - return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); - }); - - LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); - - // Inlining should not inline anything here since all Bufs are either defined - // or used in ExternalCalls - we run it just for testing - l.inlineIntermediateBufs(true); - - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor t = at::conv2d(input, weight); - at::Tensor t2 = at::matmul(t, t); - at::Tensor ref = t + t2; - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 32 * 32, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector result_buf(1 * 16 * 32 * 32, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - llvm_codegen.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - ir_eval.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Inlining) { - // This test verifies that Tensors using external calls can be used by and - // can use Tensors built with Compute API. - - BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat); - - Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return FloatImm::make(5.0f); - }); - Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return FloatImm::make(4.0f); - }); - Tensor MatmulResult = Tensor( - MatmulResultBuf.node(), - ExternalCall::make( - MatmulResultBuf, - "nnc_aten_matmul", - {BufHandle(A.buf()), BufHandle(B.buf())}, - {})); - Tensor Result = - Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return MatmulResult.load(i, j) + FloatImm::make(3.0f); - }); - - StmtPtr root_stmt = alloc(std::vector( - {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); - LoopNest l(root_stmt, {Result.buf()}); - - // Inlining should not inline anything here since all Bufs are either - // defined or used in ExternalCalls - l.inlineIntermediateBufs(false); - - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones({8, 8}, options) * 5.f; - at::Tensor b = at::ones({8, 8}, options) * 4.f; - at::Tensor t = at::matmul(a, b); - at::Tensor ref = t + 3.f; - - at::Tensor nnc_result; - std::vector result_buf(8 * 8); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Result}); - - llvm_codegen.call({result_buf}); - nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Result}); - - ir_eval.call({result_buf}); - nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, JitCustomFusionOp) { - const char* custom_op_schema_literal = - "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor"; - const char* external_func_name = "nnc_add_mul"; - - auto add_mul_lowering_func = - [external_func_name]( - const std::vector& inputs, - const std::vector& output_shape, - const std::vector& output_strides, - const std::optional& output_type, - at::Device device) { - auto output_dtype = Dtype(*output_type); - torch::jit::tensorexpr::BufHandle result_buf( - "nnc_add_mul_res_buf", output_shape, output_dtype); - const torch::jit::tensorexpr::BufHandle& a = - std::get(inputs[0]); - const torch::jit::tensorexpr::BufHandle& b = - std::get(inputs[1]); - const torch::jit::tensorexpr::BufHandle& c = - std::get(inputs[1]); - torch::jit::tensorexpr::StmtPtr s = - torch::jit::tensorexpr::ExternalCall::make( - result_buf, external_func_name, {a, b, c}, {}); - return Tensor(result_buf.node(), s); - }; - - auto add_mul_external_func = [](int64_t bufs_num, - void** buf_data, - int64_t* buf_ranks, - int64_t* buf_dims, - int64_t* buf_strides, - int8_t* buf_dtypes, - int64_t args_num, - int64_t* extra_args) {}; - - torch::jit::RegisterOperators reg({Operator( - custom_op_schema_literal, - [](const Node* node) -> Operation { - return [](Stack& _stack) { - auto a = std::move(peek(_stack, 0, 3)).toTensor(); - auto b = std::move(peek(_stack, 1, 3)).toTensor(); - auto c = std::move(peek(_stack, 2, 3)).toTensor(); - drop(_stack, 3); - auto result = (a + b) * c; - pack(_stack, std::move(result)); - return 0; - }; - }, - c10::AliasAnalysisKind::FROM_SCHEMA)}); - - auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet(); - custom_operator_set.insert({custom_op_schema_literal}); - - auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); - te_lowering_registry.insert( - parseSchema(custom_op_schema_literal), add_mul_lowering_func); - - auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); - te_nnc_func_registry[external_func_name] = add_mul_external_func; - - std::string graph_string = R"IR( - graph(%a : Float(10, 20, strides=[20, 1], device=cpu), - %b : Float(10, 20, strides=[20, 1], device=cpu), - %c : Float(10, 20, strides=[20, 1], device=cpu)): - %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c) - return (%res))IR"; - - auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::string shape_compute_python_string = R"PY( - def computOutput(a: List[int], b: List[int], c: List[int]): - expandedSizes: List[int] = [] - dimsA = len(a) - dimsB = len(b) - dimsC = len(c) - ndim = max(dimsA, dimsB, dimsC) - for i in range(ndim): - offset = ndim - 1 - i - dimA = dimsA - 1 - offset - dimB = dimsB - 1 - offset - dimC = dimsC - 1 - offset - sizeA = a[dimA] if (dimA >= 0) else 1 - sizeB = b[dimB] if (dimB >= 0) else 1 - sizeC = a[dimC] if (dimC >= 0) else 1 - - if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1: - # TODO: only assertion error is bound in C++ compilation right now - raise AssertionError( - "The size of tensor a {} must match the size of tensor b (" - "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i) - ) - - expandedSizes.append(max(sizeA, sizeB, sizeC)) - - return expandedSizes - )PY"; - auto cu_ptr = torch::jit::compile(shape_compute_python_string); - torch::jit::GraphFunction* gf = - (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput"); - ASSERT_TRUE(gf); - -#ifdef TORCH_ENABLE_LLVM - auto static_graph_case = graph->copy(); - FuseTensorExprs(static_graph_case, 1); - torch::jit::testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("nnc_custom::add_mul") - ->run(*static_graph_case); - - auto dynamic_graph_case = graph->copy(); - auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal); - ASSERT_TRUE(custom_op); - torch::jit::RegisterShapeComputeGraphForSchema( - custom_op->schema(), gf->graph()); - FuseTensorExprs(dynamic_graph_case, 1, false, true); - torch::jit::testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("nnc_custom::add_mul") - ->run(*dynamic_graph_case); -#else - torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph); -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp deleted file mode 100644 index aed73d09d14d..000000000000 --- a/test/cpp/tensorexpr/test_graph_opt.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -class GraphOpt : public ::testing::Test { - public: - void SetUp() override { - old_cat_wo_conditionals_ = getCatWoConditionals(); - getCatWoConditionals() = true; - } - - void TearDown() override { - getCatWoConditionals() = old_cat_wo_conditionals_; - } - - private: - bool old_cat_wo_conditionals_; -}; - -TEST_F(GraphOpt, OptimizeCat) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` op must be moved to the inputs of `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::log(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` and `aten::tanh` ops must be moved to the inputs of - // `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::log") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::log(at::cat({x, y, z}, 0))); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat3) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%a : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // But the `aten::mul` op must not be moved since it is not a single-tensor - // op (it has 2 tensor inputs). - testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check("aten::mul") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto a = at::rand({60}, at::kFloat); - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::cat({x, y, z}, 0)) * a; - - std::vector inputs = {a, x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Int(10, strides=[1], device=cpu), - %y : Int(20, strides=[1], device=cpu), - %z : Int(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // The scalar type of the inputs to `cat` should now be `Float` since they - // are the result of `tanh` which does the type promotion. - testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::randint(std::numeric_limits::max(), {10}, at::kInt); - auto y = at::randint(std::numeric_limits::max(), {20}, at::kInt); - auto z = at::randint(std::numeric_limits::max(), {30}, at::kInt); - auto ref = at::tanh(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Double(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation should have happened because the `aten::cat` op performs - // type promotion. This case is currently not handled. - testing::FileCheck() - .check("aten::cat") - ->check("aten::log") - ->check_not("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. - testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %1 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %one : int = prim::Constant[value=1]() - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. - testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check("aten::add") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->check_not("aten::add") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, AOTGraphPrepPasses) { - const auto graph_string = R"IR( - graph(%x, %y, %z, %i : int): - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - return (%xyz_list, %i))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - removeGraphOutput(g, 1); - replaceListOutputWithTuple(g); - LowerAllTuples(g); - - testing::FileCheck().check("return (%x, %y, %z)")->run(*g); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp deleted file mode 100644 index 4d2f8c6e906e..000000000000 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRPrinter, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - - std::stringstream ss; - ss << c; - ASSERT_EQ(ss.str(), "2 + 3"); -} - -TEST(IRPrinter, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - - std::stringstream ss; - ss << f; - ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); -} - -TEST(IRPrinter, BasicValueTest03) { - ExprHandle a(3.402823466385289e+38f); - ExprHandle b(-3.402823466385289e+38f); - std::stringstream ss; - ss << a << ", " << b; - ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f"); -} - -TEST(IRPrinter, CastTest) { - VarHandle x("x", kHalf); - VarHandle y("y", kFloat); - ExprHandle body = ExprHandle(2.f) + - (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); - - std::stringstream ss; - ss << body; - ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); -} - -TEST(IRPrinter, FunctionName) { - int M = 4; - int N = 20; - - Tensor producer = Compute( - "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return m * n; - }); - - Tensor chunk_0 = Compute( - "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n); - }); - - Tensor chunk_1 = Compute( - "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n + ExprHandle(N / 2)); - }); - - Tensor consumer = Compute( - "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { - return i * chunk_1.load(i, j); - }); - - LoopNest l({chunk_0, chunk_1, consumer}); - auto body = LoopNest::sanitizeNames(l.root_stmt()); - - std::stringstream ss; - ss << *body; - - const std::string& verification_pattern = - R"IR( - # CHECK: for (int i_2 - # CHECK: for (int j_2 - # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp deleted file mode 100644 index 886213ea9c76..000000000000 --- a/test/cpp/tensorexpr/test_ir_verifier.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRVerifier, BitwiseOps) { - VarPtr X = alloc("x", kInt); - VarPtr Y = alloc("y", kFloat); - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, CompareSelect) { - ExprPtr X = alloc(1); - ExprPtr Y = alloc(3.14f); - { - auto a = alloc(X, X, X, Y, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y, X, X, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Ramp) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kFloat); - { - auto a = alloc(I, J, 4); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Load) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, IfThenElse) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - { - // Condition must be integral - auto a = alloc(K, I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Dtypes of true and false exprs must match - auto a = alloc(I, I, J); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Can't have multiple lanes in condition expr - auto a = alloc(alloc(I, 4), I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, For) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kInt); - StmtPtr body = alloc(std::vector({})); - { - // Can't have nullptr as a Var - auto a = alloc(nullptr, I, J, body); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Block) { - VarPtr I = alloc("i", kInt); - BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); - { - StmtPtr store = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr block1 = alloc(std::vector({store})); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - StmtPtr block2 = alloc(std::vector({store})); - // Stmt can't have multiple parents, thus inserting it into several blocks - // is illegal - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(block2)); - } -} - -TEST(IRVerifier, Store) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Value and buf dtypes mismatch - auto a = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp deleted file mode 100644 index dc67928b111a..000000000000 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ /dev/null @@ -1,2133 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Kernel : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Kernel, ParallelExternalCallBuf) { - const auto graph_string = R"IR( - graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): - %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) - %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) - return (%4))IR"; - auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); -#ifdef TORCH_ENABLE_LLVM - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -#endif -} - -TEST_F(Kernel, InliningIntermediates) { - // here, each mul has only one use, so it should be completely inlined - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) - return (%5))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - } - { - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), - %1 : Float(5, 3, strides=[3, 1], device=${device})): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) - %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) - %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) - return (%4, %5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - - at::jit::TemplateEnv env; - env.s("device", use_cuda ? "cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - // aten_mul only has one use, inlined completely - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - - // aten_sub should be removed by the CUDA backend by metavar rewriting - // and by the CPU backend by horizontal fusion. - torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str()); - } - } -} - -TEST_F(Kernel, PreAllocIntermediateBufs) { - const auto graph_string = R"IR( -graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), - %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): - %2 : int = prim::Constant[value=1]() - %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 - %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::matmul(a, b) + a; - TensorExprKernel k(graph, {}, {}, true); - - std::vector inputs = {a, b}; - auto stmt = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *stmt; - - // Check whether the intermediate buffer has been added to constants - auto constants = k.getConstantDescriptors(); - ASSERT_EQ(constants.size(), 1); - - // Check the IR we produced - torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str()); - torch::jit::testing::FileCheck().check_not("Free")->run(oss.str()); - - // Check correctness - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, _1) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _2) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _3) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, Huge) { - const auto graph_string = R"IR( - graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): - %1 : int = prim::Constant[value=0]() - %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) - %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - std::ostringstream oss; - oss << *k.getCodeGenStmt(); - // The 4000000000 iterations loop will be split into 500000000 x 8 and the - // outer loop will be parallel. If LLVM is not present, it will not be split, - // and to cover both of these cases we're looking for 00000000ll; in the - // output. - const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST_F(Kernel, ParallelStrided) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), - %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): - %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) - .index( - {Slice(None, None, 2), - Slice(None, None, 2), - Slice(None, None, 2)}); - auto ref = a * (a * b); - auto o = at::zeros_like(ref); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, DISABLED_Shape_Inference) { - // disabled: doesn't do stride propagation, and isn't being used currently - - // Test TensorExpr shape inference capabilities: it should only require shapes - // for the inputs - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - const auto graph_string = R"IR( - graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), - %1 : Float(8, 8, strides=[8, 1], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) - %r : Tensor = aten::mul(%3, %4) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto t = torch::chunk(a * b, 2, 1); - auto ref = t[0] * t[1]; - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - TORCH_CHECK_EQ(o.sizes()[0], 8); - TORCH_CHECK_EQ(o.sizes()[1], 4); - for (size_t i = 0; i < 8 * 4; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::unsqueeze - - const auto graph_string = R"IR( - graph(%a : Float(4, 2, strides=[2, 1], device=cpu), - %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), - %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): - %one : int = prim::Constant[value=1]() - %minus_one : int = prim::Constant[value=-1]() - %three : int = prim::Constant[value=3]() - %minus_four : int = prim::Constant[value=-4]() - %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2] - %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1] - %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1] - %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2] - %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1] - %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2] - return (%abc))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) * - at::unsqueeze(c, -4); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_mul)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that we throw an error when input list for aten::cat is empty - - const auto graph_string = R"IR( - graph(): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct() - %r : Tensor = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - auto compile = [&]() { - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); - } - { - // Test that we throw an error when 'dim' passed to aten::cat is invalid - - const auto ir_dim_99 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=99]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - const auto ir_dim_minus_6 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=-6]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto compile = [](const std::string& graph_string) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); - ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); - } -} - -TEST_F(Kernel, CatInputTypesPromotion) { - { - // Test that we properly promote input types for aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); - } - } -} - -TEST_F(Kernel, ToDType) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %1 : NoneType = prim::Constant() - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=6]() - %4 : int = prim::Constant[value=15]() - %5 : int = prim::Constant[value=5]() - %6 : bool = prim::Constant[value=1]() - %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1) - %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4) - %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6) - %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1) - %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1) - %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1) - return (%k.3))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_to -# CHECK-NEXT: } -# CHECK-NEXT: })IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16)); - auto ref = - at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat)); - - std::vector inputs = {a}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); -#endif -} - -TEST_F(Kernel, CatAndInlineWithAConstantDim) { - const auto graph_string = R"IR( - graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), - %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = prim::ListConstruct(%0, %1) - %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) - %6 : Tensor[] = prim::ListConstruct(%5) - %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) - %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) - return (%8, %7))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, CatWithEmptyInputs) { - bool curr_cat_wo_conditionals = getCatWoConditionals(); - for (auto cat_wo_conditionals : {true, false}) { - getCatWoConditionals() = cat_wo_conditionals; - const auto graph_string = R"IR( - graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), - %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): - %3 : int = prim::Constant[value=0]() - %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) - %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) - %10 : Tensor[] = prim::ListConstruct(%6, %7) - %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) - return (%11))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } - getCatWoConditionals() = curr_cat_wo_conditionals; -} - -TEST_F(Kernel, CatWoConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - std::vector inputs = {a, b, c}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - getCatWoConditionals() = old_cat_wo_conditionals; -} - -TEST_F(Kernel, OptimizeConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - bool old_opt_conditionals = getOptConditionals(); - getCatWoConditionals() = false; - getOptConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, strides=[3, 1], device=cpu), - %b : Float(5, 7, strides=[7, 1], device=cpu), - %c : Float(5, 9, strides=[9, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim) - %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r) - return (%t))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_relu -# CHECK: for -# CHECK-NEXT: aten_relu -# CHECK: for -# CHECK-NEXT: aten_relu -# CHECK-NOT: Allocate -# CHECK-NOT: Free)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat)); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::relu(at::cat({a, b, c}, 1)); - - std::vector inputs = {a, b, c}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - getOptConditionals() = old_opt_conditionals; - getCatWoConditionals() = old_cat_wo_conditionals; -} - -namespace { - -std::string dtypeConstant(ScalarType scalar_type) { - if (scalar_type == ScalarType::Undefined) { - return "None = prim::Constant()"; - } else { - at::jit::TemplateEnv env_dtype; - env_dtype.d("scalar_type", static_cast(scalar_type)); - return format("int = prim::Constant[value=${scalar_type}]()", env_dtype); - } -} - -at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { - int64_t numel = std::accumulate( - sizes.begin(), - sizes.end(), - 1, - // NOLINTNEXTLINE(modernize-use-transparent-functors) - std::multiplies()); - std::vector values(numel); - std::iota(values.begin(), values.end(), 0); - auto a = at::tensor(values, options); - return a.reshape(sizes); -} - -} // namespace - -TEST_F(Kernel, SumAllAxes) { - // Test lowering of sum on all axes. - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : ${dtype} - %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1) - return (%2))IR"; - auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - at::jit::TemplateEnv env; - env.s("dtype", dtypeConstant(scalar_type)); - if (scalar_type == ScalarType::Undefined) { - env.s("out_dtype", "Float"); - } else { - env.s("out_dtype", "Double"); - } - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); - std::optional dtype; - if (scalar_type != ScalarType::Undefined) { - dtype = static_cast(scalar_type); - } - auto ref = a.sum(/*dtype=*/dtype); - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } -} - -std::string li_to_str(at::ArrayRef li) { - std::stringstream out; - bool first = true; - for (auto elem : li) { - if (!first) { - out << ", "; - } - out << elem; - first = false; - } - return out.str(); -} - -TEST_F(Kernel, SumOneAxis) { - // Test lowering of sum on one axis. - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : int[] = prim::Constant[value=[${dim}]]() - %2 : bool = prim::Constant[value=${keepdim}]() - %3 : ${dtype} - %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) - return (%4))IR"; - auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - for (int dim = -a.dim(); dim < a.dim(); ++dim) { - for (bool keepdim : {false, true}) { - for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - at::jit::TemplateEnv env; - env.d("dim", dim); - env.d("keepdim", keepdim); - env.s("dtype", dtypeConstant(scalar_type)); - std::optional dtype; - if (scalar_type != ScalarType::Undefined) { - dtype = static_cast(scalar_type); - } - auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); - if (scalar_type == ScalarType::Undefined) { - env.s("out_dtype", "Float"); - } else { - env.s("out_dtype", "Double"); - } - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - const auto graph_string = format(graph_template, env); - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t -# CHECK-NEXT: sum -# CHECK-NEXT: for (int64_t -# CHECK-NEXT: sum)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); - } - } - } -} - -TEST_F(Kernel, SumMultipleAxes) { - // Test lowering of sum on multiple axes. - const auto graph_template = R"IR( - graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)): - %1 : int = prim::Constant[value=${dim1}]() - %2 : int = prim::Constant[value=${dim2}]() - %3 : int[] = prim::ListConstruct(%1, %2) - %4 : bool = prim::Constant[value=${keepdim}]() - %5 : ${dtype} - %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5) - return (%6))IR"; - auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - // Only iterate over positive values of axes to keep the running time - // reasonable, since the number of pairs is quadratic. - for (const auto dim1 : c10::irange(a.dim())) { - for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { - for (bool keepdim : {false, true}) { - at::jit::TemplateEnv env; - env.d("dim1", dim1); - env.d("dim2", dim2); - env.d("keepdim", keepdim); - env.s("dtype", dtypeConstant(ScalarType::Undefined)); - auto o = at::empty({}, TensorOptions(kCPU)); - auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); - - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: sum)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } - } - } -} - -// This test and the following ones testing Softmax only tests with dim set -// to one of the valid input dimensions. It does not test with dim=None -// because that is supposed to be deprecated. -TEST_F(Kernel, Softmax2D) { - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %dt_float : int = prim::Constant[value=7]() - %dt_none : NoneType = prim::Constant() - %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt}) - return (%4))IR"; - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 5 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: aten_softmax)IR"; - - for (bool empty_dtype : {false, true}) { - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - auto other_dim = (softmax_dim + 1) % a.dim(); - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - env.s("dt", empty_dtype ? "dt_none" : "dt_float"); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("other_dim", other_dim); - ver_env.d("other_dim_size", a.sizes()[other_dim]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = - format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, - // oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } - } -} - -TEST_F(Kernel, Softmax3D) { - const auto graph_template = R"IR( - graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 3 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, Softmax4D) { - const auto graph_template = R"IR( - graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 2 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 - # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("dim3", other_dims[2]); - ver_env.d("dim3_size", a.sizes()[other_dims[2]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, SignTest) { - const auto graph_template = R"IR( - graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): - %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) - return (%2))IR"; - - auto run_test = [](const std::string& graph_string, const at::Tensor& input) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - - std::vector inputs = {input}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = at::sign(input); - ASSERT_TRUE(at::allclose(o, ref)); - }; - auto common_options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - int default_input_size = 100; - for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { - at::Tensor corner_case_inputs; - at::jit::TemplateEnv env; - auto options = common_options; - switch (scalar_type) { - case ScalarType::Float: { - env.s("dtype", "Float"); - options = options.dtype(at::kFloat); - std::vector input_float = { - 0.0f, - -0.0f, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nanf("1"), - -std::nanf("1")}; - corner_case_inputs = at::from_blob( - input_float.data(), - {static_cast(input_float.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - case ScalarType::Double: { - env.s("dtype", "Double"); - options = options.dtype(at::kDouble); - std::vector input_double = { - 0.0, - -0.0, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nan("1"), - -std::nan("1")}; - corner_case_inputs = at::from_blob( - input_double.data(), - {static_cast(input_double.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - default: - throw unsupported_dtype(); - } - } -} - -TEST_F(Kernel, InlineProducerIntoReduction) { - // Inline producer (mul) into reduction (sum). - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) - %3 : int = prim::Constant[value=7]() - %4 : Double(device=cpu) = aten::sum(%2, %3) - return (%4))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - // Check the IR we produced. - // We should have only one loop in the end. - const std::string& verification_pattern = - R"IR( - # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 - # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 - # CHECK-NEXT: sum - # CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = (a * b).sum(at::kDouble); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, InlineReductionIntoConsumer) { - // Inline producer (mul %2) into reduction (sum %4) but DO NOT - // inline the reduction into consumer (mul %4). - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : int = prim::Constant[value=6]() - %4 : Float(device=cpu) = aten::sum(%2, %3) - %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4) - return (%5))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - // Check the IR we produced. - // We should have two loops in the end. - const std::string& verification_pattern = - R"IR( - # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 - # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 - # CHECK-NEXT: sum - # CHECK: for (int64_t i_2 = 0ll; i_2 < 5 - # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3 - # CHECK-NEXT: aten_mul - # CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = (a * b).sum(at::kFloat) * (a * b); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, SanitizeNames_CUDA) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0), - %1 : Float(5, 3, strides=[3, 1], device=cuda:0)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%4))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - graph->inputs().at(0)->setDebugName("aten::add:"); - graph->inputs().at(1)->setDebugName("aten::add_"); - TensorExprKernel k(graph); - auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto ref = a * (a * b); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, SanitizeConstants_CUDA) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)): - %none : NoneType = prim::Constant() - %size : int = prim::Constant[value=16]() - %sizes : int[] = prim::ListConstruct(%size, %size) - %30 : Device = prim::Constant[value="cuda"]() - %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none) - %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we insert a call to - // aten::ones and then const-prop it - ConstantPropagation(graph); - - // We set the name of the constant to include special characters that are - // not allowed. This should be fixed by the sanitizer in TensorExprKernel. - graph->nodes().front()->output()->setDebugName("illegal.name"); - - // Check if we have a constant node with illegal name in the graph. - auto const_node = graph->nodes().front(); - ASSERT_EQ(const_node->kind(), prim::Constant); - ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensors) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %size : int = prim::Constant[value=16]() - %sizes : int[] = prim::ListConstruct(%size, %size) - %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we insert a call to - // aten::ones and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensorsNonContiguous) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat)) - .view({16, 16}) - .t(); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, RunFast) { -#ifdef TORCH_ENABLE_LLVM - // TODO: Implement call_raw in IREval and remove the ifdef - - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, RunWithAllocatedOutputs) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - std::vector args = {o, a, b}; - std::vector stack = fmap(args); - k.runWithAllocatedOutputs(stack); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, CodegenInspection) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - // Check that we could retrieve generated assembly - auto asm_str = k.getCodeText("asm"); - const std::string& asm_verification_pattern = - R"ASM( - # CHECK: .text - # CHECK: retq)ASM"; - torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str); - - // Check that we could retrieve info about codegen parameters - auto constants = k.getConstantDescriptors(); - auto buf_args = k.getBufferArgs(); - // Expected buf args: [input0, output0, constant0] - ASSERT_EQ(buf_args.size(), 3); - ASSERT_EQ(constants.size(), 1); - ASSERT_TRUE( - !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar()); -#endif -} - -Tensor lowerNanToNum( - const std::vector& inputs, - const std::vector& outputShape, - const std::vector& outputStrides, - const std::optional& outputType, - at::Device device) { - auto input_buf = std::get(inputs[0]); - auto e = Compute( - "custom_nan_to_num", - outputShape, - outputStrides, - [&](const std::vector& axes) { - std::vector indices(axes.begin(), axes.end()); - auto load = input_buf.load(indices); - return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); - }); - return e; -} - -TEST_F(Kernel, CustomLowering) { - const auto graph_string = R"IR( - graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %none : NoneType = prim::Constant() - %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) - return (%y) -)IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - std::unordered_map lowerings = { - {aten::nan_to_num, lowerNanToNum}}; - TensorExprKernel k(graph, lowerings); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Check that our custom lowering is actually used - torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str()); - torch::jit::testing::FileCheck().check("isnan")->run(oss.str()); -} - -TEST_F(Kernel, Vectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), - %1 : Float(100, 16, strides=[16, 1], device=cpu)): - %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) - %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 16; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. -TEST_F(Kernel, DISABLED_FlattenVectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 3, strides=[3, 1], device=cpu), - %1 : Float(100, 3, strides=[3, 1], device=cpu)): - %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, Strided1dWithinBounds) { - auto ir = R"IR( - graph(%0 : Float(3, strides=[1], device=cpu), - %1 : Float(3, strides=[2], device=cpu)): - %2 : int = prim::Constant[value=1]() - %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) - return (%3))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2)}); - auto expect = a + b; - - std::vector inputs = {a, b}; - - std::vector stack = fmap(inputs); - k.run(stack); - - auto output = stack[0].toTensor(); - - for (size_t i = 0; i < 3; ++i) { - TORCH_CHECK_EQ( - ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); - } -} - -TEST_F(Kernel, InputAsOutput) { - const auto graph_string = R"IR( - graph(%x : Float(5, 3, strides=[3, 1], device=cpu), - %y : Float(5, 3, strides=[1, 5], device=cpu)): - return (%x, %y))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - TensorExprKernel k(graph); - std::vector inputs = {x, y}; - - std::vector stack = fmap(inputs); - k.run(stack); - CHECK(at::allclose(x, stack[0].toTensor())); - CHECK(at::allclose(y, stack[1].toTensor())); -} - -TEST_F(Kernel, ScalarOut) { - auto ir = R"IR( -graph(%x : int, %y : int): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - return (%r, %z))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Verify the generated IR. We expect to see a scalar variable (Let) followed - // by a store to a 0-dim buffer. - const std::string& verification_pattern = R"IR( -# CHECK: int64_t -# CHECK-NEXT: [0ll] = -# CHECK-NEXT: int64_t -# CHECK-NEXT: [0ll] = -)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - int64_t x = 2, y = 3, r = 0, z = 0; - - // Verify that TEK::runFast works correctly with scalar outputs - std::vector inputs = {&x, &y}; - std::vector outputs = {&r, &z}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - - // Verify that TEK::run works correctly with scalar outputs - std::vector stack = {x, y}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - TORCH_CHECK_EQ(stack[1], x * y); -} - -TEST_F(Kernel, ScalarTensorOut) { - auto ir = R"IR( -graph(%x : int, - %xt : Long(3, strides=[1], device=cpu), - %y : int, - %yt : Long(3, strides=[1], device=cpu)): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y) - %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt) - return (%r, %rt, %z, %zt))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - int64_t x = 2, y = 3, r = 0, z = 0; - auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2; - auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3; - auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - - // Verify that TEK::runFast works correctly with mixed scalar and tensor - // inputs/outputs - std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; - std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - ASSERT_TRUE(at::equal(zt, xt * yt)); - ASSERT_TRUE(at::equal(rt, zt * xt)); - - // Verify that TEK::run works correctly with mixed scalar and tensor - // inputs/outputs - std::vector stack = {x, xt, y, yt}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); - TORCH_CHECK_EQ(stack[2], x * y); - ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); -} - -TEST_F(Kernel, FuseLoopsWithVariableBounds) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, FuseLoopsWithVariableConcatDim) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2, int dim3) { - auto a = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(3 * dim3); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int, - %SS_6 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5, -6}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t j -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) { - auto a = - at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b}, 1); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(dim4); - stack.emplace_back(dim5); - stack.emplace_back(dim4 + dim5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15, 8); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp deleted file mode 100644 index f6ffc84f62c0..000000000000 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ /dev/null @@ -1,1799 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using LLVMExprEval = ExprEval; - -// Typed tests, can't use gtest params here due to the way we instantiate tests. -#define TEST_LLVM_SCALAR_TYPES(_) \ - _(uint8_t, Byte, 24) \ - _(int8_t, Char, -20) \ - _(int16_t, Short, 3332) \ - _(int, Int, 123456) \ - _(int64_t, Long, 2631563121321) \ - _(float, Float, 0.122) \ - _(double, Double, 0.21312) \ - _(at::Half, Half, 0.128f) - -#define IMM_TEST(Type, Name, Val) \ - TEST(LLVM, Name##ImmTest) { \ - auto a = Name##Imm::make(Val); \ - LLVMExprEval cg(a); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(IMM_TEST) -#undef IMM_TEST - -#define ADD_TEST(Type, Name, Val) \ - TEST(LLVM, Name##AddTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make(Val * 2); \ - auto c = Add::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 3, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 3); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(ADD_TEST) -#undef ADD_TEST - -#define SUB_TEST(Type, Name, Val) \ - TEST(LLVM, Name##SubTest) { \ - auto a = Name##Imm::make(Val * 2); \ - auto b = Name##Imm::make(Val); \ - auto c = Sub::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(SUB_TEST) -#undef SUB_TEST - -#define MUL_TEST(Type, Name, Val) \ - TEST(LLVM, Name##MulTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make((Type)4); \ - auto c = Mul::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 4, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 4); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(MUL_TEST) -#undef MUL_TEST - -#define DIV_TEST(Type, Name, Val) \ - TEST(LLVM, Name##DivTest) { \ - auto a = Name##Imm::make((Type)6); \ - auto b = Name##Imm::make((Type)3); \ - auto c = Div::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), 2, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), 2); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(DIV_TEST) -#undef DIV_TEST - -TEST(LLVM, IntToFloatCastTest) { - auto a = IntImm::make(2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b, {}); - ASSERT_EQ(cg.value(), 2.0); -} - -TEST(LLVM, FloatToIntCastTest) { - auto a = FloatImm::make(2.0); - auto b = Cast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, IntToLongCastTest) { - auto a = IntImm::make(12345); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 12345); -} - -TEST(LLVM, ByteToCharCastTest) { - auto a = ByteImm::make(250); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), (int8_t)250); -} - -TEST(LLVM, HalfToLongCastTest) { - auto a = HalfImm::make(2.0); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, ByteToDoubleCastTest) { - auto a = ByteImm::make(2); - auto b = Cast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, FloatToByteCastTest) { - auto a = FloatImm::make(254.0); - auto b = Cast::make(kByte, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254); -} - -TEST(LLVM, FloatToCharCastTest) { - auto a = FloatImm::make(-2.0); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, ByteToFloatCastTest) { - auto a = ByteImm::make(254); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254.0); -} - -TEST(LLVM, CharToFloatCastTest) { - auto a = CharImm::make(-2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2.0); -} - -TEST(LLVM, BitCast) { - /* constexpr int16_t ref16 = 1337; */ - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - - // this is broken - /*{ - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast(k) = ref16; - auto a = HalfImm::make(k); - auto b = BitCast::make(kShort, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref16); - }*/ - - { - float k = raw_bitcast(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref32); - } - - { - double k = raw_bitcast(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref64); - } - - { - int64_t k = raw_bitcast(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff64); - } - - { - int32_t k = raw_bitcast(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff32); - } -} - -TEST(LLVM, fastLogFloat) { - const int kTotalSize = 128 * 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); - ir_eval.call({a_v, b_v}); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(LLVM, LetTest01) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - auto block = Block::make({ - Let::make(x, 3.f), - a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); -} - -TEST(LLVM, LetTest02) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - auto block = Block::make( - {Let::make(x, 3.f), - Let::make(y, 6.f), - a.store( - {IntImm::make(0)}, - ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); -} - -TEST(LLVM, LetTestMultitype) { - BufHandle a("A", {1}, kDouble); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kByte); - VarHandle y("y", kHalf); - auto block = Block::make( - {Let::make(x, 3), - Let::make(y, 6.f), - a.store( - {0}, - Cast::make( - kDouble, - ExprHandle(2.f) + - (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); -} - -TEST(LLVM, BufferTest) { - BufHandle a("A", {32}, kFloat); - std::vector v(5); - std::vector args({v.data()}); - auto rv = IntImm::make(0); - LLVMExprEval cg(rv, {a}); - ASSERT_EQ(cg.value(args), 0); -} - -TEST(LLVM, BlockTest) { - BufHandle a("A", {32}, kInt); - std::vector v = {1, 2}; - std::vector args({v.data()}); - - auto block = Block::make({ - a.store({0}, 3), - a.store({1}, 4), - a.store({0}, 4), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 4); - ASSERT_EQ(v[1], 4); -} - -TEST(LLVM, LoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - - auto store = b.store({0}, a.load(0)); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -TEST(LLVM, IfThenElseTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - std::vector c_buffer = {1}; - - auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); - LLVMCodeGen cg(store, {a, b, c}); - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -// if (x < 10) x = x + 1 -TEST(LLVM, CondNoFalseBlockTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value); - } - } -} - -// if (x < 10) { -// x = x + 1; -// } else { -// x = x - 1; -// } -TEST(LLVM, CondTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto block = Block::make({ - cond, - x.store({0}, x.load(0) * 2), - }); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(block, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); - } else { - ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); - } - } -} - -// if (x < 10) { -// if (x > 5) { -// x = x + 1; -// } else { -// x = x - 1; -// } -// } else { -// if (x <= 15) { -// x = x + 2; -// } else { -// x = x - 2; -// } -// } -TEST(LLVM, CondNestedTest) { - BufHandle x("X", {1}, kInt); - auto true_cmp = - CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); - auto true_cond = Cond::make( - true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto false_cmp = - CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); - auto false_cond = Cond::make( - false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, true_cond, false_cond); - - for (int32_t x_value : {0, 8, 15, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - if (x_value > 5) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value - 1); - } - } else { - if (x_value <= 15) { - ASSERT_EQ(x_buffer[0], x_value + 2); - } else { - ASSERT_EQ(x_buffer[0], x_value - 2); - } - } - } -} - -TEST(LLVM, DirectVectorization) { - constexpr int M = 3; - constexpr int N = 64; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {M, N}, kFloat); - BufHandle c("c", {M, N}, kFloat); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - StmtPtr s = For::make( - m, - 0, - M, - Store::make( - c, - {Ramp::make(m * 64, 1, 64)}, - Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) * - Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)}))); - LLVMCodeGen cg(s, {a, b, c}); -} - -TEST(LLVM, VecLoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {1, 1, 1, 1}; - std::vector b_buffer = {2, 2, 2, 2}; - - auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)})); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 1); - ASSERT_EQ(a_buffer[1], 1); - ASSERT_EQ(a_buffer[2], 1); - ASSERT_EQ(a_buffer[3], 1); - ASSERT_EQ(b_buffer[0], 1); - ASSERT_EQ(b_buffer[1], 1); - ASSERT_EQ(b_buffer[2], 1); - ASSERT_EQ(b_buffer[3], 1); -} - -#define FLOAT_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kFloat); \ - BufHandle b("B", {1}, kFloat); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ - } \ - } // namespace jit -FLOAT_INTRINSICS_TEST(erf, 4) -FLOAT_INTRINSICS_TEST(erfc, 4) -FLOAT_INTRINSICS_TEST(acos, 4) -FLOAT_INTRINSICS_TEST(asin, 4) -FLOAT_INTRINSICS_TEST(atan, 4) -FLOAT_INTRINSICS_TEST(cosh, 4) -FLOAT_INTRINSICS_TEST(sinh, 4) -FLOAT_INTRINSICS_TEST(tanh, 4) -FLOAT_INTRINSICS_TEST(expm1, 4) -FLOAT_INTRINSICS_TEST(lgamma, 4) -FLOAT_INTRINSICS_TEST(erf, 8) -FLOAT_INTRINSICS_TEST(erfc, 8) -FLOAT_INTRINSICS_TEST(acos, 8) -FLOAT_INTRINSICS_TEST(asin, 8) -FLOAT_INTRINSICS_TEST(atan, 8) -FLOAT_INTRINSICS_TEST(cosh, 8) -FLOAT_INTRINSICS_TEST(sinh, 8) -FLOAT_INTRINSICS_TEST(tanh, 8) -FLOAT_INTRINSICS_TEST(expm1, 8) -FLOAT_INTRINSICS_TEST(lgamma, 8) -#undef FLOAT_INTRINSICS_TEST - -#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kDouble); \ - BufHandle b("B", {1}, kDouble); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ - } \ - } // namespace jit -DOUBLE_INTRINSICS_TEST(erf, 2) -DOUBLE_INTRINSICS_TEST(erfc, 2) -DOUBLE_INTRINSICS_TEST(acos, 2) -DOUBLE_INTRINSICS_TEST(asin, 2) -DOUBLE_INTRINSICS_TEST(atan, 2) -DOUBLE_INTRINSICS_TEST(cosh, 2) -DOUBLE_INTRINSICS_TEST(sinh, 2) -DOUBLE_INTRINSICS_TEST(tanh, 2) -DOUBLE_INTRINSICS_TEST(expm1, 2) -DOUBLE_INTRINSICS_TEST(lgamma, 2) -DOUBLE_INTRINSICS_TEST(erf, 4) -DOUBLE_INTRINSICS_TEST(erfc, 4) -DOUBLE_INTRINSICS_TEST(acos, 4) -DOUBLE_INTRINSICS_TEST(asin, 4) -DOUBLE_INTRINSICS_TEST(atan, 4) -DOUBLE_INTRINSICS_TEST(cosh, 4) -DOUBLE_INTRINSICS_TEST(sinh, 4) -DOUBLE_INTRINSICS_TEST(tanh, 4) -DOUBLE_INTRINSICS_TEST(expm1, 4) -DOUBLE_INTRINSICS_TEST(lgamma, 4) -#undef DOUBLE_INTRINSICS_TEST - -TEST(LLVM, VectorizerLoadStoreTest) { - BufHandle a("A", {1}, kInt); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(4, 21); - std::vector c_vec(4, 0); - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 21); -} - -TEST(LLVM, VectorizeBitCast) { - BufHandle a("A", {128}, kInt); - - Tensor c = Compute("c", {128}, [&](const VarHandle& i) { - return bitcast(a.load(i)); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(128); - std::vector c_vec(128); - for (const auto i : c10::irange(128)) { - a_vec[i] = raw_bitcast(1337.f); - } - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 1337.f); -} - -TEST(LLVM, MemcpyTest) { - constexpr int N = 32; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - std::vector a_buffer(N, 42); - std::vector b_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 42); - assertAllEqual(b_buffer, 42); -} - -TEST(LLVM, BzeroTest) { - constexpr int N = 32; - BufHandle b("B", {N}, kInt); - std::vector b_buffer(N, 11); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, 0)); - - LLVMCodeGen cg(expr, {b}); - - std::vector args({b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(b_buffer, 0); -} - -TEST(LLVM, ElemwiseAdd) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 42); -} - -TEST(LLVM, ElemwiseAddFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 42.0f); -} - -TEST(LLVM, ElemwiseLog10Float) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, 10.0f); - std::vector b_buffer(N, 2.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 10.0f); - assertAllEqual(b_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseLog1pFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, expf(3.0f) - 1); - std::vector b_buffer(N, 42.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, expf(3.0f) - 1); - ExpectAllNear(b_buffer, 3.0f, 1e-5f); -} - -TEST(LLVM, ElemwiseMaxInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 41); -} - -TEST(LLVM, ElemwiseMinInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, ElemwiseMaxFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 41.0f); -} - -TEST(LLVM, ElemwiseMaxNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMinFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseMinNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMod) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 23); - std::vector c_buffer(N, 18); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 23); - assertAllEqual(c_buffer, 18); -} - -TEST(LLVM, CompareSelectIntEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - b_buffer[i] = 0; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectFloatEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1.0f); - std::vector b_buffer(N, 1.0f); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, CompareSelectByteGT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 1; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteGE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, StoreFloat) { - BufHandle result("result", {1}, kFloat); - std::vector result_buffer = {0.0f}; - auto expr = result.store({0}, FloatImm::make(3.14f)); - LLVMCodeGen cg(expr, {result}); - std::vector args({result_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(result_buffer[0], 3.14f); -} - -TEST(LLVM, SimpleMath01) { - const int N = 1024; - Tensor tensor = Compute( - "f", {N}, [](const VarHandle& i) { return cast(i * i + 1); }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - BufHandle f_buf(tensor.buf()); - LLVMCodeGen cg(stmt, {f_buf}); - - PaddedBuffer f_v(N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(N, "f_ref"); - for (const auto i : c10::irange(N)) { - f_ref(i) = i * i + 1; - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, ComputeMul) { - const int N = 1024; - BufHandle a("a", {N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute( - "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector a_vec(N, 21.0f); - std::vector b_vec(N, 2.0f); - std::vector c_vec(N, 0.0f); - std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 42.0f); -} - -TEST(LLVM, BroadcastAdd) { - const int M = 32; - const int N = 1024; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector av(M * N); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N, 0); - std::vector args({av.data(), bv.data(), cv.data()}); - ASSERT_EQ(cg.value(args), 0); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); - } - } -} - -TEST(LLVM, BitwiseOps) { - auto a = IntImm::make(59); - auto b = IntImm::make(11); - auto c = IntImm::make(101); - auto d = IntImm::make(2); - - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - LLVMExprEval cg(f); - - ASSERT_EQ(cg.value(), 11); -} - -TEST(LLVM, ArithmeticRightShift) { - auto a = CharImm::make(-4); - auto b = CharImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, LogicalRightShift) { - auto a = ByteImm::make(0xfc); - auto b = ByteImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), 0x7e); -} - -TEST(LLVM, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector args({aData.data(), bData.data(), cData.data(), &size}); - cg.value(args); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, BindDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, TensorDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - Tensor c = Compute( - "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, DynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LLVM, EmptyStmt) { - StmtPtr s = alloc(std::vector({})); - - LLVMCodeGen cg(s, {}); - cg.call({}); - // Just don't crash. -} - -TEST(LLVM, EliminatedStmt) { - BufHandle a("a", {1}, kFloat); - - Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {a, c}); - std::vector aData(1, 1.0f); - std::vector cData(0, 0.0f); - cg.call({aData, cData}); -} - -TEST(LLVM, SimpleReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - std::vector loops = loop.getLoopStmtsFor(b); - ForPtr loop_m = loops.at(1); - ForPtr loop_n = loops.at(2); - loop.reorderAxis(loop_m, loop_n); - - loops = loop.getLoopStmtsFor(b); - loop_m = loops.at(2); - loop_n = loops.at(1); - auto b_body = loop.getAllWritesToBuf(b.buf())[1]; - ASSERT_TRUE(loop.rfactor(b_body, loop_n)); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorVectorizedReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loopnest({b}); - std::vector loops = loopnest.getLoopStmtsFor(b); - // Reorder n and m loops - loopnest.reorderAxis(loops.at(1), loops.at(2)); - auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); - auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); - ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); - ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); - auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); - - // Vectorize initializer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0])); - // Vectorize producer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1])); - loopnest.simplify(); - - loopnest.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -template -static void testSimpleParallel() { - // Compute a simple operation, and try all loop-axis combination to be - // parallel or sequential. - const int M = 4; - const int N = 6; - Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) { - return cast(m + n); - }); - LoopNest loop_nest({f}); - auto const& loops = loop_nest.getLoopStmtsFor(f); - ForPtr m = loops[0]; - ForPtr n = loops[1]; - if (outer) { - m->set_parallel(); - } - if (inner) { - n->set_parallel(); - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {f}); - - PaddedBuffer f_v(M, N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(M, N, "f_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - f_ref(m, n) = m + n; - } - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, SimpleParallelSS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelSP) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPP) { - testSimpleParallel(); -} - -TEST(LLVM, CompositeParallel) { - int loop_count = 6; - int test_count = 1 << loop_count; - // Compute a composite operation, and try all loop-axis combination to be - // parallel or sequential. - for (const auto test_cfg : c10::irange(test_count)) { - int M = 5; - int N = 7; - Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; }); - Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; }); - Tensor t3 = - Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t1.load(m) * t2.load(n); - }); - Tensor t4 = - Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t3.load(m, n) + m + n; - }); - LoopNest loop_nest({t4}, {t1, t2, t3, t4}); - std::vector loop_list; - { - auto const& loops = loop_nest.getLoopStmtsFor(t1); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t2); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t3); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t4); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - ASSERT_EQ(loop_list.size(), loop_count); - for (const auto i : c10::irange(loop_count)) { - if (test_cfg & (1 << i)) { - loop_list[i]->set_parallel(); - } - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {t4}); - - PaddedBuffer t4_v(M, N, "t4_v"); - std::vector args({t4_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer t4_ref(M, N, "t4_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - t4_ref(m, n) = (m + 1) * (n + 2) + m + n; - } - } - ExpectAllNear(t4_v, t4_ref, 1e-5); - } -} - -TEST(LLVM, VectorizedGEMM) { - int M = 32; - int N = 32; - int K = 48; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 16); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto loops = NodeFinder::find(loop.root_stmt()); - ASSERT_TRUE(LoopNest::vectorize(loops[3])); - ASSERT_TRUE(LoopNest::vectorize(loops.back())); - } - - loop.prepareForCodegen(); - - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {AP, BP, CT}); - - PaddedBuffer a_v(M, K, "a_v"); - PaddedBuffer b_v(K, N, "b_v"); - PaddedBuffer c_v(M, N, "c_v"); - PaddedBuffer c_ref(M, N, "c_ref"); - - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - c_ref(m, n) = 0.f; - for (const auto k : c10::irange(K)) { - c_ref(m, n) += a_v(m, k) * b_v(k, n); - } - } - } - - cg.call({a_v, b_v, c_v}); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LLVM, CallRaw) { - const int M = 32; - VarHandle N("N", kInt); - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - int32_t N_value = 1024; - std::vector av(M * N_value); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N_value); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N_value, 0); - std::vector args({av.data(), bv.data(), cv.data(), &N_value}); - - LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); - cg.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } - - SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); - eval.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } -} - -TEST(LLVM, CustomTarget) { - constexpr int M = 16; - BufHandle a("a", {M}, kFloat); - BufHandle b("b", {M}, kFloat); - BufHandle c("c", {M}, kFloat); - Tensor d = Compute("d", {M}, [&](const VarHandle& m) { - return a.load(m) * b.load(m) + c.load(m); - }); - LoopNest nest({d}); - nest.prepareForCodegen(); - auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d}) - .triple("i686-elf") - .cpu("i386") - .build(); - std::ostringstream ss; - ss << cg->getCodeText("asm"); - torch::jit::testing::FileCheck() - .check("fadds") - ->check("fmuls") - ->check_not("vfmadd") - ->run(ss.str()); -} - -TEST(LLVM, CodeGenKernelFuncName) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - auto store = b.store({0}, a.load(0)); - - { - LLVMCodeGen cg(store, {a, b}); - // Check that the kernel function name used by LLVMCodeGen - // is not empty. - ASSERT_NE(cg.kernel_func_name(), ""); - } - - { - LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func"); - // Check that the kernel function name used by LLVMCodeGen - // is the one that was given above. - ASSERT_EQ(cg.kernel_func_name(), "new_func"); - } -} - -} // namespace jit -} // namespace torch - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp deleted file mode 100644 index a8bda8814dba..000000000000 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ /dev/null @@ -1,6894 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -void checkIR(StmtPtr s, const std::string& pattern) { - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(pattern, oss.str()); -} - -void checkExprIR(ExprPtr e, const std::string& pattern) { - std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; - std::ostringstream oss; - oss << *e << "\n"; - torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); -} - -void checkExprIR(const ExprHandle& e, const std::string& pattern) { - checkExprIR(e.node(), pattern); -} - -TEST(LoopNest, ExprSimple01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); -} - -TEST(LoopNest, ExprLower01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 20); - ASSERT_LT(oss.str().size(), 200); -} - -TEST(LoopNest, ExprSimple02) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }; - Tensor tensor = Compute("f", {26, 5}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure structure. - VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - BufHandle f("f", {26, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; - ForPtr stmt1 = For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); - ExprHandle x_2 = x_tail + x_outer_end * 4; - ForPtr stmt2 = For::make( - x_tail, - 0, - (ExprHandle(26) - 0) % 4, - For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); - StmtPtr stmt = Block::make({stmt1, stmt2}); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer f_v(26, 5, "f_v"); - PaddedBuffer f_ref(26, 5, "f_res"); - - stmt = FlattenIndexes(stmt); - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 26; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -BlockPtr getSimplifiedBody(const LoopNest& l) { - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - return to(simplified); -} - -void assertForRange(ForPtr f, int expected_start, int expected_stop) { - ASSERT_NE(f, nullptr); - IntImmPtr start = to(f->start()); - ASSERT_NE(start, nullptr); - ASSERT_EQ(start->value(), expected_start); - IntImmPtr stop = to(f->stop()); - ASSERT_NE(stop, nullptr); - ASSERT_EQ(stop->value(), expected_stop); -} - -void assertForRanges( - BlockPtr body, - const std::vector>& start_stops) { - ASSERT_EQ(body->nstmts(), start_stops.size()); - - auto it = body->begin(); - for (size_t i = 0; i < start_stops.size(); i++, it++) { - ForPtr loop = to(*it); - assertForRange(loop, start_stops[i].first, start_stops[i].second); - } -} - -TEST(LoopNest, ExprSliceHeadWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); - - ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceTailWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ForPtr tail_head; - ForPtr tail_tail; - tail->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); - - ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); - ASSERT_TRUE(tail_tail->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHead) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_NE(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 4}, {4, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceTail(loops[0], 4, &head, &tail); - // head: [0, 6) - // tail: [6, 10) - - LoopNest::sliceHead(tail, 2); - // tail_head: [6, 8) - // tail_tail: [8, 10) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_EQ(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_NE(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 10}}); -} - -TEST(LoopNest, ExprSplitAndSlice) { - // 0: splitWithTail - // 1: sliceTail on inner loop - // 2: sliceHead on outer loop - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {100}, func); - LoopNest l({tensor}); - - ForPtr inner; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // outer: [0, 4) - // inner: [0, 21) - // tail: [84, 100) - LoopNest::splitWithTail(loops[0], 21, &inner, &tail); - LoopNest::sliceTail(inner, 2); - LoopNest::sliceHead(loops[0], 2); - - // for (int x_outer = 0; x_outer < 2; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_outer = 2; x_outer < 4; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_tail = 0; x_tail < 16; x_tail++) { - // f[x_tail + 84] = 1.f + float(x_tail + 84); - // } - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); - - auto biter = body->begin(); - - ForPtr loop = to(*biter++); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); - - loop = to(*biter); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); -} - -TEST(LoopNest, ExprSliceAndNormalize) { - // 0: sliceHead - // 1: normalize tail - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - // head: [0, 2) - // tail: [2, 10) - - LoopNest::normalize(tail); - // normalized_tail: [0, 8) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); -} - -template -T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { - ExprEval eval(expr, {var}); - return eval.value(value); -} - -TEST(LoopNest, ExprSliceWithVariableDimension) { - auto testWithDimension = - [](int dimension, - const std::vector>& expected_for_ranges) { - VarHandle dim("dim", kInt); - Tensor tensor = - Compute("f", {dim}, [](const ExprHandle& x) { return x; }); - LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - LoopNest::sliceTail(tail, 2); - - BlockPtr body = getSimplifiedBody(l); - ASSERT_EQ(expected_for_ranges.size(), 3); - auto it = body->begin(); - for (auto& start_stop : expected_for_ranges) { - ForPtr loop = to(*it++); - int start = evalExpr(ExprHandle(loop->start()), dim, dimension); - int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); - ASSERT_EQ(start, start_stop.first); - ASSERT_EQ(stop, start_stop.second); - } - }; - - testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); - testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); - testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); - testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); - testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); - testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSplitWithTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {199}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 17); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 7); - - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - BlockPtr body = to(simplified); - ASSERT_EQ(body->nstmts(), 3); - auto biter = body->begin(); - - // Verify that the split loops are ordered correctly. - ForPtr loop = to(*biter++); - assertForRange(loop, 0, 7); - - loop = to(*biter++); - assertForRange(loop, 0, 4); - - loop = to(*biter); - assertForRange(loop, 0, 12); -} - -TEST(LoopNest, ExprSplitWithTailNone) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }; - Tensor tensor = Compute("f", {24, 5}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure structure. - VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - BufHandle f("f", {24, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; - StmtPtr stmt = alloc(std::vector({For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer f_v(24, 5, "f_v"); - PaddedBuffer f_ref(24, 5, "f_res"); - - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 24; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -TEST(LoopNest, ExprSplitWithMask01) { - const int M = 26; - const int N = 5; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[1], 4); - - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -// Tests the case where we split a loop cleanly multiple times, we should not -// insert any masks. -TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { - const int M = 64; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - LoopNest::splitWithMask(loops[0], 4); - - StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); - - // Two splits mean 3 loops, but should need no masks in this case. - checkIR(stmt1, R"IR( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: f[)IR"); -} - -TEST(LoopNest, getLoopAt) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // for (int j = 0; j < 100; j++) { - // A[i, j] = sin(i * j); - // for (int k1 = 0; k1 < 200; k1++) { - // B[i, j, k1] = (A[i, j]) / (k1 + 1); - // } - // for (int k2 = 0; k2 < 300; k2++) { - // C[i, j, k2] = (A[i, j]) * (k2 + 1); - // } - // } - // } - BufPtr A = alloc( - "A", - std::vector({alloc(100), alloc(100)}), - kInt); - BufPtr B = alloc( - "B", - std::vector( - {alloc(100), alloc(100), alloc(200)}), - kInt); - BufPtr C = alloc( - "C", - std::vector( - {alloc(100), alloc(100), alloc(300)}), - kInt); - BufHandle a_buf(A); - BufHandle b_buf(B); - BufHandle c_buf(C); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k1("k1", kInt); - VarHandle k2("k2", kInt); - auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); - auto store2 = Store::make( - b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); - auto store3 = Store::make( - c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); - auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); - auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); - auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); - auto for_i = For::make(i, 0, 100, for_j); - LoopNest l(Block::make({for_i}), {B, C}); - auto ret_k2 = l.getLoopAt(for_i, {0, 2}); - TORCH_CHECK(ret_k2 == for_k2); - - std::ostringstream oss; - oss << *ret_k2; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int k2 -# CHECK-NEXT: C[i, j, k2] = - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, TileSimple) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 4, 8); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK-NOT: for (int i_tail -# CHECK-NOT: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileWithTails) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 5, 9); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK: for (int i_inner -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileInMiddle) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 8, N = 8, L = 8, K = 8; - BufHandle a_buf("a", {M, N, L, K}, kFloat); - BufHandle b_buf("b", {M, N, L, K}, kFloat); - Tensor tensor = Compute( - "f", - {M, N, L, K}, - [&](const ExprHandle& m, - const ExprHandle& n, - const ExprHandle& l, - const ExprHandle& k) { - return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; - }); - - LoopNest nest({tensor}); - std::vector loops = - nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - nest.tile(loops[1], loops[2], 3, 3); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail_1 -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, L, K, "a"); - PaddedBuffer b_v(M, N, L, K, "b"); - PaddedBuffer c_v(M, N, L, K, "c"); - PaddedBuffer c_ref(M, N, L, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int l = 0; l < L; l++) { - for (int k = 0; k < K; k++) { - a_v(m, n, l, k) = 2 * (m + l); - b_v(m, n, l, k) = 3 * (n + k); - c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; - } - } - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, SplitWithTailWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner, tail; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - ASSERT_GT(loops.size(), 0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithTail(loops[0], 4, &inner, &tail); - ASSERT_NE(inner, nullptr); - ASSERT_NE(tail, nullptr); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); - - // Tail loop has none. - ASSERT_TRUE(tail->loop_options().isDefault()); -} - -TEST(LoopNest, SplitWithMaskWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithMask(loops[0], 4, &inner); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); -} - -TEST(LoopNest, ScheduleBroadcastAddBuffer) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - LoopNest l({c}); - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a_v"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 7 * m * n; - } - } - a_v.Backup(); - - PaddedBuffer b_v(N, K, "b_v"); - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - b_v(n, k) = 11 * n * k; - } - } - b_v.Backup(); - - PaddedBuffer c_v(M, N, K, "c_buf"); - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); - ir_eval(a_v, b_v, c_v); - - a_v.CheckBackup(); - b_v.CheckBackup(); - PaddedBuffer c_ref(M, N, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - c_ref(m, n, k) = 7 * m * n + 11 * n * k; - } - } - } - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, ScheduleFunctionCall01) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 100); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N, K); - PaddedBuffer d_v(M, N, K); - PaddedBuffer d_ref(M, N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - for (int k = 0; k < K; k++) { - d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); - eval(a_v, b_v, d_v); - - ExpectAllNear(d_v, d_ref, 1e-5); -} - -TEST(LoopNest, ScheduleInlineSimple) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, c_v, d_v, y_1); - eval2(a_v, b_v, c_v, d_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -static std::string remove_space(const std::string& str) { - std::string str_new = str; - str_new.erase( - remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); - return str_new; -} - -void InlineFunc01Helper(const std::vector& inline_order) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - for (const std::string& order : inline_order) { - if (order == "x") { - l.computeInline(x.buf()); - } else if (order == "y") { - l.computeInline(y.buf()); - } else { - throw std::runtime_error("Invalid order: " + order); - } - } - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - - std::ostringstream oss; - oss << *stmt; - std::string str1 = remove_space(oss.str()); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } - - if (inline_order.size() == 2) { - Tensor z2 = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k) + - (c_buf.load(m, n) * d_buf.load(m, k) + - a_buf.load(m, n) * b_buf.load(n, k)); - }); - LoopNest l2({z2}); - l2.prepareForCodegen(); - StmtPtr stmt2 = l2.root_stmt(); - - std::ostringstream oss2; - oss2 << *stmt2; - std::string str2 = remove_space(oss2.str()); - - ASSERT_EQ(str1, str2); - ASSERT_GT(str1.size(), 100); - } -} - -TEST(LoopNest, ScheduleInlineFunc01) { - InlineFunc01Helper({"x", "y"}); - InlineFunc01Helper({"y", "x"}); - InlineFunc01Helper({"x"}); - InlineFunc01Helper({"y"}); - InlineFunc01Helper({}); -} - -// Make sure we cache random vars if we should. -TEST(LoopNest, ScheduleInlineRandom) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: int x = rand(); -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't cache random vars that are not being inlined. -TEST(LoopNest, ScheduleInlineRandomUnrelated) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return m * n * k; - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + - Intrinsics::make(kRand, kInt); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); -} - -// Make sure we generate the right number of random values == the dimensionality -// of the production tensor. -TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute("x", {M}, [&](const VarHandle& m) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m) + x.load(m); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: int x = rand(); -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't screw up intrinsics thinking they're rand. -TEST(LoopNest, ScheduleInlineIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -// Make sure we can handle rand and non-rand intrinsics. -TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kRand, kFloat); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: float x = rand(); -# CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); -} - -// Split a Compute then inline it into another compute. -TEST(LoopNest, ScheduleSplitAThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Split a Compute then inline another Compute into it. -TEST(LoopNest, ScheduleSplitBThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); - LoopNest::splitWithMask(loops[0], 3); - l.computeInline(a.buf()); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute twice then inline it. -TEST(LoopNest, ScheduleSplitTwiceThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - ForPtr i_inner; - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4, &i_inner); - LoopNest::splitWithMask(i_inner, 2); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute, then split. -TEST(LoopNest, ScheduleInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - l.computeInline(a.buf()); - - std::vector loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 3); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute, inline it, then split the result. -TEST(LoopNest, ScheduleSplitInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - auto loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 2); - l.computeInline(a.buf()); - - loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.front(), 2); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(16, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 16; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Oversplit a loop that is simplified out after inlining. -TEST(LoopNest, ScheduleSplitInlineSimplify) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { - return ExprHandle(4) * i - ExprHandle(2) * i; - }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute with two consumers. -TEST(LoopNest, ScheduleInlineThreeMixedOnce) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline Compute A into B, then inline B into C. -TEST(LoopNest, ScheduleInlineThreeMixedTwice) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline a Compute that is both a producer and consumer. -TEST(LoopNest, ScheduleInlineThreeMixedInner) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Split 3 Computes, then inline the first two into the last. -TEST(LoopNest, ScheduleInlineThreeMixedSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); - LoopNest::splitWithMask(loops[0], 3); - loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::splitWithMask(loops[0], 2); - - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Check that inlining works for output tensors too -TEST(LoopNest, ScheduleInlineOutputTensors) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return m * n * k; - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + m; - }); - - LoopNest l1({x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; -# CHECK: for (int i_3 = 0; i_3 < 4; i_3++) -# CHECK: for (int i_4 = 0; i_4 < 5; i_4++) -# CHECK: for (int i_5 = 0; i_5 < 6; i_5++) -# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR"); -} - -TEST(LoopNest, ScheduleInlineWithCompoundIndices) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[i*2,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[0, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {static_cast(0), j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - // Inlining should fail since the producer has compound expr as index. - ASSERT_FALSE(l.computeInline(a_buf.node())); - - // The input statement must remain as is. - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t i = 0; - # CHECK-NEXT: A[ - # CHECK: for (int64_t j = 0; - # CHECK-NEXT: B[)IR"); -} - -TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[0ll,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make( - a_buf, - {static_cast(0), i}, - Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {0, j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[(int64_t)0,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[0ll, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make(a_buf, {0, i}, Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {static_cast(0), j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleFuserStyle) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - - Tensor b = - Compute("f", {kTotalSize}, [&](const std::vector& axes) { - return a_buf.load(axes[0]) + 11.0f; - }); - - Tensor c = - Compute("g", {kTotalSize}, [&](const std::vector& axes) { - return b.load(axes[0]) + 1.0f; - }); - - LoopNest l({b, c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 7.0f); - std::vector b_data(kTotalSize, 0.0f); - std::vector c_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(b_data[i], 18.0f); - ASSERT_EQ(c_data[i], 19.0f); - } -} - -TEST(LoopNest, ScheduleFuserThreeArg) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); - - Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { - return a.load(i) + b.load(i); - }); - Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { - return e.load(i) + c.load(i); - }); - Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { - return f.load(i) + d.load(i); - }); - - LoopNest l({g}, {e, f, g}); - l.computeInline(l.getLoopBodyFor(e)); - l.computeInline(l.getLoopBodyFor(f)); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 1.0f); - std::vector b_data(kTotalSize, 2.0f); - std::vector c_data(kTotalSize, 3.0f); - std::vector d_data(kTotalSize, 4.0f); - std::vector g_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(g_data[i], 10.0f); - } -} - -TEST(LoopNest, ScheduleDynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - SimpleIREvaluator cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LoopNest, LoopNestComputeAt_1) { - // Verify that compute_at works on the following example: - // - // for (int i_a = 0; i_a < N; i_a++) { - // A[i_a] = i_a * i_a - // } - // for (int i_b = 0; i_b < N; i_b++) { - // B[i_b] = A[i_b] - // } - // - // After the transformation the i_b loop should have an allocation for a temp - // buffer and that buffer should be used in computation of B. No use of A - // should be in that loop after the transformation. Also, computation of A - // should not be inlined into B. Instead, it should be computed into the temp, - // and the temp should be used in B. - VarHandle N("N", kInt); - Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); - Tensor B = - Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); - LoopNest l({B}, {A, B}); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {B, N}); - StmtPtr s = cg.stmt(); - - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1] -# CHECK: for (int i = 0; i < N; i++) -# CHECK: temp[ -# CHECK-NOT: A[ -# CHECK: B[i_1] = temp[0] -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector b_data(100, 0); - cg.call({b_data, 100}); - - std::vector b_ref(100, 0); - for (int i = 0; i < 100; i++) { - b_ref[i] = i * i; - } - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, LoopNestComputeAt_2) { - // Verify that compute_at works on the following example: - // - // for (int py = 0; py < H+1; py++) { - // for (int px = 0; px < W+1; px++) { - // p[py, px] = py*px - // } - // } - // for (int cy = 0; cy < H; cy++) { - // for (int cx = 0; cx < W; cx++) { - // c[py, px] = p[cy,cx] + p[cy+1,cx] + - // p[cy,cx+1] + p[cy+1,cx+1] - // } - // } - - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { - return px * py; - }); - Tensor c = - Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + - p.load(y + 1, x + 1); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for -# CHECK: for -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK: for -# CHECK: for -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, LoopNestComputeAt_3) { - // Verify that compute_at works on the following example: - // - // A(x,y) = x*y - // B(x,y) = A(x, y) - // C(x,y) = B(x+1, y) - // D(x,y) = A(x, y+1) + C(x, y) - // - // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. - - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor A = Compute( - "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { - return ax * ay; - }); - Tensor B = Compute( - "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { - return A.load(by, bx); - }); - Tensor C = - Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { - return B.load(cy, cx + 1); - }); - Tensor D = - Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { - return A.load(dy + 1, dx) + C.load(dy, dx); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); - } - } - - LoopNest orig_loopnest({D}, {A, B, C, D}); - { - // First let's try to compute A at axis dy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, W] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute A at axis dx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, 1] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -using Axis = const VarHandle&; - -TEST(LoopNest, Reduce2dComputeAt) { - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); - Tensor c = Reduce( - "cons", - {H, W}, - Sum(), - [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, - {2, 2}); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - checkIR(orig_loopnest.root_stmt(), R"IR( -# CHECK: for (int i = 0; i < H + 1; i++) { -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { -# CHECK: prod[i, i_1] = i_1 * i; -# CHECK: } -# CHECK: } -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) { -# CHECK: cons[i_2, i_3] = int(0); -# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { -# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { -# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - // FIXME: Calling simplify here breaks the IR: - // MALFORMED INPUT: could not find base node in Load - temp[...] - // l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); -# CHECK: } -# CHECK: } -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); -# CHECK: } -# CHECK: } -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, DISABLED_Conv1d_NH) { - // Lots of stuff is broken here. The computeAt swaps the axes for some odd - // reason. Even without that, the index flattener fails due to "dimensions - // mismatch in flatten index". - - int N = 4; - int H = 256; - int R = 3; - int Pad = 1; - BufHandle IP("input", {H}, kFloat); - - Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { - auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); - cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); - return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); - }); - Tensor B = Reduce( - "B", - {N, H}, - Sum(), - [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, - {R}); - LoopNest l({B}); - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int np = 0; np < 4; np++) { -# CHECK: for (int hp = 0; hp < 258; hp++) { -# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - // FIXME: The current IR is totally broken. The body of the inlined loop is: - - // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), - // 0.f, input[idx1 + 0, (idx0 + n) - 1]); - - // Which seems to mix up the axes. The CHECK below is my best guess at what - // the input "should" look like - - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { - temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - l.simplify(); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - SimpleIREvaluator cg(s, {IP, B}); - // auto At = at::ones({N, H}, at::kFloat); - auto At = at::arange(N * H, at::kFloat).reshape({N, H}); - auto Rt = at::conv1d( - At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); - auto Bt = at::empty_like(Rt); - cg.call({At.data_ptr(), Bt.data_ptr()}); - ASSERT_TRUE(at::allclose(Rt, Bt)); -} - -class LoopOrderHelper : public IRVisitor { - std::stringstream ordering; - - public: - std::string getOrder(StmtPtr s) { - ordering.str(""); - s->accept(this); - return ordering.str(); - } - - void visit(const ForPtr& v) final { - ordering << v->var()->name_hint() << ","; - IRVisitor::visit(v); - } -}; - -TEST(LoopNest, LoopNestReorderAxis1) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(6, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - ASSERT_NE(stmt1, stmt2); - LoopOrderHelper loopOrderHelper; - std::string order1 = loopOrderHelper.getOrder(stmt1); - std::string order2 = loopOrderHelper.getOrder(stmt2); - - ASSERT_EQ(order1, "j,i,"); - ASSERT_EQ(order2, "i,j,"); - - std::vector stmt2_output(6, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg.call({stmt2_output}); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - // Reorder them back. - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt3 = l.root_stmt(); - - std::string order3 = loopOrderHelper.getOrder(stmt3); - ASSERT_EQ(order3, order1); - - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt3; - - // Should be identical to the unreordered statement. - ASSERT_EQ(oss1.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderPartialAxes) { - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); - - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[2]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); - - StmtPtr stmt3 = Stmt::clone(l.root_stmt()); - - std::vector stmt3_output(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor}); - cg3.call({stmt3_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt3_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderInternalAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[2], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderEnclosingAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[3]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderSameAxis) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[1]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_EQ(oss.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderExtraStatements) { - /* We're going for a structure like this: - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for k in ... - * Stmt 3 - * Stmt 4 - */ - - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - BufHandle extra("res", {6, 3}, kFloat); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - VarHandle i = VarHandle(loops[0]->var()); - - StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); - StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); - // stmt 3 is the Function body. - StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); - - loops[0]->body()->prepend_stmt(store_1); - loops[1]->body()->prepend_stmt(store_2); - loops[1]->body()->append_stmt(store_3); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector extra1(6, 0); - std::vector res1(24, 0); - SimpleIREvaluator cg(stmt1, {tensor, extra}); - cg.call({res1, extra1}); - - /* Then we reorder loop y and z, we want it to look like: - * - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for j_1 in ... - * for k in ... - * Stmt 3 - * for j_2 in ... - * Stmt 4 - * - * We need extra loops because we don't have dependency info about stmt 3 - * and 4. - * - */ - - LoopNest::reorderAxis(loops[1], loops[2]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt2, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: res[i, 2] = 4 -)IR"); - - std::vector extra2(6, 0); - std::vector res2(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor, extra}); - cg2.call({res2, extra2}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res2[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra2[i]); - } - - /* Now reorder x and the y above stmt 3: - * - * - * for x in ... - * Stmt 1 - * for y in ... - * Stmt 2 - * - * for y in ... - * for z in ... - * for x in ... - * Stmt 3 - * - * for x in ... - * for y in ... - * Stmt 4 - * - * - */ - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[2]); - StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt3, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: for -# CHECK: res[i_2, 2] = 4 -)IR"); - - std::vector extra3(6, 0); - std::vector res3(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor, extra}); - cg3.call({res3, extra3}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res3[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra3[i]); - } -} - -void LoopNestReorderTestHelper( - bool prepend, - bool append, - int index1, - int index2) { - Tensor c = Compute( - "5d", {2, 3, 2, 3, 2}, [](const std::vector&) { return -1; }); - LoopNest l({c}); - - BufHandle extra("extra", {5}, kInt); - - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - int j = 0; - for (auto l : loops) { - // Add an increment at each layer of the loop which counts the number of - // times the loop executes. - LoadPtr load = - alloc(extra.node(), std::vector({alloc(j)})); - AddPtr add = alloc(load, alloc(1)); - StmtPtr store = alloc( - extra.node(), std::vector({alloc(j)}), add); - if (prepend) { - l->body()->prepend_stmt(store); - } - if (append) { - l->body()->append_stmt(Stmt::clone(store)); - } - - j++; - } - - StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - - std::vector extra1(5, 0); - std::vector res1(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg(stmt1, {c, extra}); - cg.call({res1, extra1}); - - std::vector loopExtents = {2, 3, 2, 3, 2}; - - int expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra1[i], expected_loops); - } - - loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::reorderAxis(loops[index1], loops[index2]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_NE(oss.str(), oss2.str()); - - std::vector extra2(5, 0); - std::vector res2(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg2(stmt2, {c, extra}); - cg2.call({res2, extra2}); - - expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra2[i], expected_loops); - } - - for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { - ASSERT_EQ(res2[i], res1[i]); - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(true, false, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(false, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringFull) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(true, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderInternalLoopNest) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; - ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; - LoopNest::reorderAxis(a, b); - - l.prepareForCodegen(); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - - // Check the IR we produced has the 3 nests in the right order, but k and m - // swapped in the middle. - checkIR(stmt, R"IR( -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6 -# CHECK: < 6 -# CHECK: < 5 -# CHECK: < 4 -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6)IR"); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } -} - -TEST(LoopNest, OuterLoopVectorization) { - Tensor tensor = - Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - - ASSERT_TRUE( - LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); - - StmtPtr root_stmt = l.root_stmt(); - BlockPtr outer_block = to(root_stmt); - ASSERT_NE(outer_block, nullptr); - while (BlockPtr inner_block = to(outer_block->front())) { - outer_block = inner_block; - } - - // Verify that we have only a single loop level remaining after - // vectorization. - ASSERT_EQ(outer_block->nstmts(), 1); - ForPtr for_loop = to(outer_block->front()); - ASSERT_NE(for_loop, nullptr); - BlockPtr for_body = for_loop->body(); - ASSERT_EQ(for_body->nstmts(), 1); - ASSERT_EQ(to(for_body->front()), nullptr); -} - -TEST(LoopNest, VectorizeLoopNotNormalized) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 1; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 1, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - ASSERT_TRUE(LoopNest::vectorize(inner_for)); - ASSERT_EQ(outer_for->body()->nstmts(), 1); - ASSERT_EQ(to(outer_for->body()->front()), nullptr); -} - -namespace { - -std::string constantUpperBoundLoopIR(int upper_bound_val) { - ExprHandle upper_bound(upper_bound_val); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - std::ostringstream oss; - oss << *unrolled; - return oss.str(); -} - -} // namespace - -TEST(LoopNest, Unroll) { - const std::string actual = constantUpperBoundLoopIR(3); - const std::string& verification_pattern = - R"IR( -# CHECK: A[0] = 0; -# CHECK: A[1] = 2; -# CHECK: A[2] = 4)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, UnrollOuter) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[0, i] = i; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[1, i] = i + 1; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[2, i] = i + 2; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollInner) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll( - static_to(loops[0]->body()->stmts().front()), &unrolled); - checkIR(loops[0], R"IR( -# CHECK: for (int i = 0; i < 3; i++) { -# CHECK: A[i, 0] = i; -# CHECK: A[i, 1] = i + 1; -# CHECK: A[i, 2] = i + 2; -# CHECK: A[i, 3] = i + 3; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollMultipleStatements) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Store::make(a_buf, {x}, x * 2), - Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - checkIR(unrolled, R"IR( -# CHECK: A[0] = 0; -# CHECK: B[0] = A[0]; -# CHECK: A[1] = 2; -# CHECK: B[1] = A[1]; -# CHECK: A[2] = 4 -# CHECK: B[2] = A[2];)IR"); -} - -TEST(LoopNest, UnrollNonLiteralConstantBounds) { - // Input IR: - // for (int i = 2 - 1; i < 12 / 3; i++) { - // for (int j = 0; j < 4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {3, 4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 4, for_body); - auto outer_for = For::make( - i, - IntImm::make(2) - IntImm::make(1), - IntImm::make(12) / IntImm::make(3), - inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[1, j] = j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[2, j] = 2 * j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[3, j] = 3 * j; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollNonConstantBounds) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(inner_for, 8); - l.simplify(); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { - # CHECK: A[i, 8 * j_outer] = - # CHECK: A[i, 8 * j_outer + 1] = - # CHECK: A[i, 2 * (4 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 3] = - # CHECK: A[i, 4 * (2 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 5] = - # CHECK: A[i, 8 * j_outer + 6] = - # CHECK: A[i, 8 * j_outer + 7] = - # CHECK: } - # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { - # CHECK: A[i, 8 * (N / 8) + j_tail] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorsLessThan2) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - // Unrolling by factor = 1 should do nothing. - LoopNest::unroll(inner_for, 1); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by factor = 0 should do nothing. - LoopNest::unroll(inner_for, 0); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by negative factor should do nothing. - LoopNest::unroll(inner_for, -2); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorEqualToIters) { - // Input IR: - // for (int i = 0; i < 5; i++) { - // A[i] = i * i; - // } - BufHandle a_buf("A", {5}, kInt); - VarHandle i("i", kInt); - auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); - auto for_loop = For::make(i, 0, 5, for_body); - auto block = Block::make({for_loop}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(for_loop, 5); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) - # CHECK: A[5 * i_outer] - # CHECK: A[5 * i_outer + 1] - # CHECK: A[5 * i_outer + 2] - # CHECK: A[5 * i_outer + 3] - # CHECK: A[5 * i_outer + 4] - )IR"); -} - -TEST(LoopNest, UnrollEmpty) { - const std::string actual = constantUpperBoundLoopIR(0); - const std::string& verification_pattern = R"IR( -# CHECK-NOT: A[ - )IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, NoUnroll) { - VarHandle upper_bound("N", kInt); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - ASSERT_THROWS_WITH( - LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); -} - -TEST(LoopNest, UnrollWithLet) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle e("e", kInt); - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Let::make(e, 7), - Store::make(a_buf, {x}, e), - Store::make(b_buf, {x}, e + 1)})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - std::ostringstream oss; - oss << *unrolled; - const std::string& verification_pattern = - R"IR( -# CHECK: int e = 7; -# CHECK: A[0] = e; -# CHECK: B[0] = e + 1; -# CHECK: A[1] = e; -# CHECK: B[1] = e + 1; -# CHECK: A[2] = e; -# CHECK: B[2] = e + 1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector a_v(kTotalSize, 0); - std::vector b_v(kTotalSize, 0); - SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); - eval(a_v, b_v); - for (int i = 0; i < kTotalSize; ++i) { - ASSERT_EQ(a_v[i], 7); - ASSERT_EQ(b_v[i], 8); - } -} - -TEST(LoopNest, IsNormalized) { - // Input IR: - // for (int i = 50; i < 100; i++) { - // A[i] = B[i]; - // } - BufHandle a_buf("A", {ExprHandle(100)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto for_stmt = - For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); - Block::make({for_stmt}); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); - - for_stmt->set_start(alloc(0)); - ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); - - VarHandle N("N", kInt); - for_stmt->set_start(N.node()); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); -} - -TEST(LoopNest, NormalizeStartPositive) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - const int kTotalSize = 50; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: A[x + 50] = B[x + 50]; - # CHECK: B[x + 50] = 2 * (x + 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartNegative) { - // Input IR: - // for (int x = -50; x < 100; x++) { - // A[x + 50] = B[x + 50]; - // B[x + 50] = x * 2; - // } - const int kTotalSize = 150; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), - Store::make(b_buf, {x + 50}, x * 2)}); - auto for_stmt = For::make(x, -50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 150; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * (x - 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartZero) { - // Input IR: - // for (int x = 0; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - // Should not be modified. - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 0, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * x; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartVariable) { - // Input IR: - // for (int x = y; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, y, 100, for_body); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100 - y; x++) { - # CHECK: A[x + y] = B[x + y]; - # CHECK: B[x + y] = 2 * (x + y); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedOuterLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: for (int y = 10; y < 100; y++) { - # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedInnerLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(inner_for); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 50; x < 100; x++) { - # CHECK: for (int y = 0; y < 90; y++) { - # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 10; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 5; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { - # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, NotNormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 15; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 10; - BufHandle a_buf("A", {kTotalSize}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { - # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, FlattenSimpleLoopNest2D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenSimpleLoopNest3D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // for (int k = 0; k < 7; k++) { - // A[i,j,k] = i + j * k; - // } - // } - // } - BufHandle a_buf("A", {10, 5, 7}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); - auto for1 = For::make(k, 0, 7, for_body); - auto for2 = For::make(j, 0, 5, for1); - auto for3 = For::make(i, 0, 10, for2); - auto parent_block = Block::make({for3}); - - std::vector loops = {for3, for2, for1}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { - # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5, 7); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5, 7); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestAfterNormalize) { - // Input IR: - // for (int i = 2; i < 10; i++) { - // for (int j = 3; j < 15; j++) { - // A[i - 2,j - 3] = i * j; - // } - // } - BufHandle a_buf("A", {8, 12}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); - auto inner_for = For::make(j, 3, 15, for_body); - auto outer_for = For::make(i, 2, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { - # CHECK: A[i_flat / 12, i_flat % 12] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(8, 12); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(8, 12); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { - // Input IR: - // for (int i = 0; i < 15-5; i++) { - // for (int j = 0; j < 20/4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = - For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); - auto outer_for = - For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - checkIR(result, R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenImperfectLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // A[i, i] = 0; - // for (int j = 0; j < 15; j++) { - // A[i,j] = i * j; - // } - // } - // Do not flatten. - - BufHandle a_buf("A", {10, 15}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = For::make( - i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // S[i] = 0; - // for (int j = 0; j < 15; j++) { - // S[i] = S[i] + A[i,j]; - // } - // } - // Do not flatten. - - BufHandle a_buf("A", {10, 15}, kInt); - BufHandle s_buf("S", {10}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make( - s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = - For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNestFromTensor) { - const int M = 3; - const int N = 7; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle b("b", {m, n}, kFloat); - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - HashProvider hasher; - auto hash_before = hasher.hash(loop.root_stmt()); - - auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(loop.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenIncorrectLoopsAsInput) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - // for (int x = 0; x < 10; x++) { - // for (int y = 0; y < 5; y++) { - // A[x,y] = A[x,y] + x + y; - // } - // } - // Flatten({For_i, For_y}) => should not succeed - - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - auto par = Block::make({outer_for1, outer_for2}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for1, inner_for2}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, DetectInlineRankMismatch) { - const int kTotalSize = 8; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - Tensor a = Compute( - "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); - Tensor reshape = Compute( - "reshape", - {kTotalSize / 2, 2}, - [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); - LoopNest l({reshape}, {a, reshape}); - ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); -} - -TEST(LoopNest, CacheReadsSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - // just this once: verify the whole thing. - checkIR(result, R"IR( -#CHECK: Allocate(A); // dtype=int, dims=[64, 64] -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] -#CHECK: for (int i -#CHECK: for (int j -#CHECK: A[ -#CHECK: } -#CHECK: } -#CHECK: for (int i_1 -#CHECK: for (int j_1 -#CHECK: A_local[j_1] = A[ -#CHECK: } -#CHECK: for (int j_2 -#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; -#CHECK: } -#CHECK: } -#CHECK: for (int i_2 -#CHECK: for (int j_3 -#CHECK: C[ -#CHECK: } -#CHECK: } -#CHECK: Free(A_local); -#CHECK: Free(A); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 3); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsOuter) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; - LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] -#CHECK: A_local[j_1 + 11 * i_1] = -#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInternal) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] -#CHECK: A_local[k + 11 * j_1] = -#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInner) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - // note im changing the offset of the first arg of the first call to A. - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr body = l.getLoopBodyFor(B); - LoopNest::cacheAccesses(A.buf(), "A_local", body); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] -#CHECK: A_local[l + 2 * k] = -#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheWritesSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] -#CHECK: for (int j = 0; j < 64 -#CHECK: A_local[j] = i * j; -#CHECK: for (int j_1 = 0; j_1 < 64 -#CHECK: A[j_1 + 64 * i] = A_local[ -#CHECK: Free(A_local); -#CHECK-NOT: A_local - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, DeadStoreElimination) { - VarHandle y("y", kInt); - VarHandle x("x_tail", kInt); - BufHandle f("f", {26, 5}, kInt); - BufHandle g("g", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(f, {x_2, y}, (x_2 + y)), - Store::make(g, {x_2, y}, (x_2 * y)), - }))); - StmtPtr stmt = Block::make({stmt1}); - - // Will eliminate if not used by an output. - LoopNest loop(Stmt::clone(stmt), {f.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK-NOT: g[x_tail + 5 * 4, y] - )IR"); - - // But won't eliminate if used by different outputs. - LoopNest loop2(stmt, {f.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK: g[x_tail + 5 * 4, y] - )IR"); -} - -TEST(LoopNest, DeadStoreEliminationWithIntermediates) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - BufHandle f("f", {26 * 5}, kInt); - BufHandle g("g", {26 * 5}, kInt); - BufHandle h("h", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); - ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); - ForPtr stmt3 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(h, {x, y}, Load::make(f, {x * y})), - }))); - StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); - - // Will eliminate the write to g, but not f since it used by the producer of - // h. - LoopNest loop(Stmt::clone(stmt), {h.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK-NOT: g[z] = - #CHECK: h[x, y] = f[x * y]; - )IR"); - - // Sanity check won't eliminate if g is an output. - LoopNest loop2(stmt, {h.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK: g[z] = z + 1; - #CHECK: h[x, y] = f[x * y]; - )IR"); -} - -TEST(LoopNest, CompoundTensorSimple) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - - LoopNest l({A}); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {A}); - - std::vector a_ref(50, 0); - - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 5; ++j) { - a_ref[i * 5 + j] = (i * j) + i + j; - } - } - cg.call({a_data}); - - assertAllEqual(a_data, a_ref); -} - -TEST(LoopNest, InlineConstantIndex) { - const int N = 10; - BufHandle x_buf("a", {1, N, 1}, kFloat); - Tensor y = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return x_buf.load(m, n, o); - }); - Tensor z = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return y.load(m, n, o); - }); - - LoopNest l({z}, {y, z}); - l.simplify(); - ASSERT_TRUE(l.computeInline(y.buf())); -} - -TEST(LoopNest, CompoundTensorUsed) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j + 1) + A.load(i, j + 2); - }); - - LoopNest l({B}, {A, B}); - ASSERT_FALSE(l.computeInline(A.buf())); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - std::vector b_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {B}); - - std::vector b_ref(50, 0); - - auto AT = [](int i, int j) { return i * j + i + j; }; - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 3; ++j) { - b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); - } - } - cg.call({b_data}); - - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, InlineFromLoad) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); - auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); - LoopNest l(Block::make({store_a, store_b}), {b.node()}); - - l.computeInline(a.node()); - - // Check that A[j] is replaced with j after inlining - std::ostringstream oss; - oss << *l.root_stmt(); - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (int j -# CHECK-NOT: B[j] = A[j] -# CHECK-NEXT: B[j] = j -)IR", - oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsSimple) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {15}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsNestedConditions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int i = 0; i < 10 -# CHECK-NEXT: A[i + 10] = D[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStores) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - // for (int j = 0; j < 100; j++) { - // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, storeA); - auto storeB = Store::make( - b_buf, - {j}, - IfThenElse::make( - CompareSelect::make(j, 30, kLT), - Load::make(c_buf, {j}), - Load::make(d_buf, {j}))); - auto forJ = For::make(j, 0, 100, storeB); - auto par = Block::make({forI, forJ}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int j = 0; j < 30 -# CHECK-NEXT: B[j] = C[j] -# CHECK: for (int j = 0; j < 70 -# CHECK-NEXT: B[j + 30] = D[j + 30] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) - // } - // Only the first conditional, in the write to A, will be optimized. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - auto storeB = Store::make( - b_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 30, kLT), - Load::make(c_buf, {i}), - Load::make(d_buf, {i}))); - auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK-NEXT: B[i] = C[i] -# CHECK: for (int i = 0; i < 45 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - // } - // Currently, this case where the condition variable `i` is not the - // inner-most loop variable, is not optimized. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) - // } - // No optimization should be done here because one of the conditions use '>'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - VarHandle N("N", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, N, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) - // } - // No optimization should be done here because one of the conditions use '>'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kGT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(10 colReduce(int M, int N) { - BufHandle a("a", {M, N}, kFloat); - Tensor t = Reduce( - "b", - {N}, - Sum(), - [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, - {M}); - return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; -} - -static StmtPtr splitTailReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - nest.splitWithTail(loops[0], kVectorWidth); - // Now the loopnests will look like: - // - // for (int i_outer = 0; ... - // for (int i_inner = 0; ... - // b[i_outer * 8 + i_inner] = float(0); - // for (int j = 0; ... - // b[i_outer * 8 + i_inner] = ReduceOp(...); - // - // for (int i_tail = 0; ... - // b[i_tail + ((100 - 0) / 8) * 8] = float(0); - // for (int j = 0; ... - // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); - // - // Since there are 4 writes to b, we will get 4 loopnests from the - // call to `getAllLoopNestsWritingToBuf` below. - // - // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" - // Loopnest #2: {i_outer, i_inner, j}; - // We will have to reorder i_inner and j. - auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); - LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static StmtPtr splitMaskReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - nest.splitWithMask(loops[0], kVectorWidth); - loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - LoopNest::reorderAxis(loops[1], loops[2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { - int M = immediateAs(p.dim(0)); - int N = immediateAs(p.dim(1)); - PaddedBuffer a(M, N); - PaddedBuffer b(N); - PaddedBuffer ref(N); - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a(i, j) = 1.0f; - } - } - for (int i = 0; i < N; i++) { - b(i) = 0.0f; - } - for (int i = 0; i < N; i++) { - ref(i) = 76.0f; - } - SimpleIREvaluator(s, {p, t}).call({a, b}); - ExpectAllNear(b, ref, 1e-5); -} - -TEST(LoopNest, ColReduceSplitTailEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitTailUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int i_tail -# CHECK-NEXT: b[ -# CHECK-NEXT: for (int j -# CHECK-NEXT: b[ - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ReorderAxisWithMultipleConds) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // if i > 5 { - // if i < 10 { - // for (int j = 0; j < 100; j++) { - // A[i] = i * j; - // } - // } - // } - // } - BufHandle a_buf("A", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); - auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); - auto outer_cond = - Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); - auto forI = For::make(i, 0, 20, outer_cond); - StmtPtr par = Block::make({forI}); - LoopNest l(par, {a_buf.node()}); - LoopNest::reorderAxis(forI, forJ); - ASSERT_EQ(par, l.root_stmt()); - par = IRSimplifier::simplify(par); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: for (int i -# CHECK-NEXT: if (i>5 -# CHECK-NEXT: if (i<10 -# CHECK-NEXT: A[i] = i * j -# CHECK-NOT: for ( - )IR"; - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, VectorizeUse) { - constexpr int N = 8; - BufHandle a("a", {N}, kFloat); - Tensor b = - Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); - Tensor c = - Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); - LoopNest nest({c}, {b, c}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - nest.prepareForCodegen(); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr s = nest.root_stmt(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: c[Ramp -)IR", - oss.str()); -} - -const char* int64Loop = R"IR( -# CHECK: for (int64_t i = 0ll; i < 12ll; i++) { -# CHECK: b[i] = (a[i]) + 1ll; -# CHECK: } -)IR"; - -TEST(LoopNest, Int64Direct) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - BufHandle b("b", {N}, kLong); - VarHandle n("i", kLong); - StmtPtr s = For::make( - n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, Int64Compute) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - Tensor b = Compute("b", {N}, [&](const VarHandle& n) { - return a.load(n) + LongImm::make(1l); - }); - LoopNest nest({b}); - nest.prepareForCodegen(); - nest.simplify(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {forJ}); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithoutAnyPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopOverInnerLoops) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { - // Input IR: - // for (int m = 0; m < 50; m++) { - // for (int i = 0; i < 20; i++) { - // A[m,i] = 0; - // for (int j = 0; j < 100; j++) { - // A[m,i] = A[m,i] + i * j; - // } - // B[m,i] = A[m,i]; - // for (int k = 0; k < 50; k++) { - // B[m,i] = B[m,i] + i * k; - // } - // } - // } - BufHandle a_buf("A", {100, 100}, kInt); - BufHandle b_buf("B", {100, 100}, kInt); - VarHandle m("m", kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m, i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, - {m, i}, - Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, - {m, i}, - Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - - { - // Check the case of distributing loop and its parents over all the - // statements in the loop. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParents(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(newLoops.front(), forM); - } - - { - // Check the case of distributing loop and its parents over all the inner - // loops. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(newLoops.front(), forM); - } -} - -TEST(LoopNest, fuseLoopsSimple) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsMultiple) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // A[i+100] = 20 + i; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forI = - For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forI, forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i + 100] = -# CHECK-NEXT: A[i] = -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: A[m] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m] = -# CHECK: B[m] = A[m] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forM); -} - -TEST(LoopNest, fuseLoopsNested2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested2DInner) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // B[i,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *forI; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: B[i, j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsDifferentStopBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 50; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsDifferentStartBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsNotContiguous) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // B[0] = 0; - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithDifferentParents) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j; - // } - // } - // B[0] = 0; - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {50, 100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); - auto forI = For::make(i, 0, 50, forJ); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithVariableBounds) { - // Input IR: - // for (int j = 0; j < N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithExprBounds) { - // Input IR: - // for (int j = 0; j < M + N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < M + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { - // Input IR: - // for (int j = M; j < N * 2; j++) { - // A[j] = 10 * j; - // } - // for (int k = M; k < N + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+100] = 30 * k - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); - auto par = Block::make({forJ, forK}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: A[j + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: A[i + 20, n + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0 - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + B[i,j]; - // } - // } - // for (int m = 0; m < 20; m++) { - // C[m] = A[m]; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - BufHandle c_buf("C", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto sumA = Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); - auto forJ = For::make(j, 0, 100, sumA); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); - auto forM = - For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = (A[i]) + -# CHECK-NOT: for ( -# CHECK: C[i] = A[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWith2DReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 50; j++) { - // A[i,j] = 0 - // for (int k = 0; k < 100; k++) { - // A[i,j] = A[i,j] + B[i,j,k]; - // } - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 40; n++) { - // C[m,n] = A[m,n]; - // } - // } - BufHandle a_buf("A", {20, 50}, kInt); - BufHandle b_buf("B", {20, 50, 100}, kInt); - BufHandle c_buf("C", {20, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto initA = Store::make(a_buf, {i, j}, 0); - auto sumA = Store::make( - a_buf, - {i, j}, - Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); - auto forK = For::make(k, 0, 100, sumA); - auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); - auto forI = For::make(i, 0, 20, forJ); - auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[i, j] = (A[i, j]) + -# CHECK: for (int n -# CHECK-NEXT: C[i, n] = A[i, n] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithComplexIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j*20+j+2] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,n*20+n+2]; - // } - // } - BufHandle a_buf("A", {20, 400}, kInt); - BufHandle b_buf("B", {20, 400}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = - Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,i*20+j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,m*20+n]; // Both indices of A use m - // } - // } - BufHandle a_buf("A", {20, 500}, kInt); - BufHandle b_buf("B", {20, 500}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithTranspose) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[n,m]; // Transpose - // } - // } - BufHandle a_buf("A", {20, 20}, kInt); - BufHandle b_buf("B", {20, 20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies1) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies2) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+50] = 20 * k; - // } - BufHandle a_buf("A", {150}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies3) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n+1]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {25, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies4) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {30, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies5) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // A[i,n+1] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, - 0, - 100, - Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies6) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies7) { - // Input IR: - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forK, forJ}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); -} - -TEST(LoopNest, areLoopsPerfectlyNested) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Specifying the loops in any other order fails. - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); - - // Adding a statement to forK body should be OK. - auto init = Store::make(a_buf, {i, j}, 0); - forK->body()->insert_stmt_before(init, store); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Adding a statement in forJ body should fail this test. - forK->body()->remove_stmt(init); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Similarly, adding a statement in forI body should fail this test. - forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); -} - -TEST(LoopNest, reorderNestedLoops2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); - auto forJ = For::make(j, 0, 30, store); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); - - ASSERT_EQ(reordered[0], forJ); - ASSERT_EQ(reordered[1], forI); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); - ASSERT_EQ(forJ->get_parent(), par); - ASSERT_EQ(store->get_parent(), forI->body()); -} - -TEST(LoopNest, reorderNestedLoops3D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderNestedLoops4D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // for (int l = 0; l < 50; l++) { - // A[i,j,k,l] = i * j * k * l * 500; - // } - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle l("l", kInt); - auto store = Store::make( - a_buf, - {i, j, k, l}, - Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); - auto forL = For::make(l, 0, 50, store); - auto forK = For::make(k, 0, 40, forL); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forL); - ASSERT_EQ(reordered[3], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderTrivialPermutation) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); - - ASSERT_EQ(reordered[0], forI); - ASSERT_EQ(reordered[1], forJ); - ASSERT_EQ(reordered[2], forK); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - ASSERT_EQ(forI->get_parent(), par); - ASSERT_EQ(store->get_parent(), forK->body()); -} - -TEST(LoopNest, reorderInvalidPermutations) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 2}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), - "invalid permutation for reorder"); -} - -TEST(LoopNest, reorderInvalidLoopNest) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i,j] = 0 - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - // Specifying the loops in incorrect order fails. - ASSERT_THROWS_WITH( - LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Adding a statement to forJ loop fails. - auto init = Store::make(a_buf, {i}, 0); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Moving that statement to forI loop also fails. - forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); -} - -TEST(LoopNest, compressBufferSimple) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferMultipleDims) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // B[i,j] = A[i,j] + A[i,j] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); - auto store2 = Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); - auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, 0] = -# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferMultipleDims2) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // for (int k = 0; k < 300; ++k) { - // A[i,j,k] = sin(i*j*k) - // } - // for (int k = 0; k < 299; ++j) { - // B[i,j,k] = A[i,j,k] + A[i,j,k+1] - // } - // } - // } - BufHandle aBuf("A", {100, 200, 300}, kInt); - BufHandle bBuf("B", {100, 200, 300}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); - auto forK1 = For::make(k, 0, 300, store1); - auto store2 = Store::make( - bBuf, - {i, j, k}, - Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); - auto forK2 = For::make(k, 0, 299, store2); - auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[0, 0, k] = -# CHECK: for (int k -# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 3); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); -} - -TEST(LoopNest, compressBufferDifferentOrderIndices) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[j, i] = sin(i*j) - // } - // for (int j = 0; j < 99; ++j) { - // B[i, j] = A[j, i] + A[j+1, 0] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 99, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[j, 0] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferVariableBounds) { - // Input IR: - // for (int i = 0; i < M; ++i) { - // for (int j = 0; j < N; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < N-1; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - N - 1, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferNoCommonParentLoops) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // } - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI1 = For::make(i, 0, 100, forJ1); - auto forI2 = For::make(i, 0, 100, forJ2); - auto par = Block::make({forI1, forI2}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferIndicesMixed) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i + j, j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i + j, j] + A[i + j, j+1] - // } - // } - BufHandle aBuf("A", {300, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make( - Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i + j, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressMultipleBuffers) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int k = 0; k < 199; ++k) { - // B[i,k] = A[i,k] + A[i, k+1] - // } - // for (int m = 0; m < 50; ++m) { - // C[i,m] = B[i,m] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - BufHandle cBuf("C", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forK = For::make( - k, - 0, - 199, - Store::make( - bBuf, - {i, k}, - Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); - auto forM = - For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); - auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); - auto par = Block::make({forI}); - - // This should compress all buffers A, B, and C as follows: - // A[100, 200] -> A[1, 200] - // B[100, 200] -> B[1, 200] - // C[100, 200] -> C[1, 1] - LoopNest::compressAllBuffers(par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int k -# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) -# CHECK: for (int m -# CHECK-NEXT: C[0, 0] = B[0, m] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); - ASSERT_EQ(bBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); - ASSERT_EQ(cBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); -} - -TEST(LoopNest, sanitizeNames) { - std::vector dim_args; - // Let's pick names that would overlap with default index names if not - // sanitized properly: - dim_args.emplace_back(ExprHandle(alloc("i", kInt))); - dim_args.emplace_back(ExprHandle(alloc("N:2", kInt))); - // Now let's create a many dimensions so that we had to use the same letter - // for different loops - for (int i = 0; i < 10; i++) { - dim_args.emplace_back(ExprHandle(alloc("N", kInt))); - } - - // Now create two Computes with conflicting after sanitization names: - Tensor X = Compute("$X:!", dim_args, [&](const std::vector& v) { - return v[0] + v[1] + v[9] + 1; - }); - Tensor Y = Reduce( - "%X\"+", - {}, - Sum(), - [&](const std::vector& v) { return X.load(v); }, - dim_args); - - // Finally, let's verify what we got after sanitization: - LoopNest l({X, Y}); - StmtPtr s = l.root_stmt(); - LoopNest::sanitizeNames(s); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < i_1; i++) { -# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { -# CHECK-NEXT: for (int k = 0; k < N_9; k++) { -# CHECK-NEXT: for (int l = 0; l < N_8; l++) { -# CHECK-NEXT: for (int m = 0; m < N_7; m++) { -# CHECK-NEXT: for (int n = 0; n < N_6; n++) { -# CHECK-NEXT: for (int o = 0; o < N_5; o++) { -# CHECK-NEXT: for (int p = 0; p < N_4; p++) { -# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { -# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { -# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { -# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { -# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; -# CHECK: v_X___1 = int(0); -# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { -# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { -# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { -# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { -# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { -# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { -# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { -# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { -# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { -# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { -# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { -# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { -# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp deleted file mode 100644 index 5db84eab1f50..000000000000 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ /dev/null @@ -1,3252 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -// Test helper function used to determine if two regions of a buffer have an -// overlap. No Overlap & partial overlap is obvious. Contains means A is -// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal -// ranges are ContainedOrEqual. -TEST(MemDependency, BoundOverlap) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check 3 overlap cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); - - // Partial overlap works in either order. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); - - // Total Overlap works when one bound encloses the other, and returns which. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); - - // Total overlap works when the bounds are an identical range, returns - // ContainedOrEqual. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); - - // Total overlap when only one end of the bound matches. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); - - // No overlap when a < b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); - - // No overlap when a > b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); - - // No overlap when adjacent. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); - - // Partial overlap when middle bounds match. - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); - - // Total overlap when one bound is single length over one end of the other. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); -} - -TEST(MemDependency, BoundComparison) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); -} - -TEST(MemDependency, BoundOverlapSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - // Sanity check cases where the start and end is symbolic but the diff is - // constant. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); - - // We can't infer the sign of y, so cannot tell whether adding y is larger or - // smaller than y/2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + y), CB(x, x + y / 2))); - - // No information about this bound, have to take the most conservative option: - // there may be an overlap. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); - - // Math on opaque terms works. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); - // Even requiring simplification. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); -} - -// Tests the helper function for overlap of multi dimensional indices bounds. -// This uses boundOverlap on each dimension and return the "lowest" kind of -// overlap. -TEST(MemDependency, BoundOverlapMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check one dimensional cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); - ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); - ASSERT_EQ( - OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); - - // Total overlap in 3 dims. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); - - // Total overlap in 2 dims, no overlap in another. - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - - // Total overlap in 2 dims, partial overlap in another. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - // This case is most important, so verify the overlap in any dim. (dim 2) - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); - // Dim 1. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); - // Total overlap in 1 dim, partial in 2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); - // Total overlap, partial overlap, no overlap. - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); - - // Total overlap (B) in 2 dims, total overlap (A) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); - - // Total overlap (A) in 2 dims, total overlap (B) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps( - {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); - - // Total (B), No Overlap, Total (A). - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); -} - -// Test the helper we use to subtract bounds: returns the regions(s) of A which -// remain after removing the region of B. -TEST(MemDependency, BoundSubtract) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); - ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); - - // No Overlap. - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); - - // one side overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); - ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); - - // both sides overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); - - // internal overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); -} - -TEST(MemDependency, BoundSubtractSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); - - // Subtract constant range low. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); - // Subtract constant range high. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); - // Subtract constant range total overlap. - ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); - // Subtract constant range internal. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), - {CB(x, x + 2), CB(x + 8, x + 10)})); - - // Size is inferable but not constant, only works with a single var. - ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); - - // Size is not inferable. - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); -} - -// Tests the helper function that does subtraction, but for multi dimensional -// indices bounds. -TEST(MemDependency, BoundSubtractMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // sanity check one dimension. - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); - - // Multi dim total overlap. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); - - // Multi dim one way partial in dim 1. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), - {{CB(4, 9), CB(0, 2)}})); - - // Multi dim one way partial in dim 2. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), - {{CB(0, 9), CB(11, 20)}})); - - // Partial overlap in 2 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); - - // Partial overlap in 3 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5), CB(0, 5)}, - {CB(2, 5), CB(0, 1), CB(0, 5)}, - {CB(2, 5), CB(2, 5), CB(0, 1)}})); -} - -// Tests the multi dimensional subtraction code for bounds that cannot be fully -// materialized. -TEST(MemDependency, BoundSubtractMultiDimSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // Cannot determine overlaps. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); - - // Various total Overlaps. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); - - // one-way overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), - {{CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), - {{CB(0, x), CB(0, 4)}})); - - // Internal overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), - {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), - {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); - - // Overlap in both dimensions. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), - { - {CB(0, 4), CB(0, y)}, - {CB(x - 4, x), CB(0, y)}, - {CB(0, x), CB(0, 9)}, - {CB(0, x), CB(y - 9, y)}, - })); -} - -// Simple check that the analyzer does anything at all... -TEST(MemDependency, MemDependencyCheckerSimple) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore}); - - stmt->accept(&analyzer); - - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - // sanity check, but anything that depends directly must depend indirectly. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); -} - -// Check that there is a difference between direct and indirect dependence. -TEST(MemDependency, MemDependencyCheckerMultiStmt) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0]; - * C[0] = B[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore, cStore}); - - stmt->accept(&analyzer); - - // C depends on A indirectly. - ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); - ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); - - // C depends on B directly, which depends on A directly. - ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // Dependency goes top to bottom only. - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); -} - -// Verify that we do filter writes that are totally overlapped by later writes. -TEST(MemDependency, MemDependencyCheckerOverlap) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * A[0] = 6; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr a2Store = Store::make(a, {0}, 6); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, a2Store, bStore}); - - stmt->accept(&analyzer); - - // B store depends on second A store but not first since it is completely - // overlapped. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); - - // No dependency between either A store. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); -} - -// Verify that bounds match loop iterations, and that dependencies progress -// across loop scopes. -TEST(MemDependency, MemDependencyCheckerLoop) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * for (int x = 0; x < 10; ++x) { - * A[x] = x; - * } - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {x}, x); - StmtPtr loop = For::make(x, 0, 10, aStore); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); - - StmtPtr stmt = Block::make({loop, bStore}); - - stmt->accept(&analyzer); - - // Same A->B dependency. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop but does not depend on any loop iteration. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); - - auto aStoreAccess = analyzer.accessFor(aStore); - ASSERT_NE(aStoreAccess, nullptr); - - // It should have bounds covering the range of x: 0 <= x < 10. - ASSERT_TRUE(indexBoundsEquals( - aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Reductions should promote dependencies as well. -TEST(MemDependency, MemDependencyCheckerLoopReduce) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle reduce = Sum()(a, 1, {x}, {x}); - StorePtr aReduce = Store::make(a, {0}, reduce); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Find loads within the reduction: - auto reduceLoads = NodeFinder::find(reduce.node()); - // Pull out the access for the load inside the loop. - for (auto load : reduceLoads) { - auto loopLoad = analyzer.accessFor(load); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); - } -} - -// Lowering a reduction doesn't affect dependency analysis. -TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle aLoad = Load::make(a, {x}); - StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Pull out the access for the store inside the loop. - auto loopLoad = analyzer.accessFor(aLoad.node()); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Can determine dependencies of outputs, through to inputs. -TEST(MemDependency, MemDependencyCheckerInputsOutputs) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(A[x], 0); - * } - */ - - ExprHandle aLoad = Load::make(a, {x}); - StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - // aLoad depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); - // bStore therefore depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); - // The output depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - // Not directly. - ASSERT_FALSE(analyzer.dependsDirectly(output, input)); - // Not in reverse order. - ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); - - // output -> bStore -> bLoad -> input. - auto storeAccess = analyzer.accessFor(bStore); - auto loadAccess = analyzer.accessFor(aLoad.node()); - - ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); - ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); -} - -// Can tell if an output does not depend on an input. -TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a dumb Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(x, 0); - * } - */ - - StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); - - // The output still depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); -} - -// Verify different loop extents produce accesses with different bounds, and -// that later accesses find dependencies that overlap their entire bound range. -TEST(MemDependency, MemDependencyCheckerLoopBounds) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - using namespace analysis; - - MemDependencyChecker analyzer({a}, {c}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; ++x) { - * B[x] = A[x]; - * } - * for (int x = 1; x < 9; ++x) { - * B[x] = B[x] * 2; - * } - * for (int x = 3; x < 4; ++x) { - * C[x] = A[x]; - * } - * for (int x = 0; x < 10; ++x) { - * C[x] = B[x]; - * } - */ - - std::vector stmts( - {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), - For::make( - x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), - For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); - - StmtPtr stmt = Block::make(stmts); - - stmt->accept(&analyzer); - - auto input = analyzer.input(a.node()); - auto output = analyzer.output(c.node()); - - // sanity check Output -> Input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - - // Check the For loop dependencies: - - // Last write to C depends on both writes to B since they contain the last - // write to at least one element. - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); - - // The last write to C does not depend on the other write to C. - ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 5 - * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 - * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 - * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 - * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 - * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 - * 6. Store: C[(3, 3)] - depends on: 5 - * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 - * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 - * 9. Output: C[(0, 9)] - depends on: 8 - */ - - // Now let's look at the bounds of each access. - // There are 9 accesses in this Stmt, so this is exhaustive, we won't do this - // much. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 10); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - VarPtr cVar = c.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[0], input); - - // The second access is the load of A in the first loop. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); - // It reads from A, so it should have a dependency on the last write to this - // range - with is the input. - ASSERT_EQ(history[1]->dependencies().size(), 1); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - // The third access is the store into B in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), bVar); - // It also has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - // The previous load is in its RHS, so it depends on it. - ASSERT_EQ(history[2]->dependencies().size(), 1); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The third access is the load from B in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), bVar); - // It has the bounds of the second loop, i.e. >= 1 < 9. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); - // It reads from B in a smaller range, so should depend on the previous - // store. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fourth: the store to B in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), bVar); - // It also has the bounds of the second loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); - // The previous load is in its RHS, so it depends on it as before. - ASSERT_EQ(history[4]->dependencies().size(), 1); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The fifth access is the load is from the 3rd loop, and skips previous B - // accesses. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the third loop: >= 3 < 4. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); - // It depends on the last thing to write to A, which is the A input. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[0])); - - // Sixth: the store into the output C. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), cVar); - // It also has the bounds of the third loop. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); - // The previous load is in its RHS, so it depends on it as always. - ASSERT_EQ(history[6]->dependencies().size(), 1); - ASSERT_TRUE(history[6]->hasDependency(history[5])); - - // The seventh access is the load of B in the fourth loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), bVar); - // It has the bounds of the final loop, >= 0 < 10 - ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // The bounds of this read are larger than the bounds of the previous write, - // so it depends on both previous Stores to B. - ASSERT_EQ(history[7]->dependencies().size(), 2); - ASSERT_TRUE(history[7]->hasDependency(history[2])); - ASSERT_TRUE(history[7]->hasDependency(history[4])); - - // Eight: the final store into the output C. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), cVar); - // It also has the bounds of the final loop. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - // The previous load is in its RHS, so it depends on it as always. - ASSERT_EQ(history[8]->dependencies().size(), 1); - ASSERT_TRUE(history[8]->hasDependency(history[7])); - - // The last access represents the output Buf. - ASSERT_EQ(history[9]->type(), AccessType::Output); - ASSERT_EQ(history[9]->var(), cVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[9], output); - // It depends on the last write to C only. - ASSERT_EQ(history[9]->dependencies().size(), 1); - ASSERT_TRUE(history[9]->hasDependency(history[8])); -} - -// Verify that we can still infer bounds when the loop var is offset. -TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer({a}, {b}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[x] = A[x + 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - */ - - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), - For::make( - x, - 0, - 9, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), - For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - // Sanity check output depends on Input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 - * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 - * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 - * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 - * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 - * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 - * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 - * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 - * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 - * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 - * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 - * 11. Output: B[(0, 9)] - depends on: 10 - */ - - // Now let's look at the bounds of each access. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 12); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - - // The second access is the load A[x-1]. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop modified by the offset of each index, in - // this case -1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); - // It depends on the input, but also the store in the same loop, since - // different iterations of the loop depend on each other. - ASSERT_EQ(history[1]->dependencies().size(), 2); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - ASSERT_TRUE(history[1]->hasDependency(history[2])); - - // The third access is the Store to A[x] in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - - // The fourth access is the load A[x+1] in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 1. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); - // This load totally overlaps the previous write to A, so it depends only on - // it and not the input. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fifth access is the store to A[x] in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); - - // The sixth access is the load to A[8 - x] in the third loop. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 8 - x. - // This access has a negative stride, which will be normalized. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); - // This load totally overlaps the most recent write to A, so it depends only - // on it and not the input or the first write to A. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[4])); - - // The seventh access is the store to A[9 - x] in the third loop. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), aVar); - // This store has a negative stride on it's indices, but is normalized - // internally. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); - - // The eighth access is the load A[9-x] in the second loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x, - // which essentially traverses the loop backwards. - ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // This Load has three write dependencies: - ASSERT_EQ(history[7]->dependencies().size(), 3); - // * The previous store (#6) for elements 1-9 - ASSERT_TRUE(history[7]->hasDependency(history[6])); - // * An earlier store (#4) covering element 0 - ASSERT_TRUE(history[7]->hasDependency(history[4])); - // * A future store inside this loop, since this loop modifies the buffer - // in a non distinct way (due to the load and store having different access - // strides). - ASSERT_TRUE(history[7]->hasDependency(history[8])); - - // The ninth access is the store to A[x] in the fourth loop. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), aVar); - // This store has a negative stride on it's indices, but is normalized - // internally. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - - // The tenth and 11th accesses are the copy from A[x] to B[x]. - ASSERT_EQ(history[9]->type(), AccessType::Load); - ASSERT_EQ(history[9]->var(), aVar); - ASSERT_EQ(history[10]->type(), AccessType::Store); - ASSERT_EQ(history[10]->var(), bVar); - - // The last access represents the output Buf. - ASSERT_EQ(history[11]->type(), AccessType::Output); - ASSERT_EQ(history[11]->var(), bVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); - // It depends on the last write to B only. - ASSERT_EQ(history[11]->dependencies().size(), 1); - ASSERT_TRUE(history[11]->hasDependency(history[10])); - - // ok that's enough of that. -} - -// Check many different cases of loop self dependency - when a load within a -// loop is dependent on a Store later in the same loop but in different -// iteration. This is affected by whether or not we can trust the execution -// order of the loop. -TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - // This check assumes that the Stmt has a single Store with a single Load on - // the RHS. - auto isSelfDependent = - [](const std::vector>& history) -> bool { - return history.front()->hasDependency(history.back()); - }; - - { - /* for (int y = 0; y < 10; y++) { - * A[y] = (A[y]) + 1; - * } */ - - // Not self dependent since all loop iterations use a different y. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int y = 0; y < 10; y++) { - * A[y + 1] = (A[y + 1]) + 1; - * } - */ - - // Not self dependent due to different y (with offset). - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make( - {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - // Is self dependent since all loops use a common constant element of A. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (B[0]) + x; - * } - */ - - // Is not self dependent because there is no store to the buffer that is - // read. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - */ - - // Is self dependent since all loops use a common symbolic element of A. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - // In this case it depends if we are considering execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // With analysis of order disabled, this is self dependent since the read - // from X+1 and the write to X+1 could be in reverse order. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // If order analysis is enabled, this is not dependent since the read for - // each element occurs before the write to that element. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // In this case, even with order analysis the Load is dependent on the - // Store, since the write to X occurs before the read from X. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // Still works if the execution order is reversed, so long as the read - // comes before the write. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - // However here was can determine the A store is earlier in the order than - // the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[8 - x] = A[9 - x]; - * } - */ - - // But not if it doesn't. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // And not if we're not relying on execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[x - 2] = A[x - 1]; - * } - */ - - // Forward order but negative indices. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // However here was can determine the A store is earlier in the order than - // the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2]; - * } - */ - - // With an access stride. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 1]; - * } - */ - - // Here we can use the common stride of the accesses to determine they are - // distinct. - // Note, this is the only place (loop self dependency) we use this stride - // to avoid unnecessary dependence. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 1]; - * } - */ - - // same if the read is behind the write so long as they are distinct. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 2]; - * } - */ - - // But not if the offset is in the stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 2]; - * } - */ - - // Works with negative offsets too. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 7]; - * } - */ - - // Detects accesses are distinct when offset is large but not a multiple - // of stride. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 4]; - * } - */ - - // Works with offsets which are multiples of the stride. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 6] = A[x * 6 + 5]; - * } - */ - - // detects accesses are distinct with large strides when the offset is - // within. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6]; - * } - */ - - // detects accesses are overlapping when stride is different but a - // multiple. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 4] = A[x * 2]; - * } - */ - - // still works when the read axis is the smaller stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 1]; - * } - */ - - // detects accesses are distinct when stride is different but a multiple - // and there is an offset. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 4]; - * } - */ - - // The smaller stride determines whether there is overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2 + 3] = A[x * 6]; - * } - */ - - // The smaller stride determines whether there is overlap, not the larger. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 3 + 1]; - * } - */ - - // If they have strides with no common multiple > 1, they overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 10]; - * } - */ - - // If the offset is greater than the size of the loop, they can't overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - */ - - // If they have different execution orders they may overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[19 - x * 2]; - * } - */ - - // Or they may not, depending on their start offset and strides. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2]; - * } - */ - - // If the stride is not monotonic, they overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2] + 1; - * } - */ - - // If the stride is not monotonic, they overlap - even with an offset. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x % 2] = A[x % 2]; - * } - */ - - // Mod too... - - analysis::MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = y; x < z; x++) { - * A[x] = A[x + 1]; - * } - */ - - // Still works with symbolic loop extents. - - { - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - } -} - -// Verify that a strided access still works. -// TODO: actually this only works because of the size of the ranges, revisit -// this test after strided overlap is implemented. -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - MemDependencyChecker analyzer({a.node()}, {b.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) - - }); - stmt->accept(&analyzer); - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies... the store in each loop. - auto outputAccess = analyzer.output(b.node()); - ASSERT_EQ(outputAccess->dependencies().size(), 2); -} - -/* TODO(nickg) - this test will fail due to the lack of stride math in Bound -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make( - x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) - - }); - stmt->accept(&analyzer); - - std::cout << *stmt << "\n"; - for (auto& wi : analyzer.getHistory()) { - wi->print(); - } - } -}*/ - -// analysis on Stmts using Cond. -TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * C[0] = (B[0]) + 1; - * } else { - * C[0] = (B[1]) + 1; - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), - Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * C[x] = B[x]; - * } - * } else { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // TODO(nickg): actually since the true and false branch cover the total - // range of the first store this should have 2 dependencies, but we don't - // do that yet. - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Only has true branch. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))), - nullptr)}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * } else { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Only has false branch. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - nullptr, - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (C[0]<5 ? 1 : 0) { - * C[0] = 5; - * } - */ - - // Cond's Condition depends on a previous access. - - MemDependencyChecker analyzer({a}, {c}); - StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); - ExprHandle conditionalLoad = Load::make(c, {0}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, initStore), - Cond::make( - CompareSelect::make( - conditionalLoad, 5, CompareSelectOperation::kLT), - Store::make(c, {0}, 5), - nullptr)}); - - stmt->accept(&analyzer); - - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - - ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); - ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); - } -} - -// Stmts using IfThenElse. -TEST(MemDependency, MemDependencyCheckerIfThenElse) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1; - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Add::make(Load::make(b, {0}), 1), - Add::make(Load::make(b, {1}), 1))); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - ifStore}); - - stmt->accept(&analyzer); - - // Output C should have 2 dependencies, each of the two stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // Now we need to check the Store containing the IfThenElse. - auto ifStoreAccess = analyzer.accessFor(ifStore); - - // It should have 2 dependencies. - ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * C[0] = (y < 5 ? (B[0]) + 1 : 42; - */ - - // If the load appears in only one side of an IfThenElse the output may be - // dependent on it. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Add::make(Load::make(b, {0}), 1), - 42)); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - ifStore}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = (x < 5 ? B[x] : A[x]; - * } - */ - - // In this case C is dependent on both A and B. - - // TODO: in cases like this it would be possible to split the range of B - // into two bounds, one dependent on A and one dependent on B. We'd need to - // examine conditions relative to previously encountered loop variables. I'm - // uncertain if this would be helpful. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Load::make(b, {x}), - Load::make(a, {x}))); - StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } -} - -// Cutting a loop with single elem writes -TEST(MemDependency, MemDependencyCheckerCutLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * B[5] = 100; - */ - - // Cutting a loop with single element writes. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), - Store::make(b, {5}, 100)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - } - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * for (int x = 4; x < 7; x++) { - * B[x] = B[x] + 3; - * } - * B[5] = 100; - * B[6] = 101; - * B[7] = 102; - */ - - // Cutting a loop with a smaller loop but then totally overlap that second - // loop with one element writes. - - MemDependencyChecker analyzer({a}, {b}); - ForPtr firstLoop = - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); - StorePtr secondStore = - Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); - ForPtr secondLoop = For::make(x, 4, 7, secondStore); - - StmtPtr stmt = Block::make( - {firstLoop, - secondLoop, - Store::make(b, {4}, 100), - Store::make(b, {5}, 101), - Store::make(b, {6}, 102)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 4 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 4); - - // Second loop depends on first loop. - ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); - - // Output does not depend on second loop or store. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); - } -} - -// Dynamic shapes (load in indices). -TEST(MemDependency, MemDependencyCheckerDynamicShapes) { - BufHandle a("A", {100}, kInt); - BufHandle b("B", {100}, kInt); - BufHandle c("C", {100}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < B[0]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Output dependent on A input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - // Also dependent on B input to determine the size of the region written. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The accesses in the loop depend on the load in the stop condition. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // Make a load from B to compare against. - ExprHandle loadFromB = Load::make(b, {0}); - - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); - } - - { - /* for (int x = B[0]; x < B[1]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, - Load::make(b, {0}), - Load::make(b, {1}), - Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 3 - * 1. Input: A[(0, 99)] - dependents: 4 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 - * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 - * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 - * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 - * 6. Output: C[(0, 99)] - depends on: 5 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 7); - - // The accesses in the loop depend on the load in the start condition. - ASSERT_TRUE(history[5]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - - // also the stop condition. - ASSERT_TRUE(history[5]->hasDependency(history[3])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // Make loads from B to compare against. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB1 = Load::make(b, {1}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[B[x]]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, the load of A depends on the load of B. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads in the indices depend on the relevant input buffer. - ASSERT_TRUE(history[3]->hasDependency(history[1])); - ASSERT_TRUE(history[2]->hasDependency(history[0])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - - // The load from A has bounds B[0] to B[9]. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB9 = Load::make(b, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[x]] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 - * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 - * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, neither load is dependent. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_FALSE(history[3]->hasDependency(history[2])); - ASSERT_FALSE(history[2]->hasDependency(history[3])); - - // The loads each depend on their relevant input. (but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); - - // And so does the load from A. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[A[x]]] = x; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 - * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 - * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The outer load depends on the inner. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads each depend on their relevant input. (but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from A has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - // The load from B as bounds A[0] to A[9]. - ExprHandle loadFromA0 = Load::make(a, {0}); - ExprHandle loadFromA9 = Load::make(a, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); - - // The store has bounds of B[A[0]] to B[A[9]]. - ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); - ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); - } -} - -// Verify multi dimensional bounds work. -TEST(MemDependency, MemDependencyCheckerMultiDim) { - int M = 10, N = 9, K = 12; - BufHandle a("A", {M, N, K}, kInt); - BufHandle b("B", {M, N, K}, kInt); - BufHandle c("C", {M, K}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 9; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Full range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - N, - For::make( - z, - 0, - K, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 5; x++) { - * for (int y = 0; y < 5; y++) { - * for (int z = 0; z < 5; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Partial range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - For::make( - z, - 0, - 5, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 12; y++) { - * B[x, 0, y] = A[x, 0, y]; - * } - * } - */ - - // Partial loops. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - N, - For::make( - y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 100; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); - * } - * } - * } - */ - - // Loops that don't correspond to an index, bufs with different - // dimensionality. - - MemDependencyChecker analyzer({a, c}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - 100, - For::make( - z, - 0, - K, - Store::make( - b, - {x, 0, z}, - Add::make( - Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on both inputs. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); - - // 6 accesses: 2 inputs, 2 loads, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // Simple chain from input to output over the A buf. - // history[0] is the C input, history[3] is the load from C. - ASSERT_TRUE(history[5]->hasDependency(history[4])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - // The store also depends on the load from the C input. - ASSERT_TRUE(history[4]->hasDependency(history[3])); - ASSERT_TRUE(history[3]->hasDependency(history[0])); - - // A Buf accesses. - ASSERT_TRUE( - EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); - - // C buf access. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 9; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); - * } - * } - * } - */ - // Multi-dim reductions. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - N, - For::make( - z, - 0, - K, - Store::make( - b, - {x, 0, 0}, - Add::make( - Load::make(b, {x, y, z}), - Load::make(a, {x, y, z}))))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, 2 loads, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 5); - - // Simple chain from input to output. - ASSERT_TRUE(history[4]->hasDependency(history[3])); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[3]->hasDependency(history[1])); - ASSERT_TRUE(history[2]->hasDependency(history[0])); - - // The load from B depends on the store to B. - ASSERT_TRUE(history[1]->hasDependency(history[3])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); - } -} - -// Various tests using the external Compute/Reduce API. -TEST(MemDependency, MemDependencyCheckerComputeAPI) { - using namespace analysis; - - /* for (int m = 0; m < 4; m++) { - * for (int n = 0; n < 5; n++) { - * for (int k = 0; k < 6; k++) { - * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); - * } - * } - * } - * for (int m_1 = 0; m_1 < 4; m_1++) { - * for (int n_1 = 0; n_1 < 5; n_1++) { - * for (int k_1 = 0; k_1 < 6; k_1++) { - * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); - * } - * } - * } - */ - - // Can determine if 2 loops created by Compute are dependent. - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - - MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); - - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); - - // Second loop depends on first loop. - auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); -} - -TEST(MemDependency, MemDependencyCheckerComputeInline) { - using namespace analysis; - - /* for (int m = 0; m < 4; m++) { - * for (int n = 0; n < 5; n++) { - * for (int k = 0; k < 6; k++) { - * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); - * } - * } - * } - */ - - // Check inlining affects the number of accesses returned. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.computeInline(c.buf()); - - MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); - - // broadcast_add tensor should not appear in trace at all. - for (auto& wi : analyzer.getHistory()) { - ASSERT_NE(wi->var(), c.buf()->base_handle()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeSplit) { - using namespace analysis; - // Split an axis, so the number of loops != the number of dimensions. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Splitting should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReorder) { - using namespace analysis; - // Reorder an axis, so the loop order doesn't match the indexing order. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - auto loops = l.getLoopStmtsFor(c); - l.reorderAxis(loops[0], loops[1]); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Reordering should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReduce) { - using namespace analysis; - /* for (int l2 = 0; l2 < 2; l2++) { - * for (int n1 = 0; n1 < 3; n1++) { - * for (int m1 = 0; m1 < 6; m1++) { - * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); - * } - * } - * } - * for (int l1 = 0; l1 < 2; l1++) { - * sum[l1] = float(0); - * for (int n1_1 = 0; n1_1 < 3; n1_1++) { - * for (int m1_1 = 0; m1_1 < 6; m1_1++) { - * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), - * out_args={l1}, reduce_args={n1, m1}); - * } - * } - * } - */ - - // Can determine dependencies of a Reduction. - - BufHandle a("a", {2, 3, 6}, kFloat); - BufHandle b("b", {2, 3, 6}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, 6}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); - LoopNest l({d}, {c, d}); - - MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); - - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); - - // Second loop depends on first loop. - auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); - - // Reduction depends on both inputs. - auto reduces = NodeFinder::find(l.root_stmt()); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); -} - -TEST(MemDependency, MemDependencyCheckerComputeGEMM) { - int M = 1024; - int N = 1024; - int K = 2048; - using namespace analysis; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 4); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); - } - - MemDependencyChecker analyzer_unlowered( - loop.getInputBufs(), loop.getOutputBufs()); - - MemDependencyChecker analyzer_lowered( - loop.getInputBufs(), loop.getOutputBufs()); - - // Test both unlowered and lowered form. - { - StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); - stmt->accept(&analyzer_unlowered); - - // Outputs depend on inputs. - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); - - // The last write to gemm should cover the total bound of the output. - std::shared_ptr outputAccess = - analyzer_unlowered.output(CT.buf()); - // A single dependency. - ASSERT_EQ(outputAccess->dependencies().size(), 1); - - // dependencies is a set with 1 element, so can just deref begin(). - std::shared_ptr gemmStore = - outputAccess->dependencies().begin()->second; - // Check its a store. - ASSERT_EQ(gemmStore->type(), AccessType::Store); - - ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); - - // Likewise the first read from each input cover the entire range of the - // input. - auto aInput = analyzer_unlowered.input(AP.node()); - auto bInput = analyzer_unlowered.input(BP.node()); - - // A single dependent each. - ASSERT_EQ(aInput->dependents().size(), 1); - ASSERT_EQ(bInput->dependents().size(), 1); - - // They're both loads. - std::shared_ptr aLoad = aInput->dependents().begin()->second; - std::shared_ptr bLoad = bInput->dependents().begin()->second; - ASSERT_EQ(aLoad->type(), AccessType::Load); - ASSERT_EQ(bLoad->type(), AccessType::Load); - - ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); - ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); - } - - loop.prepareForCodegen(); - SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); - - // now check lowered dependency graph. - { - StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); - stmt->accept(&analyzer_lowered); - - // Lowering will change the dimensionality of all bounds due to index - // flattening and will insert Allocates and Frees. - - auto history_before = analyzer_unlowered.getHistory(); - auto history_after = analyzer_lowered.getHistory(); - - ASSERT_EQ(history_before.size() + 2, history_after.size()); - - // Filter out the alloc/free; - auto isAllocFree = [](const auto& info) { - return info->type() == AccessType::Alloc || - info->type() == AccessType::Free; - }; - history_after.erase( - std::remove_if(history_after.begin(), history_after.end(), isAllocFree), - history_after.end()); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - - if (history_before[i]->dependencies().size() != - history_after[i]->dependencies().size()) { - // Must depend on an Alloc. - ASSERT_TRUE(std::any_of( - history_after[i]->dependencies().begin(), - history_after[i]->dependencies().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Alloc; - })); - - ASSERT_EQ( - history_before[i]->dependencies().size() + 1, - history_after[i]->dependencies().size()); - } - - if (history_before[i]->dependents().size() != - history_after[i]->dependents().size()) { - // Must depend on an Free. - ASSERT_TRUE(std::any_of( - history_after[i]->dependents().begin(), - history_after[i]->dependents().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Free; - })); - - ASSERT_EQ( - history_before[i]->dependents().size() + 1, - history_after[i]->dependents().size()); - } - - // Inputs and outputs are not flattened, only accesses. - if (history_before[i]->type() == AccessType::Input || - history_before[i]->type() == AccessType::Output) { - ASSERT_EQ( - history_before[i]->bounds().size(), - history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - } else { - ASSERT_EQ(history_after[i]->bounds().size(), 1); - ExprPtr flat_bounds = alloc(1); - - for (auto& b : history_before[i]->bounds()) { - flat_bounds = - alloc(flat_bounds, alloc(b.end, alloc(1))); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); - } - - flat_bounds = IRSimplifier::simplify(flat_bounds); - ExprPtr after_bounds = IRSimplifier::simplify( - alloc(history_after[i]->bounds()[0].end, alloc(1))); - ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); - } - } - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memplanning.cpp b/test/cpp/tensorexpr/test_memplanning.cpp deleted file mode 100644 index f5ee8747650f..000000000000 --- a/test/cpp/tensorexpr/test_memplanning.cpp +++ /dev/null @@ -1,708 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -extern void checkIR(StmtPtr s, const std::string& pattern); - -TEST(BufLiveRange, SingleRangeLine) { - VarHandle i("i", kInt), j("j", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32, 32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // a[i] = 0; - // for (int j = 0; j < 32; j++) { - // a[i] = (a[i]) + (b[i, j]); - // } - // } - // } - - StorePtr aInit = Store::make(a, {i}, 0); - ExprHandle reduce = a.load({i}) + b.load({i, j}); - StorePtr aReduce = Store::make(a, {i}, reduce); - StmtPtr loop = - For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)})); - - StmtPtr stmt = Block::make({loop}); - - auto range = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range) == 0); - ASSERT_TRUE(std::get<1>(range) == 0); -} - -TEST(BufLiveRange, MulRangeLine) { - VarHandle i("i", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // if (i<10 ? 1 : 0) { - // a[i] = i + i; - // b[i] = i * i; - // } - // } - // for (int i = 0; i < 32; i++) { - // if (i>10 ? 1 : 0) { - // a[i] = i * i; - // b[i] = i + i; - // } - // } - // } - - StorePtr aStore_1 = Store::make(a, {i}, i + i); - StorePtr bStore_1 = Store::make(b, {i}, i * i); - StmtPtr loop_1 = For::make( - i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); - - StorePtr aStore_2 = Store::make(a, {i}, i * i); - StorePtr bStore_2 = Store::make(b, {i}, i + i); - StmtPtr loop_2 = For::make( - i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); - - StmtPtr stmt = Block::make({loop_1, loop_2}); - - auto range_a = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range_a) == 0); - ASSERT_TRUE(std::get<1>(range_a) == 1); - - auto range_b = BufLiveRange::liveRange(stmt, b.node()); - ASSERT_TRUE(std::get<0>(range_b) == 0); - ASSERT_TRUE(std::get<1>(range_b) == 1); -} - -TEST(MemPlanning, MemReuseWithTypeCast) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E' - // with typecasting. - //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = float(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]); - // } - // } - // for (int i_5 = 0; i_5 < 4; i_5++) { - // for (int i_6 = 0; i_6 < 4; i_6++) { - // E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6])); - // } - // } - // for (int i_7 = 0; i_7 < 4; i_7++) { - // for (int i_8 = 0; i_8 < 4; i_8++) { - // F[i_7, i_8] = E[i_7, i_8]; - // } - // } - //} - - LoopNest l(stmt, {FT.buf()}); - l.prepareForCodegen(); - SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - PaddedBuffer a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, NoMemReuseForLargerType) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kShort); - BufHandle BP("B", {K, N}, kShort); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - auto zero = Cast::make(CT.buf()->dtype(), 0); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for - // 'E'. - //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = int16_t(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4]) a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(E); // dtype=float, dims=[4, 4] -# CHECK: Free(E); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, SameBufSizeMemReuse) { - int M = 1024; - int N = 1024; - int K = 2048; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - auto zero = Cast::make(CT.buf()->dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' - // for 'add'. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - ET.load(m, n); - }); - - auto stmt = - Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same - // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - 1; - }); - Tensor HT = - Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return GT.load(m, n) / 2; - }); - - auto stmt = Block::make( - {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and - // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for - // 'mul', and reuse 'gemm' for 'sub'. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = Compute( - "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) { - return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2); - }); - Tensor FT = Compute( - "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) { - return ET.load(fm, fn) * ET.load(fm, fn); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of - // buffer 'gemm' is smaller. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1]) -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -using Tensors = std::vector; -using Args = std::vector; -std::unique_ptr compile( - const Args& inputs, - const Tensors& outputs) { - LoopNest nest({outputs}); - nest.prepareForCodegen(); - nest.simplify(); - auto join = inputs; - join.insert(join.end(), outputs.begin(), outputs.end()); - return std::make_unique(nest.root_stmt(), join); -} - -TEST(Ops, Sum) { - constexpr int M = 8; - constexpr int N = 16; - std::vector testDims = {{0}, {1}, {0, 1}}; - std::vector> outputShapes = {{N}, {M}, {}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {M, N}, kFloat); - std::vector outStrides = - c10::fmap(make_contiguous_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(M * N, at::kFloat).view({M, N}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} - -TEST(Ops, ChannelsLastSum) { - constexpr int A = 2; - constexpr int B = 3; - constexpr int C = 4; - constexpr int D = 5; - constexpr int E = 6; - std::vector testDims = {{0}, {1}, {0, 1}}; - - std::vector> outputShapes = { - {B, C, D, E}, {A, C, D, E}, {C, D, E}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {A, B, C, D, E}, kFloat); - std::vector outStrides = - c10::fmap(make_channels_last_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp deleted file mode 100644 index af6b539ff33e..000000000000 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ /dev/null @@ -1,452 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; -using SimpleIRExprEval = ExprEval; -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Quantization : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Quantization, QuantDequantInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %3 : int = prim::Constant[value=13]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8_NLC) { - const auto graph_string = R"IR( - graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(1, 2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - x.unsafeGetTensorImpl()->set_sizes_and_strides( - std::initializer_list{1, 2, 2}, {4, 1, 2}); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_add( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto qadd_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::add", "") - .typed(); - return qadd_op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantAddDequantInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantAddDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantSigmoidDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %qa : QUInt8(2, 2) = aten::sigmoid(%q1) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto qs = at::sigmoid(q1); - auto y_expected = at::dequantize(qs); - - TensorExprKernel k(graph); - std::vector inputs = {x1}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "qs:\n" << qs << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_mul( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::mul", "") - .typed(); - return op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantMulDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_mul(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[6, 6]]() - %qz : int = prim::Constant[value=13]() - %qs : float = prim::Constant[value=0.1]() - %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2) - %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4) - %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qu = at::upsample_nearest2d(q, {6, 6}); - auto y_expected = at::dequantize(qu); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "q:\n" << q << std::endl; - std::cout << "qu:\n" << qu << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, UpsampleNearst2d) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[4, 4]]() - %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4) - return (%u))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y_expected = at::upsample_nearest2d(x, {4, 4}); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_cat( - c10::List const& xs, - int64_t dim, - double scale, - int64_t zero) { - const auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::cat", "") - .typed const&, - int64_t, - std::optional, - std::optional)>(); - return op.redispatch( - DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero); -} - -TEST_F(Quantization, QuantCatDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %qdt : int = prim::Constant[value=13]() - %qxz : int = prim::Constant[value=13]() - %qxs : float = prim::Constant[value=0.1]() - %qyz : int = prim::Constant[value=16]() - %qys : float = prim::Constant[value=0.15]() - %qzz : int = prim::Constant[value=19]() - %qzs : float = prim::Constant[value=0.2]() - %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt) - %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt) - %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt) - %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz) - %catd : int = prim::Constant[value=0]() - %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz) - %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat) - return (%cat))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8); - auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8); - auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13); - auto expected = at::dequantize(qcat); - - TensorExprKernel k(graph); - std::vector inputs = {x, y, z}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto result = stack[0].toTensor(); - bool check = at::allclose(expected, result); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y:\n" << y << std::endl; - std::cout << "z:\n" << z << std::endl; - std::cout << "qx:\n" << qx << std::endl; - std::cout << "qy:\n" << qy << std::endl; - std::cout << "qz:\n" << qz << std::endl; - std::cout << "qcat:\n" << qcat << std::endl; - std::cout << "expected:\n" << expected << std::endl; - std::cout << "result:\n" << result << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp deleted file mode 100644 index fb83ab85b71e..000000000000 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ /dev/null @@ -1,1928 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(Reductions, ReduceSum0D_1) { - const int M = 10; - - BufHandle b("b", {M}, kFloat); - std::vector in(M); - for (const auto j : c10::irange(M)) { - in[j] = j; - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], in[i]); - } -} - -TEST(Reductions, ReduceSum0D_2) { - BufHandle b("b", {}, kFloat); - std::vector in(1); - in[0] = 77.7; - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], in[0]); -} - -// Sum an array to a single value. -TEST(Reductions, ReduceSum1D) { - BufHandle b("b", {10}, kFloat); - std::vector in(10); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {10}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 45); -} -// Sum a 2D tensor to a 1D tensor with dynamic shapes. -TEST(Reductions, ReduceSum2D) { - const int M = 3; - const int N = 7; - - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, n, m}); - - cg.call({in, out, 5, 7}); - - float expected = 0; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to -// check our work. -TEST(Reductions, ReduceSum3D) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m}); - - std::vector bData(2 * 3 * M, 0); - std::vector cData(2 * 3, 6.0f); - std::vector dData(2, 1.0f); - std::vector eData(2, 1.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({bData, cData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(cData[i], expected); - } - - Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m}); - LoopNest loop2({d}); - loop2.prepareForCodegen(); - StmtPtr s2 = loop2.root_stmt(); - s2 = IRSimplifier::simplify(s2); - - SimpleIREvaluator cg2(s2, {b, d, m}); - cg2.call({bData, dData, M}); - - // We're combining an additional dimension of 3, so the sum is 3x. - expected = expected * 3; - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected); - } - - // This is the same as just reducing the original result across that axis. - BufHandle c_buf(c.buf()); - Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3}); - LoopNest loop3({e}); - loop3.prepareForCodegen(); - StmtPtr s3 = loop3.root_stmt(); - s3 = IRSimplifier::simplify(s3); - - SimpleIREvaluator cg3(s3, {c, e}); - cg3.call({cData, eData}); - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(eData[i], expected); - } -} - -// Sum a large (10 D) Tensor 5 dimensions in. -TEST(Reductions, ReduceSum10D) { - BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); - const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; - BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat); - const int OutputSize = 2 * 3 * 2 * 3 * 2; - - std::vector in(InputSize, 1.f); - std::vector out(OutputSize, -1.f); - - Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - - cg.call({in, out}); - - // NOLINTNEXTLINE(bugprone-integer-division) - float expected = InputSize / OutputSize; - for (const auto i : c10::irange(OutputSize)) { - ASSERT_EQ(out[i], expected); - } -} - -// Reduce via Mul rather than Add using a custom Reducer. -TEST(Reductions, ReduceProduct) { - const int M = 4; - const int N = 4; - - BufHandle b("b", {M, N}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = 2 + j; - } - } - - std::vector out(M, -1.f); - - Reducer product( - ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); - - Tensor c = Reduce("product", {M}, product, b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - - float expected = 1; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected *= 2 + i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Maximum reductions. -TEST(Reductions, ReduceMax) { - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10}); - - LoopNest loop({dm1}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - SimpleIREvaluator cg(s, {in_, dm1}); - - cg.call({in, out}); - - ASSERT_EQ(out[0], 9); - - BufHandle in2_("b", {2, 5}, kFloat); - std::vector out2(2, -1.f); - - Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5}); - - LoopNest loop2({m2d}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {in2_, m2d}); - cg2.call({in, out2}); - - ASSERT_EQ(out2[0], 4); - ASSERT_EQ(out2[1], 9); -} - -// Minimum reduction, with custom initialization. -TEST(Reductions, ReduceMinCustomInitializer) { - VarHandle minInit("minInit", kFloat); - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = 10 + j; - } - - Tensor min = Reduce( - "min", - {}, - Minimum(ExprHandle(minInit)), - [&](ParameterList& v) { return in_.load(v); }, - {10}); - - LoopNest loop({min}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, min, minInit}); - - // Works normally (note that out data starts lower than the correct - // minimum). - cg.call({in, out, std::numeric_limits::max()}); - ASSERT_EQ(out[0], 10); - - // With an initializer lower than the min, that's the min. - cg.call({in, out, 5.f}); - ASSERT_EQ(out[0], 5); -} - -// Example implementation of Any/All. -// TODO: this is very awkward without logical And/Or operators. -TEST(Reductions, ReduceAnyAll) { - VarHandle searchValue("searchValue", kInt); - BufHandle b("b", {4, 10}, kInt); - - Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 1, 1, b, kEQ); - }); - - Tensor any = Reduce( - "anyEqual", - {4}, - anyEqSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kEQ); - }, - {10}); - - LoopNest loop({any}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, any, searchValue}); - - std::vector in(40, 0); - std::vector out(4, 0); - - // input has 0-39 in 4 rows. - for (const auto i : c10::irange(40)) { - in[i] = i; - } - cg.call({in, out, 1}); - - // only the first row has 1 - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - cg.call({in, out, 15}); - - // 15 in the 3rd row - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 0, 0, b, kEQ); - }); - - Tensor allGreaterThan = Reduce( - "allGreaterThan", - {4}, - allGTSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kGT); - }, - {10}); - - LoopNest loop2({allGreaterThan}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); - - cg2.call({in, out, 11}); - - // 11 is in row 2. - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); - - cg2.call({in, out, -3}); - - // All are positive. - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); -} - -TEST(Reductions, ReduceMatmul2D) { - BufHandle tA("tA", {3, 2}, kFloat); - BufHandle tB("tB", {2, 3}, kFloat); - - std::vector tA_(6); - std::vector tB_(6); - - std::vector out(9, -1.f); - for (const auto i : c10::irange(3)) { - for (const auto j : c10::irange(2)) { - tA_[i * 2 + j] = i * 2 + j; - tB_[j * 3 + i] = i * 2 + j; - } - } - - Tensor mm = Reduce( - "mm", - {3, 3}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return tA.load(m, k) * tB.load(k, n); - }, - {2}); - - LoopNest loop({mm}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {tA, tB, mm}); - cg.call({tA_, tB_, out}); - - std::vector expected( - {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); - - for (const auto i : c10::irange(9)) { - ASSERT_EQ(out[i], expected[i]); - } -} - -TEST(Reductions, ReduceRfactorLike) { - BufHandle in("in", {10, 10}, kFloat); - std::vector in_(100); - for (const auto i : c10::irange(100)) { - in_[i] = i; - } - std::vector in_rf_(10, -2.f); - std::vector out(1, -1.f); - - Tensor l1 = Reduce("l1", {10}, Sum(), in, {10}); - BufHandle in_rf(l1.buf()); - - Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10}); - - LoopNest loop({l1, l2}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, l1, l2}); - cg.call({in_, in_rf_, out}); - - ASSERT_EQ(out[0], 99 * 50); -} - -TEST(Reductions, ReduceAsProducer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - Tensor d = - Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) { - return c.load(l, n) * a.load(l, n); - }); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2 * 3, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - aData[i] = 6 - i; - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({aData, bData, dData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(dData[i], expected * (6 - i)); - } -} - -TEST(Reductions, ReduceAsConsumer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3, m}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, m}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, m}); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3 * M, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j + 1; - aData[i * M + j] = 6 - i; - } - } - - cg.call({aData, bData, dData, M}); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float expected[2] = {0, 0}; - for (const auto i : c10::irange(2)) { - for (const auto j : c10::irange(3)) { - for (const auto k : c10::irange(M)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected[i] += (k + 1) * (6 - (i * 3 + j)); - } - } - } - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected[i]); - } -} - -TEST(Reductions, SplitReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[1], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, SplitNonReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, ReorderedReductionInitializer) { - /* From the quip: - for k in 0..1: // blockIdx - for m in 0..128: - for n in 0..64: // threadIdx - SumOp(c(k, n), 0, a(k, m, n), {m}) - */ - - BufHandle in("in", {1, 12, 6}, kFloat); - std::vector in_(12 * 6, 1.f); - - Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l_({tensor_}); - - l_.prepareForCodegen(); - StmtPtr s_ = Stmt::clone(l_.root_stmt()); - s_ = IRSimplifier::simplify(s_); - - Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l({tensor}); - - auto loops = l.getLoopStmtsFor(tensor); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - LoopNest::reorderAxis(loops[1], loops[2]); - - StmtPtr s = l.root_stmt(); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - s = IRSimplifier::simplify(s); - - l.prepareForCodegen(); - - s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - std::vector out1(16, -1.f); - SimpleIREvaluator cg(s_, {in, tensor_}); - cg.call({in_, out1}); - - std::vector out2(16, -1.f); - SimpleIREvaluator cg2(s, {in, tensor}); - cg2.call({in_, out2}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out1[i], out2[i]); - } -} - -TEST(Reductions, ReduceRfactor) { - const int M = 10; - const int N = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (int j = 0; j < M * N; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n}); - - cg.call({in, out, M, N}); - ASSERT_EQ(out[0], 4950); -} - -TEST(Reductions, Reduce3DRfactorInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 1); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, Reduce3DRfactorOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReduceRepeatedInternalRfactor) { - BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat); - const int InputSize = 2 * 3 * 4 * 5 * 6; - - std::vector in(InputSize, 1.f); - std::vector out(1, -1.f); - std::vector ref(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6}); - LoopNest orig_loop({c}); - - // Try rfactoring N outer loops - for (const auto rfac_number : c10::irange(1, 5)) { - LoopNest refloop(orig_loop); - LoopNest loop(orig_loop); - refloop.prepareForCodegen(); - SimpleIREvaluator ref_cg( - IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); - ref_cg.call({in, ref}); - - BufPtr tmp_buf = c.buf(); - - for (const auto idx : c10::irange(rfac_number)) { - auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; - ASSERT_TRUE(loop.rfactor( - reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); - } - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - cg.call({in, out}); - - ASSERT_EQ(ref[0], out[0]); - } -} - -// Split a reduction axis with a tail loop. -TEST(Reductions, ReduceSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly so there is no tail loop. -TEST(Reductions, ReduceSplitNoTail) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with only a tail loop (the split loop will be size 0 -// and eliminated out). -TEST(Reductions, ReduceOverSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with a mask. -TEST(Reductions, ReduceSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly not requiring a mask. -TEST(Reductions, ReduceSplitNoMask) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with all logic in the mask. -TEST(Reductions, ReduceOverSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor when there are two ReduceOps in the graph due to a -// splitWithTail. -TEST(Reductions, ReduceSplitRfactor) { - const int M = 2; - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 4; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (const auto m : c10::irange(M)) { - for (int j = 0; j < N * K; ++j) { - in[m * N * K + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); - - auto c_body = loop.getAllWritesToBuf(c.buf())[2]; - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); - all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for ([[maybe_unused]] const auto i : c10::irange(M)) { - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor which ends up being eliminated since the total loop size is -// smaller than the split factor. -TEST(Reductions, ReduceOverSplitRfactor) { - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 16; - - BufHandle b("b", {N, K}, kFloat); - std::vector in(N * K); - for (int j = 0; j < N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - ForPtr i, t; - LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); - LoopNest::reorderAxis(loops[0], i); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); - - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the IR to verify the rfactored reduce is eliminated. - // TODO: The alloc free should be eliminated here since it is size 0. - /* - const std::string& verification_pattern = - R"IR( -# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] -# CHECK: sum[0] = 0.f; -# CHECK: for (int n = 0; n < 10; n++) { -# CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { -# CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); -# CHECK: } -# CHECK: } -# CHECK: Free(tmp_buf);)IR"; - */ - // TODO: rfactor output is not consistent yet, will fix (@nickg). - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Reductions, ReduceInlineReduction) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K}); - Tensor y = Compute( - "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); - - PaddedBuffer a_v(M); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - a_v(i) = i * i; - } - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - b_v(i, j, k) = j * j * k; - } - } - } - - LoopNest l1({y}, {x, y}); - // Cannot inline a reduction computation - ASSERT_FALSE(l1.computeInline(x.buf())); -} - -TEST(Reductions, ReduceInlineConsumer) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - Tensor y = Reduce("y", {M}, Sum(), x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReduceInlineReducerInternal) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - - Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { - return Add::make(ExprHandle(1.f), Min::make(a, b, false)); - }); - Tensor y = Reduce("y", {M}, minimum, x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReductionCacheAccessesOperatorAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before( - LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[4] -#CHECK: for (int i_2 -#CHECK: d_local[i_2] = 0.f -#CHECK: for (int -#CHECK: for (int -#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: } -#CHECK: for (int i_3 -#CHECK: sum[i_3] = d_local[i_3] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: d_local[0] = sum[i_1] -#CHECK: for (int j_1 -#CHECK: for (int k_1 -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: sum[i_1] = d_local[0] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: for (int -#CHECK: d_local[0] = 0 -#CHECK: for (int -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) -#CHECK: } -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheBodyAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(c.buf(), "scale_local", d_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] -#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { -#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { -#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; -#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); -#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); -#CHECK: Free(scale_local); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); - - StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; - l.cacheAccesses(d.buf(), "sum_local", e_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Alias(sum_local,scale); -#CHECK: sum[i_1] = (sum[i_1]) + (scale[ -#CHECK: for (int j_2 = 0; j_2 < 4 -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; -#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionSplitCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - ForPtr inner; - - // Split outer reduction axis. - LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); - - // Split reduction consumer. - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - - l.cacheAccesses(d.buf(), "sum_local", inner); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - // reduction changes but cache does not. - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Alias(sum_local,scale); -#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]); -#CHECK: for (int i_2 = 0; i_2 < 6 -#CHECK: for (int j_2 = 0; j_2 < 4 -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; -#CHECK: for (int j_3 = 0; j_3 < 4 -#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionReorderCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - ForPtr inner; - - // reorder outer reduction axes. - auto loops = l.getLoopStmtsFor(d); - LoopNest::reorderAxis(loops[0], loops[1]); - - // Split reduction consumer. - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - - l.cacheAccesses(d.buf(), "sum_local", inner); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - // neither reduction body not cache changes. - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); -#CHECK: for (int i_3 = 0; i_3 < 6; -#CHECK: for (int j_2 = 0; j_2 < 4; -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; -#CHECK: for (int j_3 = 0; j_3 < 4; -#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionRfactorCacheTempOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); - loop.simplify(); - loop.prepareForCodegen(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[n] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[j] = 0 -#CHECK: } -#CHECK: for (int j_1 = 0; j_1 < n -#CHECK: for (int k -#CHECK: tmp[j_1] = (tmp[j_1]) + (B[ -#CHECK: } -#CHECK: } -#CHECK: for (int j_2 = 0; j_2 < n -#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); -#CHECK: } -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionRfactorCacheTempInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[1] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[0] = 0 -#CHECK: for (int k -#CHECK: tmp[0] = (tmp[0]) + (B[ -#CHECK: } -#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionVectorize) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(8, -1.f); - std::vector out_after(8, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); - - StmtPtr s = l.root_stmt(); - s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); - - std::ostringstream oss; - oss << *s; - const std::string& expected_ir = - R"IR( -#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); -#CHECK: for (int i = 0; i < 8; i++) { -#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - // Vectorizing should not change result. - l.prepareForCodegen(); - s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg_after(s, {in, tensor}); - cg_after.call({in_, out_after}); - for (const auto i : c10::irange(8)) { - ASSERT_EQ(out_before[i], out_after[i]); - } -} - -TEST(Reductions, ReductionVectorizeInner) { - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l({tensor}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); -} - -TEST(Reductions, ReductionVectorizeRfactor) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(1, -1.f); - std::vector out_after(1, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8}); - - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); - - // But if we rfactor this so it's not a reduce axis we can vectorize that - // loop. - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::reorderAxis(loops[0], loops[1]); - loops = l.getLoopStmtsFor(tensor); - auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; - BufPtr rfac_buf = nullptr; - ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); - - LoopNest::distributeLoop(loops.at(0)); - auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf); - - ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); - l.simplify(); - - StmtPtr s = LoopNest::sanitizeNames(l.root_stmt()); - - std::ostringstream oss; - oss << *s; - const std::string& expected_ir = - R"IR( -#CHECK: sum = 0.f; -#CHECK: for (int i = 0; i < 8; i++) { -#CHECK: sum_rfac[i] = 0.f; -#CHECK: } -#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) { -#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1}); -#CHECK: } -#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) { -#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2}); -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - // Vectorizing should not change result. - l.prepareForCodegen(); - s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg_after(s, {in, tensor}); - cg_after.call({in_, out_after}); - - ASSERT_EQ(out_before[0], out_after[0]); -} - -TEST(Reductions, InitFunction) { - constexpr int M = 32; - constexpr int N = 16; - BufHandle A("A", {M, N}, kFloat); - BufHandle B("B", {N}, kFloat); - Tensor C = Reduce( - "C", - {N}, - Sum(), - [&](const std::vector& v) { return B.load(v[0]); }, - [&](const std::vector& v) { return A.load(v[1], v[0]); }, - {M}); - LoopNest nest({C}); - nest.prepareForCodegen(); - StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt())); - std::ostringstream oss; - oss << *s << "\n"; - const std::string& expected_ir = - R"IR( -#CHECK: for (int i = 0; i < 16; i++) { -#CHECK: C[i] = B[i]; -#CHECK: for (int j = 0; j < 32; j++) { -#CHECK: C[i] = (C[i]) + (A[i + 16 * j]); -#CHECK: } -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp deleted file mode 100644 index 6cbd04264c32..000000000000 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ /dev/null @@ -1,3702 +0,0 @@ -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/registerizer.h" - -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -// Can replace a simple scalar access with a local variable. -TEST(Registerizer, RegisterizerSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't do replacement of a loop access. -TEST(Registerizer, RegisterizerLoop) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't replace even if the load is a fixed scalar, since the store could -// invalidate it. -TEST(Registerizer, RegisterizerLoopFixedLoad) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize accesses that occur entirely within inner scopes, even if -// they depend on the loop var. -TEST(Registerizer, RegisterizerLoopInternal) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * A[x] = (A[x]) + x; - * } - */ - - stmt = registerize(stmt); - - // TODO: the order of terms in addition changes and in general depends on - // some hash value. This results in unpredictable swaps of the operands from - // random changes, which is not great. Ideally, we should ensure some - // specific order (ideally, the original one). - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * A_1 = x + A_1; - * A_1 = x + A_1; - * A[x] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: int A_1 = A[x]; -# CHECK: A_1 = A_1 + x; -# CHECK: A_1 = A_1 + x; -# CHECK: A[x] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access can be overlapped by another read in the same Expr. In this case -// B[z] and B[y] overlap and prevent registerization of both accesses. -TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (B[y]) + (B[z]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeated) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) - - }); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[1]; - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[1]; -# CHECK: int A_2 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK-NOT: A[1] -# CHECK: A[0] = A_2; -# CHECK-NOT: A[1] -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) - - }); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) - - })); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Will registerize multiple accesses of different items of the same buffer. -TEST(Registerizer, RegisterizerMultiVar) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - /* - * A[0] = 0; - * A[1] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * A[1] = (A[1]) - x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * int A_2 = 0; - * for (int x = 0; x < 10; x++) { - * A_2 = x + A_2; - * A_1 = A_1 - x; - * } - * A[1] = A_2; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_2 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_2 = -# CHECK: A[1] = A_2 -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Will registerize the valid accesses while skipping invalid replacements. -TEST(Registerizer, RegisterizerVariableLoad) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle x2("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make(x, 0, 10, Store::make(b, {x}, x)), - For::make( - x2, - 0, - 10, - Block::make({Store::make( - a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A[0] = (A[0]) + (B[x_1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A_1 = A_1 + (B[x_1]); - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B[x] = x -# CHECK: for (int x_1 = 0; x_1 < 10; x_1++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize variable accesses so long as the variable does not change. -TEST(Registerizer, RegisterizerSymbolicIndices) { - VarHandle i("i", kInt); - VarHandle N("N", kInt); - BufHandle a("A", {N}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {i}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); - - /* - * A[i] = 0; - * for (int x = 0; x < 10; x++) { - * A[i] = (A[i]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[i] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[i] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize accesses dependent on multiple loop vars. -TEST(Registerizer, RegisterizerMultiLoop) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Store::make( - a, - {0}, - Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A[0] = x * y + (A[0]) * y; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A_1 = x * y + y * A_1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: for (int y = 0; y < 10; y++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize correctly if scalars already exist in the program. -TEST(Registerizer, RegisterizerRepeated) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - // Registerize manually to make sure we only replace a single target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 2); - - candidates.pop_back(); - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - // Re-analyze and replace the second target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 1); - - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_1_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_1_1 = -# CHECK: A[1] = A_1_1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize the load of A. -TEST(Registerizer, RegisterizerNoLoads) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = x + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + 1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize the load of A but not the store of B. -TEST(Registerizer, RegisterizerNoRepeatedStores) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - // TODO: its unnecessary to reorder the initializer of A[0], but it's not - // actually worse so lets not worry for now. - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: B[x] = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't registerize if there are multiple accesses which may overlap. -TEST(Registerizer, RegisterizerMultiVarOverlap) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})), - }); - stmt = IRSimplifier::simplify(stmt); - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerAllocs) { - BufHandle a("A", {2}, kInt); - BufHandle c("C", {1}, kInt); - VarHandle x("x", kInt); - - BufHandle b("B", {Load::make(c, {0})}, kInt); - - StmtPtr stmt = Block::make( - {Allocate::make(b), - Store::make(a, {0}, Load::make(c, {0})), - Store::make(b, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), - Store::make(a, {0}, Load::make(c, {0}))})), - Free::make(b)}); - - /* - * Allocate(B, int, {C[0]}); - * A[0] = C[0]; - * B[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[0] = (B[0]) + x; - * A[0] = C[0]; - * } - * Free(B); - */ - - stmt = registerize(stmt); - - /* - * int C_1 = C[0]; - * Allocate(B, int, {C_}); - * int A_1 = C_1; - * int B_1 = 0; - * for (int x = 0; x < 10; x++) { - * B_1 = B_1 + x; - * A_1 = C_1; - * } - * B[0] = B_1; - * A[0] = A_1; - * Free(B); - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int C_1 = C[0]; -# CHECK: Allocate(B -# CHECK: int A_1 = C_1; -# CHECK: int B_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B_1 = -# CHECK: A_1 = C_ -# CHECK: B[0] = B_1; -# CHECK: A[0] = A_1; -# CHECK: Free(B)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializer) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializerLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoadThenStore) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {0}, Load::make(b, {0}))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * B[0] = (A[0]) + x; - * A[0] = B[0]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * int B_1 = B[0]; - * for (int x = 0; x < 10; x++) { - * B_1 = x + A_1; - * A_1 = B_1; - * } - * B[0] = B_1; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: int B_1 = B[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: B[ -# CHECK: B_1 = -# CHECK-NOT: A[ -# CHECK: A_1 = B_ -# CHECK: B[0] = B_ -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerParallelized) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - LoopOptions loopOpts; - loopOpts.set_gpu_block_index(0); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), - loopOpts)}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - ASSERT_THROWS_WITH( - registerize(stmt), - "Registerization must occur after parallelism flattening"); -} - -// Should be able to registerize this since the scalar would exist before the -// branch. -TEST(Registerizer, RegisterizerConditionAfter) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this since the scalar exists in the same form -// after the branch and there is no overlap. -TEST(Registerizer, RegisterizerConditionBefore) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x}))}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = B[x]; - * C[x] = A[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_ 1 = A[x]; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = B[x]; - * C[x] = A_1; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this as the combination of the two above rules. -TEST(Registerizer, RegisterizerConditionInside) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * B[x] = A_1; - * A_1 = C[x]; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: B[x] = A_1; -# CHECK: A_1 = C[x]; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An example where an access is cut by an overlapping access inside a -// condition, and both sides are large enough to be registerized but cannot be -// because there is no safe place to put the initializer or finalizer. -TEST(Registerizer, RegisterizerConditionInsideOverlap1) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - // The A[0] store overlaps, A[x] cutting the region that can be registerized - // into two groups. - // Each group has 2 loads and 2 stores however, so we could registerize it, - // but the first group would need to be finalized inside the condition block, - // the second would need to be initialized inside the condition block. There's - // no safe place to put these that's visible to the other uses in the group - // and so neither registerization is possible. - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Same as the above, but the access group before the condition (and after the -// condition) are large enough to be registerized without needing the access -// from the loop. Registerization occurs but does not include any accesses in -// the condition, and the first group must be finalized before the Cond, the -// second initialized after it. -TEST(Registerizer, RegisterizerConditionInsideOverlap2) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(a, {x}, Load::make(b, {x + 1})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(b, {x + 1}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * A[x] = B[x + 1]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * B[x + 1] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; // A_1 initializer - * A_1 = B[x + 1]; // - * C[x] = A_1; // - * A[x] = A_1; // A_1 finalizer - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * int A_2 = A[x]; // A_2 initializer - * B[x] = A_2; // - * B[x + 1] = A_2; // - * A_2 = C[x]; // - * A[x] = A_2; // A_2 finalizer - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: A_1 = B[x + 1]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1; -# CHECK: if ( -# CHECK-NOT: A_1 = A_1 + 1; -# CHECK: A[x] = (A[x] -# CHECK: A[0] = -# CHECK: A[x] = (A[x] -# CHECK: } -# CHECK: int A_2 = A[x]; -# CHECK: B[x] = A_2; -# CHECK: B[x + 1] = A_2; -# CHECK: A_2 = C[x]; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// When accesses are within conditional blocks they are not visible to the wider -// program, because we don't know if the branch would be taken and if it isn't -// the accesses in it don't need to be valid (think size checks on the index). -// In this case the accesses cannot be registerized. -TEST(Registerizer, RegisterizerConditionHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But... if the same access is found in a non conditional scope, that means -// that that access is valid in the higher scope (or at least if its not it's -// the user's fault). It "unhides" the conditional accesses, allowing -// registerization to occur. -TEST(Registerizer, RegisterizerConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = (A[x]) + 1; <-- this is doing the unhiding. - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = A_1 + 1; - * if (x>5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (x<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x>5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a load that occurs in the condition of a Cond. -TEST(Registerizer, RegisterizerCondCondition) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if ((A[x])<5 ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * int C_1 = A_1; - * if (A_1<5 ? 1 : 0) { - * C_1 = C_1 + 1; - * } - * C[x] = C_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: int C_1 = A_1; -# CHECK: if (A_1<5 -# CHECK: C_1 = C_1 + 1; -# CHECK: C[x] = C_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Appearing in the condition of a Cond makes it visible to the enclosing scope, -// and so we can registerize internal usages. -TEST(Registerizer, RegisterizerCondConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); - - /* - * if ((A[x])<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } else { - * A[x] = (A[x]) + 10; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (A_1<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } else { - * A_1 = A_1 + 10; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (A_1<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } else { -# CHECK: A_1 = A_1 + 10; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Conditional hiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - {Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2)))}); - - /* - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Conditional unhiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({ - Store::make(a, {x}, 0), - Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - }); - - /* - * A[x] = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Nested IfThenElse exprs can't promote to higher level scopes. -TEST(Registerizer, RegisterizerIfThenElseNested) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - BufHandle d("D", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - IfThenElse::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Load::make(d, {x}), - Load::make(b, {x})), - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kEQ), - Load::make(c, {x}), - Load::make(d, {x}))))}); - - /* - * A[x] = IfThenElse(x<3 ? 1 : 0, - * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), - * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Cannot registerize an access completely contained within an IfThenElse -// branch, since it is not a Stmt and cannot hold variable definitions. We need -// to check that we don't promote the initializer/finalizer to the enclosing -// Block. -TEST(Registerizer, RegisterizerIfThenElseInternal) { - // Making these floats so they don't get simplified to a single access. - BufHandle a("A", {5}, kFloat); - BufHandle b("B", {5}, kFloat); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Add::make(Load::make(b, {x}), Load::make(b, {x})), - Load::make(b, {x})))}); - - /* - * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // If this was a Cond instead of an IfThenElse then we could registerize the - // two accesses to B[x] in the True branch. - - // Actually lets verify that. - - stmt = Block::make({Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))), - Store::make(a, {x}, Load::make(b, {x})))}); - - /* - * if (x<3 ? 1 : 0) { - * A[x] = (B[x]) + (B[x]); - * } else { - * A[x] = B[x]; - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<3 ? 1 : 0) { - * float B_1 = B[x]; - * A[x] = B_1 + B_1; - * } else { - * A[x] = B[x]; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK-NOT: float -# CHECK: if (x<3 -# CHECK: float B_1 = -# CHECK: A[x] = B_1 + B_1 -# CHECK: } else { -# CHECK: A[x] = B[x] -# CHECK: } -# CHECK-NOT: A[x] -# CHECK-NOT: B[x])IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a load that occurs in the condition of an IfThenElse; -TEST(Registerizer, RegisterizerIfThenElseCondition) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(a, {x})), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Load::make(b, {0}), - Load::make(c, {0})))}); - - /* - * A[x] = A[x]; <---- just here so there are enough accesses to combine. - * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * A_1 = A_1; - * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Appearing in the condition of a Cond makes it visible to the enclosing scope, -// and so we can registerize internal usages. -TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - b, - {x}, - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x}), 10)))}); - - /* - * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Cannot promote accesses internal to IfThenElse branches even if the enclosing -// scope if conditional. -TEST(Registerizer, RegisterizerConditionBranchOnly) { - BufHandle a("A", {5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({ - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), x), - Add::make(Load::make(a, {x - 5}), x))), - Store::make( - a, - {x - 5}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), x), - Add::make(Load::make(a, {x - 5}), x)))), - }))}); - stmt = IRSimplifier::simplify(stmt); - - std::ostringstream before; - before << *stmt; - - /* for (int x = 0; x < 10; x++) { - * if (x<5 ? 1 : 0) { - * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } else { - * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } - * } - */ - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// We can registerize an IfThenElse that appears in the condition branch of a -// Cond. This is a weird but valid thing to do. -TEST(Registerizer, RegisterizerCondIfThenElse) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make( - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(b, {x})), - x, - CompareSelectOperation::kEQ), - Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), - nullptr)}); - - /* - * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - // access to A can be registerized, but not B or C - - /* - * int A_1 = A[x]; - * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] -# CHECK: C[x] = (C[x]) + 1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a conditional access in the RHS of a store unhidden by it's -// LHS, and hoist it out of a loop. -TEST(Registerizer, RegisterizerIfThenElseLoop) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(b, {y})))); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * for (int y = 0; y < 10; y++) { - * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: for ( -# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Cannot registerize if the RHS overlaps the access creating visibility. -TEST(Registerizer, RegisterizerIfThenElseLoopCut) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(a, {y}))))}); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Simple case where an access is cut by an overlapping access later in the -// program, we can registerize up until the overlap. -TEST(Registerizer, RegisterizerPartialAfter) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK-NOT: A)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize an access which overlaps a previous access, the -// initializer must be inserted after the previous access. -TEST(Registerizer, RegisterizerPartialBefore) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// The combination of the previous two tests, an access is cut by an overlapping -// access in both directions. -TEST(Registerizer, RegisterizerPartialInside) { - BufHandle a("A", {1}, kInt); - VarHandle x1("x1", kInt); - VarHandle x2("x2", kInt); - VarHandle x3("x3", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), - For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), - For::make( - x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); - - /* - * A[0] = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A[0] = (A[0]) + x1; - * } - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * for (int x3 = 0; x3 < 10; x3++) { - * A[0] = (A[0]) + x3; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A_1 = A_1 + x1; - * } - * A[0] = A_1; - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * int A_2 = A[0]; - * for (int x3 = 0; x3 < 10; x3++) { - * A_2 = A_2 + x3; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x2] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x3; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An element could be registerized program wide but is cut by a conditional -// access, we should break this into two scalars and write back to the buffer -// before the condition. -TEST(Registerizer, RegisterizerPartialCondition) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Load::make(a, {x - 1})), - nullptr), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); - - /* - * A[0] = 2; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * if (x<5 ? 1 : 0) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * A[x] = A[x - 1]; - * } - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: A[x] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Tests case where an access is cut by an internal conditional access which -// itself is registerized. -TEST(Registerizer, RegisterizerPartialConditionInternalCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {0}, 4), - Store::make(a, {0}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[0] = 4; - * A[0] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * int A_2 = 1; - * A_2 = 3; - * A[x] = A_2; - * } - * int A_3 = 4; - * A_3 = 6; - * A[0] = A_3; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: int A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: A[x] = A_2; -# CHECK: } -# CHECK: int A_3 = 4; -# CHECK: A_3 = 6; -# CHECK: A[0] = A_3;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// First statement in condition closes outer access, but can be registerized -// with later statements. -TEST(Registerizer, RegisterizerPartialConditionInternalStart) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {x}, 4), - Store::make(a, {x}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[x] = 4; - * A[x] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * int A_2 = A[x]; <--- must read from the input here. - * if (x<5 ? 1 : 0) { - * A_2 = 1; - * A_2 = 3; - * } - * A_2 = 4; - * A_2 = 6; - * A[x] = A_2; - */ - - // TODO: I suppose we could refactor with a conditional initializer? - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: int A_2 = A[x]; -# CHECK: if ( -# CHECK: A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: } -# CHECK: A_2 = 4; -# CHECK: A_2 = 6; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access cuts two open overlaps and creates four scalar variables. -TEST(Registerizer, RegisterizerPartialOverlapsTwo) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {1}, Load::make(a, {0})), - Store::make(a, {0}, Load::make(a, {1})), - Store::make(a, {0}, Load::make(a, {1})), - For::make(x, 1, 10, Store::make(a, {x}, x)), - Store::make(a, {1}, Load::make(a, {0})), - Store::make(a, {0}, Load::make(a, {1})), - Store::make(a, {0}, Load::make(a, {1}))}); - - /* - * A[1] = A[0]; - * A[0] = A[1]; - * A[0] = A[1]; - * for (int x = 1; x < 10; x++) { - * A[x] = x; - * } - * A[1] = A[0]; - * A[0] = A[1]; - * A[0] = A[1]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * int A_2 = A_1; - * A_1 = A_2; - * A_1 = A_2; - * A[1] = A_2; - * A[0] = A_1; - * for (int x = 1; x < 10; x++) { - * A[x] = x; - * } - * int A_3 = A[0]; - * int A_4 = A_3; - * A_3 = A_4; - * A_3 = A_4; - * A[1] = A_4; - * A[0] = A_3; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: int A_2 = A_1; -# CHECK: A_1 = A_2; -# CHECK: A_1 = A_2; -# CHECK: A[1] = A_2; -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x] = x; -# CHECK: } -# CHECK: int A_3 = A[0]; -# CHECK: int A_4 = A_3; -# CHECK: A_3 = A_4; -# CHECK: A_3 = A_4; -# CHECK: A[1] = A_4; -# CHECK: A[0] = A_3;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Nested blocks will automatically be flattened and do not provent -// registerization of enclosed accesses. -TEST(Registerizer, RegisterizerNestedBlocks) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})}); - - /* - * A[0] = (A[0]) + 1; - * { - * A[0] = (A[0]) + 2; - * } - * { - * A[0] = (A[0]) + 3; - * { - * A[0] = (A[0]) + 4; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * A_1 = A_1 + 2; - * A_1 = A_1 + 3; - * A_1 = A_1 + 4; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: A_1 = A_1 + 2; -# CHECK: A_1 = A_1 + 3; -# CHECK: A_1 = A_1 + 4; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// The access can be registerized internally to a condition, but must ensure -// that both initializer and finalizer are within the same condition. -TEST(Registerizer, RegisterizerNestedConditions) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * if (x==2 ? 1 : 0) { - * - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * if (x==2 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If an access exists outside the scope of the condition then we can lift -// nested conditional usages into the same scalar. -TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {1}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x<5 -# CHECK: A[1] = 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -// If an access is cut by another access internal to a condition block, it still -// cuts the access. -TEST(Registerizer, RegisterizerNestedConditionsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {x}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * if (x==2 ? 1 : 0) { - * - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}))}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Three loops and four element regions, three of which should be registerized -// at different levels of the IR. -TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {4}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kGT), - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kGT), - Block::make({ - Cond::make( - CompareSelect::make(x, 4, CompareSelectOperation::kGT), - Block::make({ - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - Store::make( - a, {2}, Add::make(Load::make(a, {2}), 1)), - Store::make( - a, {3}, Add::make(Load::make(a, {3}), 1)), - Store::make( - a, {4}, Add::make(Load::make(a, {4}), 1)), - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - }), - nullptr), - Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), - }), - nullptr), - nullptr)}); - - /* - * A[4] = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * if (x>4 ? 1 : 0) { - * A[1] = (A[1]) + 1; - * A[2] = (A[2]) + 1; - * A[3] = (A[3]) + 1; - * A[4] = (A[4]) + 1; - * A[1] = (A[1]) + 1; - * } - * A[2] = (A[2]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * int A_3 = A[2]; - * if (x>4 ? 1 : 0) { - * int A_2 = A[1]; - * A_2 = A_2 + 1; - * A_3 = A_3 + 1; - * A[3] = (A[3]) + 1; - * A_1 = A_1 + 1; - * A_2 = A_2 + 1; - * A[1] = A_2; - * } - * A_3 = A_3 + 1; - * A[2] = A_3; - * } - * } - * A[4] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: if (x>2 ? 1 : 0) { -# CHECK: if (x>3 ? 1 : 0) { -# CHECK: int A_3 = A[2]; -# CHECK: if (x>4 ? 1 : 0) { -# CHECK: int A_2 = A[1]; -# CHECK: A_2 = A_2 + 1; -# CHECK: A_3 = A_3 + 1; -# CHECK: A[3] = (A[3]) + 1; -# CHECK: A_1 = A_1 + 1; -# CHECK: A_2 = A_2 + 1; -# CHECK: A[1] = A_2; -# CHECK: } -# CHECK: A_3 = A_3 + 1; -# CHECK: A[2] = A_3; -# CHECK: } -# CHECK: } -# CHECK: A[4] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can replace a simple scalar access with a local variable even when that -// variable is an outer loop var. -TEST(Registerizer, RegisterizerNestedLoopSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); - - /* - * for (int y = 0; y < 10; y++) { - * for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * for (int y = 0; y < 10; y++) { - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int y -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[y] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the positive case of the hiddenAccess split, where an internal -// conditional access can be hoisted up through a loop to match an existing -// access in a higher scope and the two can be registerized. -TEST(Registerizer, RegisterizerHiddenAccessYes) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the negative case of the hiddenAccess split, where the hoisted access is -// never unhidden at a higher scope and registerization occurs at the lower -// scope. -TEST(Registerizer, RegisterizerHiddenAccessNo) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * int A_1 = A[0]; - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * } - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: int A_1 = A[0]; -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: } -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// In this case the conditional access must be hoisted by two loops, there are -// two accesses here one is unhidden and the other isn't. A[0] can be -// registerized but B[0] cannot. -TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Cond::make( - CompareSelect::make(y, 3, CompareSelectOperation::kEQ), - Block::make( - {Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1)), - Store::make( - b, {0}, Add::make(Load::make(b, {0}), 1))}), - nullptr)})))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 1 : 0) { - * A_1 = A_1 + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: for (int y -# CHECK: if (y==3 -# CHECK: A_1 = A_1 + 1; -# CHECK: B[0] = (B[0]) + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, but the immediate parent is -// not a condition. -TEST(Registerizer, RegisterizerTwoConditionalLoops) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * if (x>5 ? 1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, cut in the middle. -TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - For::make(x, 0, 10, Store::make(a, {x}, 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: for (int x -# CHECK: A[x] = 1; -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// references a Let var in a local scope which cannot be hoisted out of the -// loop. -TEST(Registerizer, RegisterizerLoopLetVar) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( - x, - 0, - 10, - Block::make( - {Let::make(y, 30), - Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); - - /* - * for (int x = 0; x < 10; x++) { - * int y = 30; - * A[y] = x + (A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// references a Let var in an outer scope that does not prevent hoisting the -// initializer. -TEST(Registerizer, RegisterizerLoopLetVarOuter) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Let::make(y, 30), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); - - /* - * int y = 30; - * for (int x = 0; x < 10; x++) { - * A[y] = x + (A[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int y = 30; - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int y = 30; -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: A[y] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Okay so the registerizer generally goes after index flattening, but just in -// case. Test multi index registerization. -TEST(Registerizer, RegisterizerMultiDim) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 1, 2] = (A[0, 1, 2]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0, 1, 2] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0, 1, 2] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't registerize if only some dims match, but will still registerize -// distinct elements. -TEST(Registerizer, RegisterizerMultiDimPartial) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 2, 2] = (A[0, 1, 4]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[0, 1, 4]; - * int A_2 = A[0, 2, 2]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * } - * A[0, 2, 2] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[0, 1, 4]; -# CHECK: int A_2 = A[0, 2, 2]; -# CHECK: for ( -# CHECK: A_2 = A_1 + x; -# CHECK: A[0, 2, 2] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If they could overlap across all dimensions we cannot registerize. -TEST(Registerizer, RegisterizerMultiDimOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 2]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But, if one dimension is known to be distinct they do not overlap. -TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[y, 2, 4]; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = A_1 + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[y, 2, 4]; -# CHECK: for ( -# CHECK: A[0, x, 2] = A_1 + x; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// A 3D reduction with different input dimensionality. -TEST(Registerizer, RegisterizerMultiDim3DReduction1) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10, 10}, kInt); - BufHandle c("C", {10, 10, 10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - For::make( - z, - 0, - 10, - Store::make( - c, - {x, y, z}, - Add::make( - Load::make(c, {x, y, z}), - Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); - - /* - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 10; z++) { - * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); - * } - * } - * } - */ - - // We can registerize the A and B access since they can be hoisted before - // hitting a dependent loop var. - - stmt = registerize(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * for (int y = 0; y < 10; y++) { - * int B_1 = B[x, y]; - * for (int z = 0; z < 10; z++) { - * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); - * } - * } - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x -# CHECK: int A_1 = A[x]; -# CHECK: for (int y -# CHECK: int B_1 = B[x, y]; -# CHECK: for (int z -# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// A 3D reduction with the same smaller dimensionality using different loop -// vars. -TEST(Registerizer, RegisterizerMultiDim3DReduction2) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = For::make( - x, - 0, - 10, - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - For::make( - y, - 0, - 10, - For::make( - z, - 0, - 10, - Store::make( - c, - {x}, - Add::make( - Load::make(c, {x}), - Mul::make(Load::make(b, {y}), Load::make(a, {x}))))))); - - /* - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 10; z++) { - * C[x] = (C[x]) + (B[y]) * (A[x]); - * } - * } - * } - */ - - // We can registerize all accesses, the A and C access can be hoisted to the - // outer loop since they depend only on it's loop var while the B can only be - // raised to the loop of y. - - stmt = registerize(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * int C_1 = C[x]; - * for (int y = 0; y < 10; y++) { - * int B_1 = B[y]; - * for (int z = 0; z < 10; z++) { - * C_1 = A_1 * B_1 + C_1; - * } - * } - * C[x] = C_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x -# CHECK: int A_1 = A[x]; -# CHECK: int C_1 = C[x]; -# CHECK: for (int y -# CHECK: int B_1 = B[y]; -# CHECK: for (int z -# CHECK: C_1 = A_1 * B_1 + C_1; -# CHECK: } -# CHECK: } -# CHECK: C[x] = C_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp deleted file mode 100644 index 7ca2b74eaa76..000000000000 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ /dev/null @@ -1,5680 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; -using SimpleIRExprEval = ExprEval; - -TEST(Simplify, ConstantFoldSimple) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle f = (a + b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 5); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 5.f); -} - -TEST(Simplify, ConstantFoldTwoLayer) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), -4); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), -4.f); -} - -TEST(Simplify, ConstantFoldShifts) { - ExprHandle a(7); - ExprHandle b(2); - ExprHandle c(3); - ExprHandle f = ((a << b) << b) >> c; - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 14); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 7 << (4 - 3)); -} - -TEST(Simplify, ConstantFoldBitwise) { - ExprHandle a(59); - ExprHandle b(22); - ExprHandle c(101); - ExprHandle f = (a ^ b) & c; - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 37); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), (59 ^ 22) & 101); -} - -TEST(Simplify, ConstantFoldMultiOp) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle e(6.0f); - ExprHandle f(7.0f); - ExprHandle fn = ((a / e) - (c + d)) * (f / b); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldMinMax) { - ExprHandle a(12.0f); - ExprHandle b(15.0f); - ExprHandle c(17.0f); - - // x = max(12, min(15, 17)). - ExprHandle minHandle = Min::make(b, c, true); - ExprHandle fn = Max::make(a, minHandle, false); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 15.f); -} - -TEST(Simplify, ConstantFoldIntrinsics) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle powHandle = Intrinsics::make(kPow, a, b); - ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); - ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); - ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); - ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); - ExprHandle fn = Intrinsics::make(kAbs, rndHandle); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldCastToBool) { - ExprHandle f = Cast::make(kBool, IntImm::make(0)); - ExprHandle newF = IRSimplifier::simplify(f); - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), false); -} - -TEST(Simplify, ConstantFoldWithVar) { - { - VarHandle x("x", kInt); - ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->lhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } - - { - VarHandle x("x", kFloat); - ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } -} - -TEST(Simplify, ConditionalSelectFoldSimple) { - ExprHandle a(3.0f); - ExprHandle b(4.0f); - ExprHandle c(3.0f); - { - ExprHandle f = (a > b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a < b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a == c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a != c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldTwoLayer) { - ExprHandle a(3.0f); - ExprHandle b(2.0f); - ExprHandle c(2.0f); - ExprHandle d(1.0f); - { - ExprHandle f = (a + b < c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a + b > c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d == b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d != b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldWithVar) { - VarHandle x("x", kFloat); - ExprHandle f = x < 4.f; - - ExprHandle newF = IRSimplifier::simplify(f); - IntImmPtr folded = newF.AsNode(); - ASSERT_EQ(folded, nullptr); - - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 1); - } - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(5.f)); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, UnFoldableExpr) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); - - ExprHandle newF = IRSimplifier::simplify(body); - AddPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(to(root->lhs()), nullptr); - ASSERT_EQ(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(2.f)); - ASSERT_EQ(eval.value(), 9 + 10); -} - -TEST(Simplify, HashSimple) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle f = a + b * x; - - HashProvider hasher; - - auto hash_x = hasher.hash(x.node()); - auto hash_a = hasher.hash(a.node()); - auto hash_f = hasher.hash(f.node()); - - ASSERT_NE(hash_x, (size_t)0); - ASSERT_NE(hash_a, (size_t)0); - ASSERT_NE(hash_f, (size_t)0); - ASSERT_NE(hash_x, hash_a); - ASSERT_NE(hash_x, hash_f); - ASSERT_NE(hash_a, hash_f); -} - -TEST(Simplify, HashEquivalence) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle f = (x * y) + (x * y); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // but branches are equal. - ASSERT_EQ(hash_l, hash_r); - - // Still equivalent if separate. - ExprHandle a(2); - ExprHandle f2 = x + a / y; - ExprHandle b(2); - ExprHandle f3 = x + b / y; - ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); - - // Not equivalent if different vars (even with same name). - VarHandle z("x", kFloat); - ExprHandle f4 = z + b / y; - ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); - - // Intrinsics sanity check. - ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); - ASSERT_NE(hasher.hash(f5.node()), (size_t)0); -} - -TEST(Simplify, HashEquivalenceRand) { - ExprHandle f = - Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // and branches are NOT equal. - ASSERT_NE(hash_l, hash_r); -} - -TEST(Simplify, HashEquivalenceAfterFolding) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(5.0f); - - ExprHandle f1 = ((a + b) * x); - ExprHandle f2 = (c * x); - - HashProvider hasher; - auto hash_l = hasher.hash(f1.node()); - auto hash_r = hasher.hash(f2.node()); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_l, hash_r); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - auto hash_l_n = hasher.hash(ff1.node()); - auto hash_r_n = hasher.hash(ff2.node()); - // but branches are now equal. - ASSERT_EQ(hash_l_n, hash_r_n); -} - -TEST(Simplify, HashDifferenceTypes) { - HashProvider hasher; - std::vector immediates; - - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - // NOLINTNEXTLINE(modernize-use-bool-literals) - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - - // Immediates of different types are not equal. - for (unsigned int i = 0; i < immediates.size(); ++i) { - for (unsigned int j = i + 1; j < immediates.size(); ++j) { - ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); - } - } - - // But coerced immediates are if they are the same type: - ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); - ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); -} - -TEST(Simplify, HashLargeExpression) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto memcpy_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - BufHandle d("D", {1}, kInt); - BufHandle e("E", {1}, kInt); - auto store_ramp_stmt = Store::make( - e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); - - auto if_stmt = Cond::make( - CompareSelect::make( - Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), - memcpy_stmt, - store_ramp_stmt); - - HashProvider hasher; - auto hash_r = hasher.hash(if_stmt); - // We should not have to do any more work. - ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); - auto hash_t = hasher.hash(memcpy_stmt); - ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); - auto hash_f = hasher.hash(store_ramp_stmt); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_r, hash_t); - ASSERT_NE(hash_r, hash_f); - ASSERT_NE(hash_t, hash_f); -} - -TEST(Simplify, HashForLoopOptions) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto for_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - HashProvider hasher; - auto hash_before = hasher.hash(for_stmt); - hasher.clearCache(); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_X); - auto hash_block_idx = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_NE(hash_before, hash_block_idx); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); - auto hash_reset = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_EQ(hash_before, hash_reset); - for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); - auto hash_thread_idx = hasher.hash(for_stmt); - - ASSERT_NE(hash_before, hash_thread_idx); - ASSERT_NE(hash_block_idx, hash_thread_idx); -} - -/// (2 + x) + 4 => x + 6 -TEST(Simplify, SimplifyAdd) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle m("m", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n("n", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n_1("n_1", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - VarPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->name_hint(), "x"); - IntImmPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->value(), 6.f); -} - -/// (2 - x) - 4 => -2 - x -TEST(Simplify, SimplifySub) { - VarHandle x("x", kInt); - ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - SubPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), -2.f); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (1 - x) - 4 => 2 * (-3 - x) -TEST(Simplify, SimplifyMultiLayer) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_IMM_WITH_VAL(Int, sub->lhs(), -3); - IS_VAR_WITH_NAME(sub->rhs(), "x"); -} - -/// 2 * (3 * x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyMultiTerm) { - VarHandle x("x", kInt); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (3 * (long)x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyCasts) { - VarHandle x("x", kLong); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - LongImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// (x + 0) * 1 => x -TEST(Simplify, SimplifyEliminatesNoOps) { - VarHandle x("x", kInt); - ExprHandle body = (x + ExprHandle(0)) * 1; - - ExprHandle simplified = IRSimplifier::simplify(body); - VarPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(root->name_hint(), "x"); -} - -/// Cannot simplify this. -TEST(Simplify, SimplifyMultiVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x * 24 + y * 34; - - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - MulPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - VarPtr varX = to(lhs->rhs()); - ASSERT_NE(varX, nullptr); - ASSERT_EQ(varX->name_hint(), "x"); - MulPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - VarPtr varY = to(rhs->rhs()); - ASSERT_NE(varY, nullptr); - ASSERT_EQ(varY->name_hint(), "y"); -} - -// x + 2 + y => x + y + 2 -TEST(Simplify, DISABLED_SimplifyReorderings) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x + 2 + y; - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - - IS_NODE_WITH_NAME(Add, root->lhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - IS_IMM_WITH_VAL(Int, root->rhs(), 2); -} - -/// y + x * 0 => y -TEST(Simplify, SimplifyEliminatesVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = y + x * ExprHandle(0); - - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); -} - -TEST(Simplify, SimplifyAdds) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) + (x + y) => 2 * (x + y) - ExprHandle body = (x + y) + (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // (x * y) + (x * y) => 2 * (x * y) - ExprHandle body = (x * y) + (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Mul, root->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) + (x - y) => 2 * (x - y) - ExprHandle body = (x - y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + x + x + x) => 4 * x - ExprHandle body = (x + x + x + x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 4); - IS_VAR_WITH_NAME(root->rhs(), "x"); - } - - { - // (x + 0) => x. - ExprHandle body = x + 0; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x + 0.f) => float(x). - ExprHandle body = x + 0.f; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } -} - -TEST(Simplify, SimplifyMuls) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) * (x + y) => (x + y) * (x + y) - // We don't attempt to simplify multiplication of polynomials since the - // result is only very rarely more efficient. - ExprHandle body = (x + y) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x * y * x * y => x * x * y * y - // These get reordered only. - ExprHandle body = x * y * x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); - IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); - IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); - IS_VAR_WITH_NAME(mul1->rhs(), "y"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - IS_VAR_WITH_NAME(mul3->lhs(), "x"); - IS_VAR_WITH_NAME(mul3->rhs(), "x"); - } - - { - // 1 * (x * 1) => x - // Ones cancel cleanly. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // 1.f * (x * 1.f) => x - // Even float ones cancel cleanly, but carry their type. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 1.f) => x - // One float is enough to cast the expr. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 0) => 0 - // Zeroes are eliminated. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 1 * (x * 0) => 0 - // But not for Float since nan * 0 = nan. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); - } - - { - // (x - y) * (x - y) => (x - y) * (x - y) - // As with Add we don't attempt simplification of this. - ExprHandle body = (x - y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + y) * (x - y) => (x + y) * (x - y) - // Don't simplify with different ops on each side. - ExprHandle body = (x + y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with non-identity scalar. - // x * (y + 1) => x + x * y - ExprHandle body = x * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with non-identity scalar. - // (x * 1) * (y + 1) => x + x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with non-identity scalar. - // (x * 2) * (y + 1) => 2 * (x + x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with identity scalar. - // (x * 2) * (y + 0) => 2 * (x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with identity scalar. - // (x * 1) * (y + 0) => x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with identity scalar. - // x * (y + 0) => x * y - ExprHandle body = x * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } -} - -// Sub an expr from itself will result in zero. -TEST(Simplify, SimplifySubs) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) - (x + y) => 0 - ExprHandle body = (x + y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x * y) - (x * y) => 0 - ExprHandle body = (x * y) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x - y) - (x - y) => 0 - ExprHandle body = (x - y) - (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x + y) - 2 * (x + y) => -1 * x - y - ExprHandle body = (x + y) - ExprHandle(2) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - y => x - ExprHandle body = (x + y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0) => x. - ExprHandle body = x - 0; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) => x. - // Simple enough to cancel in float. - ExprHandle body = x - ExprHandle(0.f); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - (float)(y - y)) => x. - ExprHandle body = x - Cast::make(kFloat, y - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - y) - y => x - 2 * y - ExprHandle body = (x - y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // 2 * x - x => x - ExprHandle body = (ExprHandle(2) * x) - x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // x - 2 * x = -1 * x - // We don't have a unary negate, but this could be 0 -x I guess? - ExprHandle body = x - (ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // (x + y + 5) * (x - x) => 0 - // Cancelling out one side of Mul cancels both. - ExprHandle body = (x + y + 5) * (x - x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Cancel out opaque modulus. - ExprHandle body = (x % y + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Cancel out opaque modulus with a bit more going on. - ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Sub where result is negative. - ExprHandle body = x - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Sub where result is positive due to negative scalar on RHS. - ExprHandle body = x - (x - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } - - { - // Term - Polynomial sub where RHS must be negated. - ExprHandle body = (x * 2) - (x * 2 + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Term - Polynomial sub where the result is a Term. - ExprHandle body = (y * x * 2) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Term - Polynomial sub where the result is a Polynomial. - ExprHandle body = (x * 2) - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_IMM_WITH_VAL(Int, sub->rhs(), 1); - } -} - -TEST(Simplify, SimplifyDiv) { - VarHandle x("x", kInt); - - { - ExprHandle body = ExprHandle(0) / x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - ExprHandle body = x / 1; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } -} - -TEST(Simplify, SimplifyDivWithLoopContext0) { - // Stmt to simplify: - // for (int i = 0; i < 100; i++) { - // A[i] = i / 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = -4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = -j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext0) { - // Stmt to simplify: - // for (const auto i : c10::irange(100)) { - // A[i] = i % 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i + 1; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i - 5; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyMod) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Constant folding works. - ExprHandle body = ExprHandle(10) % 8; - ExprHandle simplified = IRSimplifier::simplify(body); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // x % x => 0 - ExprHandle body = x % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 0 % x => 0 - ExprHandle body = ExprHandle(0) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // x % 1 => 0 - ExprHandle body = x % 1; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Doesn't change unknown mods. - // x % y => x % y - ExprHandle body = x % y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // don't touch if RHS is unknown. - // 4 % x => 4 % x - ExprHandle body = ExprHandle(4) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_IMM_WITH_VAL(Int, mod->lhs(), 4); - IS_VAR_WITH_NAME(mod->rhs(), "x"); - } - - { - // don't touch if LHS is unknown. - // x % 4 => x % 4 - ExprHandle body = x % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 4); - } - - { - // if LHS is a multiple of RHS, mod is zero. - // 2 * x % x => 0 - ExprHandle body = (x * 2) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true even if the multiple is not constant. - // x * y % x => 0 - ExprHandle body = (x * y) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true with multiple unknown values in LHS. - // x * y * z % x => 0 - ExprHandle body = (x * y * z) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true if the denom is compound. - // x * y * z % y * z => 0 - ExprHandle body = (x * y * z) % (y * z); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check true with scalars that are multiples. - // 12 * x % 4 => 0 - ExprHandle body = (x * 12) % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check not true if the smaller scalar is on LHS. - // 4 * x % 12 => 4 * x % 12 - ExprHandle body = (x * 4) % 12; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 4); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 12); - } - - { - // Both scalar and symbolic in multiple. - // (6 * x * y) % (3 * x * y) => 0 - ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } -} - -// Test that mixing ops together simplifies as expected. -TEST(Simplify, SimplifyMultiOp) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x * y) + (x - y) => (x + x * y) - y - ExprHandle body = (x * y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - x * y => (x + y) - x * y - ExprHandle body = (x + y) - x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) - (x + y) => -2 * y - ExprHandle body = (x - y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - 0) + (x * 1) - (x + 0) => x - ExprHandle body = (x - 0) + (x * 1) - (x + 0); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) - // Even in Float simple terms cancel out, but the variable ones cannot. - ExprHandle body = - (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); - IS_VAR_WITH_NAME(cast1->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); - IS_VAR_WITH_NAME(cast2->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); - IS_VAR_WITH_NAME(cast3->src_value(), "x"); - } -} - -// Test that chaining many ops together works as expected. -TEST(Simplify, SimplifyManyOps) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x - ExprHandle body = x + y + x + x + y + y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } - - { - // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y - ExprHandle body = x - y + x + x - y - y + x - y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x + y + x - x - y - y + x + y + x = 3 * x - ExprHandle body = x + y + x - x - y - y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 3); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -TEST(Simplify, SimplifyFactorization) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (2 * x) + (2 * y) => 2 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization when scalars have common divider. - // (2 * x) + (4 * y) => 2 * (2 * y + x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Factorization attempt without a common divider. - // (2 * x) + (5 * y) => (5 * y) + (2 * x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Factorization after merging. - // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + - (ExprHandle(8) * x + ExprHandle(6) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 10); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization with common divider but different signs. - // (2 * x) + (-4 * y) => 2 * (x - 2 * y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Factorization with all negative numbers. - // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) - ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); - IS_VAR_WITH_NAME(mul2->rhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); - IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); - IS_VAR_WITH_NAME(mul3->rhs(), "y"); - } - - { - // The following test ensures that there in no infinite recursion during - // factorization when negative numbers are involved. - VarHandle a("a", kInt); - VarHandle b("b", kInt); - VarHandle c("c", kInt); - VarHandle d("d", kInt); - VarHandle e("e", kInt); - VarHandle f("f", kInt); - VarHandle g("g", kInt); - VarHandle h("h", kInt); - - ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + - f * 32 + g * (-1024) + h * (-32); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR( - simplified, - "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); - } -} - -// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) -TEST(Simplify, SimplifyFactorizeUneven) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = - (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add1); - IS_NODE_WITH_NAME(Add, add1->lhs(), add2); - - IS_VAR_WITH_NAME(add2->lhs(), "y"); - IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); - IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); - - IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); - IS_VAR_WITH_NAME(xmul->rhs(), "x"); - - IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); - IS_VAR_WITH_NAME(zmul->rhs(), "z"); -} - -// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) -// This is kind of a placeholder test for variable factorization. -TEST(Simplify, SimplifyDeeperTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); - IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); - IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); -} - -// Tests the difference between two less trivial expressions. -// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 -TEST(Simplify, SimplifyDeeperDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); -} - -// Test constant folding into the difference between expressions. -// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 -TEST(Simplify, SimplifyFoldComplexDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (IntImm::make(2) + - (Cast::make( - kChar, - (m * (ExprHandle(1) * n_1) + (n + 1)) - - (m * (ExprHandle(1) * n_1) + n)))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 3); -} - -TEST(Simplify, SimplifyIfComponents) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make( - ((ExprHandle(5) - ExprHandle(4)) * x) > y, - ExprHandle(2) * x - x, - ExprHandle(2) * y - y); - - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); - - IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kGT); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - - IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); - IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); -} - -TEST(Simplify, SimplifyOpaqueTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // 2 * x/y * y - x/y * y => x/y * y - ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Div, mul->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // x%y - (x%y - 1) => 1 - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } -} - -TEST(Simplify, SimplifySymbolicMinMax) { - { - // Minimum with constant difference between terms. - VarHandle x("x", kInt); - ExprHandle body = Min::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 3); - } - - { - // Maximum with constant difference between terms. - VarHandle x("x", kInt); - ExprHandle body = Max::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 7); - } - - { - // Can't simplify multiples because of signedness of variable component. - // TODO: maybe we could for unsigned types? - VarHandle x("x", kInt); - ExprHandle body = Max::make(x * 3, x * 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE(Max, simplified.node()); - } -} - -TEST(Simplify, SimplifyNestedMax) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Max(x + y, x + y) => x + y - ExprHandle body = Max::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Max(x + y, Max(x + y, z)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(x + y, Max(z, x + y)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x + y, z), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(z, x + y), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x, y), x) => Max(Max(x, y), x) - // Nested Max ops with different propagate_nans should not be simplified. - ExprHandle body = Max::make(Max::make(x, y, true), x, false); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); - ASSERT_TRUE(max1->propagate_nans()); - IS_VAR_WITH_NAME(max->rhs(), "x"); - ASSERT_FALSE(max->propagate_nans()); - } - - { - // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) - // When all the ops in the pattern do not have the same propagate_nans, - // it should not be simplified. - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, false), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); - ASSERT_TRUE(min1->propagate_nans()); - IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); - ASSERT_FALSE(min2->propagate_nans()); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, 8)) => Max(x, 8) - ExprHandle body = Max::make(5, Max::make(x, 8, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(8, Max(x, 5)) => Max(x, 8) - ExprHandle body = Max::make(8, Max::make(x, 5, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 8), 5) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 5), 8) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) - // Do not simplify when all the Max ops do not have the same - // propagate_nans. - ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), - 8, - false); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); - } - - { - // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } -} - -TEST(Simplify, SimplifyNestedMin) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Min(x + y, x + y) => x + y - ExprHandle body = Min::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Min(x + y, Min(x + y, z)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(x + y, Min(z, x + y)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x + y, z), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(z, x + y), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x, y), x) => Min(Min(x, y), x) - // Nested Min ops with different propagate_nans should not be simplified. - ExprHandle body = Min::make(Min::make(x, y, true), x, false); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y"); - ASSERT_TRUE(min2->propagate_nans()); - IS_VAR_WITH_NAME(min1->rhs(), "x"); - ASSERT_FALSE(min1->propagate_nans()); - } - - { - // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(x, y, true), Max::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(x, y, true), Max::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) - // When all the ops in the pattern do not have the same propagate_nans, - // it should not be simplified. - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(z, x, false), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); - ASSERT_TRUE(max1->propagate_nans()); - IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); - ASSERT_FALSE(max2->propagate_nans()); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(5, Min(x, 8)) => Min(x, 8) - ExprHandle body = Min::make(5, Min::make(x, 8, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(8, Min(x, 5)) => Min(x, 8) - ExprHandle body = Min::make(8, Min::make(x, 5, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(Min(x, 8), 5) => Min(x, 8) - ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(Min(x, 5), 8) => Min(x, 8) - ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) - // Do not simplify when all the Min ops do not have the same - // propagate_nans. - ExprHandle body = Min::make( - Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), - 8, - false); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); - } - - { - // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } -} - -TEST(Simplify, SimplifyWontReorderFloat) { - { - // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) - // This is an expression we can simplify. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 9); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). - // If the vars are floating point, ops are not associative and we can't - // reorder. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). - // We will simplify subexprs if they dont reorder floating point ops. - VarHandle x("x", kDouble); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); - IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); - IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); - } - - { - // Prevent reordering if FP propagated from dtypes. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3.f) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); - IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); - IS_VAR_WITH_NAME(yCast->src_value(), "y"); - } - - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - // x%y - (x%y - 1) => x%y - (x%y - 1). - // We won't reorder opaque ops if they are FP. - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); - IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); - - IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); - IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); - IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); - IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); - } -} - -TEST(Simplify, SimplifyRoundModPattern) { - { - // (x/y)*y + x%y => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Reverse order. - // x%y + (x/y)*y => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x % y) + ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Non opaque denominator. - // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + - (x % (y + ExprHandle(4))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Reverse order. - // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x % (y + ExprHandle(4))) + - ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Opaque denominator. - // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + - (x % (ExprHandle(2) / y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Non opaque numerator - // ((2*x)/y * y) + ((2*x) % y) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Opaque numerator. - // ((x/2) / y * y) + (x/2 % y) => x / 2. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_IMM_WITH_VAL(Int, div->rhs(), 2); - } - - { - // Numerator and denominator. - // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + - ((ExprHandle(2) * x) % (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Reverse order. - // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Negated Subtraction of Round Mod. - // (x/y) * y - (0 - x%y) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Other terms are preserved. - // (x/y)*y + x%y + (y * x) => x + (y * x). - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) + (x % y) + (y * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Sanity checking we won't do the optimization on floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = ((x / y) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); - IS_VAR_WITH_NAME(roundMul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // Sanity check we won't do it if the mod term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = ((x / y) * y) + (x % z); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(x / y) * y + x % z"); - } - - { - // Sanity check we won't do it if the div term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = (y * (x / z)) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x % y + (x / z) * y"); - } - - { - // Sanity check we won't do it if the mul term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = ((x / y) * z) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x % y + (x / y) * z"); - } -} - -TEST(Simplify, SimplifyRoundModPatternFactorization) { - { - // Full factorization. - // 2 * (x/y * y) + 2 * (x%y) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Partial Factorization. - // 32 * (x/8) + 4 * (x % 8) => 4 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 4); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Factorization requiring constant folding. - // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + - (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 5); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - VarHandle x("x", kInt); - ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - VarHandle x("x", kInt); - ExprHandle body = (x / 10) * 0 + x % 5; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 5); - } -} - -TEST(Simplify, SimplifyRoundModPatternMultivar) { - { - // Multivar. - // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + - (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Find the right var. - // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = - (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Add, add->lhs(), add2); - IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); - IS_VAR_WITH_NAME(xMod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); - IS_VAR_WITH_NAME(add2->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); - IS_VAR_WITH_NAME(zMod->lhs(), "z"); - IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); - } - - { - // Compound. - // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) - // => (z + 512 * y) + x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x + (z + 512 * y)"); - } -} - -TEST(Simplify, SimplifyModRoundModPattern) { - { - // t/7 % 9 * 7 + t % 7 => t%63 - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 - VarHandle t("t", kInt); - ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/x % y * x + t % x => t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (t / x % y) * x + t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // k*t/x % y * x + k*t % x => k*t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (k * t / x % y) * x + k * t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(k * t) % (x * y)"); - } - - { - // t/k/x % y * x + t/k % x => t/k%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / k / x % y) * x + t / k % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "t"); - IS_VAR_WITH_NAME(div->rhs(), "k"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Sanity checking we won't do the optimization on floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - VarHandle z("z", kFloat); - ExprHandle body = ((x / y % z) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), mul); - IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - IS_VAR_WITH_NAME(mod->rhs(), "z"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); - IS_VAR_WITH_NAME(mod2->lhs(), "x"); - IS_VAR_WITH_NAME(mod2->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyModRoundModPatternFactorization) { - { - // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) - VarHandle t("t", kInt); - ExprHandle body = - ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 - VarHandle t("t", kInt); - ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "t"); - IS_IMM_WITH_VAL(Int, div->rhs(), 2); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 - VarHandle t("t", kInt); - ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + - t % (ExprHandle(7) * ExprHandle(3)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 189); - } - - { - // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyModRoundModPatternMultivar) { - { - // t/7 % 9 * 7 + t % 7 + t => t % 63 + t - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "t % 63 + t"); - } - - { - // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); - IS_VAR_WITH_NAME(mod1->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); - IS_VAR_WITH_NAME(mod2->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); - } - - { - // k + t/x % y * x + t % x => k + t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = k + (t / x % y) * x + t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "k"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x - // => t%(x*y) + t/k % (x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); - } - - { - // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) - // => io_flat - VarHandle t("io_flat", kInt); - ExprHandle body = - ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + - // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + - // (i0_flat / (9 * 7) % 10) * 7 * 9 + - // (i0_flat / 7 % 9) * 7 + - // i0_flat % 7 => io_flat - VarHandle t("io_flat", kInt); - ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + - (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + - (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { - // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * - // (i0_flat / (m * n)) => io_flat - VarHandle t("io_flat", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + - // (i0_flat / (l * n * m) % k) * m * n * l + - // (i0_flat / (n * m) % l) * m * n + - // (i0_flat / m % n) * m + - // i0_flat % m => io_flat - VarHandle t("io_flat", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle l("l", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + - (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + - (t / m % n) * m + t % m; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } -} - -TEST(Simplify, SimplifyDivisionScalarFactorization) { - { - // Simple factorization of numerator and denominator. - // 8x / 4y => 2x / y. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 8) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } - - { - // Don't change anything if we can't factorize. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 7) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Don't reorder floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = (x * 8) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); - IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "y"); - IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); - } - - { - // Sanity check we do nothing if there are only scalar parts. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 1) / (y * 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } - - { - // Can factorize amounts of variables. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x + x + x + x) / (y + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyConstantBranches) { - { - // If the condition is constant true then take the true_value. - // 1 ? x : y => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle t(1); - ExprHandle body = IfThenElse::make(t, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // If the condition is constant false then take the false_value. - // 0 ? x : y => y - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle t(0); - ExprHandle body = IfThenElse::make(t, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); - } - - { - // condition is simplified before checking. - // (x-x) ? x : y => y - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(x - x, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); - } - - { - // If both branches are the same then don't do the condition. - // y ? x : x => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(y, x, x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // If both branches simplify to the same thing it still works. - // y ? (x + x) : (2 * x) => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -TEST(Simplify, SimplifyConstantCond) { - { - // If the condition is constant true then take the true_value. - // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(1); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - CondPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // If the condition is constant false then take the false_value. - // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(0); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "B"); - } - - { - // condition is simplified before checking. - // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "B"); - } - - { - // If both branches are the same then don't do the condition. - // x ? A[0] = x : A[0] = x => A[0] = x - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, x); - StmtPtr false_val = Store::make(a, {0}, x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // If both branches simplify to the same thing it still works. - // x ? (x + x) : (2 * x) => x - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); - StmtPtr false_val = Store::make(a, {0}, x + x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // But not if they dont - // x ? x : (2 * x) => x ? x : (2 * x) - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x); - StmtPtr true_val = Store::make(a, {0}, x); - StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block, nullptr); - } - - { - StmtPtr cond = alloc( - ExprHandle(false).node(), - alloc(std::vector({})), - nullptr); - StmtPtr simplified = IRSimplifier::simplify(cond); - ASSERT_EQ(simplified, nullptr); - } - - { - StmtPtr cond = alloc( - ExprHandle(true).node(), - nullptr, - alloc(std::vector({}))); - StmtPtr simplified = IRSimplifier::simplify(cond); - ASSERT_EQ(simplified, nullptr); - } -} - -TEST(Simplify, SimplifyEliminateEmptyCond) { - // If the branches are empty in different ways, eliminate. - { - VarHandle x("x", kInt); - ExprHandle condition(x); - StmtPtr true_val = alloc(std::vector({})); - - StmtPtr body = alloc(condition.node(), true_val, nullptr); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_NE(block, nullptr); - ASSERT_EQ(block->nstmts(), 0); - } - - { - VarHandle x("x", kInt); - ExprHandle condition(x); - StmtPtr false_val = alloc(std::vector({})); - - StmtPtr body = alloc(condition.node(), nullptr, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_NE(block, nullptr); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyConstantComparisons) { - auto ComparisonTest = - [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { - ExprHandle body = CompareSelect::make(a, b, op); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), result); - }; - - // Equals. - ComparisonTest(2, 2, kEQ, 1); - ComparisonTest(1, 2, kEQ, 0); - ComparisonTest(2, 1, kEQ, 0); - - // Greater than. - ComparisonTest(2, 2, kGT, 0); - ComparisonTest(1, 2, kGT, 0); - ComparisonTest(2, 1, kGT, 1); - - // Greater or Equal. - ComparisonTest(2, 2, kGE, 1); - ComparisonTest(1, 2, kGE, 0); - ComparisonTest(2, 1, kGE, 1); - - // Less Than. - ComparisonTest(2, 2, kLT, 0); - ComparisonTest(1, 2, kLT, 1); - ComparisonTest(2, 1, kLT, 0); - - // Less or Equal. - ComparisonTest(2, 2, kLE, 1); - ComparisonTest(1, 2, kLE, 1); - ComparisonTest(2, 1, kLE, 0); - - // Not equal. - ComparisonTest(2, 2, kNE, 0); - ComparisonTest(1, 2, kNE, 1); - ComparisonTest(2, 1, kNE, 1); - - // With specified results: - ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 42); -} - -TEST(Simplify, SimplifySymbolicComparisons) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; - auto TookFalseBranch = [](ExprHandle a) { - IS_IMM_WITH_VAL(Int, a.node(), 0); - }; - - // EQ - - // x == x => 1 - ExprHandle body = CompareSelect::make(x, x, kEQ); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x == x+1 => 0 - body = CompareSelect::make(x, x + 1, kEQ); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x == x * 2 cannot simplify since we don't know x is nonzero. - body = CompareSelect::make(x, x * 2, kEQ); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // x == x * 1 => 1 - body = CompareSelect::make(x, x * 1, kEQ); - TookTrueBranch(IRSimplifier::simplify(body)); - - { - // x == y => x == y - body = CompareSelect::make(x, y, kEQ); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kEQ); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - } - - { - // x == 5 => x == 5 - body = CompareSelect::make(x, 5, kEQ); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kEQ); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); - } - - // GT - - // x+1 > x => 1 - body = CompareSelect::make(x + 1, x, kGT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x > x + 1 => 0 - body = CompareSelect::make(x, x + 1, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x > x - 1 => 1 - body = CompareSelect::make(x, x - 1, kGT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x - 1 > x => 0 - body = CompareSelect::make(x - 1, x, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x > x => 0 - body = CompareSelect::make(x, x, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x * 2 > x => x * 2 > x - // since we don't know the sign of x. - body = CompareSelect::make(x * 2, x, kGT); - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // GE - - // x+1 >= x => 1 - body = CompareSelect::make(x + 1, x, kGE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x >= x + 1 => 0 - body = CompareSelect::make(x, x + 1, kGE); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x >= x => 1 - body = CompareSelect::make(x, x, kGE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x * 2 >= x => x * 2 >= x - // since we don't know the sign of x. - body = CompareSelect::make(x * 2, x, kGE); - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // LT - - // x+1 < x => 0 - body = CompareSelect::make(x + 1, x, kLT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x < x + 1 => 1 - body = CompareSelect::make(x, x + 1, kLT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x < x => 0 - body = CompareSelect::make(x, x, kLT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // LE - - // x+1 <= x => 0 - body = CompareSelect::make(x + 1, x, kLE); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x <= x + 1 => 1 - body = CompareSelect::make(x, x + 1, kLE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x <= x => 1 - body = CompareSelect::make(x, x, kLE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // NE - - // x+1 != x => 1 - body = CompareSelect::make(x + 1, x, kNE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x != x + 1 => 1 - body = CompareSelect::make(x, x + 1, kNE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x != x => 0 - body = CompareSelect::make(x, x, kNE); - TookFalseBranch(IRSimplifier::simplify(body)); -} - -TEST(Simplify, SimplifyEliminateZeroLengthFor) { - { - // Will eliminate zero loop For. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // still works if start is not zero. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // works if both terms are variable. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // works if one term simplifies down. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE(For, simplified); - } -} - -TEST(Simplify, SimplifyOneLoopFor) { - { - // Will remove the loop if the body is run once. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // still works if start is not zero. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 2); - } - - { - // works if both terms are variable. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_VAR_WITH_NAME(store->flat_index(), "x"); - } - - { - // works if one term simplifies down. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = - For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE(For, simplified); - } -} - -TEST(Simplify, SimplifyForWontLoseLoopOptions) { - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - LoopOptions options; - options.set_gpu_block_index(LoopOptions::IDX_W); - auto body = - For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, for_); - LoopOptions options2 = for_->loop_options(); - ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); - } -} - -TEST(Simplify, SimplifyMultilevelFor) { - { - // Multiple layers of For will be simplified out. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 1, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Will maintain an outer loop if the inner loop is eliminated. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 2, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - ForPtr for__ = static_to(simplified); - IS_NODE_WITH_NAME(For, for__, for_); - IS_VAR_WITH_NAME(for_->var(), "j"); - IS_IMM_WITH_VAL(Int, for_->start(), 0); - IS_IMM_WITH_VAL(Int, for_->stop(), 2); - BlockPtr block = to(for_->body()); - ASSERT_NE(block, nullptr); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Will maintain inner loop if outer loops is eliminated. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 1, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(For, block->front(), for_); - IS_VAR_WITH_NAME(for_->var(), "i"); - IS_IMM_WITH_VAL(Int, for_->start(), 0); - IS_IMM_WITH_VAL(Int, for_->stop(), 2); - IS_NODE_WITH_NAME(Store, for_->body()->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_VAR_WITH_NAME(store->flat_index(), "i"); - } -} - -TEST(Simplify, SimplifyForCleansUp) { - { - BufHandle a("a", {1, 12, 1}, kFloat); - VarHandle x("x", kInt); - Tensor b = Compute( - "x", - {1, 12, 1}, - [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { - return i + m + n; - }); - LoopNest l({b}); - l.prepareForCodegen(); - - StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); - StmtPtr simplified = IRSimplifier::simplify(body); - - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(For, block->front(), for_); - // for is over "m". - IS_VAR_WITH_NAME(for_->var(), "j"); - // x[m] = m; - IS_NODE_WITH_NAME(Store, for_->body()->front(), store); - IS_VAR_WITH_NAME(store->flat_index(), "j"); - IS_VAR_WITH_NAME(store->value(), "j"); - } -} - -TEST(Simplify, SimplifyEliminateEmptyFor) { - { - // Flatten many layers around an empty block to an empty block. - StmtPtr last = alloc(std::vector({})); - for ([[maybe_unused]] const auto i : c10::irange(11)) { - VarHandle loopVar("loopVar", kInt); - last = For::make(loopVar, 0, 10, last); - } - - StmtPtr simplified = IRSimplifier::simplify(last); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyFlattenBlock) { - { - // Flatten multiple blocks down to one. - // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store1, store2})); - BlockPtr block2 = alloc(std::vector({block1})); - - BlockPtr enclosing = alloc(std::vector({block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten multiple sub blocks containing statements. - // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store1})); - BlockPtr block2 = alloc(std::vector({store2})); - - BlockPtr enclosing = alloc(std::vector({block1, block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten sub blocks with different depths. - // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store2})); - BlockPtr block2 = alloc(std::vector({block1})); - - BlockPtr enclosing = alloc(std::vector({store1, block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten many layers around an empty block to an empty block. - StmtPtr last = alloc(std::vector({})); - for ([[maybe_unused]] const auto i : c10::irange(11)) { - last = alloc(std::vector({last})); - } - - StmtPtr simplified = IRSimplifier::simplify(last); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { - { - // Simple positive case. - BufHandle b("x", {0}, kInt); - - AllocatePtr alloc_ = Allocate::make(b); - FreePtr free_ = Free::make(b); - - BlockPtr block1 = alloc(std::vector({alloc_, free_})); - ASSERT_EQ(block1->nstmts(), 2); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 0); - } - - { - // Simple negative case. - BufHandle b("x", {2}, kInt); - - AllocatePtr alloc_ = Allocate::make(b); - FreePtr free_ = Free::make(b); - - BlockPtr block1 = alloc(std::vector({alloc_, free_})); - ASSERT_EQ(block1->nstmts(), 2); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - } - - { - // Finds right Alloc/Free. - BufHandle b1("x", {0}, kInt); - BufHandle b2("y", {2}, kInt); - - AllocatePtr alloc1 = Allocate::make(b1); - AllocatePtr alloc2 = Allocate::make(b2); - FreePtr free2_ = Free::make(b2); - FreePtr free1_ = Free::make(b1); - - BlockPtr block1 = - alloc(std::vector({alloc1, alloc2, free2_, free1_})); - ASSERT_EQ(block1->nstmts(), 4); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); - IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y"); - IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); - ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); - } - - { - // Dynamic shape. - VarHandle z("z", kInt); - BufHandle b1("x", {0}, kInt); - BufHandle b2("y", {z}, kInt); - - AllocatePtr alloc1 = Allocate::make(b1); - AllocatePtr alloc2 = Allocate::make(b2); - FreePtr free2_ = Free::make(b2); - FreePtr free1_ = Free::make(b1); - - BlockPtr block1 = - alloc(std::vector({alloc1, alloc2, free2_, free1_})); - ASSERT_EQ(block1->nstmts(), 4); - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - } -} - -TEST(Simplify, DontSimplifyRand) { - { - // rand() + rand() = rand() + rand() NOT 2 * rand(). - ExprHandle body = - Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_RAND(add->lhs()); - IS_RAND(add->rhs()); - } - - { - // rand() - rand() = rand() - rand() NOT 0. - ExprHandle body = - Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_RAND(sub->lhs()); - IS_RAND(sub->rhs()); - } - - { - // rand() * rand() = rand() * rand(). - ExprHandle body = - Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_RAND(mul->lhs()); - IS_RAND(mul->rhs()); - } -} - -TEST(Simplify, SimplifyReorderForCond) { - BufHandle a("A", {4}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - { - // for ( if ( ... ) ) => if ( for ( ... ) ). - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {i}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Can't reorder if condition is dependent on the loop var. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(i, 2, CompareSelectOperation::kEQ), - Store::make(c, {i}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Can't reorder if condition is dependent on a var that is modified inside - // the loop. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(c, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Condition based on buffer not referenced in body. Can reorder here. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(b, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Condition based on buffer read only in body. Can reorder here. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Condition depends on Let in the loop. Cannot reorder. - auto body = For::make( - i, - 0, - 4, - Block::make( - {Let::make(j, 3), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)})); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Let, loop->body()->front(), let); - IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); - } - - { - // Multi level Ifs where all conditions are distinct. Move BOTH Cond - // statements outside the loop. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kEQ), - Store::make(c, {0}, Load::make(a, {i})), - nullptr), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); - IS_NODE_WITH_NAME(For, true_block2->front(), loop); - } - - { - // Multi level Ifs where the inner condition does depend on a loop var, - // reorder only the first Cond. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Cond::make( - CompareSelect::make(i, 3, CompareSelectOperation::kEQ), - Store::make(c, {0}, Load::make(a, {i})), - nullptr), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - IS_NODE_WITH_NAME(Block, loop->body(), loop_body); - IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); - } - - { - // Don't reorder if there's an else block of the Cond. - // We could, but is it much better? - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - Store::make(c, {0}, 0))); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Condition uses distinct region of Tensor. - // We could reorder here with better analysis, but we don't. Included for - // completeness. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(c, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {1}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } -} - -TEST(Simplify, SimplifyFuseConditions) { - BufHandle a("A", {2}, kInt); - BufHandle b("B", {2}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - { - // Can fuse since the conditions are identical. - // if (A) { X }; if (A) { Y }; => if (A) { X; Y } - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can't fuse, conditions are not identical in lhs (i != j). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - { - // Can't fuse, conditions are not identical in rhs (10 != 11). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 11, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can't fuse, conditions are not identical in operation (LT vs GT). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kGT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can't fuse, CompareSelect results are different. - // Actually we totally could if we normalized CompareSelect results, but - // TODO for later. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can fuse with false stmt only. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(a, {0}, i)), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(a, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 2); - ASSERT_EQ(cond->true_stmt(), nullptr); - } - - { - // Can fuse with both true and false stmt. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - Store::make(b, {0}, i)), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - Store::make(b, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 2); - } - - { - // Can fuse with mismatched true / false stmt existing - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(b, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 1); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 1); - } - - { - // Can fuse partial block contents, ie when there are non fused stmts before - // and after. - // before: - // if (j < 10) { A[0] = j; } - // if (i < 10) { A[0] = i; } - // if (i < 10) { A[1] = i; } - // if (i < 11) { A[1] = j; } - // - // after: - // - // if (j < 10) { A[0] = j; } - // if (i < 10) { - // A[0] = i; - // A[1] = i; - // } - // if (i < 11) { A[1] = j; } - - auto body = Block::make({ - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 11, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - it++; - IS_NODE_WITH_NAME(Cond, *it, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can fuse longer sequences of identical conditions. - auto body = Block::make({ - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 4); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can't fuse through a non condition. - auto body = Block::make({ - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Store::make(b, {1}, i + j), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt2->nstmts(), 2); - ASSERT_EQ(cond2->false_stmt(), nullptr); - - auto it = block->begin(); - it++; - IS_NODE_WITH_NAME(Store, *it, middle); - } - - { - // Can fuse if the conditions simplify to the same thing. - auto body = Block::make( - {Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(87) % ExprHandle(11), - CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(300) / ExprHandle(30), - CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can fuse non-CompareSelects. - // if (i) { X } if (i) { Y } => if (i) { X; Y } - auto body = Block::make( - {Cond::make(i, Store::make(a, {0}, i), nullptr), - Cond::make(i, Store::make(a, {1}, i), nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Sanity check won't fuse different non-CompareSelects. - auto body = Block::make( - {Cond::make(i, Store::make(a, {0}, i), nullptr), - Cond::make(j, Store::make(a, {1}, i), nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - } - - { - // Sanity check constant condition elimination still occurs when merging is - // possible. - auto body = Block::make( - {Cond::make(1, Store::make(a, {0}, i), nullptr), - Cond::make(1, Store::make(a, {1}, i), nullptr)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Store, block->front(), store1); - IS_NODE_WITH_NAME(Store, block->back(), store2); - } - - { - // Sanity check for-cond reordering occurs after fusing. - auto body = For::make( - i, - 0, - 4, - Block::make( - {Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, Load::make(b, {0})), - nullptr), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {2}, Load::make(b, {0})), - nullptr)})); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } -} - -TEST(Simplify, SimplifySyncThreads) { - BufHandle a("A", {4}, kInt); - VarHandle i("i", kInt); - - { - // Merge two inner SyncThreads. - auto body = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {0}, 1), - alloc(), - alloc(), - Store::make(a, {1}, 0)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } - - { - // Eliminate outer SyncThreads. - auto body = Block::make( - {alloc(), Store::make(a, {1}, 0), alloc()}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - auto it = block->begin(); - IS_NODE(Store, *it); - } - - { - // Merge many inner SyncThreads. - auto body = Block::make( - {Store::make(a, {0}, 1), - alloc(), - alloc(), - alloc(), - alloc(), - alloc(), - Store::make(a, {1}, 0)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } - - { - // Merge multiple outer SyncThreads. - auto body = Block::make( - {alloc(), - alloc(), - Store::make(a, {1}, 0), - alloc(), - alloc(), - alloc(), - alloc()}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - auto it = block->begin(); - IS_NODE(Store, *it); - } - - { - // Merge multiple sections; - auto body = Block::make( - {Store::make(a, {0}, 1), - alloc(), - alloc(), - Store::make(a, {1}, 0), - Store::make(a, {2}, 0), - alloc(), - alloc(), - alloc(), - Store::make(a, {3}, 0)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 6); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } -} - -TEST(Simplify, SimplifyRampSubBroadcast) { - int num_lanes = 4; - ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); - ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); - ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); - RampPtr newRamp = simplified.AsNode(); - IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); - ASSERT_EQ(base->value(), 5); - IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); - ASSERT_EQ(stride->value(), 6); - ASSERT_EQ(newRamp->lanes(), num_lanes); -} - -TEST(Simplify, SimplifyBroadcastTermExpander) { - int num_lanes = 8; - ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); - ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); - ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); - // NB: We need a term in the middle which isn't simplified to trigger the - // relevant path in TermExpander::mutate. The two bc1 terms are brought - // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. - ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); - BufHandle buf("buf", {num_lanes}, kInt); - // The result isn't fully simplified currently and thus would be brittle to - // match. Observe its value instead. - auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); - SimpleIREvaluator eval(store, {buf}); - std::vector output(num_lanes); - eval(output); - for (const auto i : c10::irange(num_lanes)) { - ASSERT_EQ(output[i], 2); - } -} - -TEST(Simplify, CompareSelectLoopBounds) { - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - VarHandle m("m", kInt); - VarHandle var_N("var_N", kInt); - VarHandle var_M("var_M", kInt); - - auto test_case_fn = [](const VarHandle& n, - const BufHandle& b, - const ExprHandle& start, - const ExprHandle& stop, - const int& cmp_val, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - n, - start, - stop, - b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - auto test_case_nest_loops_fn = [](const VarHandle& n, - const VarHandle& m, - const BufHandle& b, - const ExprHandle& n_start, - const ExprHandle& n_stop, - const ExprHandle& m_start, - const ExprHandle& m_stop, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - m, - m_start, - m_stop, - b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); - StmtPtr root_s = For::make(n, n_start, n_stop, s); - root_s = IRSimplifier::simplify(root_s); - std::ostringstream oss; - oss << *root_s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, 2)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, 2)) { - // b[1] = 0.f; - // } - test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n < m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kLT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kLT, - "b[n, m] = n m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kGT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n > m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kGT, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kGE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kGE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kLE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kLE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); -} - -TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = For::make( - n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, IfThenCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = - For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { - // This test mimics the unpadded region of a conv2d. We want to remove any - // conditional that is provably satisfied (or unsatisfied) by the entire loop - // range. - // Before: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); - // After: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = 1.f; - constexpr int N = 8; - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); - s = For::make(j, 1, N - 1, s); - s = For::make(i, 1, N - 1, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[i, j] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, DISABLED_SimplifyLoopBounds) { - // This test mimics the padded region of a conv2d. We want to adjust the - // loop bounds such that the condition will be always met. Note that this - // could be solved by peeling, and applying the range-based conditional - // simplification in the previous tests. - // Before: - // for (const auto i : c10::irange(3)) { - // for (const auto j : c10::irange(3)) { - // b[i, j] = (b[i, j]) + (IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); - // After: - // for (const auto i : c10::irange(1, 3)) { - // for (const auto j : c10::irange(1, 3)) { - // b[i, j] = (b[i, j]) + 1.f; - constexpr int N = 8; - constexpr int K = 3; - BufHandle a("a", {N, N}, kFloat); - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store( - {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); - s = For::make(j, 0, K, s); - s = For::make(i, 0, K, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (const auto i : c10::irange(1, 3)) { -# CHECK: for (const auto j : c10::irange(1, 3)) { -# CHECK-NOT: IfThenElse -)IR", - oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp deleted file mode 100644 index 56535de914e4..000000000000 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ /dev/null @@ -1,402 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -struct WithCPUFuser { - WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { - overrideCanFuseOnCPU(val); - } - - ~WithCPUFuser() { - overrideCanFuseOnCPU(cpuFuserEnabled); - } - - bool cpuFuserEnabled; -}; - -TEST(TEFuserPass, FuserPass_1) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) - %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) - %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) - %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("aten::add_") - ->check("prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_2) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) - %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) - %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) - return (%d))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("aten::add_") - ->check("prim::TensorExprGroup_0") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_3) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(128, strides=[1], device=cpu), - %y : Float(128, strides=[1], device=cpu)): - %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not create a fusion group since its size would be too small - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should create a fusion group since its size is above the threshold - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_0DimInput) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(device=cpu), - %y : Float(device=cpu)): - %one : int = prim::Constant[value=1]() - %a : Float(device=cpu) = aten::mul(%x, %y) - %b : Float(device=cpu) = aten::add(%x, %a, %one) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should fuse 0-dim tensors too - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnfusibleDevice) { - WithCPUFuser cf(false); - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(10, strides=[1], device=cpu)): - %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%a))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // Test that we're not starting fusion groups from nodes with unfusible device - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnknownShapes) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor): - %a : Tensor = aten::mul(%x, %y) - %b : Tensor = aten::mul(%x, %a) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // Test that we're not generating fusion groups when shapes are not known - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Multidevice) { - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should be able to fuse this - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this aten::cat since its inputs are from different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y) - %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (cat) into another - // (mul) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %z2 : Tensor = aten::mul(%z, %z) - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (mul) into another - // (cat) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0)): - %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this graph since its inputs are from different devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cuda:0), - %y : Float(20, strides=[1], device=cuda:1), - %z : Float(20, strides=[1], device=cpu)): - %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) - %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) - %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) - return (%x2, %y2, %z2))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not fuse these two computations since they use different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_MergeGroups) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%a : Float(128, strides=[1], device=cpu), - %b : Float(128, strides=[1], device=cpu)): - %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) - %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) - return (%x, %y))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // The %x and %y computations are completely independent and yet we should put - // them into a single fusion group rather than having two separate ones. - testing::FileCheck() - .check("= prim::TensorExprGroup_") - ->check_not("= prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Bool(8, strides=[1], device=cpu), - %y : Bool(8, strides=[1], device=cpu)): - %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) - %b : Tensor = aten::__or__(%a, %y) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Where) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_WhereList) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Tensor[] = aten::where(%cond) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, DynamicShapeFusion) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), - %1 : Float(10, 5, strides=[5, 1], device=cpu)): - %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) - %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) - return (%3))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs( - g, - /* min_group_size = */ 2, - /* add_composed_op = */ true, - /* fuse_to_dynamic_shapes = */ true); - Code code(g, ""); - - testing::FileCheck() - .check("prim::TensorExprDynamicGroup_") - ->check("prim::TensorExprDynamicGuard") - ->check("prim::TensorExprGroup_") - ->run(*g); - - auto run_and_compare = [&](const std::vector& inputs) { - TORCH_INTERNAL_ASSERT(inputs.size() == 2); - - auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); - - InterpreterState interp(code); - Stack stack(inputs.begin(), inputs.end()); - interp.run(stack); - at::Tensor out = pop(stack).toTensor(); - ASSERT_TRUE(at::allclose(out, ref)); - }; - - std::vector inputs = {at::rand({10, 5}), at::rand({10, 5})}; - run_and_compare(inputs); - - std::vector inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; - run_and_compare(inputs2); - - std::vector inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; - run_and_compare(inputs3); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp deleted file mode 100644 index 6758503f4de7..000000000000 --- a/test/cpp/tensorexpr/test_type.cpp +++ /dev/null @@ -1,202 +0,0 @@ -#include - -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(Type, Test01) { - { - Dtype dt1 = kInt; - ASSERT_EQ(dt1, kInt); - } - { - Dtype dt2_a(kInt, 8); - Dtype dt2_b(kInt, 4); - Dtype dt2_c(ScalarType::Int, 8); - ASSERT_EQ(dt2_a, dt2_c); - ASSERT_NE(dt2_a, dt2_b); - } - { - ASSERT_EQ(kInt, ToDtype()); - ASSERT_EQ(kFloat, ToDtype()); - ASSERT_EQ(kByte, ToDtype()); - ASSERT_EQ(kChar, ToDtype()); - ASSERT_EQ(kShort, ToDtype()); - ASSERT_EQ(kLong, ToDtype()); - ASSERT_EQ(kHalf, ToDtype()); - ASSERT_EQ(kDouble, ToDtype()); - ASSERT_EQ(kBool, ToDtype()); - } - { - Dtype int32x8(kInt, 8); - Dtype float32x8(kFloat, 8); - ASSERT_NE(int32x8, float32x8); - ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); - ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); - } -} - -TEST(Type, BitCasting) { - { - VarHandle x("x", kFloat); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kInt); - } - { - VarHandle x("x", kInt); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kFloat); - } - { - VarHandle x("x", kShort); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kHalf); - } - { - VarHandle x("x", kHalf); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kShort); - } - - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - using SimpleIRExprEval = ExprEval; - // this is broken - /*{ - constexpr int16_t ref16 = 1337; - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast(k) = ref16; - auto a = HalfImm::make(*k); - auto b = BitCast::make(kShort, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref16); - }*/ - - { - float k = raw_bitcast(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref32); - } - - { - double k = raw_bitcast(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref64); - } - - { - int64_t k = raw_bitcast(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), reff64); - } - - { - int32_t k = raw_bitcast(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), reff32); - } - - // This segfaults :( - /*{ - VarHandle x("x", kDouble); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kFloat); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kLong); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kShort); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kInt); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - }*/ -} - -TEST(Type, Propagation) { - // Same types: - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = FloatImm::make(2.f) + - (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Int to bigger int: - { - VarHandle x("x", kShort); - VarHandle y("y", kLong); - ExprHandle body = - ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); - ASSERT_EQ(body.dtype(), kLong); - } - // Float to bigger float: - { - VarHandle x("x", kHalf); - VarHandle y("y", kDouble); - ExprHandle body = - HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Int to Float: - { - VarHandle x("x", kFloat); - VarHandle y("y", kInt); - ExprHandle body = - IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Smaller float, bigger Int: - { - VarHandle x("x", kHalf); - VarHandle y("y", kLong); - ExprHandle body = - HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kHalf); - } - // Bigger float, smaller Int: - { - VarHandle x("x", kChar); - VarHandle y("y", kDouble); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Sign change char/byte upgrades to short: - { - VarHandle x("x", kChar); - VarHandle y("y", kByte); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kShort); - } -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type_specializations.cpp b/test/cpp/tensorexpr/test_type_specializations.cpp deleted file mode 100644 index d9756627fa74..000000000000 --- a/test/cpp/tensorexpr/test_type_specializations.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include - -// Test that tensor type specializations are available in -// the custom passes - -namespace torch { -namespace jit { - -namespace { - -bool hasTensorTypeSpecializations(torch::jit::Block* block) { - for (Value* v : block->inputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - for (Node* n : block->nodes()) { - for (torch::jit::Block* b : n->blocks()) { - if (hasTensorTypeSpecializations(b)) - return true; - } - for (Value* v : n->outputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - } - return false; -} - -static bool hasSpecializations = false; -void detectTTSpecializationPass(std::shared_ptr& graph) { - GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph); - hasSpecializations = hasTensorTypeSpecializations(graph->block()); -} - -} // namespace - -TEST(SpecializationsInCustomPasses, Basic) { - RegisterPass p(detectTTSpecializationPass); - hasSpecializations = false; - std::shared_ptr graph = std::make_shared(); - parseIR( - R"IR( -graph(%a.1 : Tensor, - %b.1 : Tensor): - %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8 - %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8 - return (%d.1) - )IR", - &*graph); - - IValue ival = IValue(torch::randn({22}, at::kCPU)); - std::vector stack = {ival, ival}; - auto run = [&](std::shared_ptr& graph, std::vector stack) { - GraphExecutor executor(graph, ""); - executor.run(stack); - return stack; - }; - run(graph, stack); - - // Profiling mode will not be run with simple executor - if (!getExecutorMode()) { - EXPECT_TRUE(hasSpecializations); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h deleted file mode 100644 index 065e513c1a64..000000000000 --- a/test/cpp/tensorexpr/test_utils.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -#define IS_NODE(T, node) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - } - -#define IS_NODE_WITH_NAME(T, node, name) \ - auto name = to(node); \ - ASSERT_NE(nullptr, name); - -#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ - NodePtr name = nullptr; \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ - name = to(node_->src_value()); \ - } \ - ASSERT_NE(nullptr, name); - -#define IS_IMM_WITH_VAL(T, node, val) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->value(), val); \ - } - -#define IS_VAR_WITH_NAME(node, name) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->name_hint(), name); \ - } - -#define IS_BINOP_W_VARS(T, node, name, v1, v2) \ - NodePtr name = nullptr; \ - { \ - name = to(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v1); \ - IS_VAR_WITH_NAME(name->rhs(), v2); \ - } - -#define IS_BINOP_W_CONST(T, node, name, v, c) \ - NodePtr name = nullptr; \ - { \ - name = to(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v); \ - IS_IMM_WITH_VAL(Int, name->rhs(), c); \ - } - -#define IS_RAND(node) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->op_type(), kRand); \ - } - -void checkIR(StmtPtr s, const std::string& pattern); -void checkExprIR(ExprPtr e, const std::string& pattern); -void checkExprIR(const ExprHandle& e, const std::string& pattern); - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp deleted file mode 100644 index 3f4c32af463b..000000000000 --- a/test/cpp/tensorexpr/tutorial.cpp +++ /dev/null @@ -1,542 +0,0 @@ -// *** Tensor Expressions *** -// -// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to -// work with them, and outlines how they are used in the overall TorchScript -// compilation pipeline. This doc is permanently a "work in progress" since NNC -// is under active development and things change fast. -// -// This Tutorial's code is compiled in the standard pytorch build, and the -// executable can be found in `build/bin/tutorial_tensorexpr`. -// -// *** What is NNC *** -// -// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT -// and it performs on-the-fly code generation for kernels, which are often a -// combination of multiple aten (torch) operators. -// -// When the JIT interpreter executes a torchscript model, it automatically -// extracts subgraphs from the torchscript IR graph for which specialized code -// can be JIT generated. This usually improves performance as the 'combined' -// kernel created from the subgraph could avoid unnecessary memory traffic that -// is unavoidable when the subgraph is interpreted as-is, operator by operator. -// This optimization is often referred to as 'fusion'. Relatedly, the process of -// finding and extracting subgraphs suitable for NNC code generation is done by -// a JIT pass called 'fuser'. -// -// *** What is TE *** -// -// TE stands for Tensor Expressions. TE is a commonly used approach for -// compiling kernels performing tensor (~matrix) computation. The idea behind it -// is that operators are represented as a mathematical formula describing what -// computation they do (as TEs) and then the TE engine can perform mathematical -// simplification and other optimizations using those formulas and eventually -// generate executable code that would produce the same results as the original -// sequence of operators, but more efficiently. -// -// NNC's design and implementation of TE was heavily inspired by Halide and TVM -// projects. -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -#ifdef TORCH_ENABLE_LLVM - -// Helper function to print a snippet from a big multi-line string -static void printLinesToFrom(const std::string& input_str, int from, int to); - -#endif - -int main(int argc, char* argv[]) { - std::cout << "*** Structure of tensor expressions and statements ***" - << std::endl; - { - // A tensor expression is a tree of expressions. Each expression has a type, - // and that type defines what sub-expressions the current expression has. - // For instance, an expression of type 'Mul' would have a type 'kMul' and - // two subexpressions: LHS and RHS. Each of these two sub-expressions could - // also be a 'Mul' or some other expression. - // - // Let's construct a simple TE: - ExprPtr lhs = alloc(5); - ExprPtr rhs = alloc("x", kInt); - ExprPtr mul = alloc(lhs, rhs); - std::cout << "Tensor expression: " << *mul << std::endl; - // Prints: Tensor expression: 5 * x - - // Here we created an expression representing a 5*x computation, where x is - // an int variable. - - // Another, probably a more convenient, way to construct tensor expressions - // is to use so called expression handles (as opposed to raw expressions - // like we did in the previous example). Expression handles overload common - // operations and allow us to express the same semantics in a more natural - // way: - ExprHandle l = 5; - ExprHandle r = Var::make("x", kInt); - ExprHandle m = l * r; - std::cout << "Tensor expression: " << *m.node() << std::endl; - // Prints: Tensor expression: 5 * x - - // Converting from handles to raw expressions and back is easy: - ExprHandle handle = Var::make("x", kInt); - ExprPtr raw_expr_from_handle = handle.node(); - ExprPtr raw_expr = alloc("x", kInt); - ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); - - // We could construct arbitrarily complex expressions using mathematical - // and logical operations, casts between various data types, and a bunch of - // intrinsics. - ExprHandle a = Var::make("a", kInt); - ExprHandle b = Var::make("b", kFloat); - ExprHandle c = Var::make("c", kFloat); - ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); - std::cout << "Tensor expression: " << *x.node() << std::endl; - // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) - - // An ultimate purpose of tensor expressions is to optimize tensor - // computations, and in order to represent accesses to tensors data, there - // is a special kind of expression - a load. - // To construct a load we need two pieces: the base and the indices. The - // base of a load is a Buf expression, which could be thought of as a - // placeholder similar to Var, but with dimensions info. - // - // Let's construct a simple load: - BufHandle A("A", {64, 32}, kInt); - VarPtr i_var = alloc("i", kInt), j_var = alloc("j", kInt); - ExprHandle i(i_var), j(j_var); - ExprHandle load = Load::make(A.dtype(), A, {i, j}); - std::cout << "Tensor expression: " << *load.node() << std::endl; - // Prints: Tensor expression: A[i, j] - - // Tensor Expressions constitute Tensor Statements, which are used to - // represent computation of a given operator or a group of operators from a - // fusion group. - // - // There are three main kinds of tensor statements: - // - block - // - store - // - loop - // - // A Store represents a store to a single element of a tensor (or to a - // group of elements if it's a vectorized store). Store statements, - // similarly to Load expressions, have a base and indices, but on top of - // that they also include a value - an expression representing what needs - // to be stored at the given memory location. Let's create a Store stmt: - StmtPtr store_a = Store::make(A, {i, j}, i + j); - std::cout << "Store statement: " << *store_a << std::endl; - // Prints: Store statement: A[i, j] = i + j; - - // An operator fills the entire tensor, not just a single element, and to - // represent this we need to use For stmt: let's wrap our store stmt with - // two nested loops to represent that variables i and j need to iterate - // over some ranges. - ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); - ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); - - std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; - // Prints: - // Nested for loops: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - - // A Block statement is used when we need a sequence of other statements. - // E.g. if a fusion group contains several operators, we initially define - // separate loopnest for each of them and put them all into a common block: - BufHandle B("B", {64, 32}, kInt); - StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); - ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); - ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); - - BlockPtr block = Block::make({loop_i_a, loop_i_b}); - std::cout << "Compound Block statement: " << std::endl - << *block << std::endl; - // Prints: - // Compound Block statement: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // B[i, j] = A[i, j]; - // } - // } - // } - - // Manually constructing nested loops and blocks to represent a computation - // might be laborious, and instead we can use a 'Compute' API. This API - // requires us to specify dimensions and a lambda to compute a single - // element of the resulting tensor and returns a `Tensor` structure. This - // structure is simply a pair of a buffer that was created to represent the - // result of the computation (BufPtr) and a statement representing the - // computation itself (StmtPtr). - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *C.stmt() << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * j; - // } - // } - - // To construct statements to represent computations with reductions, we - // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple - // of extra arguments defining how to perform the reduction. Let's define a - // simple 2D sum of C using that: - Tensor D = Reduce( - "D", - {}, - Sum(), - [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, - {64, 32}); - std::cout << "Stmt produced by 'Reduce' API: " << std::endl - << *D.stmt() << std::endl; - } - - std::cout << "*** Loopnests transformations ***" << std::endl; - { - // When a statement for the computation is generated, we might want to - // apply some optimizations to it. These transformations allow us to end up - // with a statement producing the same results, but more efficiently. - // - // Let's look at a couple of transformations that are used in NNC. We will - // begin with constructing a Block statement like we did before. - - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * (j + 1); - }); - BufHandle c_buf(C.buf()); - Tensor D = - Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return c_buf.load(i, j) - i; - }); - StmtPtr block = Block::make({C.stmt(), D.stmt()}); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *block << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // One transformation we can apply to this computation is inlining: i.e. - // taking the expression that defines values of C and substituting a load - // from C with it. - // To do that, we first need to create a special object called LoopNest - - // all transformations are methods of this class. To create a loopnest we - // need to provide a list of output buffers and the root statement: - LoopNest nest(block, {D.buf()}); - - // We can always retrieve the Stmt back from LoopNest: - std::cout << "LoopNest root stmt: " << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // LoopNest root stmt: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // Now we can apply the inlining transformation: - nest.computeInline(C.buf()); - std::cout << "Stmt after inlining:" << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // Stmt after inlining: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * (j + 1) - i; - // } - // } - // } - - // We can also apply algebraic simplification to a statement: - StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); - std::cout << "Stmt after simplification:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after simplification: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * j; - // } - // } - // } - - // Many loopnest transformations are stateless and can be applied without - // creating a LoopNest object. In fact, we plan to make all transformations - // stateless. - // splitWithTail is one such transformation: it splits an iteration space - // of a given loop into two with a given factor. - ForPtr outer_loop = to(to(simplified)->stmts().front()); - LoopNest::splitWithTail(outer_loop, 13); - // Call simplifier once more to fold some arithmetic. - simplified = IRSimplifier::simplify(simplified); - std::cout << "Stmt after splitWithTail:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after splitWithTail: - // { - // for (const auto i_outer : c10::irange(4)) { - // for (const auto i_inner : c10::irange(13)) { - // for (const auto j : c10::irange(32)) { - // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); - // } - // } - // } - // for (const auto i_tail : c10::irange(12)) { - // for (const auto j : c10::irange(32)) { - // D[i_tail + 52, j] = i_tail * j + 52 * j; - // } - // } - // } - - // NNC supports a wide range of loop nest transformations, which we are not - // listing here. Please refer to documentation in - // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h - // for more details. - } - - std::cout << "*** Codegen ***" << std::endl; - { - // An ultimate goal of tensor expressions is to be provide a mechanism to - // execute a given computation in the fastest possible way. So far we've - // looked at how we could describe what computation we're interested in, but - // we haven't looked at how to actually execute it. - // - // All we've been dealing with was just symbols with no actual data - // associated, in this section we would look at how we can bridge that gap. - - // Let's start by constructing a simple computation for us to work with: - BufHandle A("A", {64, 32}, kInt); - BufHandle B("B", {64, 32}, kInt); - Tensor X = - Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + B.load(i, j); - }); - - // And let's lower it to a loop nest, as we did in the previous section. We - // can pass Tensor object directly: - LoopNest loopnest({X}); - std::cout << *loopnest.root_stmt() << std::endl; - // Prints: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // X[i, j] = (A[i, j]) + (B[i, j]); - // } - // } - - // Now imagine that we have two actual tensors 64x32 that we want sum - // together, how do we pass those tensors to the computation and how do we - // carry it out? - // - // Codegen object is aimed at providing exactly that functionality. Codegen - // is an abstract class and concrete codegens are derived from it. - // Currently, we have three codegens: - // 1) Simple Evaluator, - // 2) LLVM Codegen for CPU, - // 3) CUDA Codegen. - // In this example we will be using Simple Evaluator, since it's available - // everywhere. - - // To create a codegen, we need to provide the statement - it specifies the - // computation we want to perform - and a list of placeholders and tensors - // used in the computation. The latter part is crucial since that's the only - // way the codegen could use to correlate symbols in the statement to actual - // data arrays that we will be passing when we will actually be performing - // the computation. - // - // Let's create a Simple IR Evaluator codegen for our computation: - SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); - - // We are using the simplest codegen and in it almost no work is done at the - // construction step. Real codegens such as CUDA and LLVM perform - // compilation during that stage so that when we're about to run the - // computation everything is ready. - - // Let's now create some inputs and run our computation with them: - std::vector data_A(64 * 32, 3); // This will be the input A - std::vector data_B(64 * 32, 5); // This will be the input B - std::vector data_X(64 * 32, 0); // This will be used for the result - - // Now let's invoke our codegen to perform the computation on our data. We - // need to provide as many arguments as how many placeholders and tensors we - // passed at the codegen construction time. A position in these lists would - // define how real data arrays from the latter call (these arguments are - // referred to as 'CallArg's in our codebase) correspond to symbols - // (placeholders and tensors) used in the tensor expressions we constructed - // (these are referred to as 'BufferArg'). - // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A - // contains data for the placeholder A, data_B - for the placeholder B, and - // data_X would be used for contents of tensor X. - ir_eval(data_A, data_B, data_X); - - // Let's print one of the elements from each array to verify that the - // computation did happen: - std::cout << "A[10] = " << data_A[10] << std::endl - << "B[10] = " << data_B[10] << std::endl - << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; - // Prints: - // A[10] = 3 - // B[10] = 5 - // X[10] = A[10] + B[10] = 8 - } - - std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; - { - // This section requires a LLVM-enabled PyTorch build, so we have to use a - // guard: -#ifdef TORCH_ENABLE_LLVM - - // Often we would like to convert a TorchScript IR to TE rather than - // construct TE IR from scratch. NNC provides an API to perform such - // lowering: it takes a TorchScript graph and returns an object that can be - // used to invoke the generated kernel. - // This API is currently used by the TorchScript JIT fuser and can also be - // used ahead of time to pre-compile parts of a model. - // - // To get familiar with this API let's first start with defining a simple - // TorchScript graph: - const auto graph_string = R"IR( - graph(%A : Float(5, 3, strides=[3, 1], device=cpu), - %B : Float(5, 3, strides=[3, 1], device=cpu)): - %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) - %one : int = prim::Constant[value=1]() - %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) - %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) - return (%AAB_plus_B))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - // This graph defines a simple computation of A*A*B + B where A and B are - // input 5x3 tensors. - - // To lower this TorchScript graph to TE, we just need to create a - // TensorExprKernel object. In its constructor it constructs the - // corresponding TE IR and compiles it for the given backend (in this - // example for CPU using LLVM compiler). - TensorExprKernel kernel(graph); - - // We can retrieve the generated TE stmt from the kernel object: - StmtPtr kernel_stmt = kernel.getCodeGenStmt(); - std::cout << "TE Stmt constructed from TorchScript: " << std::endl - << *kernel_stmt << std::endl; - // Prints: - // TE Stmt constructed from TorchScript: - // { - // for (const auto v : c10::irange(5)) { - // for (const auto _tail_tail : c10::irange(3)) { - // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * - // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + - // (tB[_tail_tail + 3 * v]); - // } - // } - // } - - // We can also examine generated LLVM IR and assembly code: - std::cout << "Generated LLVM IR: " << std::endl; - auto ir_str = kernel.getCodeText("ir"); - printLinesToFrom(ir_str, 15, 20); - // Prints: - // Generated LLVM IR: - // %9 = bitcast float* %2 to <8 x float>* - // %10 = load <8 x float>, <8 x float>* %9 ... - // %11 = bitcast float* %5 to <8 x float>* - // %12 = load <8 x float>, <8 x float>* %11 ... - // %13 = fmul <8 x float> %10, %12 - // %14 = fmul <8 x float> %10, %13 - - std::cout << "Generated assembly: " << std::endl; - auto asm_str = kernel.getCodeText("asm"); - printLinesToFrom(asm_str, 10, 15); - // Prints: - // Generated assembly: - // vmulps %ymm1, %ymm0, %ymm2 - // vfmadd213ps %ymm1, %ymm0, %ymm2 - // vmovups %ymm2, (%rax) - // vmovss 32(%rcx), %xmm0 - // vmovss 32(%rdx), %xmm1 - // vmulss %xmm1, %xmm0, %xmm2 - - // We can also execute the generated kernel: - auto A = - at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * - 2.0; - auto B = - at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * - 3.0; - std::vector inputs = {A, B}; - std::vector stack = torch::fmap(inputs); - kernel.run(stack); - auto R = stack[0].toTensor(); - - // Let's print one of the elements from the result tensor to verify that the - // computation did happen and was correct: - std::cout << "R[2][2] = " << R[2][2] << std::endl; - // Prints: - // R[2][2] = 15 - // [ CPUFloatType{} ] -#endif - } - return 0; -} - -void printLinesToFrom(const std::string& input_str, int from, int to) { - std::istringstream f(input_str); - std::string s; - int idx = 0; - while (getline(f, s)) { - if (idx > from) { - std::cout << s << "\n"; - } - if (idx++ > to) { - break; - } - } -} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 63e9eb77dd34..e3dfc581179a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -291,10 +291,43 @@ void boxed_fill_infinity( stack[0] = from(res); } +Tensor my_pad(Tensor t) { + std::vector padding = {1, 2, 2, 1}; + std::string mode = "constant"; + double value = 0.0; + return pad(t, padding, mode, value); +} + +void boxed_my_pad( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_pad(to(stack[0])); + stack[0] = from(res); +} + +Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) { + return narrow(t, dim, start, length); +} + +void boxed_my_narrow( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_narrow( + to(stack[0]), + to(stack[1]), + to(stack[2]), + to(stack[3])); + stack[0] = from(res); +} + STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); m.def("my_empty_like(Tensor t) -> Tensor"); m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)"); + m.def("my_pad(Tensor t) -> Tensor"); + m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -303,6 +336,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("fill_infinity", &boxed_fill_infinity); } +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) { + m.impl("my_pad", &boxed_my_pad); + m.impl("my_narrow", &boxed_my_narrow); +} Tensor my_zero_(Tensor t) { return zero_(t); @@ -320,3 +357,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { m.impl("my_zero_", &boxed_my_zero_); } + +bool test_default_constructor(bool defined) { + Tensor out; + if (defined) { + AtenTensorHandle defined_ath; + int64_t sizes[] = {2, 3}; + int64_t strides[] = {3, 1}; + aoti_torch_empty_strided( + 2, + sizes, + strides, + aoti_torch_dtype_float32(), + aoti_torch_device_type_cpu(), + 0, + &defined_ath); + out = Tensor(defined_ath); + } + return out.defined(); +} + +void boxed_test_default_constructor( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + bool res = test_default_constructor(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("test_default_constructor(bool undefined) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_default_constructor", &boxed_test_default_constructor); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 1694bfa1b396..817732371060 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -164,3 +164,42 @@ def fill_infinity(t) -> Tensor: Returns: The modified tensor (same as input) """ return torch.ops.libtorch_agnostic.fill_infinity.default(t) + + +def test_default_constructor(defined) -> bool: + """ + Tests the default constructor for torch::stable::Tensor. + + Args: + defined: bool - if True, tests defined tensor; if False, tests undefined tensor + + Returns: bool - result of calling .defined() on the tensor + """ + return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) + + +def my_pad(t) -> Tensor: + """ + Pads the input tensor with hardcoded padding parameters. + + Args: + t: Input tensor + + Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0 + """ + return torch.ops.libtorch_agnostic.my_pad.default(t) + + +def my_narrow(t, dim, start, length) -> Tensor: + """ + Returns a new tensor that is a narrowed version of the input tensor. + + Args: + t: Input tensor + dim: Dimension along which to narrow + start: Starting position + length: Length of the narrowed section + + Returns: Narrowed tensor + """ + return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index bd409a0eb5a6..ae3c2767627f 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -218,6 +218,40 @@ def test_fill_infinity(self, device): expected = torch.full_like(t, math.inf) self.assertEqual(out, expected) + @onlyCPU + def test_default_constructor(self): + import libtorch_agnostic + + defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor( + True + ) + self.assertTrue(defined_tensor_is_defined) + + undefined_tensor_is_defined = ( + libtorch_agnostic.ops.test_default_constructor(False) + ) + self.assertFalse(undefined_tensor_is_defined) + + def test_my_pad(self, device): + import libtorch_agnostic + + t = torch.rand(2, 3, device=device) + out = libtorch_agnostic.ops.my_pad(t) + expected = torch.nn.functional.pad(t, [1, 2, 2, 1], "constant", 0.0) + self.assertEqual(out, expected) + + def test_my_narrow(self, device): + import libtorch_agnostic + + t = torch.randn(2, 5, device=device) + + dim0 = 0 + start0 = 0 + length0 = 1 + out0 = libtorch_agnostic.ops.my_narrow(t, dim0, start0, length0) + expected0 = torch.narrow(t, dim0, start0, length0) + self.assertEqual(out0, expected0) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py index 07d31e73d76b..386e34cdb56f 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py @@ -85,6 +85,7 @@ def main(): cmdclass={ "clean": BuildClean, # type: ignore[misc] }, + include_package_data=False, ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 478eb498ac5d..b64d4107ee0c 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -299,12 +299,20 @@ def _check_count(copy_count, resize_count): def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph): def _run_with_checks(graph, orig_fn): - self.assertGreater( - _count_op_in_graph( - graph, torch.ops._c10d_functional.all_gather_into_tensor.default - ), - 0, - ) + if self.world_size > 1: + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, + ) + elif self.world_size == 1: + self.assertEqual( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, + ) orig_fn(graph) @@ -315,12 +323,22 @@ def _run_with_checks(graph, orig_fn): 0, ) - self.assertGreater( - _count_op_in_graph( - graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default - ), - 0, - ) + if self.world_size > 1: + self.assertGreater( + _count_op_in_graph( + graph, + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + ), + 0, + ) + else: + self.assertEqual( + _count_op_in_graph( + graph, + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + ), + 0, + ) if fwd_fullgraph: return mock.patch.object( @@ -549,7 +567,7 @@ def test_compiled(): Developer debug context: call_method TensorVariable() backward () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0123.html""", # noqa: B950 + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0123.html""", # noqa: B950 ) else: self.assertGreater(len(counters["graph_break"]), 1) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index 2ee46febfb24..c9450a2b8f47 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -6,11 +6,9 @@ import torch.distributed as dist from torch._dynamo.test_case import run_tests from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.logging_utils import LoggingTestCase -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index cf8b86cc8e06..6ff022f46d19 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -1467,5 +1467,70 @@ def forward(self, imgs: torch.Tensor) -> torch.Tensor: check_sharded_parity(self, ref_model, model) +class TestFullyShardWorldSize1(FSDPTest): + @property + def world_size(self) -> int: + return 1 + + @skip_if_lt_x_gpu(1) + def test_train_parity_single_worldsize1(self): + """ + Tests train parity with DDP for a single FSDP group when sharding + parameters on dim-0. + """ + self.run_subtests( + { + "lin_shapes": [ + [(16, 15), (15, 8)], + [(7, 15), (15, 3)], + [(16, 17), (17, 8)], + ], + "use_shard_placement_fn": [False], + }, + self._test_train_parity_single_group, + ) + + def _test_train_parity_single_group( + self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool + ): + torch.manual_seed(42) + model = nn.Sequential( + nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1]) + ) + ref_model = copy.deepcopy(model).to(device_type) + replicate(ref_model, device_ids=[self.rank]) + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) + + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + return Shard(param.shape.index(max(param.shape))) + + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None + fully_shard(model, shard_placement_fn=shard_placement_fn) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + torch.manual_seed(42 + self.rank + 1) + inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),) + + for iter_idx in range(10): + losses: list[torch.Tensor] = [] + + ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + losses.append(ref_model(*inp).sum()) + losses[-1].backward() + ref_optim.step() + + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + comm_mode = CommDebugMode() + with comm_mode: + losses.append(model(*inp).sum()) + losses[-1].backward() + + # Before there was 1 all-gather and 1 reduce-scatter + # Now therre is 1 reduce-scatter + self.assertEqual(comm_mode.get_total_counts(), 1) + optim.step() + + self.assertEqual(losses[0], losses[1]) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 3ab0b6269b2d..bcaf06ea947a 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -277,19 +277,19 @@ def test_tp_with_fsdp_offloading(self): loss = model(inp).sum() fwd_comm_counts = fwd_comm_mode.get_comm_counts() - self.assertEqual(len(fwd_comm_counts), 2) + self.assertEqual(len(fwd_comm_counts), 1) self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps) - self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps) + self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], 0) ref_loss = ref_model(inp).sum() self.assertEqual(loss, ref_loss) with CommDebugMode() as bwd_comm_mode: loss.backward() bwd_comm_counts = bwd_comm_mode.get_comm_counts() - self.assertEqual(len(bwd_comm_counts), 3) + self.assertEqual(len(bwd_comm_counts), 2) # First MLP's input gradient does not need to be all-reduced self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1) - self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps) + self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0) self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps) ref_loss.backward() diff --git a/test/distributed/_composable/test_replicate_with_fsdp.py b/test/distributed/_composable/test_replicate_with_fsdp.py index ff61e2c05f27..099f84b9e848 100644 --- a/test/distributed/_composable/test_replicate_with_fsdp.py +++ b/test/distributed/_composable/test_replicate_with_fsdp.py @@ -256,7 +256,7 @@ def test_train_replicate_fsdp(self): @skip_if_lt_x_gpu(2) def test_train_parity_2d_mlp(self): """ - Verifies that when a device mesh is passed in, the model has the same behavior as the original model when training + Verifies when a device mesh is passed in, the model has the same behavior as the original model when training """ self._init_pg() global_mesh = self.init_replicate_tp_mesh() diff --git a/test/distributed/checkpoint/e2e/test_fsdp_ep.py b/test/distributed/checkpoint/e2e/test_fsdp_ep.py index 7489317035b9..51d4b3e99537 100644 --- a/test/distributed/checkpoint/e2e/test_fsdp_ep.py +++ b/test/distributed/checkpoint/e2e/test_fsdp_ep.py @@ -73,8 +73,8 @@ def test_e2e(self): self.device_type, (2, 4), mesh_dim_names=("dp", "tp") ) # TODO: we are using an internal API atm. Change to a public API once it is ready. - mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",)) - del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep] + mesh_fsdp_ep = _mesh_resources.create_sub_mesh(mesh_fsdp_tp, ("dp",), [(0,)]) + del _mesh_resources.child_to_root_mapping[mesh_fsdp_ep] mesh_fsdp = init_device_mesh(self.device_type, (8,)) for i, l in enumerate(model.second.ep_layers): diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py index ba07c62728d7..ad74c34c4e2e 100644 --- a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -8,6 +8,7 @@ import torch.distributed.checkpoint as dist_cp from torch import distributed as dist from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + _calculate_max_contiguous_elements, consolidate_safetensors_files, ) from torch.distributed.checkpoint._hf_utils import _metadata_fn @@ -153,6 +154,76 @@ def test_consolidate_to_two_files(self): ) dist.barrier() + def test_calculate_max_contiguous_elements_validations(self) -> None: + """Test validation logic in _calculate_max_contiguous_elements function.""" + + # Test empty lists validation + with self.assertRaisesRegex(ValueError, "Input lists cannot be empty"): + _calculate_max_contiguous_elements([], [2, 3], [4, 5]) + + # Test mismatched list lengths validation + with self.assertRaisesRegex( + ValueError, "All input lists must have the same length" + ): + _calculate_max_contiguous_elements([1], [2, 3], [4, 5]) + + # Test indices out of bounds validation + with self.assertRaisesRegex( + ValueError, "Index .* at dimension .* is out of bounds for sub-tensor shape" + ): + _calculate_max_contiguous_elements( + [2, 1], [2, 3], [4, 5] + ) # indices[0] >= sub_tensor_shape[0] + + # Test sub-tensor dimensions exceeding tensor dimensions validation + with self.assertRaisesRegex( + ValueError, + "Sub-tensor dimension .* at position .* exceeds tensor dimension", + ): + _calculate_max_contiguous_elements( + [1, 2], [2, 6], [4, 5] + ) # sub_tensor_shape[1] > tensor_shape[1] + + def test_calculate_max_contiguous_elements_valid_cases(self) -> None: + """Test valid cases for _calculate_max_contiguous_elements function.""" + + # Test 1D case - simple remaining elements + result = _calculate_max_contiguous_elements([2], [5], [10]) + self.assertEqual(result, 3) # 5 - 2 = 3 elements remaining + + # Test 2D case - at start of row, can write complete rows + result = _calculate_max_contiguous_elements([1, 0], [3, 4], [6, 4]) + self.assertEqual(result, 8) # 2 rows * 4 columns = 8 elements + + # Test 2D case - middle of row, only remaining in current row + result = _calculate_max_contiguous_elements([1, 2], [3, 4], [6, 8]) + self.assertEqual(result, 2) # 4 - 2 = 2 elements remaining in row + + # Test 3D case - at start of 2D slice, can write complete slices + result = _calculate_max_contiguous_elements([1, 0, 0], [3, 2, 4], [5, 2, 4]) + self.assertEqual(result, 16) # 2 slices * 2 rows * 4 columns = 16 elements + + # Test edge case - at last position + result = _calculate_max_contiguous_elements([2, 3], [3, 4], [6, 8]) + self.assertEqual(result, 1) # Only 1 element remaining + + # Test case where sub-tensor spans full width + result = _calculate_max_contiguous_elements([0, 0], [2, 5], [4, 5]) + self.assertEqual(result, 10) # 2 rows * 5 columns = 10 elements + + # Test column-wise sharded case - sub-tensor doesn't span full width + # Even at start of row, can only write width of one row due to column sharding + result = _calculate_max_contiguous_elements([1, 0], [3, 2], [4, 8]) + self.assertEqual( + result, 2 + ) # Only 2 elements (width of sub-tensor) can be written contiguously + + # Test another column-wise sharded case - middle of tensor + result = _calculate_max_contiguous_elements([0, 0], [2, 3], [6, 10]) + self.assertEqual( + result, 3 + ) # Only 3 elements (width of sub-tensor) can be written contiguously + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 637dd228944f..81558db13a69 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -162,8 +162,16 @@ def test_write_data_with_sharding(self) -> None: ) def test_read_data_hf(self) -> None: - # Create test tensors tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + mock_safe_open = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value.get_slice.return_value = tensor_0 + mock_safe_open.return_value = mock_context + + sys.modules["safetensors"] = MagicMock() + sys.modules["safetensors"].safe_open = mock_safe_open + with tempfile.TemporaryDirectory() as path: # Create the reader reader = HuggingFaceStorageReader(path=path) @@ -200,8 +208,6 @@ def test_read_data_hf(self) -> None: fqn="tensor_0", offset=torch.Size([0]), index=None ): _HFStorageInfo( file_path, - len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, - tensor_0.numel() * tensor_0.element_size(), tensor_0.shape, tensor_0.dtype, ), @@ -260,6 +266,9 @@ def test_read_data_hf(self) -> None: # Verify results - the target tensors should now contain the values from our test tensor self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) + mock_safe_open.assert_called_once_with(filename=file_path, framework="pt") + mock_context.__enter__.return_value.get_slice.assert_called_with("tensor_0") + def test_write_metadata_hf(self) -> None: mock_module = MagicMock() sys.modules["huggingface_hub"] = mock_module @@ -313,35 +322,50 @@ def test_write_metadata_hf(self) -> None: self.assertEqual(metadata, expected_metadata) def test_read_metadata_hf(self): + mock_safe_open = MagicMock() + mock_context = MagicMock() + + mock_safe_open.return_value = mock_context + + mock_context.__enter__.return_value.keys.return_value = ["tensor_0"] + mock_context.__enter__.return_value.metadata.return_value = {} + + mock_slice = MagicMock() + mock_slice.get_shape.return_value = [5, 10] + mock_slice.get_dtype.return_value = "F32" + mock_context.__enter__.return_value.get_slice.return_value = mock_slice + + mock_safetensors = MagicMock() + mock_safetensors.safe_open = mock_safe_open + + mock_safetensors.torch._getdtype = MagicMock(return_value=torch.float32) + + sys.modules["safetensors"] = mock_safetensors + sys.modules["safetensors.torch"] = mock_safetensors.torch + with tempfile.TemporaryDirectory() as path: reader = HuggingFaceStorageReader(path=path) key = "tensor_0" file_name = "test.safetensors" - with open(os.path.join(path, file_name), "wb") as f: - # write metadata the same way it would be in safetensors file - metadata_contents = json.dumps( - { - "tensor_0": { - "dtype": "F32", - "shape": [5, 10], - "data_offsets": [0, 200], - } - } - ) - metadata_bytes = metadata_contents.encode("utf-8") + file_path = os.path.join(path, file_name) - f.write( - len(metadata_bytes).to_bytes( - NUM_BYTES_FOR_HEADER_LEN, byteorder="little" - ) - ) - f.write(metadata_bytes) + # Create an empty file so fs.ls can find it + with open(file_path, "wb") as _: + pass + + # Mock the fs.ls method to return our test file + original_ls = reader.fs.ls + reader.fs.ls = MagicMock(return_value=[file_path]) - tensor = torch.rand(5, 10) - f.write(tensor.numpy().tobytes()) + try: + metadata = reader.read_metadata() + finally: + # Restore the original ls method + reader.fs.ls = original_ls - metadata = reader.read_metadata() + # Verify that safe_open was called with our file path + mock_safe_open.assert_called_once_with(file_path, framework="pt") self.assertEqual( metadata.state_dict_metadata, @@ -365,8 +389,6 @@ def test_read_metadata_hf(self): fqn=key, offset=torch.Size([0, 0]), index=None ): _HFStorageInfo( os.path.join(path, file_name), - len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, - 200, torch.Size([5, 10]), torch.float32, ) diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index 86a952e0701d..8134472f52d5 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -1,12 +1,23 @@ # Owner(s): ["oncall: distributed"] import dataclasses +import os +import tempfile +from datetime import timedelta import torch import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.checkpoint._state_dict_stager import StateDictStager +from torch.distributed.checkpoint.staging import _ReplicationStager +from torch.distributed.tensor import DeviceMesh, distribute_tensor from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -818,5 +829,523 @@ def test_dtensor(self): self.assertEqual(cpu_state_dict["dtensor"].size(), dtensor.size()) +class TestReplicationStager(DTensorTestBase): + """ + Test suite for _ReplicationStager functionality. + Tests replication of state_dict across training ranks using CPU tensors only. + """ + + @property + def backend(self) -> str: + return "cpu:gloo,cuda:nccl" + + def _create_simple_state_dict(self, rank: int) -> dict: + """ + Create a simple state_dict with CPU tensors, deterministically unique per rank. + + Args: + rank: The rank number to create unique tensors for + + Returns: + dict: A state dictionary with CPU tensors + """ + # Create unique tensors for each rank + torch.manual_seed(42 + rank) # Different seed per rank + + return { + "layer1.weight": torch.randn(64, 128, device="cpu"), + "layer1.bias": torch.randn(64, device="cpu"), + "layer2.weight": torch.randn(32, 64, device="cpu"), + "layer2.bias": torch.randn(32, device="cpu"), + "nested": { + "param": torch.randn(16, 16, device="cpu"), + "buffer": torch.randn(8, device="cpu"), + }, + "scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_simple_state_dict_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify that replication worked correctly. + + Args: + replicated_dict: The replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Create expected state_dict (what partner rank would have created) + expected_dict = self._create_simple_state_dict(partner_rank) + + def compare_tensors(actual, expected, path=""): + if isinstance(actual, dict) and isinstance(expected, dict): + self.assertEqual( + actual.keys(), expected.keys(), f"Keys mismatch at {path}" + ) + for key in actual: + compare_tensors( + actual[key], expected[key], f"{path}.{key}" if path else key + ) + elif isinstance(actual, torch.Tensor) and isinstance( + expected, torch.Tensor + ): + self.assertEqual( + actual.device.type, "cpu", f"Tensor at {path} should be on CPU" + ) + self.assertEqual( + actual.shape, expected.shape, f"Shape mismatch at {path}" + ) + self.assertEqual( + actual.dtype, expected.dtype, f"Dtype mismatch at {path}" + ) + self.assertTrue( + torch.equal(actual, expected), f"Values mismatch at {path}" + ) + else: + self.assertEqual(actual, expected, f"Value mismatch at {path}") + + compare_tensors(replicated_dict, expected_dict) + + def _create_dtensor_state_dict(self, rank: int, device_mesh: DeviceMesh) -> dict: + """ + Create state_dict with DTensor and regular tensors for deterministic testing + due to DTensor Shard, Replicate placements. + + Args: + rank: Current rank + device_mesh: DeviceMesh for DTensor creation + + Returns: + dict: State dictionary with DTensors + """ + # Create a large global tensor with deterministic values + # Each position contains a unique value that encodes both position and rank info + global_size = 128 + global_tensor = torch.arange(0, global_size * 16, dtype=torch.float32).reshape( + global_size, 16 + ) + + # Create DTensor with Shard(0) - each rank gets different rows + sharded_dtensor = distribute_tensor(global_tensor, device_mesh, [Shard(0)]) + + # Create DTensor with Replicate() - all ranks have the same data + replicated_global = torch.full( + (8, 8), float(global_size * 100), dtype=torch.float32, device="cpu" + ) + replicated_dtensor = distribute_tensor( + replicated_global, device_mesh, [Replicate()] + ) + + return { + "sharded_param": sharded_dtensor, + "replicated_param": replicated_dtensor, + "rank_scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_dtensor_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify DTensor replication accuracy by checking local shards and global reconstruction. + + Args: + replicated_dict: Replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Verify sharded DTensor + if "sharded_param" in replicated_dict: + replicated_sharded = replicated_dict["sharded_param"] + self.assertIsInstance(replicated_sharded, DTensor, "Should receive DTensor") + + # Get local shard from replicated DTensor + replicated_local = replicated_sharded.to_local() + + # Create expected local shard (what partner rank would have) + expected_global = torch.arange(0, 128 * 16, dtype=torch.float32).reshape( + 128, 16 + ) + + # Calculate expected shard for this rank's position + world_size = dist.get_world_size() + shard_size = 128 // world_size + start_idx = partner_rank * shard_size + end_idx = (partner_rank + 1) * shard_size + expected_local = expected_global[start_idx:end_idx] + + self.assertTrue( + torch.equal(replicated_local, expected_local), + "Sharded DTensor value mismatch", + ) + + # Verify DTensor metadata is preserved + self.assertEqual( + replicated_sharded._spec.placements[0].__class__.__name__, + "Shard", + "DTensor should maintain Shard placement", + ) + + # Verify replicated DTensor + if "replicated_param" in replicated_dict: + replicated_replicated = replicated_dict["replicated_param"] + self.assertIsInstance( + replicated_replicated, DTensor, "Should receive DTensor" + ) + + # Get local data from replicated DTensor + replicated_local = replicated_replicated.to_local() + + # Expected value should be global_size * 100 + expected_value = float(128 * 100) + expected_tensor = torch.full( + (8, 8), expected_value, dtype=torch.float32, device="cpu" + ) + + self.assertTrue( + torch.equal(replicated_local, expected_tensor), + "Replicated DTensor value mismatch", + ) + + # Verify DTensor metadata is preserved + self.assertEqual( + replicated_replicated._spec.placements[0].__class__.__name__, + "Replicate", + "DTensor should maintain Replicate placement", + ) + + # Verify regular tensors + if "rank_scalar" in replicated_dict: + self.assertEqual( + replicated_dict["rank_scalar"].item(), + float(partner_rank), + f"Rank scalar should be {partner_rank}, got {replicated_dict['rank_scalar'].item()}", + ) + + def _create_sharded_tensor_state_dict(self, rank: int, world_size: int) -> dict: + """ + Create state_dict with ShardedTensor for deterministic testing. + + Args: + rank: Current rank + world_size: Total world size + + Returns: + dict: State dictionary with ShardedTensor + """ + # Create deterministic local shard for this rank + global_size = 64 + shard_size = global_size // world_size + start_idx = rank * shard_size + end_idx = (rank + 1) * shard_size + + # Create local tensor with deterministic values + local_tensor = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + # Create ShardedTensor using init_from_local_shards + sharded_tensor = init_from_local_shards( + [ + ShardedTensorShard( + tensor=local_tensor, + metadata=ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{rank}/cpu", + ), + ) + ], + global_size, + 8, + ) + + return { + "sharded_tensor": sharded_tensor, + "rank_scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_sharded_tensor_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify ShardedTensor replication accuracy by checking local shards and metadata. + + Args: + replicated_dict: Replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Verify sharded tensor + if "sharded_tensor" in replicated_dict: + replicated_sharded = replicated_dict["sharded_tensor"] + self.assertIsInstance( + replicated_sharded, ShardedTensor, "Should receive ShardedTensor" + ) + + # Get local shard from replicated ShardedTensor + local_shards = replicated_sharded.local_shards() + self.assertEqual( + len(local_shards), 1, "Should have exactly one local shard" + ) + + local_shard = local_shards[0] + replicated_local = local_shard.tensor + + # Create expected local shard (what partner rank would have) + world_size = dist.get_world_size() + global_size = 64 + shard_size = global_size // world_size + start_idx = partner_rank * shard_size + end_idx = (partner_rank + 1) * shard_size + + expected_local = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + self.assertTrue( + torch.equal(replicated_local, expected_local), + "Sharded tensor value mismatch", + ) + + # Verify shard metadata is preserved + expected_metadata = ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{partner_rank}/cpu", + ) + self.assertEqual( + local_shard.metadata.shard_offsets, + expected_metadata.shard_offsets, + "Shard offsets should match", + ) + self.assertEqual( + local_shard.metadata.shard_sizes, + expected_metadata.shard_sizes, + "Shard sizes should match", + ) + + # Verify regular tensors + if "rank_scalar" in replicated_dict: + self.assertEqual( + replicated_dict["rank_scalar"].item(), + float(partner_rank), + f"Rank scalar should be {partner_rank}, got {replicated_dict['rank_scalar'].item()}", + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_basic(self): + """Test basic replication functionality with world_size=16""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create unique DTensor state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.new_group(backend=dist.Backend.GLOO), + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + replicated_dict = stager.stage(state_dict) + + # Calculate expected partner rank + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify DTensor replication + self._verify_simple_state_dict_replication( + replicated_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_dtensors(self): + """Test replication with DTensor and mixed tensor types""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create CPU-based DeviceMesh for DTensor + device_mesh = DeviceMesh("cpu", list(range(world_size))) + + # Create DTensor state_dict which includes different tensor types + state_dict = self._create_dtensor_state_dict(current_rank, device_mesh) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + result = stager.stage(state_dict) + + # Wait for completion + from concurrent.futures import Future + + if isinstance(result, Future): + replicated_dict = result.result() + else: + replicated_dict = result + + # Calculate expected partner + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify all DTensor types are correctly replicated + self._verify_dtensor_replication(replicated_dict, current_rank, partner_rank) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_sharded_tensors(self): + """Test replication with ShardedTensor and mixed tensor types""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create ShardedTensor state_dict for this rank + state_dict = self._create_sharded_tensor_state_dict(current_rank, world_size) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + result = stager.stage(state_dict) + + # Wait for completion + from concurrent.futures import Future + + if isinstance(result, Future): + replicated_dict = result.result() + else: + replicated_dict = result + + # Calculate expected partner + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify all ShardedTensor types are correctly replicated + self._verify_sharded_tensor_replication( + replicated_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_persistence(self): + """Test persistence functionality in _ReplicationStager""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Test 1: Default storage directory (auto-generated tempdir) + with tempfile.TemporaryDirectory() as _: + # Create state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize stager with default storage_dir (None) + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + storage_dir=None, # Let it create its own tempdir + ) + + # Perform replication to trigger persistence + stager.stage(state_dict) + + # Calculate expected partner rank + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify file was created with correct naming convention + expected_path = stager._get_persisted_path(current_rank, partner_rank) + + self.assertTrue( + os.path.exists(expected_path), + f"Persisted file should exist at {expected_path}", + ) + + # Verify the storage directory was created + self.assertTrue( + os.path.isdir(stager._storage_dir), "Storage directory should exist" + ) + self.assertTrue( + stager._storage_dir.startswith(tempfile.gettempdir()), + "Default storage directory should be in system temp directory", + ) + + # Load and verify the persisted state_dict matches the received one + loaded_state_dict = torch.load(expected_path) + self._verify_simple_state_dict_replication( + loaded_state_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + # Test 2: Custom storage directory + with tempfile.TemporaryDirectory() as custom_storage_dir: + # Create custom subdirectory + custom_subdir = os.path.join(custom_storage_dir, "custom_replication_test") + + # Create state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize stager with custom storage_dir + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + storage_dir=custom_subdir, + ) + + # Perform replication to trigger persistence + stager.stage(state_dict) + + # Verify custom storage directory was created and used + self.assertEqual( + stager._storage_dir, + custom_subdir, + "Should use custom storage directory", + ) + self.assertTrue( + os.path.isdir(custom_subdir), + "Custom storage directory should be created", + ) + + # Verify file was created in custom directory + expected_path = stager._get_persisted_path(current_rank, partner_rank) + + self.assertTrue( + os.path.exists(expected_path), + f"Persisted file should exist in custom directory at {expected_path}", + ) + + # Load and verify the persisted state_dict + loaded_state_dict = torch.load(expected_path) + self._verify_simple_state_dict_replication( + loaded_state_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index ac34246ee643..c80602c5d50f 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -31,10 +31,10 @@ sys.exit(0) -_DISTRIBUTED_STATE_DICT_IMPLS = { +_DISTRIBUTED_STATE_DICT_IMPLS = ( StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, -} +) class TestDistributedCheckpoint(FSDPTest): diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index a2ac5baaebf7..ae91911bc6a0 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) d_hid = 512 -batch_size = 256 +batch_size = 64 torch.manual_seed(0) device_type = "cuda" @@ -60,38 +60,139 @@ def backend_str(cls) -> str: def device(self) -> torch.device: return torch.device(device_type, self.rank) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize("ScheduleClass", [_ScheduleForwardOnly]) - def test_forward_only(self, ScheduleClass): - mod = MultiMLP(d_hid, n_layers=self.world_size) - mod.to(self.device) + def _setup_models_and_data(self, n_layers=None, model_class=MultiMLP): + """Setup models, input data, target data, and loss function.""" + if n_layers is None: + n_layers = self.world_size - mod_ref = copy.deepcopy(mod) + full_mod = model_class(d_hid, n_layers=n_layers) + full_mod.to(self.device) + ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) - x_clone = x.clone() + with torch.no_grad(): + y = ref_mod(x) + target = y + torch.randn(batch_size, d_hid, device=self.device) - num_microbatches = 2 * self.world_size - x_mb = x.chunk(num_microbatches)[0] + loss_fn = torch.nn.MSELoss(reduction="sum") + return full_mod, ref_mod, x, target, loss_fn - # Create a pipeline - split_spec = mod.split_spec if hasattr(mod, "split_spec") else None - pipe = pipeline( - mod, - mb_args=(x_mb,), - split_spec=split_spec, - ) + def _create_single_stage_pipeline(self, mod, x, chunks, use_tracer=True): + """Create a single-stage pipeline using either tracer or manual stage creation.""" + if use_tracer: + x_mb = x.chunk(chunks)[0] + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline(mod, mb_args=(x_mb,), split_spec=split_spec) + stage = pipe.build_stage(self.rank, self.device) + stage_module = pipe.get_stage_module(self.rank) + return stage, stage_module, [stage_module] + else: + # Manual stage creation + submod_name = f"layers.{self.rank}" + stage_module = mod.get_submodule(submod_name) + stage = PipelineStage(stage_module, self.rank, self.world_size, self.device) + return stage, stage_module, [stage_module] + + def _create_multi_stage_pipeline( + self, mod, stages_per_rank, n_stages, stage_indices=None + ): + """Create multiple pipeline stages for interleaved schedules.""" + if stage_indices is None: + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] - stage = pipe.build_stage( - self.rank, - self.device, - ) + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [mod.get_submodule(submod_name) for submod_name in submod_names] + stages = [ + PipelineStage(stage_module, stage_idx, n_stages, self.device) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + return stages, stage_modules, submod_names - # Attach to a schedule + def _run_reference_model( + self, ref_mod, x, target, loss_fn, num_iterations=2, **kwargs + ): + """Run reference model for specified iterations and return final output and loss.""" + ref_out = None + ref_loss = None + + for _ in range(num_iterations): + ref_mod.zero_grad() + ref_out = ref_mod(x, **kwargs) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + return ref_out, ref_loss + + def _check_gradients( + self, stage_modules, ref_mod, submod_names=None, rtol=1e-5, atol=4e-5 + ): + """Check that gradients match between pipeline stages and reference model using flexible comparison.""" + + def grad_check(grad1, grad2, param_name, rtol, atol, tolerance=0.05): + if grad1 is None and grad2 is None: + return + if grad1 is None or grad2 is None: + raise AssertionError( + f"One gradient is None for {param_name}: {grad1} vs {grad2}" + ) + torch.testing.assert_close(grad1, grad2, rtol=rtol, atol=atol) + + if submod_names is None: + # Single stage case - need to detect tracer vs manual pipeline + stage_modules = [stage_modules] + + # Try to detect if this is a tracer-based pipeline by checking if parameter exists in ref_mod + sample_param_name = next(iter(stage_modules[0].named_parameters()))[0] + try: + # Try to get parameter directly from reference model (tracer-based) + ref_mod.get_parameter(sample_param_name) + is_tracer_based = True + except AttributeError: + # Parameter doesn't exist at root level, must be manual pipeline + is_tracer_based = False + + if is_tracer_based: + # Tracer-based pipeline: parameter names are full paths from root model + for name, p in stage_modules[0].named_parameters(): + ref_p = ref_mod.get_parameter(name) + grad_check(p.grad, ref_p.grad, name, rtol, atol) + else: + # Manual pipeline: parameter names are local to the submodule + submod_name = f"layers.{self.rank}" + ref_submod = ref_mod.get_submodule(submod_name) + for name, p in stage_modules[0].named_parameters(): + ref_p = ref_submod.get_parameter(name) + grad_check(p.grad, ref_p.grad, f"{submod_name}.{name}", rtol, atol) + else: + # Multi-stage case - always use submodule approach + for stage_module, submod_name in zip(stage_modules, submod_names): + ref_submod = ref_mod.get_submodule(submod_name) + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + grad_check(p.grad, ref_p.grad, f"{submod_name}.{name}", rtol, atol) + + def _zero_gradients(self, stage_modules): + """Zero gradients for all stage modules.""" + if not isinstance(stage_modules, list): + stage_modules = [stage_modules] + for stage_module in stage_modules: + stage_module.zero_grad() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [_ScheduleForwardOnly]) + def test_forward_only(self, ScheduleClass): + mod, mod_ref, x, _, _ = self._setup_models_and_data() + x_clone = x.clone() + + num_microbatches = 2 * self.world_size + stage, _, _ = self._create_single_stage_pipeline(mod, x, num_microbatches) schedule = ScheduleClass(stage, num_microbatches, scale_grads=False) - # Run + # Run forward-only schedule + out = None num_iters = 20 for _ in range(num_iters): if self.rank == 0: @@ -103,11 +204,10 @@ def test_forward_only(self, ScheduleClass): else: schedule.step() - # Validate pipelined output is the same as reference model + # Validate pipelined output matches reference model if self.rank == self.world_size - 1: for _ in range(num_iters): x_clone = mod_ref(x_clone) - torch.testing.assert_close(x_clone, out) @requires_nccl() @@ -123,6 +223,7 @@ def test_forward_only(self, ScheduleClass): ], ) def test_eval_inference_mode(self, ScheduleClass): + num_microbatches = 4 if ScheduleClass in [ ScheduleInterleaved1F1B, ScheduleLoopedBFS, @@ -131,158 +232,64 @@ def test_eval_inference_mode(self, ScheduleClass): # Multi-stage schedules stages_per_rank = 2 n_stages = stages_per_rank * self.world_size - mod = MultiMLP(d_hid, n_layers=n_stages) - mod.to(self.device) - - x = torch.randn(batch_size, d_hid, device=self.device) - target = torch.randn(batch_size, d_hid, device=self.device) - loss_fn = torch.nn.MSELoss(reduction="sum") - - chunks = 4 - stage_indices = [ - self.rank + i * self.world_size for i in range(stages_per_rank) - ] - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - mod.get_submodule(submod_name) for submod_name in submod_names - ] - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, stage_indices) - ] - - # Test with eval() method for inference - schedule = ScheduleClass(stages, chunks, loss_fn=loss_fn, scale_grads=False) + mod, _, x, target, loss_fn = self._setup_models_and_data(n_layers=n_stages) - # Clear gradients - for stage_module in stage_modules: - stage_module.zero_grad() - - if self.rank == 0: - # Support with and without no_grad() - with torch.no_grad(): - schedule.eval(x) - elif self.rank == self.world_size - 1: - losses = [] - schedule.eval(target=target, losses=losses) - else: - schedule.eval() - - # Check that gradients were NOT computed during eval - grad_computed_eval = False - for stage_module in stage_modules: - for param in stage_module.parameters(): - if param.grad is not None: - grad_computed_eval = True - break - if grad_computed_eval: - break - - # Verify that gradients were not computed during eval - self.assertFalse( - grad_computed_eval, - "Gradients should not be computed during eval()", + # Create multi-stage pipeline + stages, stage_modules, _ = self._create_multi_stage_pipeline( + mod, stages_per_rank, n_stages + ) + schedule = ScheduleClass( + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) - - # Verify that losses are still computed during eval - if self.rank == self.world_size - 1: - self.assertTrue( - len(losses) > 0, "Losses should be computed during eval()" - ) else: # Single-stage schedules - mod = MultiMLP(d_hid, n_layers=self.world_size) - mod.to(self.device) - - x = torch.randn(batch_size, d_hid, device=self.device) - target = torch.randn(batch_size, d_hid, device=self.device) - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, _, x, target, loss_fn = self._setup_models_and_data() - chunks = 4 - x_mb = x.chunk(chunks)[0] - - # Create a pipeline - split_spec = mod.split_spec if hasattr(mod, "split_spec") else None - pipe = pipeline( - mod, - mb_args=(x_mb,), - split_spec=split_spec, + # Create single-stage pipeline + stage, stage_module, _ = self._create_single_stage_pipeline( + mod, x, num_microbatches ) - - stage = pipe.build_stage( - self.rank, - self.device, + stage_modules = [stage_module] + schedule = ScheduleClass( + stage, num_microbatches, loss_fn=loss_fn, scale_grads=False ) - # Test with eval() method for inference - schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) + # Clear gradients and run eval + self._zero_gradients(stage_modules) + losses = [] - # Get stage module for gradient checking - stage_module = pipe.get_stage_module(self.rank) - stage_module.zero_grad() + if self.rank == 0: + # Support with and without no_grad() + with torch.no_grad(): + schedule.eval(x) + elif self.rank == self.world_size - 1: + schedule.eval(target=target, losses=losses) + else: + schedule.eval() - if self.rank == 0: - # Support with and without no_grad() - with torch.no_grad(): - schedule.eval(x) - elif self.rank == self.world_size - 1: - losses = [] - schedule.eval(target=target, losses=losses) - else: - schedule.eval() - - # Check that gradients were NOT computed during eval - grad_computed_eval = False - for param in stage_module.parameters(): - if param.grad is not None: - grad_computed_eval = True - break - - # Verify that gradients were not computed during eval - self.assertFalse( - grad_computed_eval, - "Gradients should not be computed during eval()", - ) + # Check that gradients were NOT computed during eval + grad_computed_eval = any( + param.grad is not None + for stage_module in stage_modules + for param in stage_module.parameters() + ) - # Verify that losses are still computed during eval - if self.rank == self.world_size - 1: - self.assertTrue( - len(losses) > 0, "Losses should be computed during eval()" - ) + # Verify that gradients were not computed during eval + self.assertFalse( + grad_computed_eval, "Gradients should not be computed during eval()" + ) + + # Verify that losses are still computed during eval + if self.rank == self.world_size - 1: + self.assertTrue(len(losses) > 0, "Losses should be computed during eval()") @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_multi_iter(self, ScheduleClass): - mod = MultiMLP(d_hid, n_layers=self.world_size) - mod.to(self.device) - - x = torch.randn(batch_size, d_hid, device=self.device) - target = torch.randn(batch_size, d_hid, device=self.device) - loss_fn = torch.nn.MSELoss(reduction="sum") - + mod, _, x, target, loss_fn = self._setup_models_and_data() chunks = 4 - x_mb = x.chunk(chunks)[0] - - # Create a pipeline - split_spec = mod.split_spec if hasattr(mod, "split_spec") else None - pipe = pipeline( - mod, - mb_args=(x_mb,), - split_spec=split_spec, - ) - - stage = pipe.build_stage( - self.rank, - self.device, - ) - - # Attach to a schedule + stage, _, _ = self._create_single_stage_pipeline(mod, x, chunks) schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) # Run @@ -333,10 +340,11 @@ def test_kwargs_with_tracer(self, ScheduleClass): schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) # Run + out = None + losses = [] if self.rank == 0: schedule.step(x, y=y) elif self.rank == group_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() @@ -355,52 +363,26 @@ def test_kwargs_with_tracer(self, ScheduleClass): @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_grad_with_tracer(self, ScheduleClass): - mod = MultiMLP(d_hid, n_layers=self.world_size) - mod.to(self.device) - - ref_mod = copy.deepcopy(mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data() # Run reference - for _ in range(2): - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) - # Create a pipeline + # Create pipeline and schedule chunks = 2 * self.world_size - x_mb = x.chunk(chunks)[0] - split_spec = mod.split_spec if hasattr(mod, "split_spec") else None - pipe = pipeline( - mod, - mb_args=(x_mb,), - split_spec=split_spec, + stage, stage_module, stage_modules = self._create_single_stage_pipeline( + mod, x, chunks ) - - stage = pipe.build_stage( - self.rank, - self.device, - ) - - # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) - # Run - stage_module = pipe.get_stage_module(self.rank) + # Run pipeline + out = None + losses = [] for _ in range(2): - # Zero gradients - stage_module.zero_grad() + self._zero_gradients(stage_module) if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() @@ -409,81 +391,53 @@ def test_grad_with_tracer(self, ScheduleClass): # Last rank checks result if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for name, p in stage_module.named_parameters(): - ref_p = ref_mod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) - except AssertionError: - print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") - raise + # Check gradients using helper method + self._check_gradients(stage_module, ref_mod) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("shape_inference", [True, False]) def test_grad_with_manual(self, ScheduleClass, shape_inference): - full_mod = MultiMLP(d_hid, n_layers=self.world_size) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data() # Run reference - for _ in range(2): - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) - # Get a submodule, e.g. `layers.0` or `layers.1` - submod_name = f"layers.{self.rank}" - stage_module = full_mod.get_submodule(submod_name) + # Create manual pipeline stage chunks = 2 * self.world_size + stage, stage_module, _ = self._create_single_stage_pipeline( + mod, x, chunks, use_tracer=False + ) - if shape_inference: - input_args = None - output_args = None - else: + # Handle shape inference + if not shape_inference: input_args = (x.chunk(chunks)[0],) with torch.no_grad(): output_args = stage_module(*input_args) + stage = PipelineStage( + stage_module, + self.rank, + self.world_size, + self.device, + input_args=input_args, + output_args=output_args, + ) - # Create a pipeline stage to wrap that submodule - stage = PipelineStage( - stage_module, - self.rank, - self.world_size, - self.device, - input_args=input_args, - output_args=output_args, - ) - - # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) - # Run + # Run pipeline + out = None + losses = [] for _ in range(2): - # Zero gradients - stage_module.zero_grad() + self._zero_gradients(stage_module) if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() @@ -492,23 +446,12 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): # Last rank checks result if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - ref_submod = ref_mod.get_submodule(submod_name) - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) - except AssertionError: - print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") - raise + # Check gradients using helper method + self._check_gradients(stage_module, ref_mod) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -524,95 +467,63 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size - full_mod = MultiMLP(d_hid, n_layers=n_stages) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages + ) # Run reference - for _ in range(2): - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) + + # Create multi-stage pipeline + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + mod, stages_per_rank, n_stages + ) + print(f"Rank {self.rank} stages: {[stage.stage_index for stage in stages]}") - # Get a submodule, e.g. `layers.0` or `layers.1` - stage_indices = [ - self.rank + i * self.world_size for i in range(stages_per_rank) - ] - print(f"Rank {self.rank} stages: {stage_indices}") - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - # Create a pipeline stage to wrap that submodule num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") else 2 * self.world_size ) - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, stage_indices) - ] - # Attach to a schedule + # Create schedule schedule = ScheduleClass( stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) + + # Handle new runtime testing if use_new_runtime: old_schedule = schedule tmp_schedule = _PipelineScheduleRuntime( - stages, - num_microbatches, - loss_fn=loss_fn, - scale_grads=False, + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) tmp_schedule._load_actions(old_schedule.pipeline_order) - # test that csv round-trip works for compute_comms schedule + + # Test CSV round-trip for compute_comms schedule schedule = _PipelineScheduleRuntime( - stages, - num_microbatches, - loss_fn=loss_fn, - scale_grads=False, + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) with tempfile.NamedTemporaryFile() as f: tmp_schedule._dump_csv(f.name) f.seek(0) schedule._load_csv(f.name, format="compute_comms") + one_more_schedule = _PipelineScheduleRuntime( - stages, - num_microbatches, - loss_fn=loss_fn, - scale_grads=False, + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) one_more_schedule._load_actions( schedule.pipeline_order_with_comms, format="compute_comms" ) + + # Verify schedule consistency self.assertEqual( len(schedule.pipeline_order_with_comms), - len( - one_more_schedule.pipeline_order_with_comms, - ), + len(one_more_schedule.pipeline_order_with_comms), ) for rank in schedule.pipeline_order_with_comms: self.assertEqual( len(schedule.pipeline_order_with_comms[rank]), - len( - one_more_schedule.pipeline_order_with_comms[rank], - ), + len(one_more_schedule.pipeline_order_with_comms[rank]), ) for a, b in zip( schedule.pipeline_order_with_comms[rank], @@ -620,19 +531,19 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): ): self.assertEqual(a, b) - # Run + # Run pipeline with tensor leak checking + out = None + losses = [] with check_leaked_tensors() as garbage_tensors: for _ in range(2): - # Zero gradients - for stage_module in stage_modules: - stage_module.zero_grad() + self._zero_gradients(stage_modules) if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() + self.assertEqual( len(garbage_tensors), 0, @@ -640,28 +551,17 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): ) dist.barrier() - # Last rank checks result + # Verify results if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=1e-3) - except AssertionError: - print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") - raise + # Check gradients - use relaxed tolerances for interleaved schedules + # since gradients are small + self._check_gradients( + stage_modules, ref_mod, submod_names, rtol=5e-3, atol=5e-3 + ) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -671,54 +571,29 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): if ScheduleClass is ScheduleInterleavedZeroBubble: n_stages = 4 num_microbatches = 2 * n_stages - rank_stages = { - 0: [0, 2], - 1: [1, 3], - } + rank_stages = {0: [0, 2], 1: [1, 3]} else: n_stages = ScheduleClass.n_stages num_microbatches = ScheduleClass.num_microbatches rank_stages = ScheduleClass.rank_stages num_steps = 4 - full_mod = MultiMLP(d_hid, n_layers=n_stages) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages + ) - # Create a pipeline stage to wrap that submodule + # Create multi-stage pipeline with custom stage indices stage_indices = rank_stages[self.rank] print(f"Rank {self.rank} stages: {stage_indices}") - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) - ] + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + mod, len(stage_indices), n_stages, stage_indices + ) - # We set scale_grads=False since we use a loss function that sums instead of mean-reduces - # (note: normally we recommend using mean-reduce loss functions, but we preserve at least one test case - # using sum scaling for completeness) schedule = ScheduleClass( stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) - # Run reference + # Run reference model ref_x = x.detach().clone().requires_grad_(x.requires_grad) torch.testing.assert_close(x, ref_x) for _ in range(num_steps): @@ -726,105 +601,60 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): ref_loss = loss_fn(ref_out, target) ref_loss.backward() + # Run pipeline with tensor leak checking + losses = [] with check_leaked_tensors() as garbage_tensors: - # Run pipelined stages for _ in range(num_steps): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: - losses = [] schedule.step(target=target, losses=losses) else: schedule.step() + self.assertEqual( len(garbage_tensors), 0, "Found leaked tensors, check logs above for debug info", ) - # Every rank checks parameters compared with the reference model - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) - except AssertionError: - print( - f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}" - ) - raise + # Check gradients using helper method + self._check_gradients(stage_modules, ref_mod, submod_names) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize( - "ScheduleClass", - [ - ScheduleWithReorderedB, - ], - ) + @parametrize("ScheduleClass", [ScheduleWithReorderedB]) def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): n_stages = 2 - num_microbatches = 2 stages_per_rank = 1 - full_mod = MultiMLP(d_hid, n_layers=n_stages) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages + ) # Run reference - for _ in range(2): - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) + + # Create pipeline stages + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + mod, stages_per_rank, n_stages + ) + print(f"Rank {self.rank} stages: {[stage.stage_index for stage in stages]}") - # Get a submodule, e.g. `layers.0` or `layers.1` - stage_indices = [ - self.rank + i * self.world_size for i in range(stages_per_rank) - ] - print(f"Rank {self.rank} stages: {stage_indices}") - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - # Create a pipeline stage to wrap that submodule num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") else 8 ) - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, stage_indices) - ] - # Attach to a schedule schedule = ScheduleClass( stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) assert isinstance(schedule, _PipelineScheduleRuntime) - # Run + # Run pipeline with tensor leak checking with check_leaked_tensors() as garbage_tensors: for _ in range(2): - # Zero gradients - for stage_module in stage_modules: - stage_module.zero_grad() + self._zero_gradients(stage_modules) if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: @@ -832,6 +662,7 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): out = schedule.step(target=target, losses=losses) else: schedule.step() + self.assertEqual( len(garbage_tensors), 0, @@ -839,28 +670,14 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): ) dist.barrier() - # Last rank checks result + # Verify results if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) - except AssertionError: - print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") - raise + # Check gradients using helper method + self._check_gradients(stage_modules, ref_mod, submod_names) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -871,101 +688,57 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): if schedule_class is ScheduleZBVZeroBubble: n_stages = 4 - rank_stages = { - 0: [0, 3], - 1: [1, 2], - } + rank_stages = {0: [0, 3], 1: [1, 2]} else: n_stages = schedule_class.n_stages rank_stages = schedule_class.rank_stages - full_mod = MultiMLP(d_hid, n_layers=n_stages) - full_mod.to(self.device) - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages + ) # Run reference - for _ in range(2): - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) - # Create a pipeline stage to wrap that submodule + # Create multi-stage pipeline with custom stage indices num_microbatches = 1 stage_indices = rank_stages[self.rank] print(f"Rank {self.rank} stages: {stage_indices}") - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) - ] + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + mod, len(stage_indices), n_stages, stage_indices + ) schedule = schedule_class( - stages, - num_microbatches, - loss_fn=loss_fn, - scale_grads=False, + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) + if use_new_runtime: old_schedule = schedule schedule = _PipelineScheduleRuntime( - stages, - num_microbatches, - loss_fn=loss_fn, + stages, num_microbatches, loss_fn=loss_fn ) schedule._load_actions(old_schedule.pipeline_order) - # Run - # TODO how to better specify .step() when first and last stage are on rank 0... + # Run pipeline - special case where first and last stage are on rank 0 + out = None + losses = [] for _ in range(2): - # Zero gradients - for stage_module in stage_modules: - stage_module.zero_grad() + self._zero_gradients(stage_modules) if self.rank == 0: - losses = [] out = schedule.step(x, target=target, losses=losses) else: schedule.step() dist.barrier() - # Last rank checks result + # Verify results (rank 0 has both first and last stages) if self.rank == 0: - # Check output torch.testing.assert_close(out, ref_out) - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) - except AssertionError: - print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") - raise + # Check gradients using helper method + self._check_gradients(stage_modules, ref_mod, submod_names) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -973,42 +746,18 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size - full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - ref_loss_fn = torch.nn.MSELoss(reduction="sum") - full_loss_fn = torch.nn.MSELoss(reduction="sum") - + full_mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages, model_class=MultiMLPWithDw + ) full_mod.toggle() - # Get a submodule, e.g. `layers.0` or `layers.1` - stage_indices = [ - self.rank + i * self.world_size for i in range(stages_per_rank) - ] - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - # Run reference - for _ in range(2): - ref_stage_modules = [ - ref_mod.get_submodule(submod_name) for submod_name in submod_names - ] - for stage_module in ref_stage_modules: - stage_module.zero_grad() + ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn) - ref_mod.zero_grad() - ref_out = ref_mod(x) - ref_loss = ref_loss_fn(ref_out, target) - ref_loss.backward() + # Create multi-stage pipeline with custom dw_builder + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + full_mod, stages_per_rank, n_stages + ) class CustomState: def __init__(self, stage_module, stage_idx, rank): @@ -1019,7 +768,6 @@ def __init__(self, stage_module, stage_idx, rank): def dw_builder(self): def dw_runner(): - # This inner function would be called by PipelineStage during `backward_weight_one_chunk` self.i += 1 print( f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}" @@ -1028,12 +776,14 @@ def dw_runner(): return dw_runner + # Create custom states and rebuild stages with dw_builder cs = {} + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] for stage_module, stage_idx in zip(stage_modules, stage_indices): cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank) - # Create a pipeline stage to wrap that submodule - chunks = 2 stages = [ PipelineStage( stage_module, @@ -1045,43 +795,30 @@ def dw_runner(): for stage_module, stage_idx in zip(stage_modules, stage_indices) ] - # Attach to a schedule - schedule = ScheduleClass( - stages, chunks, loss_fn=full_loss_fn, scale_grads=False - ) + schedule = ScheduleClass(stages, 2, loss_fn=loss_fn, scale_grads=False) + # Run pipeline + out = None + losses = [] for _ in range(2): - # Zero gradients - for stage_module in stage_modules: - stage_module.zero_grad() + self._zero_gradients(stage_modules) if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() - # Last rank checks result + + # Verify results if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - - # Check loss - # Since the reduction used in the loss function above is "sum", we use - # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + # Check gradients using helper method + self._check_gradients(stage_modules, ref_mod, submod_names) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -1092,53 +829,21 @@ def dw_runner(): def test_zero_bubble_with_model_kwargs(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size - full_mod = MultiMLPKwargs(d_hid, n_layers=n_stages) - full_mod.to(self.device) - - ref_mod = copy.deepcopy(full_mod) - x = torch.randn(batch_size, d_hid, device=self.device) + mod, ref_mod, x, target, loss_fn = self._setup_models_and_data( + n_layers=n_stages, model_class=MultiMLPKwargs + ) unused_kwarg = torch.tensor([1.0], device=self.device) - with torch.no_grad(): - y = ref_mod(x) - # Add a small perturbation - target = y + torch.randn(batch_size, d_hid, device=self.device) - - loss_fn = torch.nn.MSELoss(reduction="sum") - - # Get a submodule, e.g. `layers.0` or `layers.1` - stage_indices = [ - self.rank + i * self.world_size for i in range(stages_per_rank) - ] - submod_names = [f"layers.{i}" for i in stage_indices] - stage_modules = [ - full_mod.get_submodule(submod_name) for submod_name in submod_names - ] - # Run reference - for _ in range(2): - ref_stage_modules = [ - ref_mod.get_submodule(submod_name) for submod_name in submod_names - ] - for stage_module in ref_stage_modules: - stage_module.zero_grad() - - ref_mod.zero_grad() - ref_out = ref_mod(x, unused_kwarg=unused_kwarg) - ref_loss = loss_fn(ref_out, target) - ref_loss.backward() + # Run reference with kwargs + ref_out, ref_loss = self._run_reference_model( + ref_mod, x, target, loss_fn, unused_kwarg=unused_kwarg + ) - # Create a pipeline stage to wrap that submodule - stages = [ - PipelineStage( - stage_module, - stage_idx, - n_stages, - self.device, - ) - for stage_module, stage_idx in zip(stage_modules, stage_indices) - ] + # Create multi-stage pipeline + stages, stage_modules, submod_names = self._create_multi_stage_pipeline( + mod, stages_per_rank, n_stages + ) - # Attach to a schedule num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") @@ -1148,10 +853,11 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) + # Run pipeline with kwargs + out = None + losses = [] for _ in range(2): - # Zero gradients - for stage_module in stage_modules: - stage_module.zero_grad() + self._zero_gradients(stage_modules) if self.rank == 0: schedule.step( x, @@ -1160,35 +866,22 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): .expand(num_microbatches, -1), ) elif self.rank == self.world_size - 1: - losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() - # Last rank checks result + + # Verify results if self.rank == self.world_size - 1: - # Check output torch.testing.assert_close(out, ref_out) - - # Check loss pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) - # Every rank checks gradients - for stage_module, submod_name in zip(stage_modules, submod_names): - # Get corresponding submodule from reference model - ref_submod = ref_mod.get_submodule(submod_name) - # Check gradients per parameter - for name, p in stage_module.named_parameters(): - ref_p = ref_submod.get_parameter(name) - try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=5e-3) - except AssertionError: - print( - f"Gradient test failed for {name}: {p.grad=} vs {ref_p.grad=}" - ) - raise + # Check gradients using helper method + self._check_gradients( + stage_modules, ref_mod, submod_names, rtol=1e-5, atol=5e-3 + ) instantiate_parametrized_tests(ScheduleTest) diff --git a/test/distributed/tensor/test_common_rules.py b/test/distributed/tensor/test_common_rules.py index b320f80fe03c..3450f8faa2b5 100644 --- a/test/distributed/tensor/test_common_rules.py +++ b/test/distributed/tensor/test_common_rules.py @@ -8,20 +8,17 @@ from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorContinuousTestBase, ) aten = torch.ops.aten -class CommonRulesTest(DTensorTestBase): - @property - def world_size(self) -> int: - # hard code world size to 4 as we need to test - # at least with 2d mesh - return 4 +class CommonRulesTest(DTensorContinuousTestBase): + # hard code world size to 4 as we need to test + # at least with 2d mesh + world_size = 4 def _gen_tensor_meta(self, shape): empty_tensor = torch.empty(shape) @@ -31,10 +28,9 @@ def _gen_tensor_meta(self, shape): empty_tensor.dtype, ) - @with_comms def test_einop_basic_propagation(self): # plain einsum, mm - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) mm_call = aten.mm.default # propagate col-wise sharding @@ -85,9 +81,8 @@ def test_einop_basic_propagation(self): self.assertIsNotNone(output_spec) self.assertTrue(output_spec.placements[0].is_partial()) - @with_comms def test_einop_pointwise_propagation(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) add_call = aten.add.Tensor # addition @@ -137,13 +132,12 @@ def test_einop_pointwise_propagation(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, -1, -1]) - @with_comms def test_einop_merge_sharding(self): # 2d mesh einop merge sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default @@ -163,12 +157,11 @@ def test_einop_merge_sharding(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, 1]) - @with_comms def test_einop_linearity(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default @@ -231,11 +224,10 @@ def test_einop_linearity(self): # mat2 mesh dim 1 should become partial now! self.assertTrue(mat2_spec.placements[1].is_partial()) - @with_comms def test_einop_multi_sharding_on_mesh_dim(self): # einop prop with multi sharding on same mesh dim mesh_shape = torch.arange(self.world_size) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default mat1, mat2 = [0, -1], [0, -1] @@ -260,12 +252,11 @@ def test_einop_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) - @with_comms def test_einop_errors(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add.Tensor mat1, mat2 = [0, -1], [1, -1] @@ -281,9 +272,8 @@ def test_einop_errors(self): with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {})) - @with_comms def test_pointwise_rules_broadcasting(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) where_call = aten.where.self inp1, inp2, inp3 = [0], [], [-1, -1] @@ -307,9 +297,8 @@ def test_pointwise_rules_broadcasting(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [-1, 0]) - @with_comms def test_pointwise_rules_suggestion(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) lerp_call = aten.lerp.Scalar # propagate point-wise sharding @@ -335,13 +324,12 @@ def test_pointwise_rules_suggestion(self): self.assertEqual(len(schema_suggestion.args_schema), 3) self.assertEqual(schema_suggestion.args_schema[2], -1) - @with_comms def test_pointwise_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add.Tensor @@ -381,13 +369,12 @@ def test_pointwise_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) - @with_comms def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add_.Tensor diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 86f1e9d8fb47..54ec52ee32d4 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -166,8 +166,6 @@ def forward(self, b_buffer, x): return (view_as_1,)""", # noqa: B950 ) - # During tracing, sharding propagation cache is skipped, so an extra dry run for - # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), """\ @@ -175,8 +173,8 @@ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None - add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None - view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None + add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None + view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None return (view_1,)""", # noqa: B950 ) @@ -298,29 +296,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) - @skipIfHpu - def test_dtensor_dynamic_cat(self): - # RESET COUNTS - - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) - - # test passing in tuple of DTensors as - def fn(x, y): - return ( - torch.cat((x, y), dim=0) - .redistribute(device_mesh=x.device_mesh, placements=[Replicate()]) - .to_local()[0] - ) - - x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) - y = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) - torch._dynamo.mark_dynamic(x, 0) - ref = fn(x, y) - - opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) - res = opt_fn(x, y) - self.assertEqual(res, ref) - def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index 3f724d9a85bf..e5dcdfe11c8c 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -160,7 +160,6 @@ def wrapped(fn): xfail("frexp"), xfail("full"), xfail("full_like"), - xfail("gather"), xfail("geometric"), xfail("geqrf"), xfail("grid_sampler_2d"), diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 93ce80f18ee1..2419720256de 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -271,14 +271,22 @@ def test_layer_norm_fwd(self): norm_shape_idx_list = list(range(x.ndim)) shard_dims = [-1, 0, 1, 2] elementwise_affine_list = [False, True] + + # Test RMSNorm as well if CUDA + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + itertools.product( + norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list + ) ) # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: + for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: normalized_shape = x.shape[norm_idx:] - layer_norm = torch.nn.LayerNorm( + layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -287,6 +295,7 @@ def test_layer_norm_fwd(self): def _replicate_fn(name, module, device_mesh): for name, param in module.named_parameters(): + # RMSNorm only has weight, LayerNorm has both weight and bias if name in ["weight", "bias"]: param_dist = torch.nn.Parameter( distribute_tensor(param, device_mesh, [Replicate()]) @@ -307,7 +316,7 @@ def _replicate_fn(name, module, device_mesh): self.assertLessEqual( comm_mode.get_total_counts(), 1, # TODO: This should be 0! - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -329,12 +338,20 @@ def test_layer_norm_bwd(self): norm_shape_idx_list = list(range(3)) shard_dims = [0, 1, 2] elementwise_affine_list = [False, True] + + # Test both LayerNorm and RMSNorm (if CUDA) + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + itertools.product( + norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list + ) ) # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: + for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: x = torch.rand( batch, sentence_length, @@ -343,7 +360,7 @@ def test_layer_norm_bwd(self): requires_grad=True, ) normalized_shape = x.shape[norm_idx:] - layer_norm = torch.nn.LayerNorm( + layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -364,9 +381,11 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( layer_norm_local.weight, layer_norm_dist.weight.full_tensor() ) - self.assertEqual( - layer_norm_local.bias, layer_norm_dist.bias.full_tensor() - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_local, "bias"): + self.assertEqual( + layer_norm_local.bias, layer_norm_dist.bias.full_tensor() + ) x_local = x.detach().clone().requires_grad_(True) x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) @@ -384,7 +403,7 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["forward"].values()), expected_fwd_comm, - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -398,7 +417,7 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["backward"].values()), expected_bwd_comm, - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -412,18 +431,22 @@ def _replicate_fn(name, module, device_mesh): is_tensor_partial(layer_norm_dist.weight.grad._spec), needs_reduction, ) - self.assertEqual( - is_tensor_partial(layer_norm_dist.bias.grad._spec), - needs_reduction, - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_dist, "bias"): + self.assertEqual( + is_tensor_partial(layer_norm_dist.bias.grad._spec), + needs_reduction, + ) self.assertEqual( layer_norm_local.weight.grad, layer_norm_dist.weight.grad.full_tensor(), ) - self.assertEqual( - layer_norm_local.bias.grad, - layer_norm_dist.bias.grad.full_tensor(), - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_local, "bias"): + self.assertEqual( + layer_norm_local.bias.grad, + layer_norm_dist.bias.grad.full_tensor(), + ) self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) @@ -432,8 +455,14 @@ def test_layer_norm_bwd_req_grad(self): device_mesh = self.build_device_mesh() batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32 + # Test both LayerNorm and RMSNorm (if CUDA) + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + # build our subtest configurations and filter out invalid ones class SubTest(NamedTuple): + norm_type: type multidim_norm: bool elementwise_affine: bool emb_req_grad: bool @@ -443,19 +472,24 @@ class SubTest(NamedTuple): subtest_fails = {} valid_filter = ( # noqa: E731 lambda cfg: ( - not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[3:]) ) ) subtest_cfgs = list( filter( valid_filter, - [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))], + [ + SubTest(norm_type, *cfg) + for norm_type in norm_types + for cfg in itertools.product(*(((False, True),) * 5)) + ], ) ) for subtest_cfg in subtest_cfgs: try: ( + norm_type, multidim_norm, elementwise_affine, emb_req_grad, @@ -473,7 +507,7 @@ def __init__(self): self.preln_embeddings = torch.nn.Embedding( vocab_size, embedding_dim ) - self.layer_norm = torch.nn.LayerNorm( + self.layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine ) self.postln_linear = torch.nn.Linear( @@ -572,104 +606,6 @@ def forward(self, tokens): f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" ) - @with_comms - def test_rms_norm_bwd(self): - device_mesh = self.build_device_mesh() - - # NLP example from pytorch docs - batch, sentence_length, embedding_dim = 20, 5, 10 - norm_shape_idx_list = list(range(3)) - shard_dims = [0] # non-first dimensional sharding is not supported - elementwise_affine_list = [False, True] - test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) - ) - - # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: - x = torch.rand( - batch, - sentence_length, - embedding_dim, - device=self.device_type, - requires_grad=True, - ) - normalized_shape = x.shape[norm_idx:] - rms_norm = torch.nn.RMSNorm( - normalized_shape, - elementwise_affine=elementwise_affine, - device=self.device_type, - ) - rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type) - - def _replicate_fn(name, module, device_mesh): - for name, param in module.named_parameters(): - if name == "weight": - param_dist = torch.nn.Parameter( - distribute_tensor(param, device_mesh, [Replicate()]) - ) - module.register_parameter(name, param_dist) - - rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn) - - if elementwise_affine: - self.assertEqual( - rms_norm_local.weight, rms_norm_dist.weight.full_tensor() - ) - - x_local = x.detach().clone().requires_grad_(True) - x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) - self.assertEqual(x_local, x_dist.full_tensor()) - - y_local = rms_norm_local(x_local) - # make sure that backward rms norm does not introduce extra collectives - comm_mode = CommDebugMode() - with comm_mode: - y_dist = rms_norm_dist(x_dist) - y_dist.sum().backward() - - # TODO: forward pass is sharding strategy is generated from composite, hence 1 more collective than layer_norm - # see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679 - expected_fwd_comm = 0 if shard_dim < norm_idx else 2 - - self.assertEqual( - sum(comm_mode.comm_module_counts["Global"]["forward"].values()), - expected_fwd_comm, - f"comm count={comm_mode.get_total_counts()}, " - f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", - ) - - self.assertEqual(y_local, y_dist.full_tensor()) - - # backward step - y_local.sum().backward() - - expected_bwd_comm = 0 if shard_dim < norm_idx else 1 - - self.assertEqual( - sum(comm_mode.comm_module_counts["Global"]["backward"].values()), - expected_bwd_comm, - f"comm count={comm_mode.get_total_counts()}, " - f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", - ) - - if elementwise_affine: - # if input is sharded on any outer dimension, the gradient of weight - # should be Partial - dim_map = x_dist._spec.dim_map - outer_dims = range(norm_idx) - needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) - self.assertEqual( - is_tensor_partial(rms_norm_dist.weight.grad._spec), - needs_reduction, - ) - self.assertEqual( - rms_norm_local.weight.grad, - rms_norm_dist.weight.grad.full_tensor(), - ) - - self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) - @with_comms def test_topk(self): device_mesh = self.build_device_mesh() diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 5e98934249e9..180286bd2e1d 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -87,6 +87,38 @@ def test_init_ops(self): self._run_init_op(torch.randn_like, dtype=dtype) self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype) + @with_comms + @skip_if_lt_x_gpu(4) + def test_init_with_user_generator(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(42) + rng = torch.Generator(device="cuda").manual_seed(42) + t1 = torch.distributed.tensor.empty( + (8, 3), device_mesh=device_mesh, placements=[Shard(0)] + ) + t2 = torch.distributed.tensor.empty( + (8, 3), device_mesh=device_mesh, placements=[Shard(0)] + ) + for i in range(2): + # run a second time, to make sure that `rng`'s offset-state is advancing on the second usage + torch.nn.init.uniform_(t1, 0.0, 1.0) + torch.nn.init.uniform_(t2, 0.0, 1.0, rng) + self.assertEqual(t1.full_tensor(), t2.full_tensor(), f"Failed at {i=}") + + # ensure that we do not cache the 'seed' of `rng` from the first time we see it in DTensor + # TODO: we have a semantics decision to make + # There is a discontinuity between how the default RNG and a user-supplied RNG behaves with DTensor: + # (a) if the user calls `torch.manual_seed` after already using the default RNG with DTensor, + # they may be surprised that it has no effect on DTensor. They must instead call this private API + # (`torch.distributed.tensor._random._rng_tracker._manual_seed`) + # (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that + # their RNG object never advances its state after using it with DTensor. + # torch.distributed.tensor._random._rng_tracker._manual_seed(55) + # rng.manual_seed(55) + # torch.nn.init.uniform_(t1, 0.0, 1.0) + # torch.nn.init.uniform_(t2, 0.0, 1.0, rng) + # self.assertEqual(t1.full_tensor(), t2.full_tensor()) + @with_comms @skip_if_lt_x_gpu(4) def test_meta_tensor_init(self): diff --git a/test/distributed/tensor/test_view_ops.py b/test/distributed/tensor/test_view_ops.py index 92de79bc188b..39f5b98d4eab 100644 --- a/test/distributed/tensor/test_view_ops.py +++ b/test/distributed/tensor/test_view_ops.py @@ -10,6 +10,7 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, + DTensor, init_device_mesh, Replicate, Shard, @@ -25,7 +26,7 @@ view_groups, ) from torch.distributed.tensor.debug import CommDebugMode -from torch.distributed.tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import _StridedShard, Placement from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -168,8 +169,34 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): *(device_mesh.ndim * [sharding_choices]) ) - for in_shard in all_sharding_choices: - in_dt = distribute_tensor(args[0], device_mesh, in_shard) + outer_mesh = device_mesh["outer"] + inner_mesh = device_mesh["inner"] + inner_mesh_size = inner_mesh.size() + strided_sharding_choices = [ + (_StridedShard(i, split_factor=inner_mesh_size), Shard(i)) + for i, s in enumerate(in_shape) + if s > 1 and i not in no_shard_dims + ] + + for in_shard in itertools.chain(all_sharding_choices, strided_sharding_choices): + if isinstance(in_shard[0], _StridedShard): + if op != Tensor.view: + continue + # cannot produce DTensor using ``distribute_tensor()`` + # with ``_StridedShard``. Need to distribute the input + # over inner mesh dim first, then distribute the + # _local_tensor over the outer mesh dim. + in_dt = distribute_tensor(args[0], inner_mesh, (in_shard[1],)) + in_dt = distribute_tensor( + in_dt._local_tensor, outer_mesh, (Shard(in_shard[0].dim),) + ) + in_dt = DTensor.from_local( + in_dt._local_tensor, + device_mesh, + in_shard, + ) + else: + in_dt = distribute_tensor(args[0], device_mesh, in_shard) comm_mode = CommDebugMode() with comm_mode: @@ -216,8 +243,9 @@ def test_illegal_views(self): @with_comms def test_view_ops(self): - self.device_mesh = DeviceMesh( - self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) + mesh_shape = (dist.get_world_size() // 2, 2) + self.device_mesh = init_device_mesh( + self.device_type, mesh_shape=mesh_shape, mesh_dim_names=("outer", "inner") ) self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) @@ -442,7 +470,6 @@ def test_view_ops(self): (randn(42, 24, 36), 1), (InputDim(0), Singleton(), InputDim(1), InputDim(2)), ) - self.dimmap_test( Tensor.view, (randn(6, 12, 24), 72, 24), diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 80f6705f4ef6..bafc781b591c 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -1,5 +1,6 @@ # Owner(s): ["module: c10d"] import gc +import re import threading import unittest from datetime import timedelta @@ -77,10 +78,6 @@ def device(self) -> torch.device: return torch.device(f"cuda:{self.rank}") def _init_process_group(self) -> None: - # Allow testing aoti after torch.compile - torch._inductor.config.triton.store_cubin = True - torch._inductor.config.debug = True - torch.cuda.set_device(self.device) store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( @@ -713,6 +710,12 @@ def test_collectives(self) -> None: self.assertEqual(pg.dels, 4) +def find_buffer_assignments(code): + pattern = r"buf(\d+) = empty_strided_" + matches = re.finditer(pattern, code) + return tuple(f"buf{match.group(1)}" for match in matches) + + class CompileTestCPU(TestCase): def setUp(self): super().setUp() @@ -741,23 +744,23 @@ def func(arg: torch.Tensor) -> torch.Tensor: return ar0 arg = torch.rand(4, 4, device="cpu") - torch._inductor.config.cpp_wrapper = cpp_wrapper - compiled = torch.compile(func) + with torch._inductor.config.patch({"cpp_wrapper": cpp_wrapper}): + compiled = torch.compile(func) - _, (code,) = run_and_get_code(compiled, arg) - include_ops = ( - [ - "aoti_torch_cpu__c10d_functional_all_reduce_", - "aoti_torch_cpu__c10d_functional_wait_tensor", - ] - if cpp_wrapper - else [ - "torch.ops._c10d_functional.all_reduce_.default", - "torch.ops._c10d_functional.wait_tensor.default", - ] - ) - for op in include_ops: - self.assertIn(op, code) + _, (code,) = run_and_get_code(compiled, arg) + include_ops = ( + [ + "aoti_torch_cpu__c10d_functional_all_reduce_", + "aoti_torch_cpu__c10d_functional_wait_tensor", + ] + if cpp_wrapper + else [ + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.wait_tensor.default", + ] + ) + for op in include_ops: + self.assertIn(op, code) # Test aoti AOTIRunnerUtil.run(func, (arg,)) @@ -771,9 +774,6 @@ def test_inductor_all_reduce_cpu(self): class CompileTest(TestCase): def setUp(self): super().setUp() - # Allow testing aoti after torch.compile - torch._inductor.config.triton.store_cubin = True - torch._inductor.config.debug = True self.rank = 0 self.world_size = 2 @@ -807,22 +807,33 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) + buf0, buf1 = find_buffer_assignments(code) ( FileCheck() - .check("buf0 = empty") - .check("buf7 = empty") + .check(f"{buf0} = empty") + .check(f"{buf1} = empty") # Expect in-place with inductor allocated buf - .check("torch.ops._c10d_functional.all_reduce_.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") - # Expect no in-place with graph input (buf5 is a clone) - .check("torch.ops._c10d_functional.all_reduce_.default(buf7") - .check("torch.ops._c10d_functional.wait_tensor.default(buf7") + .check(f"torch.ops._c10d_functional.all_reduce_.default({buf0}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf0}") + # Expect no in-place with graph input + .check(f"torch.ops._c10d_functional.all_reduce_.default({buf1}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf1}") # Expect no extra copy on return - .check("return (buf0, buf7, )") + .check(f"return ({buf0}, {buf1}, )") .run(code) ) + # Check the return tensor from wait_tensor is not used anywhere assert "= torch.ops._c10d_functional.wait_tensor.default" not in code + with torch._inductor.config.patch({"cpp_wrapper": True}): + code = run_and_get_triton_code(compiled, arg) + # Check the return tensors from all_reduce and wait_tensor are not used anywhere by + # checking if they are explicitly deleted by calling aoti_torch_delete_tensor_object + FileCheck().check_not( + # all_reduce must have been rewritten into all_reduce_ + "aoti_torch_cpu__c10d_functional_all_reduce(buf" + ).check_count("aoti_torch_delete_tensor_object(buf", 4).run(code) + # Test aoti AOTIRunnerUtil.run(func, (arg,)) torch.cuda.synchronize() @@ -843,26 +854,27 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: args = [torch.rand(4, 4, device="cuda") for _ in range(2)] compiled = torch.compile(func) code = run_and_get_triton_code(compiled, args) + buf0, buf1, buf2, buf3 = find_buffer_assignments(code) ( FileCheck() - .check("buf0 = empty") - .check("buf5 = empty") - .check("buf1 = empty") - .check("buf6 = empty") + .check(f"{buf0} = empty") + .check(f"{buf1} = empty") + .check(f"{buf2} = empty") + .check(f"{buf3} = empty") # Expect in-place with inductor allocated buf .check( - "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]" + f"torch.ops._c10d_functional.all_reduce_coalesced_.default([{buf0}, {buf2}]" ) - # Expect no in-place with graph input (buf5, buf6 are clones) + # Expect no in-place with graph input ({buf1}, {buf3} are clones) .check( - "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]" + f"torch.ops._c10d_functional.all_reduce_coalesced_.default([{buf1}, {buf3}]" ) - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf1") - .check("torch.ops._c10d_functional.wait_tensor.default(buf5") - .check("torch.ops._c10d_functional.wait_tensor.default(buf6") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf0}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf2}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf1}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf3}") # Expect no extra copy on return - .check("return (buf0, buf1, buf5, buf6, )") + .check(f"return ({buf0}, {buf2}, {buf1}, {buf3}, )") .run(code) ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code @@ -884,14 +896,15 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) + (buf0,) = find_buffer_assignments(code) ( FileCheck() - .check("buf0 = empty") + .check(f"{buf0} = empty") # We always call .contiguous() on the input to all_reduce_, # so input will not be a view anymore. - .check("torch.ops._c10d_functional.all_reduce_.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") - .check("return (buf0") + .check(f"torch.ops._c10d_functional.all_reduce_.default({buf0}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf0}") + .check(f"return ({buf0}") .run(code) ) @@ -938,20 +951,21 @@ def func(arg: torch.Tensor) -> torch.Tensor: arg = torch.rand(4, 4, device="cuda") compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) + buf0, buf1 = find_buffer_assignments(code) ( FileCheck() # Expect allocation - .check("buf0 = empty") - .check("torch.ops._c10d_functional.all_reduce_.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + .check(f"{buf0} = empty") + .check(f"torch.ops._c10d_functional.all_reduce_.default({buf0}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf0}") # Expect allocation - .check("buf7 = empty") - .check("extern_kernels.mm(arg0_1, buf0, out=buf7") - # Expect buf0 to be reused - .check("buf8 = buf0; del buf0 # reuse") - .check("extern_kernels.mm(arg0_1, buf7, out=buf8") + .check(f"{buf1} = empty") + .check(f"extern_kernels.mm(arg0_1, {buf0}, out={buf1}") + # Expect {buf0} to be reused + .check(f"buf8 = {buf0}; del {buf0} # reuse") + .check(f"extern_kernels.mm(arg0_1, {buf1}, out=buf8") # Expect no extra copy on return - .check("return (buf7, buf8, )") + .check(f"return ({buf1}, buf8, )") .run(code) ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code @@ -1166,19 +1180,20 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) + buf0, buf1 = find_buffer_assignments(code) ( FileCheck() - .check("buf0 = empty") - .check("buf1 = buf0") - .check("buf8 = empty") + .check(f"{buf0} = empty") + .check(f"buf1 = {buf0}") + .check(f"{buf1} = empty") # Expect in-place with inductor allocated buf .check("torch.ops._c10d_functional.broadcast_.default(buf1") .check("torch.ops._c10d_functional.wait_tensor.default(buf1") # Expect no in-place with graph input (buf5 is a clone) - .check("torch.ops._c10d_functional.broadcast_.default(buf8") - .check("torch.ops._c10d_functional.wait_tensor.default(buf8") + .check(f"torch.ops._c10d_functional.broadcast_.default({buf1}") + .check(f"torch.ops._c10d_functional.wait_tensor.default({buf1}") # Expect no extra copy on return - .check("return (buf1, buf8, )") + .check(f"return (buf1, {buf1}, )") .run(code) ) diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 5a9f22f5aeaa..95bc8b534523 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -1289,7 +1289,7 @@ def test_allgather_basics_cuda(self): @requires_gloo() def test_allgather_noncontiguous_input(self): # Take a column of 2D tensor, such that memory is not dense - self._test_allgather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0]) + self._test_allgather_basics(lambda t: t.expand(2, 2).tril().contiguous()[:, 0]) @requires_gloo() def test_allgather_inference_mode(self): diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index fd9e7594828d..a1e8d30fef6c 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3172,35 +3172,42 @@ def test_nccl_user_buffer_registration(self): @requires_multicast_support() def test_nccl_window_registration(self): store = c10d.FileStore(self.file_name, self.world_size) - c10d.init_process_group( - backend="nccl", rank=self.rank, world_size=self.world_size, store=store - ) device = torch.device(f"cuda:{self.rank}") - torch.cuda.set_device(self.rank) - pg = c10d.distributed_c10d._get_default_group() - backend = pg._get_backend(torch.device(device)) - - # Use NCCL memory allocator - # enable symmetric memory usage in NCCL - pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True) - - # allocate memory with ncclMemAlloc - # note: symmetric kernels are not available for dtypes like torch.int64 - with torch.cuda.use_mem_pool(pool): - tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32) + with torch.cuda.device(device): + # Eager init the nccl comm so that we don't implicitly create one during register_mem_pool + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + backend = pg._get_backend(torch.device(device)) + + # Use NCCL memory allocator + # enable symmetric memory usage in NCCL + pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True) + + # allocate memory with ncclMemAlloc + # note: symmetric kernels are not available for dtypes like torch.int64 + with torch.cuda.use_mem_pool(pool): + tensor = torch.arange( + 1024 * 1024 * 2, device=device, dtype=torch.float32 + ) - # register buffers to NCCL - backend.register_mem_pool(pool) + # register buffers to NCCL + backend.register_mem_pool(pool) - # allreduce now should use NVIDIA Switches - pg.allreduce(tensor).wait() - torch.cuda.synchronize(device=device) + # allreduce now should use NVIDIA Switches + pg.allreduce(tensor).wait() + torch.cuda.synchronize(device=device) - # de-register buffers from NCCL - backend.deregister_mem_pool(pool) + # de-register buffers from NCCL + backend.deregister_mem_pool(pool) - # clean up memory - del tensor, pool + # clean up memory + del tensor, pool with open(os.environ["NCCL_DEBUG_FILE"]) as f: nccl_debug_file_content = f.read() diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index 63ff2fa2bbfe..c05d5edae233 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -179,8 +179,11 @@ def func(a): .check("extern_kernels.mm") .check("triton_poi_fused_relu") .check("torch.ops._c10d_functional.all_reduce_.default") - .check("torch.ops._c10d_functional.wait_tensor.default") + .check_same("buf0") + # mm not use buf prior to wait_tensor .check("extern_kernels.mm") + .check_not("buf0") + .check("torch.ops._c10d_functional.wait_tensor.default") .check("extern_kernels.mm") .run(code) ) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 04aaad9990f9..5672171d0be4 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol +from torch._C._distributed_c10d import Backend as C10dBackend from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( @@ -30,7 +31,7 @@ DTensorTestBase, with_comms, ) -from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeStore from torch.utils._typing_utils import not_none @@ -578,6 +579,115 @@ def test_raises_mesh_shape_mesh_dim_names_mismatch(self): mesh_dim_names=["dp", "tp"], ) + def _test_backend_override_argument_dict_with_idx_and_backend(self): + opts = FakeProcessGroup.Options() + opts.fake_option = 42 + + mesh = init_device_mesh( + self.device_type, + (2, 2, 2), + mesh_dim_names=("dp", "tp", "cp"), + backend_override={0: "fake", 2: ("fake", opts)}, + ) + + def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: + return ( + mesh.get_group(dim_idx) + ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) + .options + ) + + # Fake pg only have BackendType as BackendType::CUSTOM. + self.assertEqual(mesh.get_group(0)._get_backend_name(), "custom") + self.assertNotEqual(mesh.get_group(1)._get_backend_name(), "custom") + self.assertEqual(mesh.get_group(2)._get_backend_name(), "custom") + + self.assertIsNone(get_opts(mesh, 0)) + self.assertEqual(get_opts(mesh, 2).fake_option, 42) + + dp_tp_mesh = mesh["dp", "tp"]._flatten() + dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override="fake") + tp_cp_mesh = mesh["tp", "cp"]._flatten(backend_override=("fake", opts)) + + self.assertNotEqual(dp_tp_mesh.get_group(0)._get_backend_name(), "custom") + self.assertEqual(dp_cp_mesh.get_group(0)._get_backend_name(), "custom") + self.assertEqual(tp_cp_mesh.get_group(0)._get_backend_name(), "custom") + + self.assertIsNone(get_opts(dp_cp_mesh, 0)) + self.assertEqual(get_opts(tp_cp_mesh, 0).fake_option, 42) + + @with_comms + def test_backend_override_argument_dict_with_idx_and_backend_lazy(self): + self._test_backend_override_argument_dict_with_idx_and_backend() + + @with_comms(eager_init=True) + def test_backend_override_argument_dict_with_idx_and_backend_eager(self): + self._test_backend_override_argument_dict_with_idx_and_backend() + + @with_comms(backend="fake") + def test_backend_override_argument_dict_with_name_and_options(self): + opts = FakeProcessGroup.Options() + opts.fake_option = 42 + + mesh = init_device_mesh( + self.device_type, + (2, 2, 2), + mesh_dim_names=("dp", "tp", "cp"), + backend_override={"tp": opts}, + ) + + def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: + return ( + mesh.get_group(dim_idx) + ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) + .options + ) + + self.assertIsNone(get_opts(mesh, 0)) + self.assertEqual(get_opts(mesh, 1).fake_option, 42) + self.assertIsNone(get_opts(mesh, 2)) + + dp_tp_mesh = mesh["dp", "tp"]._flatten() + dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override=opts) + + self.assertIsNone(get_opts(dp_tp_mesh, 0)) + self.assertEqual(get_opts(dp_cp_mesh, 0).fake_option, 42) + + @with_comms + def test_backend_override_argument_errors(self): + with self.assertRaisesRegex( + RuntimeError, + "Found redundant dim index 0 and name dp in backend_override", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={"dp": "foo", 0: "bar"}, + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Found invalid keys in backend_override: got \['cp'\]", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={"cp": "foo"}, + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Found invalid keys in backend_override: got \[42\]", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={42: "bar"}, + ) + class TestDeviceMeshGetItem(DTensorTestBase): @property diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 856e1c5f7b3c..f7cf7764df56 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1524,39 +1524,49 @@ def _reorder_communication_preserving_peak_memory( @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not SM80OrLater, "bfloat16") def test_all_gather_bucket(self): - def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) - # cast the inputs - ag_0_cast = ag_0.to(torch.bfloat16) ag_1_cast = ag_1.to(torch.bfloat16) - # allgather group_name = ( torch.distributed.distributed_c10d._get_default_group().group_name ) + ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_2, group_size, group_name + ) + ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) + + ag_0 = ag_2_out + ag_0 + ag_0_cast = ag_0.to(torch.bfloat16) + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( ag_0_cast, group_size, group_name ) ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out) ag_0_out = ag_0_out * 2 - ag_1_cast = ag_1_cast * 2 ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( ag_1_cast, group_size, group_name ) - # wait op ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out) - return y, ag_0_out, ag_1_out + ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_3, group_size, group_name + ) + ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) + return y, ag_0_out, ag_1_out, ag_2_out, ag_3_out x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1] + ag_2 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_3 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1, ag_2, ag_3] + correct = func(*inputs, **self.get_world_trs()) with torch._inductor.config.patch( { @@ -1568,9 +1578,14 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) # NOTE: The first return value should be the output of the first wait_tensor. # We want to make sure no unnecessary copy is made. - (FileCheck().check("all_gather_into_tensor_out").run(code)) + ( + FileCheck() + .check("= torch.ops._c10d_functional.all_gather_into_tensor") + .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(") + .check("= torch.ops._c10d_functional.all_gather_into_tensor") + .run(code) + ) out = compiled(*inputs, **self.get_world_trs()) - correct = func(*inputs, **self.get_world_trs()) assert same(out, correct), f"{out} va {correct}" @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @@ -1745,10 +1760,15 @@ def _reorder_communication_preserving_peak_memory( _reorder_communication_preserving_peak_memory, ], "allow_buffer_reuse": False, + "test_configs.track_memory_lifecycle": "error", } ): - compiled = torch.compile(func) + compiled = torch.compile(func, fullgraph=True) code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) + + # make sure memory tracking is codegen. the ops will then do runtime checking with assertion. + FileCheck().check("check_memory_step").check("tracked_empty_strided").run(code) + # NOTE: The first return value should be the output of the first wait_tensor. # We want to make sure no unnecessary copy is made. ( diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index c4565a96496c..15dca00d0121 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -1,9 +1,7 @@ # Owner(s): ["oncall: distributed"] - # To run: # python test/distributed/test_nvshmem_triton.py - import triton.language as tl import torch @@ -14,14 +12,15 @@ from torch.testing._internal.common_distributed import MultiProcContinousTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + parametrize, run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, ) -from torch.testing._internal.inductor_utils import requires_triton +from torch.testing._internal.inductor_utils import IS_H100, requires_triton -# Decorator +# Decorators def requires_nvshmem(): return skip_but_pass_in_sandcastle_if( not symm_mem.is_nvshmem_available(), @@ -29,6 +28,13 @@ def requires_nvshmem(): ) +def requires_h100(): + return skip_but_pass_in_sandcastle_if( + not IS_H100, + "NVSHMEM requires H100. Skipping test on non-H100 GPU.", + ) + + # So that tests are written in device-agnostic way device_type = "cuda" device_module = torch.get_device_module(device_type) @@ -36,47 +42,47 @@ def requires_nvshmem(): # Shared Triton JIT kernels @triton.jit -def put_kernel( - dst_ptr, - src_ptr, - numel, - peer, +def nvshmem_put_kernel( + dest, + src, + nelems, + pe, ): - nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.put(dest, src, nelems, pe) @triton.jit -def get_kernel( - dst_ptr, - src_ptr, - numel, - peer, +def nvshmem_get_kernel( + dest, + src, + nelems, + pe, ): - nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.get(dest, src, nelems, pe) @triton.jit -def put_signal_kernel( +def nvshmem_putmem_signal_block_kernel( dst_ptr, src_ptr, - numel, + size_bytes, sig_ptr, signal_val, sig_op, peer, ): nvshmem.putmem_signal_block( - dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + dst_ptr, src_ptr, size_bytes, sig_ptr, signal_val, sig_op, peer ) @triton.jit -def signal_wait_until_kernel(sig_ptr, cmp_op, cmp_val): +def nvshmem_signal_wait_until_kernel(sig_ptr, cmp_op, cmp_val): nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) @triton.jit -def signal_op_kernel( +def nvshmem_signal_op_kernel( sig_addr, signal, sig_op, @@ -86,75 +92,65 @@ def signal_op_kernel( @triton.jit -def wait_until_kernel( - ivar_ptr, +def nvshmem_wait_until_kernel( + ivar, cmp_op, cmp_val, ): - nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + nvshmem.wait_until(ivar, cmp_op, cmp_val) @triton.jit -def put_and_signal_kernel( - dst_ptr, - src_ptr, - numel, - sig_ptr, - signal_val, - sig_op, - peer, -): - nvshmem.putmem_signal_block( - dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer - ) +def nvshmem_fence_kernel(): + nvshmem.fence() @triton.jit -def put_with_fence_kernel( - dst_ptr1, - dst_ptr2, - src_ptr1, - src_ptr2, - flag_ptr, - flag_src_ptr, - numel, +def nvshmem_put_with_fence_kernel( + dst1, + src1, + dst2, + src2, + flag_dst, + flag_src, + nelems, peer, ): # First put - nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer) + nvshmem.put(dst1, src1, nelems, peer) # Ensure the first put is ordered before the next. nvshmem.fence() # Second put - nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer) + nvshmem.put(dst2, src2, nelems, peer) # Order the second put before flag update. nvshmem.fence() # Write the flag (single int64) to signal completion. - nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer) + nvshmem.put(flag_dst, flag_src, 1, peer) @triton.jit -def put_with_quiet_kernel( - dst_ptr, - src_ptr, - flag_dst_ptr, - flag_src_ptr, - numel, +def nvshmem_put_with_quiet_kernel( + dst, + src, + flag_dst, + flag_src, + nelems, peer, ): # Put data - nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.put(dst, src, nelems, peer) # Call quiet to ensure put is complete nvshmem.quiet() # Only after quiet, set the completion flag # This ensures the data put is complete before flag is set - nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) + nvshmem.put(flag_dst, flag_src, 1, peer) @triton.jit -def barrier_test_kernel( - dst_ptr, - src_ptr, - numel, +def nvshmem_barrier_test_kernel( + dst, + src, + nelems, ): # Testing barrier_all() requires coordinated operations across PEs within # the same kernel execution. Unlike other kernels that just wrap NVSHMEM @@ -162,73 +158,90 @@ def barrier_test_kernel( # device-side barrier synchronization. my_pe = nvshmem.my_pe() n_pes = nvshmem.n_pes() + # Rank 0 broadcasts its value to all other ranks if my_pe == 0: # Write initial value - p_src = src_ptr.to(tl.pointer_type(tl.int32)) + p_src = src.to(tl.pointer_type(tl.int32)) tl.store(p_src, 42) # Put to all other ranks i = 1 while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + nvshmem.put(dst, src, nelems, i) i += 1 + # Synchronize all PEs nvshmem.barrier_all() + # Non-zero ranks increment the received value if my_pe != 0: - p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + p_dst = dst.to(tl.pointer_type(tl.int32)) received = tl.load(p_dst) tl.store(p_dst, received + 1) @triton.jit -def sync_test_kernel( - dst_ptr, - src_ptr, - numel, +def nvshmem_barrier_all_kernel(): + nvshmem.barrier_all() + + +@triton.jit +def nvshmem_sync_test_kernel( + local_data, + remote_data, + nelems, ): my_pe = nvshmem.my_pe() n_pes = nvshmem.n_pes() - # Rank 0 broadcasts its value to all other ranks - if my_pe == 0: - # Write initial value - p_src = src_ptr.to(tl.pointer_type(tl.int32)) - tl.store(p_src, 42) - # Put to all other ranks - i = 1 - while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) - i += 1 - # Synchronize all PEs (this is more lightweight than barrier_all() b/c it only ensures local store visibility - # and doesn't wait for remote ops to complete) + # Each PE writes a unique value to its local memory + p_local = local_data.to(tl.pointer_type(tl.int32)) + unique_value = my_pe + 100 # PE 0 writes 100, PE 1 writes 101, etc. + tl.store(p_local, unique_value) + + # sync_all() ensures local stores are visible to other PEs + # but doesn't guarantee completion of any remote operations nvshmem.sync_all() - # Non-zero ranks increment the received value - if my_pe != 0: - p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) - received = tl.load(p_dst) - tl.store(p_dst, received + 1) + + # Now each PE reads from the next PE's memory to verify visibility + # PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0 + next_pe = (my_pe + 1) % n_pes + nvshmem.get(remote_data, local_data, nelems, next_pe) + + # The get should now see the value that the next PE wrote locally + # because sync_all() made those local stores visible @triton.jit -def alltoall_kernel( +def nvshmem_alltoall_kernel( team_handle, - dest_ptr, - src_ptr, - nelems, + dst, + src, + nelems_per_pe, ): - nvshmem.alltoall(team_handle, dest_ptr, src_ptr, nelems) + nvshmem.alltoall(team_handle, dst, src, nelems_per_pe) @triton.jit -def broadcast_kernel( +def nvshmem_broadcast_kernel( team_handle, - dest_ptr, - src_ptr, + dst, + src, nelems, pe_root, ): - nvshmem.broadcast(team_handle, dest_ptr, src_ptr, nelems, pe_root) + nvshmem.broadcast(team_handle, dst, src, nelems, pe_root) + + +@triton.jit +def nvshmem_reduce_kernel( + team_handle, + dest_tensor, + source_tensor, + nreduce, + operation: tl.constexpr, +): + nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) @instantiate_parametrized_tests @@ -248,6 +261,7 @@ def device(self) -> torch.device: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -259,36 +273,52 @@ def test_triton_put(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 - dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize + # Configuration + nelems = 5 # number of elements to transfer + dtype = torch.int64 + val = 42 + rank # Each rank has different data - val = 5 - inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) - out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + # Create symmetric tensors + src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) - peer = (self.world_size - 1) - rank + # Fill source tensor with rank-specific pattern + for i in range(nelems): + src[i] = ( + val * 10 + i + ) # Rank 0: [420, 421, 422, 423, 424], Rank 1: [430, 431, ...] + + # Rendezvous + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + + # Synchronize before operation + dist.barrier() + + peer = 1 - rank if rank == 0: - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - put_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - numel=numel, - peer=peer, + # Rank 0 puts its data to Rank 1 + nvshmem_put_kernel[(1,)]( + dst, + src, + nelems, + peer, extern_libs=nvshmem_lib, ) + # Synchronize after operation dist.barrier() + if rank == 1: + # Verify that rank 1 received rank 0's data + expected = [420 + i for i in range(nelems)] torch.testing.assert_close( - out, val * torch.ones(numel, dtype=dtype, device=self.device) + dst, torch.tensor(expected, device=self.device, dtype=dtype) ) @skipIfRocm @requires_triton() + @requires_h100() def test_triton_get(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -297,27 +327,29 @@ def test_triton_get(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 + + # Configuration + numel = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize val = 7 + + # Create symmetric tensors inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_( val if rank == 0 else -1 ) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + dist.barrier() - peer = (self.world_size - 1) - rank + peer = 1 - rank if rank == 1: - # Rank 1 gets data from rank 0 - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - get_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - numel=numel, - peer=peer, + # Rank 1 gets data from rank 0 using tensor-aware API + nvshmem_get_kernel[(1,)]( + out, + inp, + numel, + peer, extern_libs=nvshmem_lib, ) if rank == 1: @@ -327,6 +359,7 @@ def test_triton_get(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_get_ring(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -336,29 +369,29 @@ def test_triton_get_ring(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank world_size = dist.get_world_size() - msg_size_bytes = 8 + + # Configuration + numel = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize # Each rank fills its input buffer with its own rank value inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + dist.barrier() # Ring topology: each rank gets data from the rank to its left # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc. peer = (rank - 1) % world_size - # All ranks execute the get operation - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - get_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - numel=numel, - peer=peer, + # All ranks execute the get operation using tensor-aware API + nvshmem_get_kernel[(1,)]( + out, + inp, + numel, + peer, extern_libs=nvshmem_lib, ) @@ -369,6 +402,7 @@ def test_triton_get_ring(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put_signal_set(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -394,7 +428,7 @@ def test_triton_put_signal_set(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set SIGNAL_VAL = 1 # Signal completion value NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until @@ -404,10 +438,10 @@ def test_triton_put_signal_set(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_signal_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, sig_ptr=sig_ptr, signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_SET, @@ -418,7 +452,7 @@ def test_triton_put_signal_set(self) -> None: if rank == 1: # Wait until signal flag is set by Rank 0 sig_ptr_local = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1,)]( + nvshmem_signal_wait_until_kernel[(1,)]( sig_ptr_local, cmp_op=NVSHMEM_CMP_EQ, cmp_val=SIGNAL_VAL, @@ -434,6 +468,7 @@ def test_triton_put_signal_set(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put_signal_add(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -459,7 +494,7 @@ def test_triton_put_signal_add(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_SIGNAL_ADD = 5 # atomic add operation SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD NVSHMEM_CMP_EQ = 0 @@ -469,10 +504,10 @@ def test_triton_put_signal_add(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_signal_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, sig_ptr=sig_ptr, signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_ADD, @@ -482,7 +517,7 @@ def test_triton_put_signal_add(self) -> None: if rank == 1: sig_ptr_local = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1, 1, 1)]( + nvshmem_signal_wait_until_kernel[(1, 1, 1)]( sig_ptr_local, cmp_op=NVSHMEM_CMP_EQ, cmp_val=SIGNAL_VAL, @@ -497,6 +532,7 @@ def test_triton_put_signal_add(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_wait_until(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -506,72 +542,52 @@ def test_triton_wait_until(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank - NVSHMEM_CMP_EQ = 0 # from nvshmem.h - - # Allocate symmetric buffers - msg_size_bytes = 8 - dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize - val = 13 - flag_val = 21 + peer = 1 - rank + NVSHMEM_CMP_EQ = 0 # equal comparison + FLAG_INITIAL_VALUE = 0 + FLAG_FINAL_VALUE = 42 - inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) - out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + # Use a single int64 symmetric tensor as our synchronization flag. + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_( + FLAG_INITIAL_VALUE + ) + symm_mem.rendezvous(flag, group=group_name) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + nvshmem_barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) if rank == 0: - # Rank 0 waits for the flag to be set by Rank 1, then checks the data - ivar_ptr = out_hdl.signal_pad_ptrs[rank] - - wait_until_kernel[(1, 1, 1)]( - ivar_ptr, + # Rank 0 (the waiter) + nvshmem_wait_until_kernel[(1,)]( + flag, cmp_op=NVSHMEM_CMP_EQ, - cmp_val=flag_val, + cmp_val=FLAG_FINAL_VALUE, extern_libs=nvshmem_lib, ) + # Verification torch.testing.assert_close( - out, - val * torch.ones(numel, dtype=dtype, device=self.device), + flag, + torch.tensor([FLAG_FINAL_VALUE], dtype=torch.int64, device=self.device), ) if rank == 1: - # Rank 1 puts data into Rank 0's output buffer - dst_ptr = out_hdl.buffer_ptrs[peer] - src_ptr = inp_hdl.buffer_ptrs[rank] - - put_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - numel=numel, - peer=peer, - extern_libs=nvshmem_lib, + # Rank 1 (the signaler) + val_to_put = torch.tensor( + [FLAG_FINAL_VALUE], dtype=torch.int64, device=self.device ) - # Fence to order data put before flag put - @triton.jit - def fence_kernel(): - nvshmem.fence() - - fence_kernel[(1, 1, 1)](extern_libs=nvshmem_lib) - - # Put the flag value (do not use signal_op here) - flag_src = torch.tensor([flag_val], dtype=torch.int64, device=self.device) - flag_dst_ptr = out_hdl.signal_pad_ptrs[peer] - - put_kernel[(1, 1, 1)]( - flag_dst_ptr, - flag_src.data_ptr(), - numel=1, - peer=peer, + # Launch a kernel to put the value to Rank 0's flag tensor. + nvshmem_put_kernel[(1,)]( + flag, # Destination symmetric tensor on the remote PE + val_to_put, # Source data tensor (local) + 1, # Number of elements + peer, # The target PE (Rank 0) extern_libs=nvshmem_lib, ) @skipIfRocm @requires_triton() + @requires_h100() def test_triton_signal_wait_until(self) -> None: self._init_device() # Enable NVSHMEM for Triton @@ -579,7 +595,7 @@ def test_triton_signal_wait_until(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank + peer = 1 - rank # NVSHMEM constants from documentation NVSHMEM_CMP_EQ = 0 # equal comparison @@ -589,6 +605,7 @@ def test_triton_signal_wait_until(self) -> None: msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize + val_to_put = 123 # arbitrary test value COMPLETION_FLAG_VAL = 1 @@ -607,11 +624,11 @@ def test_triton_signal_wait_until(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_and_signal_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel, - sig_ptr, + size_bytes=msg_size_bytes, + sig_ptr=sig_ptr, signal_val=COMPLETION_FLAG_VAL, sig_op=NVSHMEM_SIGNAL_SET, peer=peer, @@ -620,7 +637,7 @@ def test_triton_signal_wait_until(self) -> None: elif rank == 1: # Consumer (rank 1): Waits on the signal variable using `signal_wait_until`. sig_ptr = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1, 1, 1)]( + nvshmem_signal_wait_until_kernel[(1, 1, 1)]( sig_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=COMPLETION_FLAG_VAL, @@ -639,6 +656,7 @@ def test_triton_signal_wait_until(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_fence(self) -> None: """ Rank 0 performs two put operations into Rank 1's buffers with a fence @@ -648,18 +666,17 @@ def test_triton_fence(self) -> None: its arrival implies that both preceding puts have been delivered in order. """ - torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank + peer = 1 - rank # Message configuration - msg_size_bytes = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize + numel = 8 + val1 = 10 val2 = 20 flag_val = 1 @@ -668,42 +685,35 @@ def test_triton_fence(self) -> None: inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2) out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp1_hdl = symm_mem.rendezvous(inp1, group=group_name) - inp2_hdl = symm_mem.rendezvous(inp2, group=group_name) - out1_hdl = symm_mem.rendezvous(out1, group=group_name) - out2_hdl = symm_mem.rendezvous(out2, group=group_name) - - # Flag buffer resides in the signal pad of out2. - flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + symm_mem.rendezvous(inp1, group=group_name) + symm_mem.rendezvous(inp2, group=group_name) + symm_mem.rendezvous(out1, group=group_name) + symm_mem.rendezvous(out2, group=group_name) + + # Use regular symmetric memory tensor for flag + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0) + symm_mem.rendezvous(flag, group=group_name) flag_update_val = torch.tensor( [flag_val], dtype=torch.int64, device=self.device ) NVSHMEM_CMP_EQ = 0 # compare equal if rank == 0: - dst_ptr1 = out1_hdl.buffer_ptrs[rank] - dst_ptr2 = out2_hdl.buffer_ptrs[rank] - src_ptr1 = inp1_hdl.buffer_ptrs[rank] - src_ptr2 = inp2_hdl.buffer_ptrs[rank] - flag_ptr = out2_hdl.signal_pad_ptrs[rank] - flag_src_ptr = flag_update_val.data_ptr() - - put_with_fence_kernel[(1, 1, 1)]( - dst_ptr1, - dst_ptr2, - src_ptr1, - src_ptr2, - flag_ptr, - flag_src_ptr, - numel, + nvshmem_put_with_fence_kernel[(1,)]( + out1, + inp1, + out2, + inp2, + flag, + flag_update_val, + nelems=numel, peer=peer, extern_libs=nvshmem_lib, ) elif rank == 1: - # Wait until flag is set by Rank 0. - ivar_ptr = out2_hdl.signal_pad_ptrs[rank] - wait_until_kernel[(1, 1, 1)]( - ivar_ptr, + # Wait until flag is set by Rank 0 + nvshmem_wait_until_kernel[(1,)]( + flag, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, @@ -722,63 +732,60 @@ def test_triton_fence(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_quiet(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() - # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 + peer = 1 - rank + dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize - # Data buffers + numel = 8 val = 15 + flag_val = 42 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) - # Use signal pad as completion flag - flag_val = 42 - peer = (self.world_size - 1) - rank + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0) + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(flag, group=group_name) + NVSHMEM_CMP_EQ = 0 - if rank == 0: - # Rank 0 waits for flag from Rank 1 - ivar_ptr = out_hdl.signal_pad_ptrs[rank] - wait_until_kernel[(1, 1, 1)]( - ivar_ptr, + dist.barrier() + if rank == 1: + nvshmem_put_with_quiet_kernel[(1,)]( + out, + inp, + flag, + flag_update_val, + nelems=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 0: + nvshmem_wait_until_kernel[(1,)]( + flag, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, ) - # After flag is set, data should be complete due to quiet torch.testing.assert_close( out, val * torch.ones(numel, dtype=dtype, device=self.device) ) - if rank == 1: - # Rank 1 puts data and flag with quiet in between - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - flag_dst_ptr = out_hdl.signal_pad_ptrs[rank] - # Create a tensor for the flag value - flag_update_val = torch.tensor( - [flag_val], dtype=torch.int64, device=self.device - ) - flag_src_ptr = flag_update_val.data_ptr() - put_with_quiet_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - flag_dst_ptr, - flag_src_ptr, - numel=numel, - peer=peer, - extern_libs=nvshmem_lib, - ) + dist.barrier() @skipIfRocm @requires_triton() + @requires_h100() def test_triton_barrier(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -788,77 +795,85 @@ def test_triton_barrier(self) -> None: rank = self.rank numel = 1 dtype = torch.int32 - # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) - # Launch kernel with cooperative grid - barrier_test_kernel[(1,)]( - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - numel=numel, + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + + nvshmem_barrier_test_kernel[(1,)]( + dst, + src, + nelems=numel, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, ) - # Verify results - # Rank 0 should have 42, and then the rest should have incremented + 1 to 43 + dist.barrier() + if rank == 0: - # Rank 0 should have its original value (42) in src torch.testing.assert_close( src, torch.tensor([42], device=self.device, dtype=dtype) ) else: - # Other ranks should have received 42 and incremented to 43 torch.testing.assert_close( dst, torch.tensor([43], device=self.device, dtype=dtype) ) @skipIfRocm @requires_triton() + @requires_h100() def test_triton_sync(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() + nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank numel = 1 dtype = torch.int32 + # Create symmetric buffers - src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + local_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + remote_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(local_data, group=group_name) + symm_mem.rendezvous(remote_data, group=group_name) + # Launch kernel with cooperative grid - sync_test_kernel[(1,)]( - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - numel=numel, + nvshmem_sync_test_kernel[(1,)]( + local_data, + remote_data, + nelems=numel, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, ) + # Verify results - if rank == 0: - # Rank 0 should have its original value (42) in src - torch.testing.assert_close( - src, torch.tensor([42], device=self.device, dtype=dtype) - ) - else: - # Other ranks should have received 42 and incremented to 43 - torch.testing.assert_close( - dst, torch.tensor([43], device=self.device, dtype=dtype) - ) + # Each PE should have written rank + 100 to its local_data + expected_local = rank + 100 + torch.testing.assert_close( + local_data, torch.tensor([expected_local], device=self.device, dtype=dtype) + ) + + # Each PE should have read (next_rank + 100) into its remote_data + # PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0 + next_rank = (rank + 1) % self.world_size + expected_remote = next_rank + 100 + torch.testing.assert_close( + remote_data, + torch.tensor([expected_remote], device=self.device, dtype=dtype), + ) @skipIfRocm @requires_triton() + @requires_h100() def test_triton_alltoall(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) world_size = dist.get_world_size() rank = self.rank @@ -876,16 +891,16 @@ def test_triton_alltoall(self) -> None: src[i * nelems_per_pe : (i + 1) * nelems_per_pe] = value # Destination buffer dst = symm_mem.empty(src_size, dtype=dtype, device=self.device).fill_(-1) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) # Synchronize before alltoall dist.barrier() team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 - # Launch the kernel - alltoall_kernel[(1,)]( + # Launch the kernel using new tensor-aware API + nvshmem_alltoall_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], + dst, + src, nelems_per_pe, extern_libs=nvshmem_lib, launch_cooperative_grid=True, @@ -902,19 +917,25 @@ def test_triton_alltoall(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_broadcast(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank + # Configuration nelems = 4 # number of elements dtype = torch.int64 + # Source buffer - only root will have meaningful data pe_root = 0 # PE 0 will be the root src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + # Destination buffer + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) + if rank == pe_root: # Root fills with specific pattern for i in range(nelems): @@ -922,31 +943,265 @@ def test_triton_broadcast(self) -> None: else: # Non-root PEs have dummy data src.fill_(-1) - # Destination buffer - dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + # Synchronize before broadcast dist.barrier() + # Execute broadcast team_handle = 0 # NVSHMEM_TEAM_WORLD - broadcast_kernel[(1,)]( + nvshmem_broadcast_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], + dst, + src, nelems, pe_root, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) + # Synchronize after broadcast dist.barrier() + # Verify results - all ranks should have the root's data expected = [100 + i for i in range(nelems)] torch.testing.assert_close( dst, torch.tensor(expected, device=self.device, dtype=dtype) ) + @skipIfRocm + @requires_triton() + @requires_h100() + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_sum_reduce(self, dtype) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 3 # number of separate reductions + # Source buffer - each rank contributes different values + src = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + for i in range(nreduce): + src[i] = (rank + 1) * (i + 1) # Rank 0: [1,2,3], Rank 1: [2,4,6], etc. + # Destination buffer + dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + # Calculate expected results + expected = [] + for i in range(nreduce): + # Sum across all ranks: sum((rank+1)*(i+1) for rank in range(world_size)) + total = sum((r + 1) * (i + 1) for r in range(world_size)) + expected.append(total) + + # Synchronize before reduction + dist.barrier() + + # Execute sum reduction across all ranks + team_handle = 0 # NVSHMEM_TEAM_WORLD + nvshmem_reduce_kernel[(1,)]( + team_handle, + dst, + src, + nreduce, + operation="sum", + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + + # Synchronize after reduction + dist.barrier() + + # Verify results + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + + @skipIfRocm + @requires_triton() + @requires_h100() + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_minmax_reduce(self, dtype) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 2 # number of values to reduce + # Source buffers for min and max + src_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + src_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + # Each rank contributes different values + # For min: rank 0: [10, 20], rank 1: [15, 5], etc. + # For max: same values + for i in range(nreduce): + if i == 0: + src_min[i] = 10 + rank * 5 # 10, 15, 20, ... + src_max[i] = 10 + rank * 5 + else: + src_min[i] = 20 - rank * 15 # 20, 5, -10, ... + src_max[i] = 20 - rank * 15 + # Destination buffers + dst_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + dst_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + symm_mem.rendezvous(src_min, group=group_name) + symm_mem.rendezvous(src_max, group=group_name) + symm_mem.rendezvous(dst_min, group=group_name) + symm_mem.rendezvous(dst_max, group=group_name) + # Calculate expected results + all_values = [] + for i in range(nreduce): + values = [] + for r in range(world_size): + if i == 0: + values.append(10 + r * 5) + else: + values.append(20 - r * 15) + all_values.append(values) + expected_min = [min(vals) for vals in all_values] + expected_max = [max(vals) for vals in all_values] + dist.barrier() + # Execute MIN reduction + team_handle = 0 + nvshmem_reduce_kernel[(1,)]( + team_handle, + dst_min, + src_min, + nreduce, + operation="min", + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Execute MAX reduction + nvshmem_reduce_kernel[(1,)]( + team_handle, + dst_max, + src_max, + nreduce, + operation="max", + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + dist.barrier() + # Verify results + torch.testing.assert_close( + dst_min, torch.tensor(expected_min, device=self.device, dtype=dtype) + ) + torch.testing.assert_close( + dst_max, torch.tensor(expected_max, device=self.device, dtype=dtype) + ) + + @skipIfRocm + @requires_triton() + @requires_h100() + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_prod_reduce(self, dtype) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 3 # number of separate reductions + # Source buffer - each rank contributes different values + # Use very small values to avoid overflow, especially for small integer types + src = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + for i in range(nreduce): + # Use values that won't overflow even for int8: all values 1 or 2 + if i == 0: + # For first element: rank 0,2,4... gets 1, rank 1,3,5... gets 2 + src[i] = 1 if rank % 2 == 0 else 2 + elif i == 1: + # For second element: all get 1 (no multiplication effect) + src[i] = 1 + else: + # For third element: rank 0,1 get 1, rank 2,3 get 2, etc. (groups of 2) + src[i] = 1 if (rank // 2) % 2 == 0 else 2 + # Destination buffer + dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + # Calculate expected results + vals = torch.empty(nreduce, world_size, dtype=dtype) + vals[0, ::2] = 1 + vals[0, 1::2] = 2 + vals[1] = 1 + vals2 = vals[2].view(-1, 2, 2) + vals2[:, 0] = 1 + vals2[:, 1] = 2 + expected = vals.prod(-1).tolist() + + # Synchronize before reduction + dist.barrier() + + # Execute product reduction across all ranks + team_handle = 0 # NVSHMEM_TEAM_WORLD + nvshmem_reduce_kernel[(1,)]( + team_handle, + dst, + src, + nreduce, + operation="prod", + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + + # Synchronize after reduction + dist.barrier() + + # Verify results + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/cpython/3_13/test_itertools.diff b/test/dynamo/cpython/3_13/test_itertools.diff index 44a17e58becc..21763d689ac6 100644 --- a/test/dynamo/cpython/3_13/test_itertools.diff +++ b/test/dynamo/cpython/3_13/test_itertools.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py -index 7d5ba727389..637fbe545cd 100644 +index 7d5ba727389..d15d83a2184 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1,3 +1,25 @@ @@ -50,7 +50,50 @@ index 7d5ba727389..637fbe545cd 100644 def pickletest(self, protocol, it, stop=4, take=1, compare=None): """Test that an iterator is the same after pickling, also when part-consumed""" -@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase): +@@ -454,14 +476,8 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) + +- @pickle_deprecated + def test_permutations(self): +- self.assertRaises(TypeError, permutations) # too few arguments +- self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments +- self.assertRaises(TypeError, permutations, None) # pool is not iterable +- self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative + self.assertEqual(list(permutations('abc', 32)), []) # r > n +- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None + self.assertEqual(list(permutations(range(3), 2)), + [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) + +@@ -498,7 +514,7 @@ class TestBasicOps(unittest.TestCase): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) + +- for n in range(7): ++ for n in range(5): + values = [5*x-12 for x in range(n)] + for r in range(n+2): + result = list(permutations(values, r)) +@@ -515,9 +531,6 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(result, list(permutations(values, None))) # test r as None + self.assertEqual(result, list(permutations(values))) # test default r + +- for proto in range(pickle.HIGHEST_PROTOCOL + 1): +- self.pickletest(proto, permutations(values, r)) # test pickling +- + @support.bigaddrspacetest + def test_permutations_overflow(self): + with self.assertRaises((OverflowError, MemoryError)): +@@ -756,7 +769,7 @@ class TestBasicOps(unittest.TestCase): + def test_cycle(self): + self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) + self.assertEqual(list(cycle('')), []) +- self.assertRaises(TypeError, cycle) ++ # self.assertRaises(TypeError, cycle) + self.assertRaises(TypeError, cycle, 5) + self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) + +@@ -888,7 +901,7 @@ class TestBasicOps(unittest.TestCase): # Check normal pickled for proto in range(pickle.HIGHEST_PROTOCOL + 1): dup = [] @@ -59,7 +102,7 @@ index 7d5ba727389..637fbe545cd 100644 for elem in g: self.assertEqual(k, elem[0]) dup.append(elem) -@@ -896,8 +918,8 @@ class TestBasicOps(unittest.TestCase): +@@ -896,8 +909,8 @@ class TestBasicOps(unittest.TestCase): # Check nested case dup = [] @@ -70,7 +113,7 @@ index 7d5ba727389..637fbe545cd 100644 for elem in ig: self.assertEqual(k, elem[0]) self.assertEqual(ik, elem[2]) -@@ -907,8 +929,8 @@ class TestBasicOps(unittest.TestCase): +@@ -907,8 +920,8 @@ class TestBasicOps(unittest.TestCase): # Check nested and pickled for proto in range(pickle.HIGHEST_PROTOCOL + 1): dup = [] @@ -81,7 +124,7 @@ index 7d5ba727389..637fbe545cd 100644 for elem in ig: self.assertEqual(k, elem[0]) self.assertEqual(ik, elem[2]) -@@ -917,7 +939,7 @@ class TestBasicOps(unittest.TestCase): +@@ -917,7 +930,7 @@ class TestBasicOps(unittest.TestCase): # Check case where inner iterator is not used @@ -90,7 +133,7 @@ index 7d5ba727389..637fbe545cd 100644 expectedkeys = set([r[0] for r in s]) self.assertEqual(set(keys), expectedkeys) self.assertEqual(len(keys), len(expectedkeys)) -@@ -925,7 +947,7 @@ class TestBasicOps(unittest.TestCase): +@@ -925,7 +938,7 @@ class TestBasicOps(unittest.TestCase): # Check case where inner iterator is used after advancing the groupby # iterator s = list(zip('AABBBAAAA', range(9))) @@ -99,7 +142,7 @@ index 7d5ba727389..637fbe545cd 100644 _, g1 = next(it) _, g2 = next(it) _, g3 = next(it) -@@ -936,7 +958,7 @@ class TestBasicOps(unittest.TestCase): +@@ -936,7 +949,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(g3), []) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -108,23 +151,118 @@ index 7d5ba727389..637fbe545cd 100644 _, g = next(it) next(it) next(it) -@@ -1038,6 +1060,7 @@ class TestBasicOps(unittest.TestCase): - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, filterfalse(isEven, range(6))) - -+ @skipIfTorchDynamo("infinite loop in torch dynamo") - def test_zip(self): - # XXX This is rather silly now that builtin zip() calls zip()... - ans = [(x,y) for x, y in zip('abc',count())] -@@ -1082,6 +1105,7 @@ class TestBasicOps(unittest.TestCase): - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, zip('abc', count())) - -+ @skipIfTorchDynamo("infinite loop in torch dynamo") - def test_ziplongest(self): - for args in [ - ['abc', range(6)], -@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2]) + self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2]) + self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6]) +- self.assertRaises(TypeError, filter) +- self.assertRaises(TypeError, filter, lambda x:x) +- self.assertRaises(TypeError, filter, lambda x:x, range(6), 7) +- self.assertRaises(TypeError, filter, isEven, 3) +- self.assertRaises(TypeError, next, filter(range(6), range(6))) ++ # these tests raise dynamo exceptions, not TypeError ++ # self.assertRaises(TypeError, filter) ++ # self.assertRaises(TypeError, filter, lambda x:x) ++ # self.assertRaises(TypeError, filter, lambda x:x, range(6), 7) ++ # self.assertRaises(TypeError, filter, isEven, 3) ++ # dynamo raises Unsupported in this case ++ # self.assertRaises(TypeError, next, filter(range(6), range(6))) + + # check copy, deepcopy, pickle +- ans = [0,2,4] +- +- c = filter(isEven, range(6)) +- self.assertEqual(list(copy.copy(c)), ans) +- c = filter(isEven, range(6)) +- self.assertEqual(list(copy.deepcopy(c)), ans) +- for proto in range(pickle.HIGHEST_PROTOCOL + 1): +- c = filter(isEven, range(6)) +- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans) +- next(c) +- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:]) +- for proto in range(pickle.HIGHEST_PROTOCOL + 1): +- c = filter(isEven, range(6)) +- self.pickletest(proto, c) ++ # ans = [0,2,4] ++ ++ # c = filter(isEven, range(6)) ++ # self.assertEqual(list(copy.copy(c)), ans) ++ # c = filter(isEven, range(6)) ++ # self.assertEqual(list(copy.deepcopy(c)), ans) ++ # for proto in range(pickle.HIGHEST_PROTOCOL + 1): ++ # c = filter(isEven, range(6)) ++ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans) ++ # next(c) ++ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:]) ++ # for proto in range(pickle.HIGHEST_PROTOCOL + 1): ++ # c = filter(isEven, range(6)) ++ # self.pickletest(proto, c) + + @pickle_deprecated + def test_filterfalse(self): +@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3))) + self.assertEqual(list(zip('abcdef')), lzip('abcdef')) + self.assertEqual(list(zip()), lzip()) +- self.assertRaises(TypeError, zip, 3) +- self.assertRaises(TypeError, zip, range(3), 3) ++ # self.assertRaises(TypeError, zip, 3) ++ # self.assertRaises(TypeError, zip, range(3), 3) + self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')], + lzip('abc', 'def')) + self.assertEqual([pair for pair in zip('abc', 'def')], +@@ -1105,19 +1120,19 @@ class TestBasicOps(unittest.TestCase): + + self.assertEqual(list(zip_longest('abc', 'defg', **{})), + list(zip(list('abc')+[None], 'defg'))) # empty keyword dict +- self.assertRaises(TypeError, zip_longest, 3) +- self.assertRaises(TypeError, zip_longest, range(3), 3) +- +- for stmt in [ +- "zip_longest('abc', fv=1)", +- "zip_longest('abc', fillvalue=1, bogus_keyword=None)", +- ]: +- try: +- eval(stmt, globals(), locals()) +- except TypeError: +- pass +- else: +- self.fail('Did not raise Type in: ' + stmt) ++ # self.assertRaises(TypeError, zip_longest, 3) ++ # self.assertRaises(TypeError, zip_longest, range(3), 3) ++ ++ # for stmt in [ ++ # "zip_longest('abc', fv=1)", ++ # "zip_longest('abc', fillvalue=1, bogus_keyword=None)", ++ # ]: ++ # try: ++ # eval(stmt, globals(), locals()) ++ # except TypeError: ++ # pass ++ # else: ++ # self.fail('Did not raise Type in: ' + stmt) + + self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], + list(zip('abc', 'def'))) +@@ -1296,7 +1311,6 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(list(product(*(args*r))), + list(product(*args, **dict(repeat=r)))) + self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) +- self.assertRaises(TypeError, product, range(6), None) + + def product1(*args, **kwds): + pools = list(map(tuple, args)) * kwds.get('repeat', 1) +@@ -1336,7 +1350,8 @@ class TestBasicOps(unittest.TestCase): + argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), + set('abcdefg'), range(11), tuple(range(13))] + for i in range(100): +- args = [random.choice(argtypes) for j in range(random.randrange(5))] ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ args = [random.choice(argtypes) for j in range(random.randrange(5))] + expected_len = prod(map(len, args)) + self.assertEqual(len(list(product(*args))), expected_len) + self.assertEqual(list(product(*args)), list(product1(*args))) +@@ -1767,6 +1782,7 @@ class TestBasicOps(unittest.TestCase): script_helper.assert_python_ok("-c", script) # Issue 13454: Crash when deleting backward iterator from tee() @@ -132,7 +270,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_tee_del_backward(self): forward, backward = tee(repeat(None, 20000000)) try: -@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1920,7 +1936,7 @@ class TestBasicOps(unittest.TestCase): tp.foobar = 1 @@ -141,7 +279,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) -@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase): +@@ -2032,7 +2048,7 @@ class TestExamples(unittest.TestCase): self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4]) @@ -150,7 +288,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_batched_recipe(self): def batched_recipe(iterable, n): -@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2081,6 +2097,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): for i, element in zip(range(i + 1, stop), iterable): pass @@ -158,7 +296,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_islice_recipe(self): self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB')) self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD')) -@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2265,7 +2282,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): raise @@ -167,7 +305,7 @@ index 7d5ba727389..637fbe545cd 100644 def makecycle(self, iterator, container): container.append(iterator) -@@ -2465,7 +2491,7 @@ def L(seqn): +@@ -2465,7 +2482,7 @@ def L(seqn): return chain(map(lambda x:x, R(Ig(G(seqn))))) @@ -176,7 +314,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_accumulate(self): s = [1,2,3,4,5] -@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase): +@@ -2644,7 +2661,7 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, tee, N(s)) self.assertRaises(ZeroDivisionError, list, tee(E(s))[0]) @@ -185,7 +323,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_repeat(self): self.assertEqual(operator.length_hint(repeat(None, 50)), 50) -@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase): +@@ -2657,7 +2674,7 @@ class LengthTransparency(unittest.TestCase): self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0) self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0) @@ -194,7 +332,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_sf_793826(self): # Fix Armin Rigo's successful efforts to wreak havoc -@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase): +@@ -2718,6 +2735,7 @@ class RegressionTests(unittest.TestCase): @support.skip_if_pgo_task @support.requires_resource('cpu') @@ -202,7 +340,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_long_chain_of_empty_iterables(self): # Make sure itertools.chain doesn't run into recursion limits when # dealing with long chains of empty iterables. Even with a high -@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase): +@@ -2750,7 +2768,7 @@ class RegressionTests(unittest.TestCase): next(g, None) # shouldn't crash @@ -211,7 +349,7 @@ index 7d5ba727389..637fbe545cd 100644 def test_keywords_in_subclass(self): # count is not subclassable... testcases = [ -@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase): +@@ -2805,49 +2823,5 @@ class SubclassWithKwargsTest(unittest.TestCase): self.assertEqual(u.newarg, 3) diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py index 637fbe545cd2..d15d83a2184d 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -476,14 +476,8 @@ def test_combinations_with_replacement_tuple_reuse(self): self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) - @pickle_deprecated def test_permutations(self): - self.assertRaises(TypeError, permutations) # too few arguments - self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments - self.assertRaises(TypeError, permutations, None) # pool is not iterable - self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative self.assertEqual(list(permutations('abc', 32)), []) # r > n - self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -520,7 +514,7 @@ def permutations2(iterable, r=None): if len(set(indices)) == r: yield tuple(pool[i] for i in indices) - for n in range(7): + for n in range(5): values = [5*x-12 for x in range(n)] for r in range(n+2): result = list(permutations(values, r)) @@ -537,9 +531,6 @@ def permutations2(iterable, r=None): self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, permutations(values, r)) # test pickling - @support.bigaddrspacetest def test_permutations_overflow(self): with self.assertRaises((OverflowError, MemoryError)): @@ -778,7 +769,7 @@ def test_count_with_step_threading(self): def test_cycle(self): self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) self.assertEqual(list(cycle('')), []) - self.assertRaises(TypeError, cycle) + # self.assertRaises(TypeError, cycle) self.assertRaises(TypeError, cycle, 5) self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) @@ -1024,27 +1015,29 @@ def test_filter(self): self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2]) self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2]) self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6]) - self.assertRaises(TypeError, filter) - self.assertRaises(TypeError, filter, lambda x:x) - self.assertRaises(TypeError, filter, lambda x:x, range(6), 7) - self.assertRaises(TypeError, filter, isEven, 3) - self.assertRaises(TypeError, next, filter(range(6), range(6))) + # these tests raise dynamo exceptions, not TypeError + # self.assertRaises(TypeError, filter) + # self.assertRaises(TypeError, filter, lambda x:x) + # self.assertRaises(TypeError, filter, lambda x:x, range(6), 7) + # self.assertRaises(TypeError, filter, isEven, 3) + # dynamo raises Unsupported in this case + # self.assertRaises(TypeError, next, filter(range(6), range(6))) # check copy, deepcopy, pickle - ans = [0,2,4] - - c = filter(isEven, range(6)) - self.assertEqual(list(copy.copy(c)), ans) - c = filter(isEven, range(6)) - self.assertEqual(list(copy.deepcopy(c)), ans) - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - c = filter(isEven, range(6)) - self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans) - next(c) - self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:]) - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - c = filter(isEven, range(6)) - self.pickletest(proto, c) + # ans = [0,2,4] + + # c = filter(isEven, range(6)) + # self.assertEqual(list(copy.copy(c)), ans) + # c = filter(isEven, range(6)) + # self.assertEqual(list(copy.deepcopy(c)), ans) + # for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # c = filter(isEven, range(6)) + # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans) + # next(c) + # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:]) + # for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # c = filter(isEven, range(6)) + # self.pickletest(proto, c) @pickle_deprecated def test_filterfalse(self): @@ -1060,7 +1053,6 @@ def test_filterfalse(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, filterfalse(isEven, range(6))) - @skipIfTorchDynamo("infinite loop in torch dynamo") def test_zip(self): # XXX This is rather silly now that builtin zip() calls zip()... ans = [(x,y) for x, y in zip('abc',count())] @@ -1070,8 +1062,8 @@ def test_zip(self): self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3))) self.assertEqual(list(zip('abcdef')), lzip('abcdef')) self.assertEqual(list(zip()), lzip()) - self.assertRaises(TypeError, zip, 3) - self.assertRaises(TypeError, zip, range(3), 3) + # self.assertRaises(TypeError, zip, 3) + # self.assertRaises(TypeError, zip, range(3), 3) self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')], lzip('abc', 'def')) self.assertEqual([pair for pair in zip('abc', 'def')], @@ -1105,7 +1097,6 @@ def test_zip_tuple_reuse(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, zip('abc', count())) - @skipIfTorchDynamo("infinite loop in torch dynamo") def test_ziplongest(self): for args in [ ['abc', range(6)], @@ -1129,19 +1120,19 @@ def test_ziplongest(self): self.assertEqual(list(zip_longest('abc', 'defg', **{})), list(zip(list('abc')+[None], 'defg'))) # empty keyword dict - self.assertRaises(TypeError, zip_longest, 3) - self.assertRaises(TypeError, zip_longest, range(3), 3) - - for stmt in [ - "zip_longest('abc', fv=1)", - "zip_longest('abc', fillvalue=1, bogus_keyword=None)", - ]: - try: - eval(stmt, globals(), locals()) - except TypeError: - pass - else: - self.fail('Did not raise Type in: ' + stmt) + # self.assertRaises(TypeError, zip_longest, 3) + # self.assertRaises(TypeError, zip_longest, range(3), 3) + + # for stmt in [ + # "zip_longest('abc', fv=1)", + # "zip_longest('abc', fillvalue=1, bogus_keyword=None)", + # ]: + # try: + # eval(stmt, globals(), locals()) + # except TypeError: + # pass + # else: + # self.fail('Did not raise Type in: ' + stmt) self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], list(zip('abc', 'def'))) @@ -1320,7 +1311,6 @@ def test_product(self): self.assertEqual(list(product(*(args*r))), list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) - self.assertRaises(TypeError, product, range(6), None) def product1(*args, **kwds): pools = list(map(tuple, args)) * kwds.get('repeat', 1) @@ -1360,7 +1350,8 @@ def product2(*iterables, repeat=1): argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), set('abcdefg'), range(11), tuple(range(13))] for i in range(100): - args = [random.choice(argtypes) for j in range(random.randrange(5))] + with torch._dynamo.set_fullgraph(fullgraph=False): + args = [random.choice(argtypes) for j in range(random.randrange(5))] expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) self.assertEqual(list(product(*args)), list(product1(*args))) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index d64334533f9b..6b7662cbe646 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,7 @@ from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import ( checkpoint, @@ -28,7 +28,6 @@ ) -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -243,7 +242,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_function_via_global_checkpoint(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -262,7 +261,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_function_with_kwargs(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -282,7 +281,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_sequential_layers(self, device): def gn(x): x = x.cos() @@ -307,7 +306,7 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton def test_tags_multiple_checkpoints(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -329,7 +328,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_module(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -357,7 +356,7 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton def test_tags_decomps(self, device): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): @@ -392,7 +391,7 @@ def fn(x): ) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_recomputed_rand(self, device): def gn(x, y): @@ -416,7 +415,7 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_rand(self, device): def gn(x, y): @@ -443,7 +442,7 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_dropout(self, device): # Figure out a way to test the number of inductor_random calls @@ -551,7 +550,7 @@ def _factory_fn(): Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", ) - @requires_cuda + @requires_cuda_and_triton def test_fallback(self, device): def gn(x, y): torch._dynamo.graph_break() @@ -579,7 +578,7 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(cnt.graphs), 2) - @requires_cuda + @requires_cuda_and_triton def test_kwargs(self, device): def gn(x, y, z=None): a = torch.matmul(x, y) @@ -613,7 +612,7 @@ def fn(x, y, z): body_function = getattr(cnt.graphs[0], wrap_node.args[0].name) self.assertEqual(op_count(body_function), 2) - @requires_cuda + @requires_cuda_and_triton def test_symints_location(self, device): def gn(x, y): return torch.matmul(x, torch.nn.functional.dropout(y, 0.5)) @@ -643,7 +642,7 @@ def fn(x, y): wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint) self.assertEqual(len(wrap_node.args), 3) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_recompute(self, device): def context_fn_must_recompute_mm(): @@ -710,7 +709,7 @@ def fn(x): ), ) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): def selective_checkpointing_context_fn(): @@ -757,7 +756,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_tensor_subclass(self, device): def selective_checkpointing_context_fn(): @@ -807,7 +806,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_custom_rule(self, device): def _get_custom_policy(meta): @@ -872,7 +871,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_partial_ctx_fn(self, device): def selective_checkpointing_context_fn(no_recompute_list): @@ -918,7 +917,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_outplace_op(self, device): def selective_checkpointing_context_fn(): @@ -963,7 +962,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_list_ops(self, device): def selective_checkpointing_context_fn(): @@ -1011,7 +1010,7 @@ def fn(x, y): "In-place op support in selective checkpointing + torch.compile " "requires TorchDispatchMode + torch.compile work to complete" ) - @requires_cuda + @requires_cuda_and_triton def test_compile_selective_checkpoint_inplace_op(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -1057,7 +1056,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) def test_compile_selective_checkpoint_random_op(self, device): @@ -1117,7 +1116,7 @@ def fn(x): self._validate(fn, backend, x, skip_check=not preserve_rng_state) self._compare_orig_and_checkpointed_fns(gn, fn, x) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_invalid_context(self): def gn(x, y): @@ -1155,7 +1154,7 @@ def fn(x, y): ): self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_compile_selective_checkpoint_parametrization(self): def sac_policy(): @@ -1249,7 +1248,7 @@ def reset_parameters(self): self.assertEqual(input.grad, input_compiled.grad) @skipIfRocm - @requires_cuda + @requires_cuda_and_triton def test_autocast_flash_attention(self, device): def fn(primals_1, primals_2, primals_3): return torch.ops.aten._scaled_dot_product_efficient_attention.default( @@ -1273,7 +1272,7 @@ def gn(*args): res = opt_gn(*args) self.assertEqual(ref, res) - @requires_cuda + @requires_cuda_and_triton def test_error_msg(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1297,7 +1296,7 @@ def fn(x): ): opt_fn(x) - @requires_cuda + @requires_cuda_and_triton def test_list_inputs(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1322,7 +1321,7 @@ def fn(x, ys): res = opt_fn(x, [y, z]) self.assertEqual(ref, res) - @requires_cuda + @requires_cuda_and_triton def test_pattern_matcher(self, device): # Check that the sdpa op is recomputed in the backward graph # tests percolate_tags @@ -1402,7 +1401,7 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): ) @requires_distributed() - @requires_cuda + @requires_cuda_and_triton def test_distributed_utils_checkpoint_wrapper(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as dist_checkpoint_wrapper, @@ -1428,7 +1427,7 @@ def forward(self, x): self.assertEqual(ref, res) @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_dynamo_does_not_trace_getattr_as_top_frame(self): # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do. diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 0d4a1f01f9a3..7e6895ccde5c 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -37,7 +37,7 @@ skipIfWindows, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_triton -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -447,8 +447,8 @@ def test_non_bundled_to_bundled_config_change(self): def fn(x, y): return (x * 2, y @ y) - a = torch.rand(25, device="cuda") - b = torch.rand(5, 5, device="cuda") + a = torch.rand(25, device=GPU_TYPE) + b = torch.rand(5, 5, device=GPU_TYPE) compiled_fn = torch.compile(fn, backend="inductor") self.assertEqual(fn(a, b), compiled_fn(a, b)) @@ -690,7 +690,7 @@ def fn(a, b): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -746,7 +746,7 @@ def backward(ctx, grad_output): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -788,8 +788,7 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - @requires_cuda - @requires_triton() + @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -822,7 +821,7 @@ def backward(ctx, grad_output): def fn(a): return MyAutogradFunction.apply(a) - a = torch.randn(5, device="cuda", requires_grad=True) + a = torch.randn(5, device=GPU_TYPE, requires_grad=True) a2 = a.clone().detach_().requires_grad_(True) compiled_fn = torch.compile(fn, backend="inductor") result = compiled_fn(a) @@ -842,6 +841,214 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"autograd_cache_allow_custom_autograd_functions": True}) + def test_custom_autograd_function_with_custom_triton_kernel_cache_invalidation( + self, + ): + @triton.jit + def my_jit(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + y = torch.ops.test.my_triton_op(x) + ctx.save_for_backward(y) + ctx.foo = x.cos() + return y + + @staticmethod + def backward(ctx, grad_output): + result = ctx.saved_tensors[0] + return grad_output * result + ctx.foo * grad_output + + def fn(a): + return MyAutogradFunction.apply(a) + + a = torch.randn(5, device=GPU_TYPE, requires_grad=True) + a2 = a.clone().detach_().requires_grad_(True) + a3 = a.clone().detach_().requires_grad_(True) + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + result.sum().backward() + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Clear dynamo and run again. Should be a cache hit. + counters.clear() + self._clear_dynamo_and_codecache() + result = compiled_fn(a2) + self.assertEqual(fn(a2), result) + result.sum().backward() + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + + # Now modify the source code of my_jit by redefining it + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) # Changed from +1 to +2 + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + # Clear dynamo and run again. Should be a cache miss due to modified source code. + counters.clear() + self._clear_dynamo_and_codecache() + compiled_fn = torch.compile(fn, backend="inductor") + + result = compiled_fn(a3) + # Assert that after changing the source code, the cache no longer hits + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(fn(a3), result) + + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_triton_op_cache_invalidation(self): + from torch._library import capture_triton + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + capture_triton(my_jit)[1,](y) + return y + + def fn(a): + return torch.ops.test.my_triton_op(a) + + a = torch.randn(5, device=GPU_TYPE) + a2 = a.clone().detach_() + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + self._clear_dynamo_and_codecache() + + # Redefine the triton op + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a2) + + # Second run should still miss + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + + self.assertEqual(fn(a2), result) + + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @unittest.expectedFailure # Currently ops that call other ops does not properly invalidate cache + def test_triton_op_cache_multiple_ops_invalidation(self): + @triton.jit + def my_jit(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @triton.jit + def my_jit2(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + torch._library.capture_triton(my_jit2)[1,](y) + return y + + @torch._library.triton_op("test::my_triton_op2", mutates_args=()) + def my_triton_op2(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch.ops.test.my_triton_op(y) + return y + + def fn(a): + return torch.ops.test.my_triton_op2(a) + + a = torch.randn(5, device=GPU_TYPE) + a2 = a.clone().detach_() + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + self._clear_dynamo_and_codecache() + + # Redefine the triton op + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + torch._library.capture_triton(my_jit2)[1,](y) + return y + + @torch._library.triton_op("test::my_triton_op2", mutates_args=()) + def my_triton_op2(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch.ops.test.my_triton_op(y) + return y + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a2) + + # Second run should still miss + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + + self.assertEqual(fn(a2), result) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch({"fx_graph_cache": True}) @functorch_config.patch({"enable_autograd_cache": True}) @@ -1260,7 +1467,7 @@ def f(): result = f() self.assertEqual(result[0].device, torch.device("cuda:1")) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_cache", True) @inductor_config.patch("fx_graph_remote_cache", False) @functorch_config.patch({"enable_autograd_cache": True}) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 6f460b402404..de5afce14598 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -8,10 +8,13 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils -from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda +from torch.testing._internal.triton_utils import ( + HAS_CUDA_AND_TRITON, + requires_cuda_and_triton, +) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: import triton from torch.testing._internal.triton_utils import add_kernel @@ -1473,7 +1476,7 @@ def fn(): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_basic(self): class Add(torch.autograd.Function): @staticmethod @@ -1504,7 +1507,7 @@ def f(x, y): loss.backward() self.assertEqual(x + y, z) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_multiple_out(self): class Add(torch.autograd.Function): @staticmethod diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 9d61bbf31acb..be1470c08e79 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -16,10 +16,7 @@ onlyHPU, ) from torch.testing._internal.common_utils import skipIfHpu -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton class Seq(torch.nn.Module): @@ -133,7 +130,7 @@ def test_aot_eager_decomp_partition(self, device): def test_aot_ts(self, device): self._check_backend_works("aot_ts", device) - @requires_cuda + @requires_cuda_and_triton def test_aot_cudagraphs(self, device): self._check_backend_works("cudagraphs", device) diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 18cdf78c61f2..607b502351aa 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest import unittest.mock as mock import torch @@ -13,10 +12,6 @@ ) from torch._higher_order_ops.schema import find_hop_schema from torch.testing._internal.common_utils import instantiate_parametrized_tests -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") def normalize_graph(gm): diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index 8112a2e89e95..e51636462631 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -8,7 +8,7 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._guards import CompileId from torch.testing._internal.common_utils import TEST_WITH_ROCM -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda_and_triton class CallbackTests(TestCase): @@ -61,7 +61,7 @@ def test_counter_assertion(self) -> None: @unittest.skipIf( TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs" ) - @unittest.skipIf(not HAS_CUDA, "requires triton") + @requires_cuda_and_triton @torch._inductor.config.patch(force_disable_caches=True) def test_triggers(self) -> None: torch._dynamo.reset() diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index a5a350c0d1ad..161f9674cd4a 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] -import unittest from contextlib import contextmanager from importlib import import_module @@ -11,19 +10,18 @@ from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.test_case import TestCase from torch.library import _scoped_library, Library -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda_and_triton aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 -@requires_cuda +@requires_cuda_and_triton class TestCompilerBisector(TestCase): test_ns = "_test_bisector" diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index ea39f6fbd9e1..eae4d06d9890 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import os -import unittest from unittest.mock import patch import torch @@ -10,11 +9,8 @@ from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string from torch._dynamo.test_case import TestCase from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.inductor_utils import HAS_CUDA -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") - f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 3b29e5e96119..9bf982c5b90e 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,6 +10,7 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -892,6 +893,9 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) + @skipIfWindows( + msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." + ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 063e6863b870..e91e7ef52097 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -62,7 +62,7 @@ def fn(): Developer debug context: aten.nonzero.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0036.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html from user code: File "test_error_messages.py", line N, in fn @@ -84,7 +84,7 @@ def fn(): Developer debug context: aten.linalg_lstsq.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0037.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0037.html from user code: File "test_error_messages.py", line N, in fn @@ -107,7 +107,7 @@ def fn(x): Developer debug context: call_method TensorVariable() item () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0124.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html from user code: File "test_error_messages.py", line N, in fn @@ -131,7 +131,7 @@ def fn(x): Developer debug context: aten.equal.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0033.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0033.html from user code: File "test_error_messages.py", line N, in fn @@ -159,7 +159,7 @@ def fn(lst): Developer debug context: TensorVariable() - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0207.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html from user code: File "test_error_messages.py", line N, in fn @@ -185,7 +185,7 @@ def fn(it): Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0156.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html from user code: File "test_error_messages.py", line N, in fn @@ -214,7 +214,7 @@ def fn(x, items): Developer debug context: call_method UserDefinedObjectVariable(dict_items) __iter__ [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0156.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html from user code: File "test_error_messages.py", line N, in fn @@ -238,7 +238,7 @@ def fn(it): Developer debug context: call_function UserDefinedObjectVariable(zip) [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0147.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0147.html from user code: File "test_error_messages.py", line N, in fn @@ -262,7 +262,7 @@ def fn(obj): Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0142.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html from user code: File "test_error_messages.py", line N, in fn @@ -293,7 +293,7 @@ def fn(x): return x + 1 - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0219.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0219.html""", ) def test_unsupported_builtin(self): @@ -312,7 +312,7 @@ def fn(): Developer debug context: builtin print [] False - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0059.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0059.html from user code: File "test_error_messages.py", line N, in fn @@ -338,7 +338,7 @@ def post_munge(s): Developer debug context: module: unittest.case, qualname: skip, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -360,7 +360,7 @@ def fn(): Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -389,7 +389,7 @@ def post_munge(s): Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0008.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0008.html from user code: File "test_error_messages.py", line N, in fn @@ -411,7 +411,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -432,7 +432,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -454,7 +454,7 @@ def fn(): Developer debug context: module: _warnings, qualname: warn, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -483,7 +483,7 @@ def fn(x): Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) @scoped_load_inline @@ -530,7 +530,7 @@ def f(x): Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) cpp_source = """ @@ -582,7 +582,7 @@ def fn(x, y): Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0038.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html from user code: File "test_error_messages.py", line N, in fn @@ -604,7 +604,7 @@ def fn(): Developer debug context: raised exception RuntimeError([ConstantVariable(str: 'test')]) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0088.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html from user code: File "test_error_messages.py", line N, in fn @@ -630,7 +630,7 @@ def fn(mod): Developer debug context: Foo - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0119.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0119.html from user code: File "test_error_messages.py", line N, in fn @@ -659,7 +659,7 @@ def fn(mod, x): Developer debug context: nn.Module subclass: Foo, name: attr, attribute type: module - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0161.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0161.html from user code: File "test_error_messages.py", line N, in fn @@ -689,7 +689,7 @@ def fn(): Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)] - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0066.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html from user code: File "test_error_messages.py", line N, in fn @@ -705,7 +705,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html""", ) def test_load_build_class(self): @@ -726,7 +726,7 @@ class Foo: Developer debug context: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0075.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0075.html from user code: File "test_error_messages.py", line N, in fn @@ -759,7 +759,7 @@ def post_munge(s): Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. Developer debug context: GET_AITER with args (, Instruction(GET_AITER) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0082.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0082.html from user code: File "test_error_messages.py", line N, in fn @@ -790,7 +790,7 @@ def post_munge(s): Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0092.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html from user code: File "test_error_messages.py", line N, in fn @@ -826,7 +826,7 @@ def post_munge(s): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_reconstruction_failure_gb torch.compile(fn, backend="eager")() @@ -846,7 +846,7 @@ def post_munge(s): Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0092.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html from user code: File "test_error_messages.py", line N, in fn @@ -875,7 +875,7 @@ def fn(x): Developer debug context: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0087.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0087.html from user code: File "test_error_messages.py", line N, in fn @@ -899,7 +899,7 @@ def fn(x): Developer debug context: attempted to jump with TensorVariable() - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html from user code: File "test_error_messages.py", line N, in fn @@ -966,7 +966,7 @@ def fn(x): Developer debug context: value: ConstantVariable(bool: False) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0034.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0034.html from user code: File "test_error_messages.py", line N, in fn @@ -1010,7 +1010,7 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -1063,7 +1063,7 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -1099,7 +1099,7 @@ def hn(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_nested_compile_user_frames torch.compile(fn, backend="eager")(torch.randn(3)) @@ -1213,7 +1213,7 @@ def f3(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_graph_break_traceback_collapsed_resume_frames f1(torch.randn(3)) @@ -1303,7 +1303,7 @@ def post_munge(s): Developer debug context: .f at 0xmem_addr> - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0098.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html from user code: File "test_error_messages.py", line N, in outer @@ -1325,7 +1325,7 @@ def g(x): Developer debug context: .g at 0xmem_addr> - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0098.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html from user code: File "test_error_messages.py", line N, in outer @@ -1351,7 +1351,7 @@ def forward(self, x): Developer debug context: source: LocalSource(local_name='fn', is_input=True, dynamism=None, is_derefed_cell_contents=False) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0148.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0148.html from user code: File "test_error_messages.py", line N, in outer diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index a7cb02132bd5..ad56417ed568 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -43,7 +43,7 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_exc.py", line N, in fn001 @@ -183,7 +183,7 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_exc.py", line N, in test_graph_break_log torch.compile(fn001, backend="eager")(torch.randn(1)) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4afb6acc5d87..31505b9445d4 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -268,6 +268,48 @@ def test_itertools_product(a, b): v = v + x * i return v + def test_itertools_product_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(*args, **kwargs): + return torch.tensor(list(itertools.product(*args, **kwargs))) + + self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1) + + @make_test + def test_itertools_product_various_iterators(a, b): + itertools.product( + [a, b], + zip([1, 2], [3, 4]), + map(lambda x: x, [1, 2]), + filter(lambda x: True, [1, 2]), + ) + return a + + def test_itertools_permutations_basic(self): + def fn(): + return torch.tensor(list(itertools.permutations([1, 2, 3], 2))) + + actual = torch.compile(fn, backend="eager", fullgraph=True)() + expected = fn() + self.assertEqual(actual, expected) + + def test_itertools_permutations_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(*args, **kwargs): + return torch.tensor(list(itertools.permutations(*args, **kwargs))) + + self.assertRaises(Unsupported, fn) + self.assertRaises(Unsupported, fn, [1, 2, 3], 1, 2) + self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1) + + @make_test + def test_itertools_permutations_various_iterators(a, b): + itertools.permutations([a, b]) + itertools.permutations(zip([1, 2], [3, 4])) + itertools.permutations(map(lambda x: x, [1, 2])) + itertools.permutations(filter(lambda x: True, [1, 2])) + return a + @make_test def test_itertools_chain(a, b): v = a @@ -4094,6 +4136,7 @@ def func(): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) + @torch._dynamo.config.patch(assume_dunder_attributes_remain_unchanged=False) def test_meth_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg @@ -5030,6 +5073,29 @@ def __getattribute__(self, name): with self.assertRaises(Unsupported): a.call_function(None, [], {}) + def test_inspect_method_source(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + + def check(self, x): + return x * 2 + + def forward(self, x): + return x * 2 + + mod = Mod() + + def fn(x): + inspect.signature(mod.check).parameters.items() + return mod(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(FunctionTests) instantiate_parametrized_tests(DefaultsTests) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 0164b6f9c680..47e9ee3cb888 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -3,20 +3,73 @@ import logging import subprocess import sys -import tempfile import unittest import torch import torch._logging.structured import torch.distributed as dist +from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE +from torch.utils._triton import has_triton if torch.distributed.is_available(): from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore +if has_triton(): + import triton + import triton.language as tl + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + @triton.jit + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 1024}, + num_warps=4, + num_stages=2, + pre_hook=init_to_zero("output_ptr"), + ) + ], + pre_hook=init_to_zero("output_ptr"), + post_hook=init_to_zero("output_ptr"), + key=["n_elements"], + ) + @triton.jit + def add_kernel_autotune( + x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu + class FxGraphRunnableArtifactFilter(logging.Filter): def filter(self, record): @@ -79,7 +132,7 @@ def _exec_and_verify_payload(self): self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing") self.assertIn("def forward", payload) # sanity-check for actual FX code - with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp: + with WritableTempFile("w", suffix=".py") as tmp: tmp.write(payload) tmp.flush() res = subprocess.run( @@ -100,6 +153,41 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() + @unittest.skipUnless(has_triton(), "Triton not available") + def test_user_defined_triton_kernel_autotune(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = output.numel() + + def grid( + meta, + ): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + add_kernel_autotune[grid](x, y, output, n_elements) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + + @unittest.skipUnless(has_triton(), "Triton not available") + @requires_gpu + def test_user_defined_triton_kernel(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = x.numel() + add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 94fd37ec8ac1..27401f36e02f 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,5 +1,7 @@ # Owner(s): ["module: dynamo"] +import abc import functools +import inspect import unittest import weakref @@ -880,8 +882,9 @@ def hook(guard_wrapper, f_locals, builder): counter += 1 class Bar: - x = 4 - y = torch.randn(4) + def __init__(self): + self.x = 4 + self.y = torch.randn(4) bar = Bar() @@ -930,7 +933,7 @@ def hook(guard_wrapper, f_locals, builder): # Check types of foo.x foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source) - self.assertTrue(foo_x_mgr.is_guarded_value_dict()) + self.assertTrue(issubclass(foo_x_mgr.get_type_of_guarded_value(), dict)) # Check types of foo.x["a"] foo_x_a_source = DictGetItemSource(foo_x_source, "a") @@ -945,12 +948,14 @@ def hook(guard_wrapper, f_locals, builder): # Check types of foo.z foo_z_source = AttrSource(foo_source, "z") foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source) - self.assertTrue(foo_z_mgr.is_guarded_value_empty_dict()) + self.assertTrue(issubclass(foo_z_mgr.get_type_of_guarded_value(), dict)) # Check types of mod mod_source = LocalSource("mod") mod_mgr = builder.get_guard_manager_from_source(mod_source) - self.assertTrue(mod_mgr.is_guarded_value_nn_module()) + self.assertTrue( + issubclass(mod_mgr.get_type_of_guarded_value(), torch.nn.Module) + ) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) with install_guard_manager_testing_hook(hook): @@ -1005,6 +1010,12 @@ def hook(guard_wrapper, f_locals, builder): from torch._dynamo.source import AttrSource, LocalSource foo_source = LocalSource("foo") + foo_mgr = builder.get_guard_manager_from_source(foo_source) + for accessor in foo_mgr.get_accessors(): + if isinstance(accessor, GetAttrGuardAccessor): + self.assertTrue( + accessor.get_attr_name() in ("a", "b", "c", "d", "e") + ) # Check types of foo.a foo_a_source = AttrSource(foo_source, "a") @@ -1141,21 +1152,32 @@ def hook(guard_wrapper, f_locals, builder): def test_nn_module_tag_safe(self): class Foo(torch.nn.Module): + c = 2 + def __init__(self): super().__init__() self.a = 4 + def check(self, x): + return True + def forward(self, x): - return x + self.a + inspect.signature(self.check).parameters.items() + return x + self.a + self.c foo = Foo() - class Baz(torch.nn.Module): + class Env(metaclass=abc.ABCMeta): # noqa: B024 + pass + + class Baz(torch.nn.Module, Env): def __init__(self): super().__init__() self.foo = foo def forward(self, x): + if "Foo" in str(type(self).__mro__): + x = torch.sin(x) return self.foo(x) baz = Baz() @@ -1170,7 +1192,6 @@ def fn(x): from utils import install_guard_manager_testing_hook def hook(guard_wrapper, f_locals, builder): - from torch._C._dynamo.guards import GetGenericDictGuardAccessor from torch._dynamo.source import LocalSource baz_source = LocalSource("baz") @@ -1180,27 +1201,6 @@ def hook(guard_wrapper, f_locals, builder): self.assertTrue(baz_mgr.is_tag_safe()) self.assertTrue(baz_mgr.is_tag_safe_root()) - # Check tagness of baz.__dict__ - self.assertTrue(len(baz_mgr.get_accessors()) == 1) - dunder_dict_accessor = baz_mgr.get_accessors()[0] - self.assertTrue( - isinstance(dunder_dict_accessor, GetGenericDictGuardAccessor) - ) - - dunder_dict_mgr = baz_mgr.get_child_managers()[0] - self.assertTrue(dunder_dict_mgr.is_tag_safe()) - self.assertFalse(dunder_dict_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"] - modules_mgr = dunder_dict_mgr.get_child_managers()[0] - self.assertTrue(modules_mgr.is_tag_safe()) - self.assertFalse(modules_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"]["foo"] - modules_foo_mgr = modules_mgr.get_child_managers()[0] - self.assertTrue(modules_foo_mgr.is_tag_safe()) - self.assertFalse(modules_foo_mgr.is_tag_safe_root()) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) with install_guard_manager_testing_hook(hook): opt_fn(torch.randn(4, 4)) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 10808c922b3f..969460364630 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -1325,6 +1325,27 @@ def getattr_new(*args, **kwargs): finally: builtins_dict["getattr"] = getattr_original + def test_skipped_objects(self): + def foo(): + pass + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.code = foo.__code__ + self.foo = foo + self.p = torch.nn.Parameter(torch.randn(3, 2)) + + def forward(self, x): + z = x + 1 + for p in self.parameters(): + z += p + return z + + m = Module() + ref, loaded = self._test_serialization("TENSOR_MATCH", m, torch.randn(3, 2)) + self._test_check_fn(ref, loaded, {"self": m, "x": torch.randn(3, 2)}, True) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index b9c1ff3a61fe..5844a13fcad0 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -38,11 +38,8 @@ xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db -from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton def count_ops(gm, args, freq, op): @@ -6845,7 +6842,7 @@ def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): for arg, cloned_arg in zip(args, cloned_args): self.assertEqual(arg.grad, cloned_arg.grad) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function(self): def gn(x, y): @@ -6864,7 +6861,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function_with_kwargs(self): def gn(x, y): @@ -6887,7 +6884,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout(self): def gn(x, y): @@ -6913,7 +6910,7 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout_inductor(self): def gn(x, y): @@ -6932,7 +6929,7 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_fallback(self): def gn(x, y): @@ -6963,7 +6960,7 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(backend.graphs), 2) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_module(self): class MockModule(torch.nn.Module): @@ -7216,7 +7213,7 @@ def false_branch(x): class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): - @requires_cuda + @requires_cuda_and_triton @parametrize("backend", ("aot_eager", "inductor")) @ops( list(filter(lambda op: op.name not in xfail_hops_compile, hop_db)), diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 2ecf73d0b9b5..a5a6ee54aa74 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,21 +21,29 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, + IS_WINDOWS, munge_exc, skipIfTorchDynamo, + skipIfWindows, TEST_XPU, xfailIf, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU +from torch.testing._internal.inductor_utils import ( + HAS_CUDA_AND_TRITON, + HAS_XPU_AND_TRITON, +) from torch.testing._internal.logging_utils import ( LoggingTestCase, make_logging_test, make_settings_test, ) +from torch.testing._internal.triton_utils import requires_cuda_and_triton + +requires_gpu = unittest.skipUnless( + HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" +) -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") -requires_gpu = unittest.skipUnless(HAS_CUDA or HAS_XPU, "requires cuda or xpu") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -131,7 +139,7 @@ def test_fusion(self, records): self.assertGreater(len(records), 0) self.assertLess(len(records), 8) - @requires_cuda + @requires_cuda_and_triton @make_logging_test(cudagraphs=True) def test_cudagraphs(self, records): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) @@ -244,7 +252,7 @@ def throw(x): exitstack.close() @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @make_logging_test(ddp_graphs=True) def test_ddp_graphs(self, records): class ToyModel(torch.nn.Module): @@ -522,7 +530,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\n") + lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -538,6 +546,7 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() + @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" @@ -717,10 +726,10 @@ def f(x, y, z): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ -+- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # -+- __SHAPE_GUARD__: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) -+- __SHAPE_GUARD__: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # -+- __SHAPE_GUARD__: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ++- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # ++- __SHAPE_GUARD__: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) ++- __SHAPE_GUARD__: ((2*L['y'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # ++- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ) @make_logging_test(guards=True) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a4da77b4c98d..624f0603678a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -16,11 +16,13 @@ import math import operator import os +import pickle import random import sys import tempfile import threading import traceback +import types import typing import unittest import unittest.mock as mock @@ -54,6 +56,7 @@ ) from torch._dynamo.utils import call_size, counters, ifdynstaticdefault from torch._dynamo.variables import builder +from torch._inductor.codecache import WritableTempFile from torch._inductor.utils import fresh_cache, run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -8519,6 +8522,50 @@ def global_context_capture_fn(frame_summary): self.assertEqual(seen_frames[0].name, "fn") self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") + def test_fullgraph_capture(self): + def foo(x): + return x + x.shape[0] + + compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo) + compiled_foo(torch.randn(3, 2)) + compiled_foo(torch.randn(4)) + artifacts = compiled_foo.get_artifacts() + + guarded_codes = artifacts.dynamo_artifacts.guarded_codes + backend_ids = list(artifacts.backend_inputs.keys()) + gms = [b.graph_module for b in artifacts.backend_inputs.values()] + + def _convert_to_ep_demo(code, backend_id, gm, args): + # Inject compiled function as the original gm + new_globals = copy.copy(globals()) + new_globals[backend_id] = gm + # Minimal boilerplate to setup a callable. + SerializedCode = type(code.dynamo_code) + dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code) + guards_state = pickle.loads(code.guards_state) + guard_manager = torch._dynamo.guards.CheckFunctionManager( + foo.__code__, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=new_globals, + ).guard_manager + + class ModuleForExport(torch.nn.Module): + def forward(self, x): + return types.FunctionType(dynamo_bytecode, new_globals)(x) + + m = ModuleForExport() + return guard_manager, torch.export.export(m, args) + + guards0, ep0 = _convert_to_ep_demo( + guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),) + ) + self.assertTrue(guards0.check({"x": torch.randn(3, 2)})) + self.assertFalse(guards0.check({"x": torch.randn(4)})) + input0 = torch.randn(3, 2) + self.assertEqual(ep0.module()(input0), foo(input0)) + def test_torch_guards_stack_frame_register_inlining_deep(self): x = torch.tensor([0.5, 0.5]) y = torch.tensor([0.75, 0.75, 0.75, 0.75]) @@ -8556,64 +8603,15 @@ def global_context_capture_fn(frame_summary): self.assertEqual(seen_frames[1].name, "uwu_inline_me") self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") - def test_recompile_on_disable_1(self): - # fix https://github.com/pytorch/pytorch/issues/157399 + def test_error_on_recompile(self): @torch.compile(backend="eager") - def fn(x): - @torch._dynamo.disable - def inner(x): - return x + 10 - - return inner(x) + 1 - - with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): - try: - for i in range(5): - fn(torch.rand(2, 3)) - except torch._dynamo.exc.RecompileError as e: - self.fail("RecompileError raised unexpectedly: " + str(e)) - - def test_recompile_on_disable_2(self): - def outer(x, cond): - @torch._dynamo.disable() - def fn0(y): - return y + 1 - - @torch._dynamo.disable() - def fn1(y): - return y + 2 - - if cond: - f = fn0 - else: - f = fn1 - - torch._dynamo.graph_break() - # there will be a resume function here - return f(x) - - with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): - with self.assertRaises(torch._dynamo.exc.RecompileError): - x = torch.rand(2, 3) - self.assertEqual(outer(x, True), torch.compile(outer)(x, True)) - self.assertEqual(outer(x, False), torch.compile(outer)(x, False)) - - def test_create_nested_fn_cache_clear(self): - def outer(x): - @torch._dynamo.disable() - def f(y): - return y + 2 - - return f(x) + 1 + def fn(a, b): + return a + b - outer = torch.compile(outer) with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): with self.assertRaises(torch._dynamo.exc.RecompileError): - outer(torch.randn(3, 3)) - from torch._dynamo.utils import create_nested_fn_cache - - create_nested_fn_cache.clear() - outer(torch.randn(3, 3)) + fn(torch.rand(2, 3), torch.rand(2, 3)) + fn(torch.rand(2, 3), (1, 2, 3)) def test_guards_strip_function_call(self): from torch._dynamo.guards import strip_function_call @@ -11292,7 +11290,7 @@ def EEE(): def fn(): return 3 """ - with tempfile.NamedTemporaryFile(mode="w") as f: + with WritableTempFile(mode="w") as f: f.write(src) f.flush() from torch._dynamo.funcname_cache import get_funcname @@ -11993,6 +11991,19 @@ def fn(x, d): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) + def test_hash_hop(self): + associative_scan = importlib.import_module( + "torch._higher_order_ops.associative_scan" + ) + + @torch.compile(fullgraph=True) + def fn(y, s): + d = dict() + d[s] = y + return d[s] + 1.0 + + fn(torch.ones(2, 2, device="cpu"), associative_scan.AssociativeScanOp()) + def test_iter_type(self): @torch.compile(fullgraph=True) def fn(y): diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index f8869fd804ef..ec9c4473a17f 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -12,6 +12,8 @@ _push_on_torch_function_stack, ) from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode +from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE from torch.testing._internal.triton_utils import requires_gpu from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode @@ -33,6 +35,23 @@ def __torch_function__(self, func, types, args, kwargs=None): return super().__torch_function__(func, types, args, kwargs) +class HopDetectionError(Exception): + pass + + +class TestModeRaises(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + import torch._higher_order_ops + + if func == torch._higher_order_ops.flex_attention: + raise HopDetectionError("test") + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -660,6 +679,51 @@ def forward(self, x): with torch.device("cpu"): torch.compile(mod, fullgraph=True)(x) + @requires_gpu + @skipIfXpu(msg="XPU does not support flex attention") + def test_hop(self): + import torch + import torch._higher_order_ops + from torch.nn.attention.flex_attention import ( + flex_attention as flex_attention_eager, + ) + + with torch.device(GPU_TYPE): + flex_attention = torch.compile(flex_attention_eager, dynamic=False) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "raised exception HopDetectionError([ConstantVariable(str: 'test')])", + ): + # This runs in fullgraph already + with TestModeRaises(): + flex_attention( + torch.ones(2, 2, 2, 2), + torch.ones(2, 2, 2, 2), + torch.ones(2, 2, 2, 2), + ) + + @requires_gpu + @skipIfXpu(msg="XPU does not support flex attention") + def test_hop_eager(self): + import torch + import torch._higher_order_ops + from torch.nn.attention.flex_attention import ( + flex_attention as flex_attention_eager, + ) + + with torch.device(GPU_TYPE): + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "raised exception HopDetectionError([ConstantVariable(str: 'test')])", + ): + with TestModeRaises(): + flex_attention_eager( + torch.ones(2, 2, 2, 2), + torch.ones(2, 2, 2, 2), + torch.ones(2, 2, 2, 2), + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index f38b9bc50277..7cac7eca7239 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3422,6 +3422,58 @@ def forward(self, x): compiled_mod = torch.compile(mod, backend="eager") compiled_mod(x) + def test_trace_delattr(self): + TMP_PREFIX = "_tmp_" + + def pre_forward_rename_hook(module: torch.nn.Module, _input: torch.Tensor): + param_name = "weight" + original_param = getattr(module, param_name) + setattr(module, TMP_PREFIX + param_name, original_param) + new_param = original_param + 1.0 + delattr(module, param_name) + setattr(module, param_name, new_param) + + def post_forward_restore_hook( + module: torch.nn.Module, _input: torch.Tensor, _output: torch.Tensor + ): + param_name = "weight" + tmp_param_name = TMP_PREFIX + param_name + original_param = getattr(module, tmp_param_name) + delattr(module, param_name) + setattr(module, param_name, original_param) + delattr(module, tmp_param_name) + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + torch.manual_seed(0) + model = SimpleModel() + + model.linear.register_forward_pre_hook(pre_forward_rename_hook) + model.linear.register_forward_hook(post_forward_restore_hook) + + input_tensor = torch.randn(4, 10) + + eager_output = model(input_tensor) + assert hasattr(model.linear, "weight") + assert not hasattr(model.linear, "_tmp_weight") + + torch.manual_seed(0) + model_to_compile = SimpleModel() + model_to_compile.linear.register_forward_pre_hook(pre_forward_rename_hook) + model_to_compile.linear.register_forward_hook(post_forward_restore_hook) + + compiled_model = torch.compile(model_to_compile, fullgraph=True) + compiled_output = compiled_model(input_tensor) + assert hasattr(model.linear, "weight") + assert not hasattr(compiled_model.linear, "_tmp_weight") + torch.testing.assert_close(eager_output, compiled_output) + devices = ["cuda", "hpu", "xpu"] instantiate_device_type_tests( diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 630f0428ca0d..fdd01135ea2f 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -22,8 +22,12 @@ instantiate_parametrized_tests, parametrize, skipIfRocm, + skipIfXpu, +) +from torch.testing._internal.inductor_utils import ( + HAS_CUDA_AND_TRITON, + HAS_XPU_AND_TRITON, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU def compute_loss_helper(x): @@ -93,9 +97,9 @@ def forward(self, x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_basic_fn(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -137,9 +141,9 @@ def fn(x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_lazy_backward(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -184,9 +188,9 @@ def fn(x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_graph_break_bomb(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -248,9 +252,9 @@ def guard_filter_fn(guards): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamic_shape(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -367,9 +371,9 @@ def guard_filter_fn(guards): @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamo_cache_manual_load(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -404,9 +408,9 @@ def fn2(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_serialize(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -437,11 +441,12 @@ def fn2(x): @parametrize("device", ("cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) + @skipIfXpu @skipIfRocm def test_automatic_dynamo_autotune_cache(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x, y): @@ -472,9 +477,9 @@ def fn(x, y): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -505,9 +510,9 @@ def fn(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_graph_breaks(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x, l, r): @@ -551,9 +556,9 @@ def guard_filter_fn(guards): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_lazy_backward(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -580,9 +585,9 @@ def fn(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_call_function_from_resume(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") mod = torch.nn.Linear(2, 3, device=device) diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index 93e5274431be..623143ae4dcb 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -12,7 +12,9 @@ import torch.compiler.config import torch.nested from torch._dynamo.testing import CompileCounter +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.utils import clear_caches, fresh_cache +from torch.testing._internal.common_utils import IS_WINDOWS class PgoTest(torch._dynamo.test_case.TestCase): @@ -55,6 +57,10 @@ def f(x): f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 1) + @torch._dynamo.config.patch( + force_parameter_static_shapes=False, + force_nn_module_property_static_shapes=False, + ) def test_whitelist_suggestion(self): cnts = CompileCounter() @@ -194,14 +200,16 @@ def run(): self.assertEqual(cnts.frame_count, 3) # parameter static shapes are forced static, so we recompile once - run() - self.assertEqual(cnts.frame_count, 2) + with torch._dynamo.config.patch( + force_parameter_static_shapes=False, + force_nn_module_property_static_shapes=False, + ): + run() + self.assertEqual(cnts.frame_count, 2) - # flags are flipped, PGO records dynamism, so params are dynamically compiled to start - torch._dynamo.config.force_parameter_static_shapes = False - torch._dynamo.config.force_nn_module_property_static_shapes = False - run() - self.assertEqual(cnts.frame_count, 1) + # because flags were flipped, params were included in PGO + run() + self.assertEqual(cnts.frame_count, 1) def test_njt(self): cnts = CompileCounter() @@ -322,8 +330,9 @@ def func(x): temp_dir1 = tempfile.TemporaryDirectory() temp_dir2 = tempfile.TemporaryDirectory() - path1 = os.path.join(temp_dir1.name, "example.py") - path2 = os.path.join(temp_dir2.name, "example.py") + # We need normalize_path_separator for Windows file path. + path1 = normalize_path_separator(os.path.join(temp_dir1.name, "example.py")) + path2 = normalize_path_separator(os.path.join(temp_dir2.name, "example.py")) cnts = CompileCounter() assert path1 != path2 @@ -341,7 +350,11 @@ def write_load_and_run(path): write_load_and_run(path1) self.assertEqual(cnts.frame_count, 2) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) - self.assertTrue("hash(390fe689)" in state) + + # Windows can't create unification temp path: + # hash(a18a3259)C:/Users/Xuhan/AppData/Local/Temp/tmpx3hfkuqa/example.py + # Skip hash check + self.assertTrue("hash" if IS_WINDOWS else "hash(390fe689)" in state) self.assertTrue("/example.py:4:func:" in state) self.assertTrue(" L['x']: tensor size=[?] stride=[1]" in state) # We should compile this only once due to PGO. diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 0cafaf9878e6..9f3d41964195 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,7 +7,7 @@ import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE -from torch.testing._internal.inductor_utils import requires_triton +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton from torch.utils._triton import ( has_triton_experimental_host_tma, has_triton_tensor_descriptor_host_tma, @@ -420,7 +420,7 @@ def create_tma(tensor): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) @@ -441,7 +441,7 @@ def create_tma(tensor): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index e833dd9df886..be6bf8085af2 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -211,7 +211,7 @@ def f(x): Developer debug context: call_method TensorVariable() item () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e0a3f7a5223f..fe16e4906ef3 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7160,7 +7160,7 @@ def fn(): "Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.\n\n" " Developer debug context: \n\n" " For more details about this graph break, please visit: " - "https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0264.html" + "https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0264.html" ) self.assertEqual(explain_output.break_reasons[0].reason, expected_msg) @@ -7673,6 +7673,31 @@ def forward(self, x): out2 = torch.compile(model, backend="eager")(input.clone()) self.assertEqual(out1, out2) + @requires_cuda + def test_zero_dim_param_mixed_device_grad(self): + # cpu 0-dim params with cuda grads + # https://github.com/pytorch/pytorch/issues/160084 + class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + + def forward(self, x): + return x * self.a + self.b + + model = RegressionModel() + model.forward = torch.compile( + model.forward, backend="aot_eager", fullgraph=True + ) + inputs = torch.randn(4, 10).to("cuda") + out = model(inputs) + out.sum().backward() + self.assertIsNotNone(model.a.grad) + self.assertIsNotNone(model.b.grad) + self.assertEqual(model.a.grad.device, torch.device("cpu")) + self.assertEqual(model.b.grad.device, torch.device("cpu")) + def test_filter_warnings(self): x = torch.ones(2, 2, requires_grad=True) diff --git a/test/dynamo/test_sets.py b/test/dynamo/test_sets.py index 0871c0c1e565..7b6421ce6a25 100644 --- a/test/dynamo/test_sets.py +++ b/test/dynamo/test_sets.py @@ -174,7 +174,7 @@ def fn(x, s): Developer debug context: Python set containing torch.Tensor elements - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0222.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0222.html from user code: File "test_sets.py", line N, in fn diff --git a/test/dynamo/test_skip_guard_eval_unsafe.py b/test/dynamo/test_skip_guard_eval_unsafe.py index 5a31047aedb2..dc7d74bc3629 100644 --- a/test/dynamo/test_skip_guard_eval_unsafe.py +++ b/test/dynamo/test_skip_guard_eval_unsafe.py @@ -54,8 +54,9 @@ def fn(x, y): def test_post_recompile(self): class Foo: - a = 4 - b = 5 + def __init__(self): + self.a = 4 + self.b = 5 foo = Foo() diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index cde880df17a6..5897c129b267 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -10,6 +10,7 @@ import subprocess import tempfile import unittest.mock +from contextlib import contextmanager import torch import torch._dynamo.test_case @@ -21,12 +22,15 @@ from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import find_free_port -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda_and_triton + + +if torch.distributed.is_available(): + from torch.testing._internal.distributed.fake_pg import FakeStore HAS_TLPARSE = shutil.which("tlparse") is not None requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -233,7 +237,7 @@ def test_compile_id_serialization_deserialization(self): with self.assertRaises(ValueError): torch._guards.CompileId.from_string(bad_cid) - @requires_cuda + @requires_cuda_and_triton def test_schedule(self): fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -265,7 +269,7 @@ def test_schedule(self): self.assertParses() - @requires_cuda + @requires_cuda_and_triton def test_cudagraphs(self): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -312,11 +316,11 @@ def fn(x, y): {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -523,7 +527,7 @@ def throw(x): self.assertParses() @requires_distributed() - @requires_cuda + @requires_cuda_and_triton def test_ddp_graphs(self): class ToyModel(torch.nn.Module): def __init__(self) -> None: @@ -1120,6 +1124,215 @@ def user_context() -> str: f(torch.randn(i + 2 // 3, 5)) step.next_step() + @contextmanager + def _setup_collective_schedule_capture(self): + """Helper to turn on and capture the 'inductor_collective_schedule' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter( + StructuredTraceTestingFilter("inductor_collective_schedule") + ) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + def test_collective_schedule_empty(self): + """Verify logging when no collective kernels are present (empty schedule).""" + with self._setup_collective_schedule_capture() as payload_buffer: + from torch._inductor.debug import log_collective_schedule + + log_collective_schedule([]) + + # With no collectives, artifact should not be logged and payload should be empty + self.assertNotIn('"inductor_collective_schedule"', self.buffer.getvalue()) + self.assertEqual(payload_buffer.getvalue().strip(), "") + + @requires_tlparse + @requires_distributed() + @torch._inductor.config.patch("fx_graph_cache", False) + def test_collective_schedule_real(self): + """Test collective schedule with _c10d_functional ops that work with FakeStore.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class CollectiveModule(torch.nn.Module): + def forward(self, x): + # Use _c10d_functional ops that actually trigger collective kernels + y = torch.ops._c10d_functional.all_reduce.default(x, "sum", "0") + y = torch.ops._c10d_functional.wait_tensor.default(y) + return y * 2 + + try: + with self._setup_collective_schedule_capture() as payload_buffer: + torch._dynamo.reset() + + mod = CollectiveModule() + compiled = torch.compile(mod, backend="inductor") + + compiled(torch.randn(4, 4)) + + # Verify collective schedule artifact was logged + self.assertIn('"inductor_collective_schedule"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + schedule = json.loads(payload_content) + self.assertIsInstance(schedule, list) + + # Verify expected collective operations are present + self.assertExpectedInline( + str(schedule), + """\ +['torch.ops._c10d_functional.all_reduce_.default', 'torch.ops._c10d_functional.wait_tensor.default']\ +""", + ) + self.assertParses() + finally: + dist.destroy_process_group() + + @contextmanager + def _setup_runtime_estimates_capture(self): + """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter( + StructuredTraceTestingFilter("inductor_tlparse_runtime") + ) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + @requires_distributed() + @requires_cuda_and_triton + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_simple(self): + """Test runtime estimates logging with simple compute and collective ops.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + h = self.linear(x) + h = torch.relu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + return h + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = SimpleModule().cuda() + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device="cuda")) + + # Verify runtime estimates artifact was logged + self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Verify runtime estimates + compute_ops = [op for op in ops if op["type"] == "compute"] + collective_ops = [op for op in ops if op["type"] == "collective"] + + self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0) + + # Just check each op has an estimated runtime value (any value, including 0) + for op in ops: + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) + + self.assertParses() + finally: + dist.destroy_process_group() + + @requires_tlparse + @requires_distributed() + @requires_cuda_and_triton + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_mixed(self): + """Test runtime estimates logging with mixed compute and collective sequence.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class MixedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + h = self.norm(x) + h = torch.nn.functional.gelu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + + h = h * 0.5 + + gathered = torch.ops._c10d_functional.all_gather_into_tensor.default( + h, 2, "0" + ) + gathered = torch.ops._c10d_functional.wait_tensor.default(gathered) + + return gathered.sum(dim=0) + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = MixedModule().cuda() + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device="cuda")) + + # Verify runtime estimates artifact was logged + self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Should have both compute and collective ops + op_types = {op["type"] for op in ops} + self.assertIn("compute", op_types) + self.assertIn("collective", op_types) + + # Just check each op has an estimated runtime value (any value, including 0) + for op in ops: + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) + + self.assertParses() + finally: + dist.destroy_process_group() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 17a01f745d40..9d60cbe81c97 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -31,7 +31,7 @@ parametrize, subtest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor from torch.utils._python_dispatch import return_and_correct_aliasing @@ -145,8 +145,6 @@ def mk_subclass_dense_subclass_dense(): VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") - compile_full_eager = torch.compile(backend="eager", fullgraph=True) @@ -3798,7 +3796,7 @@ def fn1(nt1, nt2): def test_basic_autograd(self): self._test_autograd("aot_eager") - @requires_cuda + @requires_cuda_and_triton def test_basic_autograd_inductor(self): self._test_autograd("inductor") diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 70ba2a8bd1bd..91862e6d3eb0 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -714,6 +714,40 @@ def fn(x, y): self.assertEqual(fn_opt(x, y3), fn(x, y3)) self.assertEqual(cnt.frame_count, 1) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_1(self): + @torch.compile(backend="aot_eager") + def f(x): + y = x.sum() + return x + y.item() + + dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64] + for i, dtype in enumerate(dtypes): + x = torch.ones(3, 3, dtype=dtype) + self.assertEqual(f(x), x + x.sum().item()) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_2(self): + @torch.compile(backend="aot_eager") + def f(x): + return x.item() * x.item() * torch.ones((), dtype=torch.float64) + + x = torch.tensor(1e20, dtype=torch.float32) + self.assertEqual( + f(x), x.item() * x.item() * torch.ones((), dtype=torch.float64) + ) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_3(self): + @torch.compile(backend="aot_eager") + def f(x): + y = x.item() * 101 + return y * torch.tensor([1], dtype=torch.float32) + + finfo_float16 = torch.finfo(torch.float16) + x = torch.tensor([finfo_float16.max], dtype=torch.float16) + self.assertEqual(f(x), x.item() * 101 * torch.tensor([1], dtype=torch.float32)) + @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False) def test_unspec_float_input_f64(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index b14a6c41dbdc..fdb34ab0b68e 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -12,6 +12,9 @@ from torch._inductor.test_case import TestCase +_IS_WINDOWS = sys.platform == "win32" + + class TestUtils(TestCase): def test_nan(self): a = torch.Tensor([float("nan")]) @@ -243,6 +246,32 @@ def add(x, y): utils.reset_frame_count() torch._logging._internal.structured_logging_overhead.clear() + @dynamo_config.patch({"log_compilation_metrics": True}) + @inductor_config.patch({"force_disable_caches": True}) + def test_stack_trace(self): + self.warmup() + + compilation_events = [] + with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event: + self.run_forward_backward() + compilation_events = [arg[0][0] for arg in log_event.call_args_list] + stack_trace_list = [] + for e in compilation_events: + stack_trace_list.append(e.stack_trace) + + self.assertGreater(len(stack_trace_list), 0) + result = "\n".join( + item + for sublist in stack_trace_list + if sublist + for item in (sublist if isinstance(sublist, list) else [sublist]) + ) + self.assertIn( + "test_stack_trace", + result, + "Log file does not contain the expected string: 'test_stack_trace'", + ) + @dynamo_config.patch( { "log_compilation_metrics": True, @@ -283,6 +312,37 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(utils.compilation_time_metrics), """\ +{'GraphLowering.codegen': [0.0, 0.0], + 'GraphLowering.compile_to_fn': [0.0, 0.0], + 'GraphLowering.compile_to_module': [0.0, 0.0], + 'GraphLowering.run': [0.0, 0.0], + 'OutputGraph.call_user_compiler': [0.0], + 'PyCodeCache.load_by_key_path': [0.0, 0.0], + 'PythonWrapperCodegen.generate': [0.0, 0.0], + 'Scheduler.__init__': [0.0, 0.0], + 'Scheduler.codegen': [0.0, 0.0], + 'Scheduler.fused_nodes': [0.0, 0.0], + '_compile.compile_inner': [0.0], + '_recursive_joint_graph_passes': [0.0], + '_recursive_post_grad_passes': [0.0, 0.0], + '_recursive_pre_grad_passes': [0.0], + 'additional_fake_tensor_prop': [0.0, 0.0], + 'aot_collect_metadata': [0.0], + 'aot_trace_joint_graph': [0.0], + 'backward._backward_impl': [0.0], + 'build_guards': [0.0], + 'bytecode_tracing': [0.0], + 'compile_attempt_0': [0.0], + 'compile_file': [0.0, 0.0], + 'compile_fx..bw_compiler': [0.0], + 'compile_fx..fw_compiler_base': [0.0], + 'compile_fx_inner': [0.0, 0.0], + 'create_aot_dispatcher_function': [0.0], + 'fx_codegen_and_compile': [0.0, 0.0], + 'gc': [0.0], + 'min_cut_rematerialization_partition': [0.0]}""" + if _IS_WINDOWS + else """\ {'GraphLowering.codegen': [0.0, 0.0], 'GraphLowering.compile_to_fn': [0.0, 0.0], 'GraphLowering.compile_to_module': [0.0, 0.0], @@ -321,6 +381,18 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(time_spent), """\ +{'_recursive_joint_graph_passes': 0.0, + '_recursive_post_grad_passes': 0.0, + '_recursive_pre_grad_passes': 0.0, + 'backend_compile': 0.0, + 'code_gen': 0.0, + 'entire_backward_compile': 0.0, + 'entire_frame_compile': 0.0, + 'gc': 0.0, + 'inductor_compile': 0.0, + 'total_wall_time': 0.0}""" + if _IS_WINDOWS + else """\ {'_recursive_joint_graph_passes': 0.0, '_recursive_post_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0, @@ -350,6 +422,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): e.cuda_version = None e.triton_version = None e.python_version = None + e.stack_trace = None # First event is for the forward. Formatting makes reading diffs # much easier. @@ -433,6 +506,89 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, + 'stack_trace': None, + 'start_time': 0.0001, + 'start_time_us': 100, + 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_us': 0, + 'tensorify_float_attempt': None, + 'tensorify_float_failure': None, + 'tensorify_float_success': None, + 'triton_compile_time_us': None, + 'triton_kernel_compile_times_us': None, + 'triton_version': None}""" + if _IS_WINDOWS + else """\ +{'accumulated_cache_size': 0, + 'aot_autograd_cumulative_compile_time_us': 0, + 'backend_compile_time_s': 0.0, + 'backward_cumulative_compile_time_us': None, + 'cache_size': 0, + 'co_filename': None, + 'co_firstlineno': None, + 'co_name': 'forward', + 'code_gen_time_s': 0.0, + 'compile_id': '1/0', + 'compile_time_autotune_time_us': None, + 'compliant_custom_ops': set(), + 'config_inline_inbuilt_nn_modules': False, + 'config_suppress_errors': False, + 'cuda_version': None, + 'cudagraph_skip_reason': None, + 'distributed_ephemeral_timeout_us': None, + 'duration_us': 0, + 'dynamo_compile_time_before_restart_us': 0, + 'dynamo_config': None, + 'dynamo_cumulative_compile_time_us': 0, + 'dynamo_time_before_restart_s': 0.0, + 'end_time_us': 100, + 'entire_frame_compile_time_s': 0.0, + 'fail_reason': None, + 'fail_type': None, + 'fail_user_frame_filename': None, + 'fail_user_frame_lineno': None, + 'frame_key': '1', + 'gc_time_us': 0, + 'graph_input_count': 1, + 'graph_node_count': 3, + 'graph_op_count': 1, + 'guard_count': 9, + 'has_guarded_code': True, + 'inductor_code_gen_cumulative_compile_time_us': 0, + 'inductor_compile_time_s': 0.0, + 'inductor_config': None, + 'inductor_cumulative_compile_time_us': 0, + 'inductor_fx_remote_cache_backend_type': None, + 'inductor_fx_remote_cache_hit_count': None, + 'inductor_fx_remote_cache_hit_keys': None, + 'inductor_fx_remote_cache_miss_count': None, + 'inductor_fx_remote_cache_miss_keys': None, + 'is_forward': True, + 'is_runtime': False, + 'joint_graph_pass_time_us': 0, + 'log_format_version': 3, + 'non_compliant_ops': set(), + 'num_graph_breaks': 0, + 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, + 'post_grad_pass_time_us': 0, + 'pre_grad_pass_time_us': 0, + 'python_version': None, + 'recompile_reason': None, + 'recompile_user_contexts': None, + 'remote_cache_time_saved_s': None, + 'remote_cache_version': None, + 'remote_fx_graph_cache_get_time_ms': None, + 'remote_fx_graph_cache_get_time_us': None, + 'remote_fx_graph_cache_put_time_ms': None, + 'remote_fx_graph_cache_put_time_us': None, + 'restart_reasons': set(), + 'runtime_cudagraphify_time_us': None, + 'runtime_triton_autotune_time_us': None, + 'shape_env_guard_count': 0, + 'specialize_float': False, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -525,6 +681,89 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, + 'stack_trace': None, + 'start_time': 0.0001, + 'start_time_us': 100, + 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_us': 0, + 'tensorify_float_attempt': None, + 'tensorify_float_failure': None, + 'tensorify_float_success': None, + 'triton_compile_time_us': None, + 'triton_kernel_compile_times_us': None, + 'triton_version': None}""" + if _IS_WINDOWS + else """\ +{'accumulated_cache_size': None, + 'aot_autograd_cumulative_compile_time_us': None, + 'backend_compile_time_s': None, + 'backward_cumulative_compile_time_us': 0, + 'cache_size': None, + 'co_filename': None, + 'co_firstlineno': None, + 'co_name': None, + 'code_gen_time_s': 0.0, + 'compile_id': '1/0', + 'compile_time_autotune_time_us': None, + 'compliant_custom_ops': None, + 'config_inline_inbuilt_nn_modules': False, + 'config_suppress_errors': False, + 'cuda_version': None, + 'cudagraph_skip_reason': None, + 'distributed_ephemeral_timeout_us': None, + 'duration_us': 0, + 'dynamo_compile_time_before_restart_us': None, + 'dynamo_config': None, + 'dynamo_cumulative_compile_time_us': None, + 'dynamo_time_before_restart_s': None, + 'end_time_us': 100, + 'entire_frame_compile_time_s': None, + 'fail_reason': None, + 'fail_type': None, + 'fail_user_frame_filename': None, + 'fail_user_frame_lineno': None, + 'frame_key': None, + 'gc_time_us': None, + 'graph_input_count': None, + 'graph_node_count': None, + 'graph_op_count': None, + 'guard_count': None, + 'has_guarded_code': None, + 'inductor_code_gen_cumulative_compile_time_us': 0, + 'inductor_compile_time_s': 0.0, + 'inductor_config': None, + 'inductor_cumulative_compile_time_us': 0, + 'inductor_fx_remote_cache_backend_type': None, + 'inductor_fx_remote_cache_hit_count': None, + 'inductor_fx_remote_cache_hit_keys': None, + 'inductor_fx_remote_cache_miss_count': None, + 'inductor_fx_remote_cache_miss_keys': None, + 'is_forward': False, + 'is_runtime': False, + 'joint_graph_pass_time_us': None, + 'log_format_version': 3, + 'non_compliant_ops': None, + 'num_graph_breaks': 0, + 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, + 'post_grad_pass_time_us': 0, + 'pre_grad_pass_time_us': None, + 'python_version': None, + 'recompile_reason': None, + 'recompile_user_contexts': None, + 'remote_cache_time_saved_s': None, + 'remote_cache_version': None, + 'remote_fx_graph_cache_get_time_ms': None, + 'remote_fx_graph_cache_get_time_us': None, + 'remote_fx_graph_cache_put_time_ms': None, + 'remote_fx_graph_cache_put_time_us': None, + 'restart_reasons': None, + 'runtime_cudagraphify_time_us': None, + 'runtime_triton_autotune_time_us': None, + 'shape_env_guard_count': None, + 'specialize_float': None, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_filter b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_filter deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_permutations deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_product deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_zip b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_zip deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_ziplongest b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_ziplongest deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_cycle b/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_cycle deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_permutations deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_product deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_accumulate b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_accumulate deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_cycle b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_cycle deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_permutations deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_product deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestVariousIteratorArgs.test_accumulate b/test/dynamo_expected_failures/CPython313-test_itertools-TestVariousIteratorArgs.test_accumulate deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 5c5795f45ce2..c650b102bf1a 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,6 +75,7 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out +aten::_cudnn_attention_backward aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index d5da9d8eb9cf..8dbe257ec3ae 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -52,7 +52,7 @@ torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None torch.fx.node.Node.append(self, x: 'Node') -> None -torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str] +torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str] torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.prepend(self, x: 'Node') -> None torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >, propagate_meta: bool = False) -> List[Node] diff --git a/test/export/test_export.py b/test/export/test_export.py index d5d0bcc7a9f6..1c997b8e86be 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -59,6 +59,7 @@ OutputSpec, TensorArgument, ) +from torch.export.passes import move_to_device_pass from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck @@ -85,7 +86,7 @@ ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.torchbind_impls import load_torchbind_test_lib -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import ( LeafSpec, @@ -6348,7 +6349,9 @@ def forward(self, kjt) -> torch.Tensor: efoo = torch.export.export( foo, inputs, - dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]}, + dynamic_shapes={ + "kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}, None, None] + }, ) self.assertEqual( [out.shape for out in efoo.module()(*inputs)], @@ -8314,6 +8317,29 @@ def forward(self, b_a_buffer, x): torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4))) ) + def test_ccode_python_mod(self): + import sympy + + from torch.utils._sympy.functions import PythonMod + + class Foo(torch.nn.Module): + def forward(self, xs): + u0, u1 = xs.tolist() + torch._check_is_size(u1) + return u0, u1 + + ep = export(Foo(), (torch.tensor([2, 3]),), strict=False) + u0_node, u1_node = list(ep.graph.nodes)[-1].args[0] + u0 = u0_node.meta["val"] + u1 = u1_node.meta["val"] + self.assertExpectedInline( + sympy.ccode(PythonMod(u0, 3)), """(u0 % 3) < 0 ? u0 % 3 + 3 : u0 % 3""" + ) + self.assertExpectedInline( + sympy.ccode(PythonMod(u0, u1)), + """(u0 % u1) < 0 ? u0 % u1 + abs(u1) : u0 % u1""", + ) + def test_aten_lift_fresh_copy(self): class M(torch.nn.Module): def forward(self, x): @@ -8356,7 +8382,7 @@ def forward(self, x): len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) - @requires_cuda + @requires_cuda_and_triton @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_dim(self): device = torch.device("cuda") @@ -8381,7 +8407,7 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) - @requires_cuda + @requires_cuda_and_triton @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_scandim(self): device = torch.device("cuda") @@ -8406,7 +8432,7 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) - @requires_cuda + @requires_cuda_and_triton def test_export_associative_scan_lifted_buffers(self): if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") @@ -8889,7 +8915,7 @@ def _decompose_linear_custom(x, weight, bias): self.assertExpectedInline( str(ep_decompose_linear.graph_module.code).strip(), """\ -def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_bias, c_linear_weight, x, y): +def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y): conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None @@ -13902,21 +13928,6 @@ def forward(self, x): ep = export(m, args) self.assertEqual(ep.module()(*args), m(*args)) - def test_deepcopy(self): - class Model(torch.nn.Module): - def forward(self, input): - return input + input - - model = Model().eval() - inputs = (torch.ones(2, 2),) - - program = export(model, inputs) - copied_program = copy.deepcopy(program) - self.assertEqual(str(program.graph), str(copied_program.graph)) - self.assertEqual( - str(program.graph_module.code), str(copied_program.graph_module.code) - ) - def test_cse_for_symint(self): class Foo(torch.nn.Module): # check sym ops only get computed once @@ -14938,6 +14949,51 @@ def fn(x): self.assertEqual(x.sin(), ep.module()(x)) pytree._deregister_pytree_node(torch.FunctionSchema) + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + def test_exception(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + self.register_buffer("buffer", torch.ones(4, 4)) + self.register_buffer("param", torch.ones(4, 4)) + + def forward(self, x): + token_ids = torch.randint(0, 10, (4,), device=x.device) + embedded = self.embedding(token_ids).sum() + return self.buffer.sum() + self.param.sum() + x.sum() + embedded + + class BarModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = Model() + + def forward(self, x): + if "cuda" in str(x.device): + mod = self.mod.to(x.device) + return mod(x) + else: + return x.sum() + + class BarBar(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = BarModel() + + def forward(self, x): + with torch.amp.autocast(device_type="cuda"): + y = self.mod(x) + return y + + with torch.no_grad(): + with self.assertRaisesRegex(RuntimeError, "Couldn't swap Embedding.weight"): + _ = torch.export.export( + BarBar(), + (), + {"x": torch.randn(4, 4, 4, device="cuda")}, + strict=False, + ).module() + def test_export_for_training_with_state_dict_hooks(self): def _state_dict_pre_hook(mod, prefix, keep_vars): mod._buffers["test"] = torch.Tensor([1]) @@ -15861,6 +15917,22 @@ def forward(self, x): len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs) ) + @requires_cuda_and_triton + def test_assert_tensor_metadata_device_index(self): + class N(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + x = x.float() + y = y.float() + return x + y + + inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda")) + ep = export(N(), inp) + ep = move_to_device_pass(ep, {"cuda:0": "cuda"}) + ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0")) + def test_input_output_no_stacktrace(self): class M(torch.nn.Module): def forward(self, x): diff --git a/test/export/test_passes.py b/test/export/test_passes.py index d3194ea352c3..d083b5a7cc6d 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -1302,6 +1302,28 @@ def forward(self, x): return (b_state, getitem_3, getitem_4)""", ) + @unittest.skipIf(not TEST_CUDA, "requires cuda") + def test_move_device_submod(self): + class M(torch.nn.Module): + def forward(self, x): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + x = x.to(device="cuda:0") + return x + x + + ep = torch.export.export(M(), (torch.ones(3),)) + ep = move_to_device_pass(ep, "cuda") + ep.graph_module.submod_1.recompile() + self.assertExpectedInline( + ep.graph_module.submod_1.code.strip("\n"), + """\ +def forward(self, arg0_1): + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(arg0_1, dtype = torch.float32, device = 'cuda', layout = torch.strided); _assert_tensor_metadata_default = None + to = torch.ops.aten.to.dtype_layout(arg0_1, dtype = torch.float32, layout = torch.strided, device = 'cuda'); arg0_1 = None + add = torch.ops.aten.add.Tensor(to, to); to = None + return (add,) + """, # noqa: B950 + ) + @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_to_device_pass(self): class Model(torch.nn.Module): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 7b86048d21be..f4f7b68a494a 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -152,25 +152,6 @@ def forward(self, x): self.assertEqual(exp_out, actual_out) self.assertEqual(exp_out.requires_grad, actual_out.requires_grad) - def test_conflicting_name(self): - class Model(torch.nn.Module): - def forward(self, input): - return input + input - - model = Model().eval() - inputs = (torch.ones(2, 2),) - - program = torch.export.export(model, inputs) - - buffer = io.BytesIO() - torch.export.save(program, buffer) - buffer.seek(0) - loaded_program = torch.export.load(buffer) - self.assertEqual(str(program.graph), str(loaded_program.graph)) - self.assertEqual( - str(program.graph_module.code), str(loaded_program.graph_module.code) - ) - def test_export_example_inputs_preserved(self): class MyModule(torch.nn.Module): """A test module with that has multiple args and uses kwargs""" diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 90dafb6507f4..d24262dab2b1 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -24,7 +24,7 @@ _empty_tensor_queue, init_torchbind_implementations, ) -from torch.testing._internal.triton_utils import requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton def _assertEqualSkipScriptObject(test_case, exp, actual): @@ -1330,7 +1330,7 @@ def setattr_f(tq): return tq with self.assertRaisesRegex( - RuntimeError, "call method __setattr__ on script object is not safe" + RuntimeError, "Weird method call on TorchScript object" ): torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) @@ -1343,7 +1343,7 @@ def setattr_f(tq): return tq._not_defined_attr with self.assertRaisesRegex( - RuntimeError, "doesn't define method _not_defined_attr" + RuntimeError, "FakeScriptObject missing method implementation" ): torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) @@ -1434,7 +1434,7 @@ def f(tq, x): x = torch.randn(2, 3) with self.assertRaisesRegex( - RuntimeError, "FakeScriptObject doesn't define method" + RuntimeError, "FakeScriptObject missing method implementation" ): torch.compile(f, backend=backend)(_empty_tensor_queue(), x) @@ -1552,7 +1552,7 @@ def f(tq, x): self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) ) - @requires_gpu + @requires_cuda_and_triton @parametrize("device", ["cpu", "cuda"]) @parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_compile_obj_torchbind_op_with_autocast(self, backend, device): @@ -1570,7 +1570,7 @@ def f(tq, x): self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) ) - @requires_gpu + @requires_cuda_and_triton @parametrize("device", ["cpu", "cuda"]) def test_export_obj_torchbind_op_with_autocast(self, device): class Mod(torch.nn.Module): diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index ab635d7bcd4b..3510403cc164 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -178,6 +178,39 @@ def forward(self, x): id(getattr(unflattened_module.sub_net, "2")), ) + def test_assert_tensor_metadata_stack(self): + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(3) + + def forward(self, x, y): + x = x.to(dtype=torch.int32) + y = y.to(dtype=torch.int32) + x = x + self.a + return x + y + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.n = N() + + def forward(self, x, y): + x = x * x + y = y * y + return self.n(x, y) + + m = M() + ep = torch.export.export(m, (torch.randn(3), torch.randn(3))) + for node in ep.graph.nodes: + if node.target == torch.ops.aten._assert_tensor_metadata.default: + self.assertEqual(len(node.meta.get("nn_module_stack")), 2) + + uep = torch.export.unflatten(ep) + + inp = (torch.randn(3), torch.randn(3)) + self.assertTrue(torch.allclose(uep(*inp), m(*inp))) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo") def test_unflatten_preserve_signature(self): diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py index 430d4a3d56dd..fde84b6683ed 100644 --- a/test/functorch/test_ac.py +++ b/test/functorch/test_ac.py @@ -6,7 +6,7 @@ import torch import torch._functorch.config as config from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint from torch.utils.flop_counter import FlopCounterMode, register_flop_formula @@ -405,5 +405,5 @@ def call(): if __name__ == "__main__": # I'm using the cuda memory allocator to verify memory allocations - if HAS_CUDA and not TEST_WITH_ROCM: + if HAS_CUDA_AND_TRITON and not TEST_WITH_ROCM: run_tests() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 54ccd0f7fef2..f6901be327d9 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -7449,6 +7449,7 @@ def forward(self, l_inp_, l_tmp_): ) self.assertEqual(out, f(inp, tmp)) + @skipIfCrossRef # Args get renamed to r in crossref mode @parametrize("requires_grad", [True, False]) def test_cond_symint_operands(self, requires_grad): backend = EagerAndRecordGraphs() diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index ab50b59fb96b..f4c3ef072f9a 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -24,6 +24,7 @@ def __init__(self, graph_module: torch.fx.GraphModule): ) +# original graph node order is: ['x', 'add', 'add_1', 'output'] class AddModule(torch.nn.Module): def forward(self, x): y = torch.add(x, x) @@ -32,8 +33,18 @@ def forward(self, x): class TestPartitionerOrder(TestCase): - # partitoner test to check graph node order - def test_partitioner_order(self): + # partitoner test to check graph node order remains the same with the original graph after partitioning + def test_partitioner_graph_node_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + origin_node_order = [n.name for n in traced_m.graph.nodes] + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + partition_node_order = [n.name for n in partion_nodes[0]] + self.assertTrue(partition_node_order == origin_node_order) + + # partitoner test to check graph node order remains the same during multiple runs + def test_partitioner_multiple_runs_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) partitions = DummyPartitioner(traced_m).propose_partitions() diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index c800eb78f905..df1bd941d885 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -34,7 +34,7 @@ TestCase, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu nested_compile_region = torch.compiler.nested_compile_region @@ -556,7 +556,7 @@ def fn(x): self.assertEqual(ref, res) self.assertEqual(x.grad, x_clone.grad) - @requires_cuda + @requires_cuda_and_triton def test_sdpa(self): @nested_compile_region def gn(q, k, v): @@ -1195,17 +1195,11 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x, y) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_input_input_aliasing(self): @nested_compile_region def gn(x, y): @@ -1219,17 +1213,11 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_output_output_aliasing(self): @nested_compile_region def gn(x): @@ -1244,17 +1232,11 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_mod_attr_aliasing(self): class MutateParam(torch.nn.Module): def __init__(self): @@ -1447,7 +1429,7 @@ def forward(self, l_x_: "f32[8, 8]"): """, ) - @requires_cuda + @requires_cuda_and_triton def test_return_none(self): from torch.nn import functional as F diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index ae1d00c5b634..ade7695a10d0 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,7 +1,7 @@ #include // @manual=fbcode//caffe2:libtorch -#include -#include +#include // @manual +#include // @manual #include #include diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py index 9a60afba3224..ac0467a2d1b8 100644 --- a/test/inductor/test_analysis.py +++ b/test/inductor/test_analysis.py @@ -337,6 +337,7 @@ def test_augment_trace_helper_unit(self): ], ) @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune") + @torch._inductor.config.patch(force_disable_caches=True) def test_triton_has_metadata(self, device, dtype, maxat): """ make sure that the chrome trace of triton kernels contains certain values @@ -359,7 +360,6 @@ def om(i, w): options={ "benchmark_kernel": True, "max_autotune_gemm_backends": backends, - "force_disable_caches": True, "max_autotune": max_autotune, }, ) @@ -396,6 +396,7 @@ def verify_triton(comp): @unittest.skipIf( not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune" ) + @torch._inductor.config.patch(force_disable_caches=True) def test_augment_trace_against_flop_counter(self, device, dtype, maxat): # this tests to see if we can only use a Triton backend for max autotune max_autotune, backends = maxat @@ -408,7 +409,6 @@ def test_augment_trace_against_flop_counter(self, device, dtype, maxat): options={ "benchmark_kernel": True, "max_autotune_gemm_backends": backends, - "force_disable_caches": True, "max_autotune": max_autotune, }, ) @@ -507,6 +507,7 @@ def test_augment_trace_against_flop_counter(self, device, dtype, maxat): @unittest.skipIf( not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune" ) + @torch._inductor.config.patch(force_disable_caches=True) def test_pointwise_bandwidth(self, device, dtype, maxat): # this tests to see if we can only use a Triton backend for max autotune max_autotune, backends = maxat @@ -518,7 +519,6 @@ def test_pointwise_bandwidth(self, device, dtype, maxat): options={ "benchmark_kernel": True, "max_autotune_gemm_backends": backends, - "force_disable_caches": True, "max_autotune": max_autotune, }, ) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index b7157909dbdd..9fa13dc180f9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -20,6 +20,7 @@ from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters from torch._inductor import config +from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.test_case import TestCase @@ -551,7 +552,7 @@ def forward(self, a, b): triton.set_allocator( lambda size, align, stream: torch.empty( - size, dtype=torch.int8, device="cuda" + size, dtype=torch.int8, device=GPU_TYPE ) ) @@ -5234,9 +5235,9 @@ def forward(self, a, b, c): return z example_inputs = ( - torch.randn(10, 20, device="cuda"), - torch.randn(20, 30, device="cuda"), - torch.randn(10, 30, device="cuda"), + torch.randn(10, 20, device=GPU_TYPE), + torch.randn(20, 30, device=GPU_TYPE), + torch.randn(10, 30, device=GPU_TYPE), ) model = Model() kernel_calls = [ @@ -5495,6 +5496,43 @@ def forward_block(self, x): example_inputs = (torch.randn(2, 128, 4096, device=self.device),) self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}}) + @requires_gpu + def test_d2h_copy(self): + # device to copy host should always have the same stride + if "cuda" not in self.device: + raise unittest.SkipTest("This test is only for CUDA") + + class ToCpuModel(nn.Module): + def forward(self, x): + predictions = x.permute(1, 0) + predictions = torch.nan_to_num( + predictions, nan=0.0, posinf=0.0, neginf=0.0 + ) + predictions = predictions.to("cpu", non_blocking=True) + p = predictions[0] + ones = p.new_ones(1) + return p, ones + + model = ToCpuModel().to(GPU_TYPE) + input_tensor = torch.randn(5, 10, device=GPU_TYPE).to(dtype=torch.float16) + ep = torch.export.export(model, (input_tensor,)) + package_path = torch._inductor.aoti_compile_and_package(ep) + + aoti_model = torch._inductor.aoti_load_package(package_path) + + expect_res = model(input_tensor) + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) as prof: + true_res = aoti_model(input_tensor) + + self.assertEqual(expect_res, true_res) + all_ops = [event.key for event in prof.key_averages()] + self.assertTrue(not any("aten::contiguous" in op for op in all_ops)) + def test_so_without_weight(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -5602,7 +5640,7 @@ def forward(self, a): example_inputs=example_inputs, ) - with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + with WritableTempFile(suffix=".pt2") as f: package_path = package_aoti( f.name, {"model": aoti_files}, @@ -6229,9 +6267,6 @@ def forward( dynamic_shapes=dynamic_shapes, ) - @skipIfXpu( - msg="The operator 'aten::_int_mm' is not currently implemented for the XPU device" - ) def test__int_mm(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -6689,6 +6724,34 @@ def forward(self, x): # the output should have int type self.check_model(Model2(), (x,)) + def test_upper_bound_i64(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return x + y + + inp = ( + torch.randint(0, 100, (2**18,), device=self.device, dtype=torch.int8), + torch.tensor([4], device=self.device, dtype=torch.int8), + ) + ep = torch.export.export( + Model(), + inp, + dynamic_shapes=({0: Dim("d", min=0, max=2**33)}, {0: Dim.STATIC}), + ) + so_path = torch._inductor.aot_compile(ep.module(), inp) + m = torch._export.aot_load(so_path, self.device) + + self.assertEqual(Model()(*inp), m(*inp)) + del inp + + inp = ( + torch.randint(0, 100, (3 * 2**30,), device=self.device, dtype=torch.int8), + torch.tensor([4], device=self.device, dtype=torch.int8), + ) + # don't check the accuracy of the result to reduce memory usage + # this test is mostly checking to ensure there's no IMA. + m(*inp) + def test_using_model_name_for_files(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -6722,6 +6785,36 @@ def forward(self, x, y): aot_inductor_module = torch._inductor.aoti_load_package(package_path) self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs)) + def test_copy_non_blocking_is_pinned(self): + if self.device == "cpu" or self.device == "mps": + raise unittest.SkipTest("only matters for device-to-cpu copy") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + a_cpu = a.to(device="cpu", non_blocking=True) + b_cpu = b.to(device="cpu", non_blocking=True) + a_to_cpu_event = torch.Event() + a_to_cpu_event.record() + a_to_cpu_event.synchronize() + return torch.cat([a_cpu, b_cpu]) + + model = Model() + a = torch.randn(2, 2, device=self.device) + b = torch.randn(2, 2, device=self.device) + example_inputs = (a, b) + outputs = model(*example_inputs) + package_path, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + FileCheck().check("pinned").run(code) + model_aoti = torch._inductor.aoti_load_package(package_path) + outputs_aoti = model_aoti(*example_inputs) + + self.assertEqual(outputs, outputs_aoti) + class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) @@ -6740,6 +6833,25 @@ def forward(self, x): torch._inductor.aot_compile(ep.module(), inputs) self.assertEqual([r.msg == "create_env" for r in records].count(True), 1) + @make_logging_test(dynamic=logging.DEBUG) + def test_shape_env_reuse_zero_consts_use_consts_asm_false(self, records): + # make sure ShapeEnv is only created once and reused afterwards + class Foo(torch.nn.Module): + def forward(self, x): + return x + 2 + + inputs = (torch.randn(4, 4),) + dynamic_shapes = { + "x": {0: Dim.AUTO, 1: Dim.AUTO}, + } + ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False) + with ( + torch.no_grad(), + config.patch({"aot_inductor.use_consts_asm_build": False}), + ): + torch._inductor.aot_compile(ep.module(), inputs) + self.assertEqual([r.msg == "create_env" for r in records].count(True), 1) + class TestAOTInductorConfig(TestCase): def test_no_compile_standalone(self): @@ -6751,11 +6863,21 @@ def test_compile_standalone_sets_package_cpp(self): result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True}) self.assertEqual(result["aot_inductor.package_cpp_only"], True) self.assertEqual(result["aot_inductor.compile_standalone"], True) + self.assertEqual(result["aot_inductor.embed_kernel_binary"], True) + self.assertEqual( + result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip + ) + self.assertEqual( + result["aot_inductor.model_name_for_generated_files"], "aoti_model" + ) - def test_compile_standalone_package_cpp_already_true(self): + def test_compile_standalone_explicit_set(self): patches = { "aot_inductor.compile_standalone": True, "aot_inductor.package_cpp_only": True, + "aot_inductor.embed_kernel_binary": True, + "aot_inductor.emit_multi_arch_kernel": not torch.version.hip, + "aot_inductor.model_name_for_generated_files": "aoti_model", } result = maybe_aoti_standalone_config(patches) self.assertEqual(result, patches) @@ -6834,31 +6956,16 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # MPS doesn't support float8 "test_fp8": fail_mps(), "test_fp8_view_of_param": fail_mps(), - # unsupported operator: aten._scaled_dot_product_attention_math_for_mps.default - "test_issue_140766": fail_mps(), # cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t' "test_fallback_kernel_with_symexpr_output": fail_mps(), - # while-loop subgraph calls same kernel as outside. need to figure out how to - # either (1) tell outside to initialize a new kernel or (2) generate - # subgraph as a separate function, which would(?) cause (1) to happen automatically. - "test_while_loop_nested": fail_mps(), # correctness issue "test_index_put_with_none_index": fail_mps(), - # Dynamism - "test_shifted_constraint_ranges": fail_mps(), - "test_while_loop_with_sym_expr_cond_dynamic_True": fail_mps(), - "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_mps(), - "test_cond_mismatched_branch_output_dynamic_True": fail_mps(), - "test_cond_unbacked_symint_closure_dynamic_True": fail_mps(), - "test_cond_non_tensor_predicates_dynamic_True": fail_mps(), - "test_zero_grid_with_unbacked_symbols": fail_mps(), - "test_reuse_kernel_dynamic": fail_mps(is_skip=True), - "test_cond_with_parameters": fail_mps(is_skip=True), - "test_cond_share_predicte": fail_mps(is_skip=True), # Error device may not be nil "test_zero_size_weight": fail_mps(is_skip=True), # RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0 "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True), + # MPSGraph does not support tensor dims > INT_MAX + "test_upper_bound_i64": fail_mps(is_skip=True), # MPS doesn't support triton "test_autotuning_args_reuse": fail_mps(), "test_triton_autotuning": fail_mps(), diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index aa3c589b4546..0b4f508477ac 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -24,7 +24,7 @@ skipIfXpu, ) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test -from torch.testing._internal.triton_utils import HAS_CUDA +from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON from torch.utils._python_dispatch import TorchDispatchMode @@ -556,5 +556,5 @@ class AOTInductorTestABICompatibleCuda(AOTICustomOpTestCase): from torch._inductor.test_case import run_tests # cpp_extension N/A in fbcode - if HAS_CUDA or sys.platform == "darwin": + if HAS_CUDA_AND_TRITON or sys.platform == "darwin": run_tests(needs="filelock") diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 51343b6b1883..46152103836a 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -15,6 +15,7 @@ from parameterized import parameterized_class import torch +import torch._inductor.config from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase @@ -156,6 +157,7 @@ def cmake_compile_and_run(self, base_dir): check=True, ) subprocess.run(["make"], cwd=build_path, check=True) + result = subprocess.run( ["./build/main"], cwd=base_dir, @@ -363,6 +365,7 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_after_package_static(self): # compile_standalone will set package_cpp_only=True self.check_package_cpp_only() @@ -419,12 +422,46 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") + @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) + def test_compile_standalone_cos(self): + # compile_standalone will set package_cpp_only=True + self.check_package_cpp_only() + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return torch.cos(x) + + with torch.no_grad(): + example_inputs = (torch.randn(8, 32, device=self.device),) + model = Model().to(device=self.device) + + # Test compilation when model name is passed in + options = { + "aot_inductor.compile_standalone": True, + "aot_inductor.model_name_for_generated_files": "cos", + } + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + build_path, _ = self.cmake_compile( + model, example_inputs, options, tmp_dir + ) + # Check if the .a file was build successfully + a_path = build_path / "libcos.a" + self.assertTrue(a_path.exists()) + @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter(self): self.check_package_cpp_only() @@ -466,16 +503,62 @@ def default(*args, **kwargs): if self.device == GPU_TYPE: self.assertEqual( result.stdout, - "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + "output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n" " 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n", ) else: self.assertEqual( result.stdout, - "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + "output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n" " 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n", ) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) + @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @skipIfRocm # doesn't support multi-arch binary + @skipIfXpu # doesn't support multi-arch binary + @torch._inductor.config.patch("test_configs.use_libtorch", True) + def test_compile_with_exporter_weights(self): + self.check_package_cpp_only() + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.fc1(x) + return x + + def default(*args, **kwargs): + return None + + example_inputs = (torch.ones(3, 3).to(self.device),) + + package = _ExportPackage() + m1 = Model().to(self.device) + exporter1 = package._exporter("Model", m1)._define_overload("default", default) + exporter1(*example_inputs) + expected_res = m1(*example_inputs) + + package_example_inputs = True + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + package._compiled_and_package( + tmp_dir + "/package.pt2", True, package_example_inputs + ) + + # Test compiling generated files + self.cmake_compile_and_run(tmp_dir) + tensor_model = torch.load( + tmp_dir + "/output_tensor1.pt", weights_only=False + ) + true_res = next(iter(tensor_model.parameters())) + self.assertEqual(expected_res, true_res) + def test_metadata(self): class Model(torch.nn.Module): def __init__(self) -> None: diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index a2706933d615..a86690270461 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -148,7 +148,7 @@ def legacy_run( @staticmethod def compile( model: Union[torch.nn.Module, types.FunctionType], - example_inputs: list[torch.Tensor], + example_inputs: tuple[torch.Tensor, ...], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -169,7 +169,7 @@ def compile( @staticmethod def run( model: Union[torch.nn.Module, types.FunctionType], - example_inputs: list[torch.Tensor], + example_inputs: tuple[torch.Tensor, ...], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -185,7 +185,7 @@ def run( @staticmethod def run_multiple( model: Union[torch.nn.Module, types.FunctionType], - list_example_inputs: list[list[torch.Tensor]], + list_example_inputs: list[tuple[torch.Tensor, ...]], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index e05bbedbb95a..6025c90cdb4a 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -185,8 +185,7 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # Custom comment for test - foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None return ()""", # noqa: B950 ignore_comments=True, ) @@ -247,7 +246,7 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", # noqa: B950 @@ -334,9 +333,8 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - # Custom comment for test - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ -arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = \ +arg3_1 = arg0_1 = arg1_1 = foo_default = None return ()""", ignore_comments=True, ) @@ -416,10 +414,10 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None - copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None - copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None +def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu", arg3_1: "f32[s77][1]cpu", arg4_1: "f32[s77][1]cpu", arg5_1: "f32[s77][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1); arg4_1 = arg5_1 = arg3_1 = foo_default = None + copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -429,9 +427,9 @@ def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72 post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None - copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -523,11 +521,11 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return (getitem_4, getitem_5)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -581,13 +579,13 @@ def f(x, y): self.assertExpectedInline( graph_aot, """\ -def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) - getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1] - getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None - add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) - copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None - copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None +def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg2_1]) + getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[s77][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None + copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = copy__1 = None return (add,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -597,12 +595,12 @@ def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17 graph_aot, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) - copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None - copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None return (add,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -613,11 +611,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_inductor, """\ -def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None - add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) - copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg1_1, arg2_1); foo_default = None + add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg2_1) + copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None return (add,)""", ignore_comments=True, ignore_empty_lines=True, @@ -627,8 +625,8 @@ def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17 graph_inductor, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None - add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1) + foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return (add,)""", @@ -843,11 +841,11 @@ def f(x, y): graph_aot, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None - copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None - copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -859,7 +857,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): graph_inductor, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None + foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return ()""", # noqa: B950 @@ -979,8 +977,8 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None - copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg1_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index b3afba7d6843..8a61cc051c20 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -13,7 +13,7 @@ from torch.testing._internal.inductor_utils import ( get_func_call, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, IS_BIG_GPU, ) @@ -197,7 +197,7 @@ def f(x): self.common(f, (x,)) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: class BenchmarkFusionCudaTest(TestCase): common = check_model_cuda @@ -347,5 +347,5 @@ class BenchmarkFusionCpuTest(TestCase): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 7c50ee1dbd1f..f73a47e45a57 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -22,11 +22,11 @@ _quantize_rowwise, _quantize_tensorwise, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") log = logging.getLogger(__name__) @@ -464,5 +464,5 @@ def compiled_bmm(x, w): from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. - if HAS_CUDA and HAS_CPU and is_big_gpu(): + if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 996e81032a05..757ea061c26f 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -29,6 +29,7 @@ TensorMetadata, TensorMetadataAndValues, ) +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.custom_graph_pass import ( CustomGraphModulePass, CustomGraphPass, @@ -59,7 +60,6 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA, HAS_GPU, HAS_MULTIGPU, HAS_TRITON, @@ -67,7 +67,7 @@ requires_gpu, requires_triton, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -872,7 +872,7 @@ def fn(x): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton def test_no_arguments_tensor_device_guards(self): """ Usually, when there are example inputs, the device index of the inputs @@ -902,7 +902,7 @@ def f(): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton def test_tensor_device_guards_cpu_tensor(self): """ CPU tensor arguments should still cache hit @@ -1006,7 +1006,7 @@ def fn(x, op): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) - @requires_cuda + @requires_cuda_and_triton @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @with_tf32_off @@ -1464,7 +1464,7 @@ def f(x, val): self.assertNotEqual(a, b) @config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}) - @requires_cuda + @requires_cuda_and_triton @unittest.expectedFailure # TODO: pass in optimize_mem at runtime def test_async_compile_cache(self): class SimpleFunction(torch.autograd.Function): @@ -1807,7 +1807,9 @@ def f(x): assert not kwargs with tempfile.TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, "compiled_artifact.bin") + path = normalize_path_separator( + os.path.join(temp_dir, "compiled_artifact.bin") + ) with fresh_cache(): compiled_artifact = torch._inductor.standalone_compile(gm, args) @@ -2574,7 +2576,7 @@ def test_get_hash_for_files(self): class TestCudaCompileCommand(TestCase): - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton def test_cuda_compile_command(self): cmd_no_extra_args: str = cuda_compile_command( ["abc.cu", "def.cu"], "output", "so" @@ -2619,7 +2621,7 @@ def reset(self): torch._dynamo.reset() clear_caches() - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @unittest.skipIf( TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM" @@ -2670,7 +2672,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2711,7 +2713,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2772,7 +2774,7 @@ def f(a, b, c, d, e, f): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_triton() - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2801,8 +2803,8 @@ def get_autotune_stats(): def fn(x, y): return (x + y).relu() - x = torch.randn(100, 100).cuda() - y = torch.randn(100, 100).cuda() + x = torch.randn(100, 100).to(GPU_TYPE) + y = torch.randn(100, 100).to(GPU_TYPE) with config.patch( { @@ -2836,7 +2838,7 @@ def fn(x, y): class TestRemoteAOTAutogradCache(TestCase): - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2875,7 +2877,7 @@ def f(a, b): for k in global_stats.fx_graph.cache.keys(): self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2950,7 +2952,7 @@ def fn(x, y): # This combination of settings exposed a bug where we cleared the # PyCodeCache disk artifacts while they were still needed: - @requires_cuda + @requires_cuda_and_triton @config.patch( { "coordinate_descent_tuning": True, diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index a054464bf668..90399546d26e 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -10,8 +10,8 @@ instantiate_parametrized_tests, TestCase, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton aten = torch.ops.aten @@ -55,7 +55,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_activation_functions(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -75,7 +75,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_reduce_functions(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -98,7 +98,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2) - @requires_cuda + @requires_cuda_and_triton def test_mutated_args(self): def test_mutated(a, b, c, d): a.add_(1) @@ -121,7 +121,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_reduce_split(self): def fn(a, b): a1 = torch.linalg.vector_norm(a) @@ -137,7 +137,7 @@ def fn(a, b): self.assertEqual(out_eager, out_compiled) - @requires_cuda + @requires_cuda_and_triton def test_2d_blocking_partitioning(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -184,7 +184,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_activation_benchmark(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -204,7 +204,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_reduce_benchmark(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -227,7 +227,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) - @requires_cuda + @requires_cuda_and_triton def test_mutated_benchmark(self): def test_mutated(a, b, c, d): a.add_(1) @@ -250,7 +250,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) - @requires_cuda + @requires_cuda_and_triton def test_round_robin_dispatch(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -274,7 +274,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) - @requires_cuda + @requires_cuda_and_triton def test_2d_blocking_benchmark(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -296,7 +296,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda + @requires_cuda_and_triton def test_persistent_reduction_no_x_dim(self): def fn(x, y): return x.sum(1), y.sum(1) @@ -346,7 +346,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_activations(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -366,7 +366,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_2d_blocking(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -388,7 +388,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_reduce(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -411,7 +411,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_mutated(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -435,7 +435,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernels_autotune", 0) def test_dynamic_shapes_activations_no_autotune(self): def test_activations(a, b, c): @@ -456,7 +456,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim(self): @@ -475,7 +475,7 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim_2(self): @@ -494,7 +494,7 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_2d_blocking_round_robin(self): @@ -533,7 +533,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) @torch._inductor.config.patch("triton.autotune_at_compile_time", True) @@ -558,5 +558,5 @@ def fn(x, y, z): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_compile.py b/test/inductor/test_compile.py index e1f4f146636d..6908936eca3f 100644 --- a/test/inductor/test_compile.py +++ b/test/inductor/test_compile.py @@ -1,6 +1,14 @@ # Owner(s): ["module: inductor"] +import os +import shlex +import subprocess +import sys +from unittest import mock + import torch from torch import _dynamo as dynamo, _inductor as inductor +from torch._inductor.codecache import write +from torch._inductor.cpp_builder import CppBuilder, CppOptions from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import gen_gm_and_inputs from torch.fx import symbolic_trace @@ -8,6 +16,25 @@ from torch.testing._internal.inductor_utils import HAS_CPU +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + + +def safe_command_output(cmd, timeout=30): + try: + return subprocess.check_output( + cmd, + stderr=subprocess.STDOUT, + text=True, + timeout=timeout, + shell=isinstance(cmd, str), + ).strip() + except subprocess.CalledProcessError as e: + return f"run failed(error code {e.returncode}): {e.output.strip()}" + except subprocess.TimeoutExpired: + return "runt timeout" + + class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -109,6 +136,53 @@ def test_inductor_via_op_with_multiple_outputs(self): mod_opt = inductor.compile(mod, inp) self.assertEqual(mod(*inp), mod_opt(*inp)) + @mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"}) + def test_inductor_generate_debug_symbol(self): + cpp_code = """ +int main(){ + return 0; +} + """ + + _, source_path = write( + cpp_code, + "cpp", + ) + build_option = CppOptions() + cpp_builder = CppBuilder( + name="test_symbol", + sources=source_path, + output_dir=os.path.dirname(source_path), + BuildOption=build_option, + ) + cpp_builder.build() + binary_path = cpp_builder.get_target_file_path() + + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + + def check_linux_debug_section(module_path: str): + check_cmd = shlex.split(f"readelf -S {module_path}") + output = safe_command_output(check_cmd) + has_debug_sym = ".debug_info" in output + self.assertEqual(has_debug_sym, True) + + def check_windows_pdb_exist(module_path: str): + file_name_no_ext = os.path.splitext(module_path)[0] + file_name_pdb = f"{file_name_no_ext}.pdb" + has_pdb_file = os.path.exists(file_name_pdb) + self.assertEqual(has_pdb_file, True) + + if _IS_WINDOWS: + check_windows_pdb_exist(binary_path) + elif _IS_MACOS: + pass # MacOS not sure that if it should be works. + else: + check_linux_debug_section(binary_path) + if __name__ == "__main__": if HAS_CPU: diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 04297c38bf29..51aa7b70b9c4 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -62,9 +62,6 @@ "test_remove_noop_slice_scatter": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_default": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_dtype": TestFailure(("xpu"), is_skip=True), - # TODO:remove test_upsample_bicubic2d after the following issue resolved: - # https://github.com/intel/intel-xpu-backend-for-triton/issues/4184 - "test_upsample_bicubic2d": TestFailure(("xpu"), is_skip=False), } diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index dcbf1b380934..8fde26c6acf6 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import operator import os +import tempfile from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -66,6 +67,19 @@ def test_quiesce(self): finally: pool.shutdown() + @skipIfWindows(msg="pass_fds not supported on Windows.") + def test_logging(self): + os.environ["MAST_HPC_JOB_NAME"] = "test_job" + os.environ["ROLE_RANK"] = "0" + with tempfile.NamedTemporaryFile(delete=True) as temp_log: + os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name + pool = SubprocPool(2) + try: + pool.submit(operator.add, 100, 1) + self.assertEqual(os.path.exists(temp_log.name), True) + finally: + pool.shutdown() + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index c0d304290b05..dff94b4aa092 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -29,6 +29,7 @@ from torch._dynamo.testing import normalize_gm from torch._dynamo.utils import counters from torch._inductor import config as inductor_config +from torch._inductor.cpp_builder import is_msvc_cl from torch._inductor.test_case import run_tests, TestCase from torch.nn.attention.flex_attention import flex_attention from torch.nn.parallel import DistributedDataParallel as DDP @@ -40,13 +41,20 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_S390X, + IS_WINDOWS, parametrize, scoped_load_inline, skipIfWindows, ) from torch.testing._internal.hop_db import hop_db -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_CUDA_AND_TRITON, + HAS_GPU, +) from torch.testing._internal.logging_utils import logs_to_string +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._python_dispatch import TorchDispatchMode @@ -193,6 +201,18 @@ def model(i): for _ in range(3): self.run_as_subprocess(script) + def gen_cache_miss_log_prefix(self): + if IS_WINDOWS: + if is_msvc_cl(): + return "Cache miss due to new autograd node: struct " + else: + self.fail( + "Compilers other than msvc have not yet been verified on Windows." + ) + return "" + else: + return "Cache miss due to new autograd node: " + def test_reset(self): compiled_autograd.compiled_autograd_enabled = True torch._C._dynamo.compiled_autograd.set_autograd_compiler(lambda: None, True) @@ -2975,7 +2995,7 @@ def backward(ctx, grad): b = MyFunc.apply(a) b.sum().backward() - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -3015,7 +3035,7 @@ def test_cudagraphs_cpu_graph(self): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_sdpa(self): query = torch.rand( 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True @@ -3037,7 +3057,7 @@ def test_cudagraphs_sdpa(self): 2 if inductor_config.cpp_wrapper else 0, ) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): class MyFn(torch.autograd.Function): @staticmethod @@ -3065,10 +3085,19 @@ def backward(ctx, gO): self.assertEqual(counters["compiled_autograd"]["captures"], 1) # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + else: + expected_cudagraph_skips = 1 + + self.assertEqual( + counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips + ) @scoped_load_inline - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -3130,9 +3159,18 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + elif inductor_config.cpp_wrapper: + expected_cudagraph_skips = 2 + else: + expected_cudagraph_skips = 1 + self.assertEqual( counters["inductor"]["cudagraph_skips"], - 2 if inductor_config.cpp_wrapper else 1, + expected_cudagraph_skips, ) def test_logs(self): @@ -3146,7 +3184,7 @@ def test_logs(self): self.assertEqual(counters["compiled_autograd"]["compiles"], 1) assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue() assert ( - "Cache miss due to new autograd node: torch::autograd::GraphRoot" + self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot" not in logs.getvalue() ) @@ -3353,7 +3391,6 @@ def fn(x, obj): sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) ) - @skipIfWindows(msg="AssertionError: Scalars are not equal!") def test_verbose_logs_cpp(self): torch._logging.set_logs(compiled_autograd_verbose=True) @@ -3381,8 +3418,9 @@ def fn(): self.check_output_and_recompiles(fn) patterns1 = [ - r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " - r"previous key sizes=\[\]\n", + r".*" + + self.gen_cache_miss_log_prefix() + + r"torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), previous key sizes=\[\]\n", ] all_logs = logs.getvalue() @@ -3395,6 +3433,7 @@ def fn(): ) # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]... self.assertEqual(len(matches1), len(patterns1)) + @skipIfWindows(msg="node name demangling inconsistent on windows") def test_verbose_logs_dynamic_shapes(self): logs, ctx = logs_to_string( torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" @@ -3419,7 +3458,8 @@ def test_verbose_logs_dynamic_shapes(self): actual_logs = logs.getvalue() expected_logs = [ - "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", + self.gen_cache_miss_log_prefix() + + "torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", ] for expected in expected_logs: self.assertTrue(expected in actual_logs) @@ -3450,7 +3490,7 @@ def fn(): fn() unexpected_logs = [ - "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)" + self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot (NodeCall 0)" ] self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0) @@ -3694,7 +3734,7 @@ def inner_compiler(gm_, example_inputs_): self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node)) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_flex_attention(self): def _squared(score, b, h, m, n): """Joint graph needed for correctness""" @@ -3862,7 +3902,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check), ) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @requires_cuda_and_triton def test_cpu_offloading(self): def fn(): def pack(x): @@ -5030,7 +5070,7 @@ def wrap_test_class(orig_cls): dct[name] = unittest.expectedFailure elif name.startswith("test_"): backend = lookup_backend(name) - if not HAS_CUDA and backend == "inductor": + if not HAS_CUDA_AND_TRITON and backend == "inductor": continue ctxs = [ compiled_autograd._enable( @@ -5267,7 +5307,7 @@ def wrap_test_class(orig_cls): skipped_tests = set() -if not HAS_CUDA: +if not HAS_CUDA_AND_TRITON: # Found Tesla M60 which is too old to be supported by the triton GPU compiler skipped_tests.add("test_type_conversions") @@ -5293,7 +5333,7 @@ def wrap_test_class(orig_cls): test_higher_order_ops.ActivationCheckpointingTests ) -if torch.distributed.is_available() and HAS_CUDA: +if torch.distributed.is_available() and HAS_CUDA_AND_TRITON: test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile") TestDTensorCompileWithCompiledAutograd = wrap_test_class( test_dtensor.TestDTensorCompile diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 9751b3ca8f55..3b23e7a51f70 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -64,7 +64,7 @@ HAS_GPU, has_triton, ) -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu def get_inputs(optim): @@ -916,7 +916,7 @@ def fn(xs, ys): self.assertLess(end - start, 90) - @requires_cuda + @requires_cuda_and_triton def test_S429861(self): # Just verify we can compile this function without error try: @@ -935,7 +935,7 @@ def test_S429861(self): kwargs = aot_graph_input_parser(forward) torch.compile(forward)(**kwargs) - @requires_cuda + @requires_cuda_and_triton def test_foreach_map_adam(self): params = [ torch.rand( diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 107a65d6fa1d..511b9cea5e14 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -472,6 +472,9 @@ def false_fn(x): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @torch._inductor.config.patch(size_asserts=False) + # TODO: graph partition does not support creating tensor + # with dynamic shape in conditional subgraph yet + @torch._inductor.config.patch(graph_partition=False) def test_cond_unbacked_symint_inner(self, device): class Model(torch.nn.Module): def forward(self, p, a): diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index fc296b12a9d7..0b8f60dc0d26 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -18,7 +18,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestingHeuristics(InductorChoices): @@ -381,5 +381,5 @@ def fn(x, y): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 50001c24fd07..53b3e013a6b2 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -2644,6 +2644,18 @@ def fn(a, dim, index, b): self.common(fn, inps) assert metrics.generated_cpp_vec_kernel_count == 2 + def test_large_mean(self): + size = (30000, 100000) + t = torch.rand(size, dtype=torch.float) + op = torch.mean + expected = op(t) + actual = torch.compile(op)(t) + self.assertEqual(expected, actual) + with set_num_threads(1): + expected = op(t) + actual = torch.compile(op)(t) + self.assertEqual(expected, actual) + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @requires_vectorization @patch("torch.cuda.is_available", lambda: False) @@ -3105,6 +3117,30 @@ def get_traj_idx(lengths: torch.Tensor, num_slices: int) -> torch.Tensor: lengths = torch.zeros(11, dtype=torch.long) get_traj_idx(lengths, num_slices=4) + def test_store_reduction(self): + # fix https://github.com/pytorch/pytorch/issues/157683 + def fn(x, y): + r1 = x.amax(dim=0) + r2 = y.amax(dim=0) + return r1, r2 + + device = "cpu" + for int_dypte, float_dtype in zip( + [torch.int64, torch.int32, torch.int16, torch.int8], + [torch.float64, torch.float32, torch.float16, torch.bfloat16], + ): + x = torch.randint( + low=0, high=100, size=(16, 24, 59), dtype=int_dypte, device=device + ) + y = torch.randn(16, 24, 59, dtype=float_dtype, device=device) + self.common( + fn, + ( + x, + y, + ), + ) + @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 7e35c93ee0b7..75d091595cd8 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,6 +26,7 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, + IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3094,5 +3095,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not IS_MACOS: + if HAS_CPU and not (IS_MACOS or IS_WINDOWS): run_tests() diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index bb59b626bef1..53506698297f 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -26,6 +26,7 @@ run_fw_bw_and_get_code, ) from torch.fx.experimental.proxy_tensor import make_fx +from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -177,9 +178,10 @@ def test_effn_attn_bias_padding_misaligned(self): inputs = [q, k, v, mask] def f(q, k, v, mask): - return F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0 - ) + with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) f_compiled = torch.compile(f) @@ -187,9 +189,9 @@ def f(q, k, v, mask): # padded bias should have an expanded dim FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) # single fused padded kernel - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("return").run(code[0]) + FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check( + "return" + ).run(code[0]) self.assertEqual(out, f(*inputs)) @@ -2216,7 +2218,7 @@ def forward(self, x): if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CUDA + from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - if HAS_CUDA and not TEST_WITH_ASAN: + if HAS_CUDA_AND_TRITON and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 36f73b200476..b6786130416b 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,7 +1,6 @@ # Owner(s): ["module: inductor"] import ctypes -import unittest import torch from torch._inductor.async_compile import AsyncCompile @@ -10,10 +9,7 @@ from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton _SOURCE_CODE = r""" @@ -41,7 +37,7 @@ class TestCUDACodeCache(InductorTestCase): - @requires_cuda + @requires_cuda_and_triton def test_cuda_load(self): with fresh_cache(): # Test both .o and .so compilation. @@ -69,14 +65,14 @@ def test_cuda_load(self): ) torch.testing.assert_close(y, expected_y) - @requires_cuda + @requires_cuda_and_triton def test_compilation_error(self): with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") - @requires_cuda + @requires_cuda_and_triton def test_async_compile(self): with fresh_cache(): async_compile = AsyncCompile() diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index dc8ec985fbae..763384671eb5 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -40,6 +40,7 @@ skipIfRocm, TEST_CUDA_GRAPH, ) +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -55,11 +56,8 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CUDA - aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_multigpu = functools.partial( unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" ) @@ -124,7 +122,7 @@ def tearDown(self): torch._dynamo.reset() -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: def get_all_cudagraph_segments(): segments = torch.cuda.memory_snapshot() @@ -281,10 +279,14 @@ def foo(x, y): with capture_stderr() as captured_output: foo(torch.ones([10], device="cuda"), torch.ones([20])) - FileCheck().check( - "skipping cudagraphs due to cpu device (arg1_1). Found from" - ).check("y + 2").run(captured_output[0]) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if torch._inductor.config.graph_partition: + # graph partition splits on cpu ops + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + else: + FileCheck().check( + "skipping cudagraphs due to cpu device (arg1_1). Found from" + ).check("y + 2").run(captured_output[0]) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) with capture_stderr() as captured_output: foo( @@ -294,7 +296,10 @@ def foo(x, y): FileCheck().check("skipping cudagraphs due to multiple devices").run( captured_output[0] ) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 1 if torch._inductor.config.graph_partition else 2, + ) @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) def test_skip_symbolic(self): @@ -809,10 +814,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) @@ -1129,8 +1140,13 @@ def foo2(x): node = self.curr_node() first_node = next(node._path_from_root) - self.assertFalse(first_node.unaliased_in_all_paths[0]) - self.assertTrue(first_node.cached_tensor_outputs[0] is None) + if torch._inductor.config.graph_partition: + # graph partition may changed the order of outputs + self.assertFalse(first_node.unaliased_in_all_paths[1]) + self.assertTrue(first_node.cached_tensor_outputs[1] is None) + else: + self.assertFalse(first_node.unaliased_in_all_paths[0]) + self.assertTrue(first_node.cached_tensor_outputs[0] is None) @torch._inductor.config.patch("implicit_fallbacks", True) def test_multinomial(self): @@ -1633,10 +1649,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) def test_separate_recordings(self): @@ -2139,8 +2161,8 @@ def forward(self, x) -> torch.Tensor: with self.assertRaisesRegex( Exception, r"(?s)static input data pointer changed.\n" - r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*" - r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", ): self.curr_node().run( @@ -2849,6 +2871,28 @@ def foo(x): self.assertEqual(x, torch.tensor(1, device="cpu")) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar_multiple(self): + def f(x, y, z): + return x + y, x + z + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = ( + torch.ones((), device="cpu"), + torch.ones((), device="cpu"), + torch.ones(2, 2, device="cuda"), + ) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_regex(r".copy_.*True").run(code[0]) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + @torch._inductor.config.patch("graph_partition", True) @torch._inductor.config.patch("triton.cudagraphs", False) def test_graph_partition_reduce_overhead_mode_effectiveness(self): @@ -3531,6 +3575,278 @@ def run(padded_size, original_size): self.assertEqual(self.get_manager().new_graph_id().id, 2) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_simple(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + self.assertEqual(eager_out, compiled_out) + + _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").check( + "recursively_apply_fns = runner.recursively_apply_fns" + ).run(code[0]) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_foreach_op(self): + def fn(a0, a1): + c = torch._foreach_abs([a0, a1]) + return torch.mul(c[0], a0) + + compiled_fn = torch.compile(fn) + + a0 = torch.randn(2, 3, device="cuda") + a1 = torch.randn(2, 3, device="cuda") + eager_out = fn(a0, a1) + compiled_out = compiled_fn(a0, a1) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_condition_op(self): + def f(p, b): + def true_fn(x): + return torch.cos(x) + + def false_fn(x): + return torch.sin(x) + + return torch.cond(p, true_fn, false_fn, [b]) + + compiled_f = torch.compile(f) + + # static shape + p = torch.tensor([True], device="cuda") + a = torch.ones([2, 3], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + # dynamic shape with backed symint + p = torch.tensor([True], device="cuda") + a = torch.ones([4, 5], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_unbacked_symint_multi_output_layout(self): + def f(p, size_tensor): + size_val = size_tensor.item() + b = torch.ones([size_val, 3], device="cuda") + + def true_fn(x): + return torch.cos(x), torch.cos(x) + 1 + + def false_fn(x): + return torch.sin(x), torch.sin(x) + 1 + + cond_out = torch.cond(p, true_fn, false_fn, [b]) + return cond_out[0] + cond_out[1] + + compiled_f = torch.compile(f) + p = torch.tensor([True], device="cuda") + size_tensor = torch.tensor(2, device="cuda") + eager_out = f(p, size_tensor) + compiled_out = compiled_f(p, size_tensor) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + x, y = ( + torch.ones(4, 4, device="cuda"), + torch.randn(4, 4, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device="cuda") + eager_w = torch.randn(2, 2, device="cuda", requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device="cuda") + compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) + @config.patch(implicit_fallbacks=True) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_nested_indirect_indexing(self): + def nested(x, repeats): + rank = torch.arange(repeats.numel(), device=x.device) + index = rank.repeat_interleave(repeats, dim=0) + return torch.index_select(x, index=index, dim=0) + + example_inputs = ( + torch.randn((32, 64), device="cuda"), + repeats := torch.tensor([5, 10, 15], device="cuda"), + ) + torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + + nested_opt = torch.compile(nested, backend="inductor") + + expect = nested(*example_inputs) + actual = nested_opt(*example_inputs) + self.assertEqual(expect, actual) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_mutation_index(self): + x = torch.zeros(7, device="cuda") + + def fn(n, a): + a[n] = -1 + return a + + opt_fn = torch.compile(fn, fullgraph=True) + + for n in range(2, x.shape[0]): + opt_fn(n, x) + self.assertEqual(x[n], -1) + + # Negative index triggers new compilation. + opt_fn(-x.shape[0], x) + + self.assertEqual(x[0], -1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_unbacked_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y) + eager_out = f(x, y) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_dynamic_scalar_inputs(self): + def f(x, y, integer): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + z += integer + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y, 5) + self.assertEqual(compiled_out, f(x, y, 5)) + + compiled_out = f_compiled(x, y, 6) + self.assertEqual(compiled_out, f(x, y, 6)) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_item(self): + def f(x): + y = x + 1 + scalar = y.item() + return x + y + scalar + + compiled_f = torch.compile(f) + compiled_out = compiled_f(torch.tensor(1, device="cuda")) + self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda"))) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_buffer_reuse(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x1 + y1 + x @ y + u = (y_cpu.to("cuda") + 2) @ y + 3 + u_cpu = u.cpu() + 2 + return z + u_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_fused_scheduler_node(self): + def foo(x): + x = x * 20 + x_alias = x[0] + y = x * 10 + y_alias = y[0] + torch._dynamo.graph_break() + ind = torch.tensor(4, device="cuda") + x_alias2 = x[ind:] + y_alias2 = y[ind:] + return x, x_alias, x_alias2, y_alias, y_alias2 + + compiled_foo = torch.compile(foo) + x = torch.rand([20, 20], device="cuda") + + eager_out = foo(x) + compiled_out = compiled_foo(x) + self.assertEqual(eager_out, compiled_out) + def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 @@ -4035,5 +4351,5 @@ def fn(x, y): sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_cudagraph_trees_expandable_segments.py b/test/inductor/test_cudagraph_trees_expandable_segments.py index 04f2ad96fdc0..65597316091d 100644 --- a/test/inductor/test_cudagraph_trees_expandable_segments.py +++ b/test/inductor/test_cudagraph_trees_expandable_segments.py @@ -8,13 +8,13 @@ import torch from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: try: from .test_cudagraph_trees import CudaGraphTreeTests except ImportError: @@ -32,7 +32,12 @@ sys.path.remove(str(REPO_ROOT)) if __name__ == "__main__": - if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS and HAS_CUDA: + if ( + torch.cuda.is_available() + and not IS_JETSON + and not IS_WINDOWS + and HAS_CUDA_AND_TRITON + ): get_disabled_tests(".") torch.cuda.memory._set_allocator_settings("expandable_segments:True") diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 454ee544458c..8b0712dc810a 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -58,12 +58,12 @@ _quantize_rowwise, _quantize_tensorwise, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) torch.set_float32_matmul_precision("high") -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -158,8 +158,8 @@ def select_no_algorithm(*args, **kwargs): @instantiate_parametrized_tests class TestCutlassBackend(TestCase): def setUp(self): - if not HAS_CUDA: - self.skipTest("CUDA is not available") + if not HAS_CUDA_AND_TRITON: + self.skipTest("CUDA and triton are not available") if torch.version.hip: self.skipTest("CUTLASS backend is not supported on HIP") @@ -200,6 +200,19 @@ def run_evt_test(self, model, op, shape, num_fusions=1): ) torch.testing.assert_close(result, ref_result) + def test_check_paths(self): + cutlass_mock_imports_path = os.path.join( + os.path.dirname(torch.__file__), + "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports", + ) + cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda") + cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot") + cutlass_mock_scipy_path = os.path.join(cutlass_mock_imports_path, "scipy") + self.assertTrue(os.path.exists(cutlass_mock_imports_path)) + self.assertTrue(os.path.exists(cutlass_mock_cuda_path)) + self.assertTrue(os.path.exists(cutlass_mock_pydot_path)) + self.assertTrue(os.path.exists(cutlass_mock_scipy_path)) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_threshold(self): @@ -281,20 +294,19 @@ def test_cutlass_backend_subproc_mm(self): Y = torch.mm(a, b) torch.testing.assert_close(Y_compiled, Y) - @unittest.skipIf( - True, "FIXME: Disabled temporarily since IMA or crashing in subprocess" - ) @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - def test_cutlass_backend_subproc_addmm(self, shape_combo): + @parametrize("dtype", (torch.float16, torch.bfloat16)) + def test_cutlass_backend_subproc_addmm(self, dtype): """ Test autotune_in_subproc works for addmm. """ M, N, K = 4096, 2048, 25728 + dtype = torch.float16 - a = torch.randn(M, K).cuda().half() - b = torch.randn(N, K).cuda().half().t() + a = torch.randn(M, K, dtype=dtype).cuda() + b = torch.randn(N, K, dtype=dtype).cuda().t() x_shapes = [ (M, N), @@ -316,7 +328,10 @@ def test_cutlass_backend_subproc_addmm(self, shape_combo): } ): for x_shape in x_shapes: - x = torch.randn(x_shape).cuda().half() + torch._dynamo.reset() + clear_caches() + + x = torch.randn(x_shape).cuda().to(dtype) Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta) Y = torch.addmm(x, a, b, alpha=alpha, beta=beta) torch.testing.assert_close(Y_compiled, Y) @@ -747,11 +762,7 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( Make sure autotuning mm in sub processes work without crashes. """ - def mm(a, b): - return a @ b - - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() + compiled_model = torch.compile(torch.mm, dynamic=dynamic) with config.patch( { @@ -778,12 +789,66 @@ def mm(a, b): ): a = torch.randn(M, K).cuda().half() b = torch.randn(N, K).cuda().half().t() - Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) - Y = mm(a, b) + Y_compiled = compiled_model(a, b) + Y = torch.mm(a, b) # we need relaxed numerical limits due to the sheer size of the # matmuls involved. Many small addition differences add up. torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01) + @unittest.skipIf(not SM90OrLater, "need sm_90") + def test_streamk_with_dynamic( + self, + ): + """ + Test streamk with dynamic=True. Streamk should be filtered out. + + Problem is streamk can have a different workspace depending on the + shape. Without a correct workspace, the kernel will fail at runtime. + """ + + a = torch.randn(128, 16).cuda().half() + b = torch.randn(128, 16).cuda().half().t() + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + } + ): + with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): + _ = torch.compile(torch.mm, dynamic=True)(a, b) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + def test_streamk_with_static( + self, + ): + """ + Test streamk with dynamic=False. Streamk should work. + """ + + shapes = [ + (18432, 3072, 6144), + (9216, 3072, 6144), + (4608, 3072, 6144), + ] + compiled_model = torch.compile(torch.mm, dynamic=False) + + for shape in shapes: + M, N, K = shape + a = torch.randn(M, K).cuda().half() + b = torch.randn(N, K).cuda().half().t() + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 1, + "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + } + ): + _ = compiled_model(a, b) + def _test_max_autotune_cutlass_backend_epilogue_fusion( self, dynamic: bool = False, @@ -1743,6 +1808,26 @@ def test_cutlass_backend_matmul_same_tensor(self): torch.testing.assert_close(A @ A.t(), compiled(A, A.t())) + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_cutlass_backend_matmul_nonzero_offset(self): + max_autotune_gemm_backends = "CUTLASS" + + M = 129 + A = torch.randn(M, M - 1).cuda().half() + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_max_profiling_configs": 2, + } + ): + compiled = torch.compile(torch.mm) + torch.testing.assert_close( + A[1:, :] @ A[1:, :].t(), compiled(A[1:, :], A[1:, :].t()) + ) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_flexible_layout(self): @@ -2051,23 +2136,25 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): ), ) @parametrize("has_bias", (False, True)) - @parametrize("use_fast_accum", (False,)) + @parametrize("use_fast_accum", (False, True)) + @parametrize("input_dtype", (torch.bfloat16, torch.float16)) def test_fp8_rowwise_scaling( self, float8_dtype: torch.dtype, shape: tuple[int, int, int], has_bias: bool, use_fast_accum: bool, + input_dtype: torch.dtype, ): # Only bf16 output type is supported for row-wise scaling, not fp32 output_dtype: torch.dtype = torch.bfloat16 device = "cuda" M, K, N = shape # Matmul Y = X [M, K] x W [N, K] - x = torch.randn(M, K, dtype=output_dtype, device=device) - w = torch.randn(N, K, dtype=output_dtype, device=device) + x = torch.randn(M, K, dtype=input_dtype, device=device) + w = torch.randn(N, K, dtype=input_dtype, device=device) bias = None if has_bias: - bias = torch.randn(N, device=device, dtype=torch.bfloat16) + bias = torch.randn(N, device=device, dtype=input_dtype).to(torch.bfloat16) # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype) @@ -2124,24 +2211,25 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): ) @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False,)) + @parametrize("input_dtype", (torch.bfloat16, torch.float16)) def test_fp8_tensorwise_scaling( self, float8_dtype: torch.dtype, shape: tuple[int, int, int], has_bias: bool, use_fast_accum: bool, + input_dtype: torch.dtype, ): device = "cuda" M, K, N = shape # Matmul Y = X [M, K] x W [N, K] - input_dtype = torch.bfloat16 - output_dtype = torch.bfloat16 + output_dtype = input_dtype # input and output dtypes of _scaled_mm do not need to be the same, but # typically in a model they are x = torch.randn(M, K, dtype=input_dtype, device=device) w = torch.randn(N, K, dtype=input_dtype, device=device) bias = None if has_bias: - bias = torch.randn(N, device=device, dtype=torch.bfloat16) + bias = torch.randn(N, device=device, dtype=input_dtype) # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_tensorwise(w, float8_dtype) @@ -2240,5 +2328,5 @@ def test_config_number_post_filtering(self) -> None: from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. - if HAS_CUDA and HAS_CPU and is_big_gpu(): + if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index 1a90277d4e96..9c2b9a624a20 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -15,7 +15,7 @@ from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.utils import OrderedSet from torch.testing._internal.common_cuda import SM90OrLater -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON if try_import_cutlass(): @@ -392,12 +392,12 @@ def test_evt_argument_codegen(self): {}, /* C */ {}, /* compute_0 */ }, - {/* ptr_aux */ (float*) aux, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ {}, /* compute_1 */ }, - {/* ptr_aux */ (float*) F, /* dAux */ {2048, _1{}, _0{}}}, /* F */ + {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */ }, - {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ {}, /* compute_2 */ {}, /* compute_3 */ {}, /* compute_4 */ @@ -444,9 +444,9 @@ def fn(accum, bias): { /* thread */ { /* E */ {}, /* accum */ - {/* ptr_aux */ (float*) E, /* dAux */ {2048, _1{}, _0{}}}, /* E */ + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */ }, - {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ {}, /* compute_0 */ } """, @@ -455,7 +455,7 @@ def fn(accum, bias): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_evt_codegen(self): - _, _, code = trace( + _, _, code, _ = trace( BIAS_CODE, EXAMPLE_TENSORS, DataType.f32, @@ -571,5 +571,5 @@ def test_evt_codegen(self): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 8be6e2347592..919d97f987f6 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -15,7 +15,7 @@ parametrize, TEST_XPU, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_gpu @@ -117,7 +117,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -128,7 +128,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 3 if should_decompose and HAS_CUDA else 0 + expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -177,7 +177,7 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -224,7 +224,7 @@ def test_decompose_linear_mixed_precision( self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -269,7 +269,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -281,7 +281,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -331,7 +331,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -343,7 +343,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -367,7 +367,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -381,7 +381,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_gradients(module, traced) expected_val = 0 - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: expected_val = 1 if has_bias else 2 self.assertEqual( diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index e78cf68244ee..8e4746212a0b 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1889,6 +1889,50 @@ def score_mod_scale(qk, b, h, q, kv): self.run_test(score_mod_scale, dtype, device=device) + @supported_platform + @dtypes(*device_configs["cpu"].dtypes_fast) + @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) + @skip_on_cpu + def test_dynamic_divisibility_guards(self, device, dtype): + """Test guards for divisible/non-divisible shape transitions""" + if device == "cpu" and dtype is torch.float16: + dtype = torch.float32 + + def score_mod(qk, b, h, q, kv): + return torch.where(q >= kv, qk, -float("inf")) + + def test_shape(S, backend): + """Test a single shape configuration""" + block_mask = create_block_mask(noop_mask, 1, 1, S, S, device=device) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) + + tensors = [ + torch.randn( + 2, 4, S, 64, dtype=dtype, device=device, requires_grad=False + ) + for _ in range(3) + ] + + compiled_sdpa = torch.compile(sdpa_partial, backend=backend) + out, code = run_and_get_code(compiled_sdpa, *tensors) + + # Check divisibility flag + is_divisible = S % 128 == 0 + expected_flag = f"IS_DIVISIBLE : tl.constexpr = {is_divisible}" + self.assertIn( + expected_flag, str(code), f"S={S} should have {expected_flag}" + ) + + self.assertEqual(out.shape, (2, 4, S, 64)) + return out, code + + torch._dynamo.reset() + backend = CompileCounterWithBackend("inductor") + + # Test divisible and non-divisible shapes + test_shapes = [256, 255, 383, 384] + _ = [test_shape(S, backend) for S in test_shapes] + @supported_platform def test_multiple_score_mod_calls(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index b5ec59dc291c..9a0cb945fc33 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -2,6 +2,7 @@ # flake8: noqa: B950 import functools +import sys import unittest from collections import namedtuple from typing import Callable, Optional, Union @@ -27,6 +28,15 @@ flex_attention_supported_platform as supported_platform, instantiate_device_type_tests, ) +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS + + +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : Need track if it is a requirement on windows. + sys.stderr.write("This UT is validated on windows, a lot of crash. Skip it.\n") + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("skip on Windows") Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index 8eb113f18329..c51d0bba229e 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -14,8 +14,8 @@ IS_FBCODE, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._pytree import tree_flatten @@ -269,29 +269,29 @@ def fn(a0, a1): ) # called in test_cuda_cpp_wrapper.py - @requires_cuda + @requires_cuda_and_triton def test_foreach_cpp_wrapper_cuda(self): self._test_single_list(op=torch._foreach_add) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_single_list(self, op): self._test_single_list(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_single_scalar(self, op): self._test_single_scalar(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_tensor_bin_ops def test_single_scalar_tensor(self, op): self._test_single_scalar_tensor(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_scheduler_fusion_list(self, op): if op in un_ops_under_test: @@ -319,7 +319,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_scheduler_fusion_scalar(self, op): def fn(a0, a1): @@ -336,7 +336,7 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_broadcasting(self, op): def fn(a0, a1, b0, b1): @@ -355,7 +355,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_singleton_lists(self, op): if op in un_ops_under_test: @@ -392,7 +392,7 @@ def fn(a0, b0, c0): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_type_promotion(self, op): def fn(a0, a1, b0, b1): @@ -413,7 +413,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_kernel_split_arg_limit_list(self, op): # NB: foeach_copy won't pass this test because it will dce one set of buffers @@ -435,7 +435,7 @@ def fn(a, b): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops @unittest.skip( "Triton recursion depth exceeded: https://github.com/triton-lang/triton/issues/1763" @@ -455,7 +455,7 @@ def fn(a): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_fusion_duplicate_buffer_list(self, op): def fn(a0, a1, b0, b1): @@ -479,7 +479,7 @@ def fn(a0, a1, b0, b1): kernel_count = 2 self.assertEqual(torch._inductor.metrics.generated_kernel_count, kernel_count) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_consumer_list(self, op): if op in un_ops_under_test: @@ -507,7 +507,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_consumer_scalar(self, op): def fn(a0, a1): @@ -524,7 +524,7 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_producer_list(self, op): if op in un_ops_under_test: @@ -554,7 +554,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -574,7 +574,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_consumer_producer_list(self, op): if op in un_ops_under_test: @@ -616,7 +616,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_consumer_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -641,7 +641,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @@ -661,7 +661,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -680,7 +680,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -715,7 +715,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @decomp_ops def test_decomp(self, op): def fn(a0, a1, b0, b1, c0, c1): @@ -735,7 +735,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_fuse_concat(self): def fn(x1, x2, x3, w1, w2, w3): x = torch.stack([x1, x2, x3]) @@ -758,7 +758,7 @@ def fn(x1, x2, x3, w1, w2, w3): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton def test_zero_elems(self): def fn(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -775,7 +775,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking(self, op): def fn(a0, a1, b0, b1): @@ -793,7 +793,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking_partitioning(self, op): def fn(a0, a1, b0, b1): @@ -811,7 +811,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking_partitioning_elems(self, op): """2D blocking should be grouped by number of yelems""" @@ -833,7 +833,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_blocking_partitioning_mixed_sizes(self, op): @@ -856,7 +856,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing(self, op): def fn(a0, a1, b0, b1): @@ -874,7 +874,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing_mut_before(self, op): def fn(a0, a1, b0, b1): @@ -893,7 +893,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing_mut_after(self, op): def fn(a0, a1, b0, b1): @@ -912,7 +912,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_multi_device(self): def test_foreach_add(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -930,7 +930,7 @@ def test_foreach_add(a0, a1, b0, b1): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton def test_aliasing(self): def test_foreach_add(a0, a1, a2, b0, b1, b2): return torch._foreach_add_([a0, a1, a2], [b0, b1, b2]) @@ -952,7 +952,7 @@ def test_foreach_add(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1) def test_2d_block_no_mixed_sizes_no_mask(self): """2D blocking with no mixed sizes constant mask""" @@ -974,7 +974,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_block_mixed_sizes_with_mask(self): """2D blocking with mixed sizes should have mask""" @@ -996,7 +996,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @foreach_map_bin_ops def test_foreach_map_backward_binary(self, op): from torch._dynamo.polyfills import foreach_map_fn @@ -1037,7 +1037,7 @@ def ref_fn(xs, ys): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_foreach_map_input_mutation(self): def fn(xs, ys): outs = foreach_map_add_inplace(xs, ys) @@ -1073,7 +1073,7 @@ def fn(xs, ys): ): _ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps)) - @requires_cuda + @requires_cuda_and_triton @foreach_map_un_ops def test_foreach_map_backward_unary(self, op): from torch._dynamo.polyfills import foreach_map_fn @@ -1109,5 +1109,5 @@ def ref_fn(xs): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 50044b2c1943..11d320315cdc 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -22,7 +22,7 @@ _quantize_tensorwise, _to_fp8_saturated, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) from torch.utils._triton import has_triton_tma_device @@ -766,5 +766,5 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": - if HAS_CUDA or HAS_CPU: + if HAS_CUDA_AND_TRITON or HAS_CPU: run_tests() diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index a0e1b47032b8..25e96fa9f1e9 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -15,7 +15,12 @@ SM80OrLater, ) from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_CUDA_AND_TRITON, + HAS_XPU_AND_TRITON, +) def checkpoint_wrapper(fn): @@ -1114,7 +1119,7 @@ def dot_prod_attention( ) -if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): +if HAS_XPU_AND_TRITON or (HAS_CUDA_AND_TRITON and PLATFORM_SUPPORTS_FUSED_ATTENTION): class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): device = GPU_TYPE diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index e10690666799..d474f66c6b91 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -411,6 +411,28 @@ def test_dynamic_shapes_precomputed_size(self): ) self.assertIn("ks0", triton_node.kwargs["kwargs"]) + def test_dynamic_launch_grid_calc(self): + """ + Test the dyanmic launch grid calculation for Triton kernel wrapper + """ + func = torch.add + args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]] + (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True}) + + # Check for the precomputed size arg. + (triton_node,) = gm.graph.find_nodes( + op="call_function", target=triton_kernel_wrapper_mutation + ) + self.assertIn("grid", triton_node.kwargs) + self.assertIn("xnumel", triton_node.kwargs["kwargs"]) + self.assertIn("XBLOCK", triton_node.kwargs["kwargs"]) + grid = triton_node.kwargs["grid"][0] + xblock = triton_node.kwargs["kwargs"]["XBLOCK"] + xnumel = triton_node.kwargs["kwargs"]["xnumel"] + self.assertEqual(grid[0].node.expr, ((xnumel + xblock - 1) // xblock)) + self.assertEqual(grid[1], 1) + self.assertEqual(grid[2], 1) + @config.patch({"trace.enabled": True}) @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code") def test_debug(self, mock_output_code): diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py index 1def72ae9e27..2bd0b6ef43f1 100644 --- a/test/inductor/test_graph_transform_observer.py +++ b/test/inductor/test_graph_transform_observer.py @@ -11,7 +11,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON try: @@ -28,7 +28,10 @@ class TestGraphTransformObserver(TestCase): def test_sdpa_rewriter(self): if not ( - HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT + HAS_CUDA_AND_TRITON + and PLATFORM_SUPPORTS_FUSED_ATTENTION + and HAS_PYDOT + and HAS_DOT ): return diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index 75f53f4dd9b8..3824b25cdeae 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -3,7 +3,7 @@ import torch._inductor.config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton class InductorAnnotationTestCase(TestCase): @@ -18,7 +18,7 @@ def f(a, b): _, code = run_and_get_code(f_comp, a, b) return code[0] - @requires_cuda + @requires_cuda_and_triton def test_no_annotations(self): code = self.get_code() @@ -26,15 +26,16 @@ def test_no_annotations(self): self.assertTrue("training_annotation" not in code) @inductor_config.patch(annotate_training=True) - @requires_cuda + @requires_cuda_and_triton def test_training_annotation(self): code = self.get_code() self.assertTrue("from torch.cuda import nvtx" in code) - self.assertEqual( - code.count("training_annotation = nvtx._device_range_start('inference')"), 1 + self.assertTrue( + code.count("training_annotation = nvtx._device_range_start('inference')") + >= 1 ) - self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1) + self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1) if __name__ == "__main__": diff --git a/test/inductor/test_inductor_scheduler.py b/test/inductor/test_inductor_scheduler.py index a93112cc6ddd..f180c9d003df 100644 --- a/test/inductor/test_inductor_scheduler.py +++ b/test/inductor/test_inductor_scheduler.py @@ -95,15 +95,14 @@ def test_disable_get_estimated_runtime_logging(self, device, dtype): { "max_autotune": True, "max_autotune_gemm_backends": "TRITON", - "force_disable_caches": True, }, { "max_autotune": True, "max_autotune_gemm_backends": "TRITON,ATEN", - "force_disable_caches": True, }, ], ) + @torch._inductor.config.patch({"force_disable_caches": True}) @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune") def test_flop_counter_op(self, device, dtype, options): if device == "cpu": diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 46d5cf61121e..7ddd0dd4441b 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -233,9 +233,9 @@ def f(x, y): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) opt_f = torch.compile(f) diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 6e438cdeab91..4c35cec9bee9 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -386,6 +386,9 @@ def f(a, b, c): max_autotune=True, max_autotune_gemm_backends="TRITON", force_shape_pad=True ) def test_slice_mm_bandwidth_computation(self): + if GPU_TYPE == "xpu" and not torch._inductor.utils.is_big_gpu(): + raise unittest.SkipTest("unsupported device") + M, N, K = 1000, 2000, 3000 @torch.compile diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 57f7cdf2abe3..ff1d8c3fb875 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -35,7 +35,10 @@ TritonTemplate, TritonTemplateCaller, ) -from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig +from torch._inductor.template_heuristics import ( + CUDAMMTemplateConfigHeuristic, + GemmConfig, +) from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -65,13 +68,13 @@ get_kernel_launch, GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, ) torch.set_float32_matmul_precision("high") -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -139,8 +142,16 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) with config.patch( { @@ -163,8 +174,8 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -191,8 +202,8 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -258,9 +269,17 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with config.patch( { @@ -283,9 +302,9 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -312,9 +331,9 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -359,15 +378,15 @@ def scaled_mm( # Create large matrices to ensure we use all possible sms size = 2560 - a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + a = torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) b = ( - torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) .transpose(0, 1) .contiguous() .transpose(0, 1) ) - scale_a = torch.tensor(1, dtype=torch.float32, device="cuda") - scale_b = torch.tensor(1, dtype=torch.float32, device="cuda") + scale_a = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) + scale_b = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) args = ( (a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b) @@ -946,9 +965,9 @@ def f(x, y): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) import torch._inductor.utils as inductor_utils @@ -982,8 +1001,8 @@ def test_max_autotune_decompose_k(self, sizes, dtype, dynamic): M, N, K = sizes - a = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=dtype, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True) possible_splits = range(2, min(K // M, K // N) + 1) @@ -1080,10 +1099,10 @@ def f(a, b): return (a_in @ b).relu() a = torch.randn( - 32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) b = torch.randn( - 32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32768, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) torch._dynamo.reset() @@ -1123,9 +1142,11 @@ def f(a, b): a_in = torch.cat([a for _ in range(256)], dim=0) return (a_in @ b).relu().sum() - a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + a = torch.randn( + 8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True + ) b = torch.randn( - 64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) torch._dynamo.reset() @@ -1165,15 +1186,15 @@ def f(a, b): a = a.transpose(0, 1) return a @ b - a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16) - b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16) + a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16) b = b[:, :1096] # Force only decomposeK choice with ( mock.patch( - "torch._inductor.kernel.mm.V.choices.get_base_mm_configs" + "torch._inductor.kernel.mm.V.choices.get_mm_configs" ) as base_mm_mock, mock.patch( "torch._inductor.kernel.mm.use_decompose_k_choice" @@ -1333,8 +1354,8 @@ def func_test1(x, y, z, m): if not TEST_WITH_ROCM: expected = """{ 'input_nodes':[ - "[[s77,s17],[s17,1],torch.float32,device(type='cuda',index=0),0]", - "[[s17,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], + "[[s77,s27],[s27,1],torch.float32,device(type='cuda',index=0),0]", + "[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True, @@ -1519,8 +1540,8 @@ def test_max_autotune_decompose_k_envvars( for M, N, K in shapes: get_k_splits.cache_clear() use_decompose_k_choice.cache_clear() - a = torch.randn(M, K, dtype=torch.float16, device="cuda") - b = torch.randn(K, N, dtype=torch.float16, device="cuda") + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE) with config.patch( { @@ -1557,13 +1578,13 @@ def f(a, b): M, N, K = (1024, 1024, 1024) - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) with mock.patch( - "torch._inductor.kernel.mm.V.choices.get_config_heuristics" + "torch._inductor.template_registry.get_template_heuristic" ) as config_mock: - config_heuristics = CUDAConfigHeuristic() + config_heuristics = CUDAMMTemplateConfigHeuristic() # Traditionally, this would be set of all possible configs # We mock out the code path for the sake of the unit test @@ -1593,8 +1614,8 @@ def mm(x, y): for i in range(90, 100): torch._dynamo.reset() - a = torch.randn((i, 1), device="cuda", dtype=torch.float32) - b = torch.randn((1, i), device="cuda", dtype=torch.float32) + a = torch.randn((i, 1), device=GPU_TYPE, dtype=torch.float32) + b = torch.randn((1, i), device=GPU_TYPE, dtype=torch.float32) compiled_f = torch.compile(mm) out, code = run_and_get_code(compiled_f, a, b) @@ -2152,6 +2173,9 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes @@ -2316,6 +2340,9 @@ def test_multiple_fusions(x): ).run(code[0]) self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05) + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_multiple_inputs(self, sizes): M, K, N = sizes diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 3e23442b38ec..80372bca9fdc 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -68,9 +68,16 @@ def test_reorder_peak_memory(self): outp_corr = self.model(self.inputs) compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) + + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -105,6 +112,12 @@ def reorder_with_only_lpmf( methods=[memory.topological_sort_lpmf], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_lpmf ): @@ -113,7 +126,7 @@ def reorder_with_only_lpmf( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -148,15 +161,22 @@ def reorder_with_only_bfs( methods=[memory.topological_sort_bfs], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_bfs ): compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) + ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf0 = ") .check("buf1 = ") .check("buf2 = ") @@ -191,6 +211,12 @@ def reorder_with_only_dfs( methods=[memory.topological_sort_dfs], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_dfs ): @@ -199,7 +225,7 @@ def reorder_with_only_dfs( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf0 = ") .check("buf2 = ") .check("buf4 = ") @@ -215,6 +241,7 @@ def reorder_with_only_dfs( @mock.patch.object(config, "allow_buffer_reuse", False) @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") + @config.patch("test_configs.track_memory_lifecycle", "assert") def test_mutation_size_propogation(self): """ This tests correct size propogation in the case of mutations. @@ -262,6 +289,7 @@ def assign_memory_planning_info_for_scheduler_buffers_with_records( buffer_info[buf_name] = ( buf.mpi_buffer.size_alloc, buf.mpi_buffer.size_free, + buf.mpi_buffer.succ_nodes, ) # test example and checks @@ -281,11 +309,15 @@ def f(a, p): ): f_compiled = torch.compile(f) f_compiled(a, p) - for buf_name in ["buf0", "buf2", "buf4", "buf6"]: - self.assertEqual(buffer_info[buf_name], (2048, 0)) - for buf_name in ["buf1", "buf3", "buf5", "buf7"]: - self.assertEqual(buffer_info[buf_name], (0, 2048)) + pre_mutation = ["buf0", "buf2", "buf4", "buf6"] + post_mutation = ["buf1", "buf3", "buf5", "buf7"] + + for pre, post in zip(pre_mutation, post_mutation): + self.assertEqual(buffer_info[pre][0:2], (2048, 2048)) + self.assertEqual(buffer_info[post][0:2], (0, 0)) + # succ nodes should be forwarded to pre mutation buffer + self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2]) @unittest.skipIf( not torch.cuda.is_available() @@ -359,6 +391,49 @@ def f(x, y, z): .run(code) ) + @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") + def test_multiple_mutations_of_buf(self): + @torch.compile() + def foo(inp, inp2): + inp = inp @ inp + inp = inp.view(2, -1, 256) + x = inp[0] + y = inp[1] + x, y = torch._foreach_add([x, y], 1.0) + out = x.sum() + out2 = y.sum(dim=-1) + + return out, out2, inp2 @ inp2 + + inp = torch.rand([256, 256], device=GPU_TYPE) + inp2 = torch.rand([256, 256], device=GPU_TYPE) + + def replace_foreach(gm): + nodes = gm.find_nodes( + op="call_function", target=torch.ops.aten._foreach_add.Scalar + ) + assert len(nodes) == 1 + node = nodes[0] + nodes[0].target = torch.ops.aten._foreach_add_.Scalar + for inp, out in zip(node.args[0], list(node.users.keys())): + out.replace_all_uses_with(inp) + gm.erase_node(out) + + with torch._inductor.config.patch( + { + "post_grad_custom_post_pass": replace_foreach, + "test_configs.track_memory_lifecycle": "assert", + "allow_buffer_reuse": False, + # make sure the mm is at the end so + # the earlier deallocation is not at the last step, + # which doesnt distinguish between returned tensors + # and which tensors are deallocated immediately prior + "reorder_for_peak_memory": False, + } + ): + code = run_and_get_triton_code(foo, inp, inp2) + FileCheck().check("allocated=['buf0']").run(code) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index d5f90e662697..1bcdeaa08e95 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -24,6 +24,14 @@ from torch.export import Dim +try: + from .test_aot_inductor import AOTIRunnerUtil +except ImportError: + from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library + AOTIRunnerUtil, + ) + + @requires_gpu() @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): @@ -76,13 +84,6 @@ def test_cpp_wrapper(self): @skipIfXpu(msg="aoti doesn't work on XPU") def test_aoti(self): - try: - from .test_aot_inductor import AOTIRunnerUtil - except ImportError: - from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library - AOTIRunnerUtil, - ) - f, args = self._generate(device=GPU_TYPE) dim0_x = Dim("dim0_x", min=1, max=2048) dynamic_shapes = ({0: dim0_x}, None, None) @@ -103,6 +104,54 @@ def test_aoti(self): ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code) self.assertTrue(same(f(*args), result)) + @config.patch({"triton.autotune_at_compile_time": False}) + def test_unbacked_symint(self): + # when allocation's size has unbacked symints + # the unbacked symints are only available after computed + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Repro(torch.nn.Module): + def forward(self, x, y): + x = x + 1 + u0 = x.item() + torch._check(u0 >= 1) + s0 = y.size(0) + expr = u0 * s0 + sevens = torch.empty_strided( + size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device + ).fill_(7) + return sevens * 3 + + example_inputs = ( + torch.scalar_tensor(2, dtype=torch.int, device=self.device), + torch.ones(8, device=self.device), + ) + model = Repro().to(self.device) + result, code = run_and_get_cpp_code( + lambda: AOTIRunnerUtil.run(model, example_inputs) + ) + self.assertTrue(same(model(*example_inputs), result)) + + # check allocation is done after the unbacked symint is computed + FileCheck().check("auto u0 = u0_raw;").check( + "const int64_t int_array_2[] = {10L, 8L*u0, 32L};" + ).check("AtenTensorHandle pool0_handle;").check( + "aoti_torch_empty_strided(3, int_array_2, int_array_3" + ).run(code) + + # all AtenTensorHandle allocated using aoti_torch__alloc_from_pool are wrapped with RAIIAtenTensorHandle + # otherwise we'll have memory leak + FileCheck().check_count( + "aoti_torch__alloc_from_pool(pool1", 1, exactly=True + ).check_count("aoti_torch__alloc_from_pool(pool0", 1, exactly=True).run(code) + + FileCheck().check( + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));" # noqa: B950 + ).check("RAIIAtenTensorHandle(tmp_tensor_handle_0);").check( + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" # noqa: B950 + ).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code) + if __name__ == "__main__": if HAS_GPU: diff --git a/test/inductor/test_move_constructors_to_cuda.py b/test/inductor/test_move_constructors_to_cuda.py index 3c3b8708c630..b174c79f1ebd 100644 --- a/test/inductor/test_move_constructors_to_cuda.py +++ b/test/inductor/test_move_constructors_to_cuda.py @@ -9,7 +9,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON requires_multigpu = functools.partial( @@ -112,5 +112,5 @@ def foo(x): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_needs_exact_strides.py b/test/inductor/test_needs_exact_strides.py index 97c760c712bf..2d636db3f88f 100644 --- a/test/inductor/test_needs_exact_strides.py +++ b/test/inductor/test_needs_exact_strides.py @@ -13,7 +13,7 @@ IS_LINUX, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestNeedsExactStrides(InductorTestCase): @@ -98,5 +98,5 @@ def f(x, other): instantiate_parametrized_tests(TestNeedsExactStrides) if __name__ == "__main__": - if IS_LINUX and HAS_GPU: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 798d86b0dd61..1e94ff1f4987 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -14,7 +14,7 @@ IS_LINUX, parametrize, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -297,5 +297,5 @@ def f(x, mask): instantiate_parametrized_tests(TestOnlineSoftmax) if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 458d64aa41d5..6f7eec601666 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -260,7 +260,7 @@ def test_downcast_div_mod(self): def fn(x, y): return x % y, x / y - x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2)) + x, y = (torch.rand([8], dtype=torch.float16, device=GPU_TYPE) for _ in range(2)) out, code = run_and_get_code(torch.compile(fn), x, y) @@ -271,7 +271,7 @@ def fn(x, y): @config.patch("test_configs.runtime_triton_dtype_assert", True) def test_constant(self): def fn(): - return (torch.full((2, 3), 3.1416, device="cuda", dtype=torch.float16),) + return (torch.full((2, 3), 3.1416, device=GPU_TYPE, dtype=torch.float16),) out, code = run_and_get_code(torch.compile(fn)) FileCheck().check("static_assert").check_same(".dtype").run(code[0]) @@ -284,7 +284,7 @@ def test_any(self): def fn(x): return torch.any(x) - x = torch.rand([40], device="cuda").to(torch.bool) + x = torch.rand([40], device=GPU_TYPE).to(torch.bool) out, code = run_and_get_code(torch.compile(fn), x) self.assertEqual(fn(x), out) @@ -293,7 +293,7 @@ def fn(x): def test_assoc_scan(self): from torch._higher_order_ops.associative_scan import associative_scan - x = torch.randn(10, device="cuda") + x = torch.randn(10, device=GPU_TYPE) # dtype check correctly associative_scan( lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise" diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index bcd1519c5935..d04bed2a9032 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -16,7 +16,7 @@ from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class PadMMTest(TestCase): @@ -541,5 +541,5 @@ def fn(x, y): if __name__ == "__main__": - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 15c1abdf32db..41944a916923 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -49,6 +49,18 @@ def geninp(): return input_dict +def get_padded_stride(shape, alignment_bytes, pad_output, itemsize): + align = alignment_bytes // itemsize + new_strides = [0 for _ in range(len(shape))] + new_strides[len(shape) - 1] = 1 + for i in range(len(shape) - 1, 0, -1): + stride = shape[i] * new_strides[i] + if pad_output and stride % align != 0: + stride = (stride + align - 1) // align * align + new_strides[i - 1] = stride + return tuple(new_strides) + + class LinearAndSoftmax(nn.Module): """ It's very common that a transformer model will do a matmul and then @@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)] config_patches = { - "compile_threads": 1, "comprehensive_padding": pad_output, "cpu_backend": "triton", - "disable_padding_cpu": False, - "implicit_fallbacks": False, - "inplace_buffers": False, "padding_alignment_bytes": alignment_bytes, - "pad_channels_last": True, "pad_outputs": True, "padding_stride_threshold": 0, - "triton.prefer_nd_tiling": True, - "triton.use_block_ptr": True, - "triton.codegen_upcast_to_fp32": False, - "unroll_reductions_threshold": 1, } with config.patch(config_patches): compiled = torch.compile(torch.cat) @@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: output_shape = (shape[0] * num_inputs, shape[1]) output_stride = input_tensors[0].stride() output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)" - self.assertTrue(any(output_line in line for line in code)) + self.assertTrue(output_line in code[0]) + + @parametrize( + "shape,alignment_bytes,pad_output", + [ + ((512, 1), 32, False), + ((512, 1), 32, True), + ((32, 30), 64, False), + ((32, 30), 64, True), + ((512, 100, 1), 32, False), + ((512, 100, 1), 32, True), + ((32, 50, 30), 64, False), + ((32, 50, 30), 64, True), + ], + ) + def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output): + """ + When only the outermost dim is dynamic shape, the output can still be padded up + based on padding configuration. + """ + num_inputs = 2 + input_tensors = [ + torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs) + ] + + config_patches = { + "comprehensive_padding": pad_output, + "cpu_backend": "triton", + "padding_alignment_bytes": alignment_bytes, + "pad_outputs": True, + "padding_stride_threshold": 0, + } + with config.patch(config_patches): + torch._dynamo.mark_dynamic(input_tensors[0], 0) + torch._dynamo.mark_dynamic(input_tensors[1], 0) + compiled = torch.compile(torch.add) + result, _ = run_and_get_code(compiled, *input_tensors) + + expected_stride = get_padded_stride( + result.shape, alignment_bytes, pad_output, result.dtype.itemsize + ) + self.assertEqual(result.stride(), expected_stride) + + @parametrize( + "shape,alignment_bytes,pad_output", + [ + ((500, 10, 1), 32, False), + ((500, 20, 1), 32, True), + ((30, 10, 20), 64, True), + ((30, 10, 20), 64, False), + ], + ) + def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output): + """ + When only the outermost dim is dynamic shape, the output can still be padded up + based on padding configuration. Test when this occurs after a permute op. + """ + + def permute_contig(x): + return torch.transpose(x, 0, 2).contiguous() + + num_inputs = 1 + input_tensors = [ + torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs) + ] + + config_patches = { + "comprehensive_padding": pad_output, + "cpu_backend": "triton", + "padding_alignment_bytes": alignment_bytes, + "pad_outputs": True, + "padding_stride_threshold": 0, + "triton.use_block_ptr": True, + } + with config.patch(config_patches): + torch._dynamo.mark_dynamic(input_tensors[0], 2) + compiled = torch.compile(permute_contig) + result, _ = run_and_get_code(compiled, *input_tensors) + + expected_stride = get_padded_stride( + result.shape, alignment_bytes, pad_output, result.dtype.itemsize + ) + self.assertEqual(result.stride(), expected_stride) if __name__ == "__main__": diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index ac940f048009..0ffe7cb37deb 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1355,13 +1355,13 @@ def repl(inp, x1, x2): FileCheck().check_not("extern_kernels.addmm(").run(code[0]) def test_addmm_dtype_mismatch(self): - a = torch.nn.Linear(1024, 1024, bias=False).cuda() + a = torch.nn.Linear(1024, 1024, bias=False).to(GPU_TYPE) a = a.to(dtype=torch.float16) - w = torch.randn(1024, 1024, device="cuda") + w = torch.randn(1024, 1024, device=GPU_TYPE) def func(): - x = torch.ones(1024, 1024, device="cuda", dtype=torch.float16) + x = torch.ones(1024, 1024, device=GPU_TYPE, dtype=torch.float16) x = a(x) x = x + w return x diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 0ca54257250f..83cd236875f4 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -28,13 +28,16 @@ # performance for that setting. # # Defines all the kernels for tests -from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda +from torch.testing._internal.triton_utils import ( + HAS_CUDA_AND_TRITON, + requires_cuda_and_triton, +) # set so that metrics appear torch._logging.set_logs(inductor_metrics=True) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: import triton # @manual import triton.language as tl # @manual @@ -920,7 +923,7 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_training(self): @triton.jit def sin_kernel( @@ -964,7 +967,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_not_fusable_with_users(self): @triton.jit def _sin_kernel( @@ -1017,7 +1020,7 @@ def f(x): # (it will cost an extra kernel) self.assertExpectedInline(count_numel_train(f, x), """27""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_training_two_mutated_inputs(self): @torch.library.custom_op( "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} @@ -1037,7 +1040,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel(f, x), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_training(self): @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) def sin(x: torch.Tensor, result: torch.Tensor) -> None: @@ -1066,7 +1069,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1096,7 +1099,7 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_intermediate(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1127,7 +1130,7 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_two_mutated_inputs(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor") @@ -1159,7 +1162,7 @@ def f(): self.assertExpectedInline(count_numel(f), """39""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v1(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1171,7 +1174,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """50""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v2(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1184,7 +1187,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v3(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1197,7 +1200,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v4(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1211,7 +1214,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v5(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1225,7 +1228,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v6(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1292,5 +1295,5 @@ def f(a, b): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index 3d54c378de4a..f22f0374813b 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -12,7 +12,7 @@ from torch._inductor import config from torch.profiler import ProfilerActivity from torch.testing._internal.common_utils import TemporaryFileName -from torch.testing._internal.inductor_utils import HAS_CUDA, IS_BIG_GPU +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, IS_BIG_GPU from torch.torch_version import TorchVersion from torch.utils._triton import has_triton @@ -313,5 +313,5 @@ def fn(x, y): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 2dd9ca44eb68..77e099cf0cb9 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -19,7 +19,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.virtualized import V from torch.testing._internal.inductor_utils import HAS_GPU -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -229,7 +229,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): if filepath: shutil.rmtree(filepath) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_cuda(self): self._test_triton_kernel_to_post_grad_tracing(device="cuda") @@ -237,7 +237,7 @@ def test_triton_kernel_to_post_grad_tracing_cuda(self): def test_triton_kernel_to_post_grad_tracing_cpu(self): self._test_triton_kernel_to_post_grad_tracing(device="cpu") - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): M = 8 N = 6 @@ -285,7 +285,7 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): if filepath: shutil.rmtree(filepath) - @requires_cuda + @requires_cuda_and_triton def _test_pt_tracing_combo_kernel(self, backend): """This test checks that generated provenance tracing artifact from triton combo kernel to post grad nodes""" a = torch.randn(10, 10, device="cuda") @@ -320,7 +320,7 @@ def _test_pt_tracing_combo_kernel(self, backend): expected_data = {"triton_poi_fused_0": ["relu", "sigmoid", "tanh"]} self._check_provenance_tracing_artifact(filepath, expected_data) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_combo_kernel(self): self._test_pt_tracing_combo_kernel(backend="inductor") self._test_pt_tracing_combo_kernel(backend="aot_inductor") @@ -437,7 +437,7 @@ def get_node_with_target(self, gm, target): """ return next(iter([node for node in gm.graph.nodes if node.target == target])) - @requires_cuda # test only works for cuda pattern matcher + @requires_cuda_and_triton # test only works for cuda pattern matcher def test_pattern_matcher_transfer_meta(self): """ Test that stack trace is transfered when node is decomposed in post_grad_passes diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index d2cd77fe5cd2..e5838f2d4d32 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -163,7 +163,6 @@ def foo(a, b): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) @patches - @skipIfXpu(msg="XPU has not supported _int_mm yet") def test__int_mm(self): @torch.compile def foo(a, b): diff --git a/test/inductor/test_smoke.py b/test/inductor/test_smoke.py index 895e8ba16ab0..2a247fddbe76 100644 --- a/test/inductor/test_smoke.py +++ b/test/inductor/test_smoke.py @@ -6,7 +6,11 @@ import torch._logging from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA_AND_TRITON, + HAS_GPU, +) class MLP(torch.nn.Module): @@ -62,5 +66,5 @@ def test_compile_invalid_options(self): from torch._inductor.test_case import run_tests if IS_LINUX and HAS_GPU: - if (not HAS_CUDA) or torch.cuda.get_device_properties(0).major <= 5: + if (not HAS_CUDA_AND_TRITON) or torch.cuda.get_device_properties(0).major <= 5: run_tests() diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index 354552c497d9..0ec7825df001 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -5,7 +5,7 @@ from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -248,7 +248,7 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -291,7 +291,7 @@ def test_split_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -317,7 +317,7 @@ def test_split_cat_post_grad_singular(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -342,7 +342,7 @@ def test_select_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 2ce294ed0ff5..654bfd269f76 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -13,10 +13,10 @@ from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfRocm -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton -@requires_cuda +@requires_cuda_and_triton class TestStaticCudaLauncher(TestCase): def setUp(self): super().setUp() @@ -396,7 +396,7 @@ def kernel_many_args(out_tensor, {decl}): self.assertEqual(buf0, buf1) -@requires_cuda +@requires_cuda_and_triton @torch._inductor.config.patch( {"use_static_cuda_launcher": True, "strict_static_cuda_launcher": True} ) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index 40695e6affb1..201590d02ed5 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,6 +1,5 @@ # Owner(s): ["module: functorch"] import json -import tempfile import zipfile from pathlib import Path @@ -11,8 +10,10 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing from torch._inductor import aot_compile, ir +from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu from torch.testing._internal.torchbind_impls import ( _empty_tensor_queue, @@ -158,6 +159,7 @@ def test_torchbind_hop_schema_no_output(self): "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0", ) + @skipIfWindows(msg="AOTI is not fully support on Windows") def test_torchbind_aot_compile(self): ep, inputs, _, _ = self.get_exported_model() aoti_files = aot_compile( @@ -280,7 +282,7 @@ def test_torchbind_aot_compile(self): ) # Test that the files are packaged - with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + with WritableTempFile(suffix=".pt2") as f: package_path = package_aoti(f.name, aoti_files) with zipfile.ZipFile(package_path, "r") as zip_ref: @@ -302,6 +304,7 @@ def test_torchbind_aoti(self): self.assertEqual(result, orig_res) @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True) + @skipIfWindows(msg="AOTI is not fully support on Windows") def test_torchbind_aot_compile_constant_folding(self): ep, inputs, orig_res, _ = self.get_exported_model() pt2_path = torch._inductor.aoti_compile_and_package(ep) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 2ef129d5fe10..0e76ca489284 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -69,6 +69,7 @@ from torch.nn import functional as F from torch.testing import FileCheck, make_tensor from torch.testing._internal.common_cuda import ( + IS_SM90, PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, @@ -137,7 +138,7 @@ skipCPUIf, skipCUDAIf, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton _T = TypeVar("_T") @@ -9569,7 +9570,6 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) - @xfail_if_mps_unimplemented def test_avg_pool3d_backward(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9591,7 +9591,6 @@ def fn(a, b): ], ) - @xfail_if_mps_unimplemented @skip_if_halide # compiles for 5+ minutes def test_avg_pool3d_backward2(self): def fn(a, b): @@ -9614,7 +9613,6 @@ def fn(a, b): ], ) - @xfail_if_mps_unimplemented def test_avg_pool3d_backward3(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9638,7 +9636,6 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 1) - @xfail_if_mps_unimplemented def test_avg_pool3d_backward4(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9840,6 +9837,7 @@ def fn(x): ], ) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin2(self): def fn(x): return ( @@ -9851,6 +9849,7 @@ def fn(x): self.common(fn, (torch.randn([144, 144]),)) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin_with_duplicates(self): def fn(x): return ( @@ -9872,6 +9871,7 @@ def fn(x): t1 = torch.randint(8, size=(1028, 1028)) self.common(fn, (t1,)) + @skipIfXpu(msg="# Incorrect XPU reference ") @xfail_if_mps # eager nan is wrong, see https://github.com/pytorch/pytorch/issues/130295 @skip_if_halide # nan behavior def test_argmax_argmin_with_nan(self): @@ -9972,6 +9972,7 @@ def shrink_rank(x, rank): [rank4_inps, rank3_inps, rank5_inps], ) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin3(self): def fn(x): return ( @@ -13154,7 +13155,7 @@ def f(x): "assert_size_stride(buf2, (16, 32), (32, 1)" ).run(code) - @requires_cuda + @requires_cuda_and_triton @config.patch(use_fast_math=True) def test_prepare_softmax_with_fast_math(self): """ @@ -13563,7 +13564,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar ) @config.patch("min_num_split", 256) - @xfail_if_mps # TypeError: cannot determine truth value of Relational def test_split_reduction_dynamic_shape(self): from torch._dynamo.decorators import mark_dynamic @@ -13654,6 +13654,71 @@ def forward(self, x): inputs = (torch.randn(4, device=self.device),) self.common(Model(), inputs) + @requires_cuda_and_triton + @parametrize("use_cat", [True, False]) + def test_copy_non_blocking_is_pinned(self, use_cat): + def f(a_list): + a_cpu_list = [] + a_to_cpu_event_list = [] + + for a in a_list: + a_cpu = a.to(device="cpu", non_blocking=True) + a_to_cpu_event = torch.Event() + a_to_cpu_event.record() + a_cpu_list.append(a_cpu) + a_to_cpu_event_list.append(a_to_cpu_event) + + for e in a_to_cpu_event_list: + e.synchronize() + + if use_cat: + return torch.cat(a_cpu_list) + else: + return a_cpu_list + + f_compiled = torch.compile(f) + inputs = [ + torch.rand(1000, dtype=torch.float16, device=GPU_TYPE) for _ in range(100) + ] + outputs = f(inputs) + + with torch.profiler.profile( + activities=[ + getattr(torch.profiler.ProfilerActivity, GPU_TYPE.upper()), + ], + ) as p: + outputs_compiled = f_compiled(inputs) + + # outputs_compiled, (code,) = run_and_get_code(f_compiled, inputs) + # self.assertTrue("pinned" in code) + + self.assertEqual(outputs, outputs_compiled) + profile_output = str(p.key_averages()) + print(profile_output) + self.assertFalse("Pageable" in profile_output) + + @unittest.skipIf( + config.cpp_wrapper, + "cpp_wrapper samples will lead to invalid indexing", + ) + def test_inductor_triton_bucketize_respects_masking(self): + def fn(inp, repeats, output_size): + # return torch.repeat_interleave(inp, repeats, dim=0, output_size=output_size) + idx = torch.searchsorted( + repeats.cumsum(0), + torch.arange(0, output_size, device=repeats.device), + right=True, + ) + return torch.index_select(inp, 0, idx) + + inp = torch.arange(0, 4, device=self.device) + repeats = torch.tensor([1, 2, 3, 4], device=self.device) + output_size = repeats.sum().item() + args = (inp, repeats, output_size) + self.assertEqual(fn(*args), torch.compile(fn)(*args)) + + # end of class CommonTemplate - add new tests here + @dataclasses.dataclass class TestFailure: @@ -13697,6 +13762,25 @@ def new_test(self, value=value): other_cls.is_dtype_supported = my_cls.is_dtype_supported +def add_test_failures( + test_failures: dict[str, TestFailure], added_test_failures: dict[str, TestFailure] +): + """ + In-place modifies the given dictionary of `test_failures` to add the + contents of `added_test_failures` by unioning the test_failure.suffixes, and + or-ing the the is_skip value. + """ + for name, new_failure in added_test_failures.items(): + if name in test_failures: + orig_failure = test_failures[name] + orig_failure.suffixes = tuple( + set(orig_failure.suffixes).union(set(new_failure.suffixes)) + ) + orig_failure.is_skip = orig_failure.is_skip or new_failure.is_skip + else: + test_failures[name] = new_failure + + if RUN_CPU: class SweepInputsCpuTest(SweepInputs2, TestCase): @@ -14009,7 +14093,7 @@ def forward( torch._inductor.aot_compile(traced, inputs) @skipCUDAIf(not SM90OrLater, "Requires sm90") - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(TEST_WITH_ROCM, "no grouped_mm support") @config.patch(implicit_fallbacks=True) def test_grouped_mm(self): @@ -14523,11 +14607,11 @@ def fn(x): else: self.assertTrue("Graph fragment" in code) self.assertTrue( - "%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default]" + f'%sin : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]' in code ) self.assertTrue( - "%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default]" + f'%relu : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]' in code ) @@ -14982,301 +15066,60 @@ def fn(x): "'XBLOCK': 'constexpr'" ).run(code[0]) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - self.assertEqual(eager_out, compiled_out) - - _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").check( - "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)" - ).check("recursively_apply_fns = runner.recursively_apply_fns").run( - code[0] - ) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_foreach_op(self): - def fn(a0, a1): - c = torch._foreach_abs([a0, a1]) - return torch.mul(c[0], a0) - - compiled_fn = torch.compile(fn) - - a0 = torch.randn(2, 3, device=self.device) - a1 = torch.randn(2, 3, device=self.device) - eager_out = fn(a0, a1) - compiled_out = compiled_fn(a0, a1) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_multiple_functions(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - def g(x): - return x + 1 - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = g(f(x, y)) + @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support") + def test_respect_scaled_grouped_mm_layout_tag(self): + # scaled_grouped_mm needs `mat2` to be column-major + M, K, N = 128, 64, 32 # K and N must be divisible by 16 + num_groups = 2 + E = num_groups # B_t batch size must match number of groups + group_size = M // num_groups - f_compiled = torch.compile(f) - g_compiled = torch.compile(g) - compiled_out = g_compiled(f_compiled(x_cloned, y_cloned)) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_condition_op(self): - def f(p, b): - def true_fn(x): - return torch.cos(x) - - def false_fn(x): - return torch.sin(x) - - return torch.cond(p, true_fn, false_fn, [b]) - - compiled_f = torch.compile(f) - - # static shape - p = torch.tensor([True], device=self.device) - a = torch.ones([2, 3], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - # dynamic shape with backed symint - p = torch.tensor([True], device=self.device) - a = torch.ones([4, 5], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_unbacked_symint_multi_output_layout(self): - def f(p, size_tensor): - size_val = size_tensor.item() - b = torch.ones([size_val, 3], device=GPU_TYPE) - - def true_fn(x): - return torch.cos(x), torch.cos(x) + 1 - - def false_fn(x): - return torch.sin(x), torch.sin(x) + 1 - - cond_out = torch.cond(p, true_fn, false_fn, [b]) - return cond_out[0] + cond_out[1] - - compiled_f = torch.compile(f) - p = torch.tensor([True], device=GPU_TYPE) - size_tensor = torch.tensor(2, device=GPU_TYPE) - eager_out = f(p, size_tensor) - compiled_out = compiled_f(p, size_tensor) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) + A = torch.randn( + M, K, dtype=torch.bfloat16, device=GPU_TYPE + ) # Row-major by default - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - x, y = ( - torch.ones(4, 4, device=self.device), - torch.randn(4, 4, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_cat_backward(self): - def f(x, w): - y = torch.cat((x, x), dim=0) - z = y @ w - return z @ z.T - - compiled_f = torch.compile(f) - - for shape in (2, 3): - torch.manual_seed(42) - eager_x = torch.randn(shape, 2, device=self.device) - eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) - torch.manual_seed(42) - compiled_x = torch.randn(shape, 2, device=self.device) - compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) - - f(eager_x, eager_w).sum().backward() - compiled_f(compiled_x, compiled_w).sum().backward() - self.assertEqual(eager_w.grad, compiled_w.grad) - - @dynamo_config.patch("capture_dynamic_output_shape_ops", True) - @config.patch(implicit_fallbacks=True) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_nested_indirect_indexing(self): - def nested(x, repeats): - rank = torch.arange(repeats.numel(), device=x.device) - index = rank.repeat_interleave(repeats, dim=0) - return torch.index_select(x, index=index, dim=0) + # Create B_t with proper column-major layout + B_t_transposed = torch.randn( + E, N, K, dtype=torch.bfloat16, device=GPU_TYPE + ).contiguous() + B_t = B_t_transposed.transpose(-2, -1) # (E, K, N) + B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1) - example_inputs = ( - torch.randn((32, 64), device=self.device), - repeats := torch.tensor([5, 10, 15], device=self.device), - ) - torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + # Verify column-major layout + def _is_column_major(x: torch.Tensor) -> bool: + """Check if tensor is column-major (stride(-2) == 1 and stride(-1) > 1)""" + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-2) == 1 and x.stride(-1) > 1 - nested_opt = torch.compile(nested, backend="inductor") + self.assertTrue(_is_column_major(B_t)) - expect = nested(*example_inputs) - actual = nested_opt(*example_inputs) - self.assertEqual(expect, actual) + offs = torch.tensor([group_size, M], dtype=torch.int32, device=GPU_TYPE) + out_dtype = torch.bfloat16 - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_mutation_index(self): - x = torch.zeros(7, device=GPU_TYPE) - - def fn(n, a): - a[n] = -1 - return a - - opt_fn = torch.compile(fn, fullgraph=True) - - for n in range(2, x.shape[0]): - opt_fn(n, x) - self.assertEqual(x[n], -1) - - # Negative index triggers new compilation. - opt_fn(-x.shape[0], x) - - self.assertEqual(x[0], -1) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_unbacked_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y) - eager_out = f(x, y) - self.assertEqual(compiled_out, eager_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_dynamic_scalar_inputs(self): - def f(x, y, integer): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - z += integer - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y, 5) - self.assertEqual(compiled_out, f(x, y, 5)) - - compiled_out = f_compiled(x, y, 6) - self.assertEqual(compiled_out, f(x, y, 6)) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_item(self): - def f(x): - y = x + 1 - scalar = y.item() - return x + y + scalar - - compiled_f = torch.compile(f) - compiled_out = f(torch.tensor(1, device=GPU_TYPE)) - self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE))) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_buffer_reuse(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x1 + y1 + x @ y - u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3 - u_cpu = u.cpu() + 2 - return z + u_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_fused_scheduler_node(self): - def foo(x): - x = x * 20 - x_alias = x[0] - y = x * 10 - y_alias = y[0] - torch._dynamo.graph_break() - ind = torch.tensor(4, device=GPU_TYPE) - x_alias2 = x[ind:] - y_alias2 = y[ind:] - return x, x_alias, x_alias2, y_alias, y_alias2 - - foo = torch.compile(foo) - x = torch.rand([20, 20], device=GPU_TYPE) - _, code = run_and_get_code(foo, x) + @torch.compile + def fn(): + A_scales = torch.ones(M, dtype=torch.float32, device=GPU_TYPE) + A_scaled = A.to(torch.float32) * A_scales.unsqueeze(-1) + A_fp8_row_major = A_scaled.to(torch.float8_e4m3fn) + + B_t_scales = torch.ones(E, N, dtype=torch.float32, device=GPU_TYPE) + B_t_scaled = B_t.to(torch.float32) * B_t_scales.unsqueeze(1) + B_t_fp8_col_major = B_t_scaled.to(torch.float8_e4m3fn) + + A_scales_reciprocal = A_scales.reciprocal() + B_t_scales_reciprocal = B_t_scales.reciprocal() + + return torch.ops.aten._scaled_grouped_mm( + A_fp8_row_major, + B_t_fp8_col_major, + A_scales_reciprocal, + B_t_scales_reciprocal, + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").run(code[0]) + fn() class RNNTest(TestCase): device_type = GPU_TYPE diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 6a7d40b6b7ca..cdf76772b936 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -25,6 +25,7 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + add_test_failures, CommonTemplate, copy_tests, run_and_get_cpp_code, @@ -382,9 +383,10 @@ def run(*ex, **kwargs): # Refinement means we don't actually generate dynamic shapes (but only on # cpu apparently?!) "test_nonzero_unbacked_refinement_dynamic_shapes": TestFailure(("cpu",)), - **dynamic_shapes_test_failures, } +add_test_failures(test_failures, dynamic_shapes_test_failures) + if not TEST_WITH_ROCM: test_failures.update( { diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index b75907894f63..8b6d625a5447 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -36,6 +36,7 @@ GPU_TYPE, HAS_CPU, HAS_GPU, + HAS_MPS, patch_inductor_backend, ) @@ -59,9 +60,34 @@ "test_kwargs_dynamic_shapes": TestFailure(("cpu",)), # calling div on only symint args "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure( - ("cpu", "cuda", "xpu") + ("cpu", "cuda", "xpu", "mps") + ), + "test_argmax_argmin_with_duplicates_dynamic_shapes": TestFailure(("mps",)), + "test_batch_norm_2d_2_dynamic_shapes": TestFailure(("mps",)), + "test_buffer_batch_norm_dynamic_shapes": TestFailure(("mps",)), + "test_convolution4_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_abs_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_floordiv_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_remainder_dynamic_shapes": TestFailure(("mps",)), + "test_multilayer_var_dynamic_shapes": TestFailure(("mps",)), + "test_multilayer_var_lowp_dynamic_shapes": TestFailure(("mps",)), + "test_reduction2_dynamic_shapes": TestFailure(("mps",)), + "test_reduction3_dynamic_shapes": TestFailure(("mps",)), + "test_reduction5_dynamic_shapes": TestFailure(("mps",)), + "test_reflection_pad2d_dynamic_shapes": TestFailure(("mps",)), + "test_require_stride_expanded_dynamic_shapes": TestFailure(("mps",)), + "test_roll_dynamic_shapes": TestFailure(("mps",)), + "test_std_dynamic_shapes": TestFailure(("mps",)), + "test_var_correction_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_div_by_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_tile_reduction_False_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_tile_reduction_True_dynamic_shapes": TestFailure(("mps",)), + "test_vectorized_ops_masked_var_novec_dynamic_shapes": TestFailure(("mps",)), + "test_reflection_pad2d_backward_dynamic_shapes": TestFailure( + ("mps",), is_skip=True ), } + if not torch._inductor.config.cpp_wrapper: test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure( ("cuda",) @@ -106,7 +132,7 @@ class DynamicShapesCpuTests(TestCase): copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures) -if HAS_GPU and not TEST_WITH_ASAN: +if (HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: class DynamicShapesGPUTests(TestCase): common = check_model_gpu @@ -121,7 +147,7 @@ class TestInductorDynamic(TestCase): compile_fn = partial(torch.compile, dynamic=True) def setUp(self): - # HAS_CUDA also checks compute capability to skip tests + # HAS_CUDA_AND_TRITON also checks compute capability to skip tests # on older devices if not HAS_GPU: self.skipTest("Triton not available") @@ -1133,5 +1159,5 @@ def fn(a, descending): from torch._inductor.test_case import run_tests # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 - if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: + if (HAS_CPU or HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 33bcb8cd7d1a..c3a6662f1bf3 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -46,9 +46,9 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, has_triton, - HAS_XPU, + HAS_XPU_AND_TRITON, maybe_skip_size_asserts, ) from torch.utils._dtype_abbrs import dtype_abbrs @@ -286,8 +286,6 @@ def format_op(op): "torch.ops.aten._efficient_attention_forward": {f16, f32}, "to_sparse": {f32, f64}, "linalg.eig": {f32, f64}, - # Double and complex datatype matmul is not supported in oneDNN - "byte": {f16, f32}, ("linalg.pinv", "singular"): {f64}, # could not create a primitive "addmv": {f64}, @@ -295,9 +293,17 @@ def format_op(op): # a deconvolution forward propagation primitive "nn.functional.conv_transpose2d": {f32, f64}, "nn.functional.conv_transpose3d": {f32, f64}, - # not implemented for 'Half' - "sort": {b8}, - "argsort": {b8}, + # [Begin] Incorrect XPU reference due to new driver. + "masked.prod": {b8, i32, i64}, + "masked.amin": {i64}, + "masked.amax": {i64}, + "amax": {i64}, + "amin": {i64}, + "std": {f64}, + "var": {f64}, + "std_mean": {f64}, + "var_mean": {f64}, + # [End] } @@ -677,6 +683,14 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.unfold", f16): { "reference_in_float": True, }, + # Reference crash on Intel LTS2 driver. + ("nn.functional.interpolate.trilinear", f32): { + "check_gradient": False, + }, + # Reference crash on Intel LTS2 driver. + ("nn.functional.interpolate.trilinear", f64): { + "check_gradient": False, + }, } if TEST_WITH_ROCM: inductor_override_kwargs["cuda"].update( @@ -1120,8 +1134,10 @@ def tearDown(self): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently - @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") - @skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found") + @skipCUDAIf(not HAS_CUDA_AND_TRITON, "Skipped! Triton not found") + @skipXPUIf( + not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" + ) @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @skipIfTorchDynamo("Test uses dynamo already") diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 82bfdd6290bb..034f83096c1a 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -26,7 +26,7 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, requires_gpu, skip_windows_ci, @@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes( # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2) - def test_2d_reduction_no_x_dim(self): - """ - Tests a 2D reduction without an "x" dimension. - """ - # We need a size to get no x dim. - view = self._discontiguous_tensor((2, 346), self.device) - - # Expect 1 block pointer for the input. - result, (code,) = self._run_and_compare( - torch.prod, - view, - expected_num_block_pointers=1, - expected_num_triton_kernels=1, - config_patches=tiled_reduction_config, - ) - - # Check that there's no X dimension in the signature. - (signature_line,) = ( - line for line in code.splitlines() if line.startswith("def triton") - ) - self.assertNotIn("BLOCK", signature_line) - - # Check for 2 reduction dimensions in the body. - self._assert_reduction_ndims(code, 2) - @parametrize( "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback", [ @@ -1188,6 +1163,9 @@ def foo(x, y, z): # } # This is now fixed by ensuring that that wild symbols only match integers @xfail_if_use_tensor_descriptor + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) def test_ensure_integral_dims_and_strides(self): def model(data, *args): return torch.nn.functional.unfold(data, *args) @@ -1346,7 +1324,7 @@ class TritonBlockPointerTestGPU(BlockDescriptorTestBase): @unittest.skipIf( - not (HAS_CUDA and torch.cuda.get_device_capability()[0] >= 9), + not (HAS_CUDA_AND_TRITON and torch.cuda.get_device_capability()[0] >= 9), "Requires Triton CUDA backend and CUDA compute capability >= 9.0", ) @config.patch({"triton.use_tensor_descriptor": True, "assume_aligned_inputs": True}) diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index a9f898a36af5..4c2a04678b88 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -257,7 +257,7 @@ def grid(meta): def fn(x): return triton_sqr(x) - x = torch.randn(32, device="cuda") + x = torch.randn(32, device=GPU_TYPE) ref = fn(x) res = torch.compile(fn)(x) self.assertEqual(ref, res) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 689cf218b2bc..fc9f92477c79 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -31,7 +31,12 @@ skipIfWindows, skipIfXpu, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA_AND_TRITON, + HAS_GPU, + HAS_XPU_AND_TRITON, +) from torch.testing._internal.logging_utils import log_settings, logs_to_string # Defines all the kernels for tests @@ -47,7 +52,7 @@ import triton from triton import language as tl - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: try: from triton.language.extra.libdevice import ( # @manual fast_dividef, @@ -58,7 +63,7 @@ fast_dividef, fast_dividef as my_fast_dividef, ) - elif HAS_XPU: + elif HAS_XPU_AND_TRITON: from triton.language.extra.intel.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, @@ -2195,7 +2200,7 @@ def f(x): self.assertEqual(compiled_out, eager_out) # TODO enable this test case on XPU. - @requires_cuda + @requires_cuda_and_triton @parametrize("cfg", ["normal", "cpp_wrapper"]) def test_triton_kernel_dtype_view(self, cfg): # https://github.com/pytorch/pytorch/issues/136159 @@ -3578,6 +3583,40 @@ def f(x, y): self.assertNotIn(libname, code) self.assertNotIn(opname, code) + @requires_gpu + def test_subclass(self): + libname = "my_cool_namespace" + opname = "my_triton_operator" + + @torch.library.triton_op(f"{libname}::{opname}", mutates_args={}) + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = output.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) + + return output + + def f(x, y): + return add(x, y) + + x0 = torch.randn(3, device=GPU_TYPE) + y0 = torch.randn(3, device=GPU_TYPE) + x1 = torch.randn(3, device=GPU_TYPE) + y1 = torch.randn(3, device=GPU_TYPE) + + from torch.testing._internal.two_tensor import TwoTensor + + x = TwoTensor(x0, x1) + y = TwoTensor(y0, y1) + + out = torch.compile(f, fullgraph=True)(x, y) + expected = f(x, y) + self.assertEqual(out.a, expected.a) + self.assertEqual(out.b, expected.b) + @requires_gpu @dynamo_config.patch("recompile_limit", 1) def test_triton_dynamic_grid_no_recompile(self): diff --git a/test/inductor/test_xpu_basic.py b/test/inductor/test_xpu_basic.py index 0572eccb77fd..4501b8264c5f 100644 --- a/test/inductor/test_xpu_basic.py +++ b/test/inductor/test_xpu_basic.py @@ -53,7 +53,7 @@ def fn(a, b): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_XPU + from torch.testing._internal.inductor_utils import HAS_XPU_AND_TRITON - if HAS_XPU: + if HAS_XPU_AND_TRITON: run_tests(needs="filelock") diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index b84bc96519cb..781080f5deb6 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -131,6 +131,164 @@ def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) self.assertEqual(py_out, jit_out) + def test_torch_check(self): + """Test torch._check functionality with flexible argument handling""" + + def test_check_basic(x): + torch._check(x.sum().item() > -1000) + return x + + def test_check_with_message(x): + torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable") + return x + + def test_check_with_kwarg_message(x): + torch._check( + x.sum().item() > -1000, message="Tensor sum must be reasonable" + ) + return x + + def test_check_cond_kwarg(x): + torch._check(cond=x.sum().item() > -1000) + return x + + def test_check_both_kwargs(x): + torch._check(cond=x.sum().item() > -1000, message="Both as kwargs") + return x + + def test_check_kwargs_reversed(x): + torch._check(message="Reversed order", cond=x.sum().item() > -1000) + return x + + def test_check_in_loop(x): + sizes = torch.jit.annotate(List[int], x.tolist()) + for s in sizes: + torch._check(s > -100) + return x + + test_tensor = torch.tensor([1, 2, 3]) + + # Test all variations + self.checkScript(test_check_basic, (test_tensor,)) + self.checkScript(test_check_with_message, (test_tensor,)) + self.checkScript(test_check_with_kwarg_message, (test_tensor,)) + self.checkScript(test_check_cond_kwarg, (test_tensor,)) + self.checkScript(test_check_both_kwargs, (test_tensor,)) + self.checkScript(test_check_kwargs_reversed, (test_tensor,)) + self.checkScript(test_check_in_loop, (test_tensor,)) + + # Test that the compiled functions work correctly + scripted_basic = torch.jit.script(test_check_basic) + scripted_with_message = torch.jit.script(test_check_with_message) + scripted_with_kwarg = torch.jit.script(test_check_with_kwarg_message) + scripted_cond_kwarg = torch.jit.script(test_check_cond_kwarg) + scripted_both_kwargs = torch.jit.script(test_check_both_kwargs) + scripted_kwargs_reversed = torch.jit.script(test_check_kwargs_reversed) + scripted_in_loop = torch.jit.script(test_check_in_loop) + + # These should all succeed without throwing + result1 = scripted_basic(test_tensor) + result2 = scripted_with_message(test_tensor) + result3 = scripted_with_kwarg(test_tensor) + result4 = scripted_cond_kwarg(test_tensor) + result5 = scripted_both_kwargs(test_tensor) + result6 = scripted_kwargs_reversed(test_tensor) + result7 = scripted_in_loop(test_tensor) + + # Results should be the same as input + for result in [result1, result2, result3, result4, result5, result6, result7]: + self.assertEqual(result, test_tensor) + + # Check that the message constants are present in the graphs + FileCheck().check("Tensor sum must be reasonable").run( + scripted_with_message.graph + ) + FileCheck().check("Tensor sum must be reasonable").run( + scripted_with_kwarg.graph + ) + FileCheck().check("Both as kwargs").run(scripted_both_kwargs.graph) + FileCheck().check("Reversed order").run(scripted_kwargs_reversed.graph) + + # Verify the graphs contain some computation (not just empty) + basic_graph_str = str(scripted_basic.graph) + self.assertTrue( + len(basic_graph_str) > 100, "Basic graph should contain some computation" + ) + + # Verify the loop case contains a loop + FileCheck().check("prim::Loop").run(scripted_in_loop.graph) + + for scripted_func in [ + scripted_basic, + scripted_with_message, + scripted_with_kwarg, + scripted_cond_kwarg, + scripted_both_kwargs, + scripted_kwargs_reversed, + ]: + FileCheck().check("prim::If").check("prim::RaiseException").run( + scripted_func.graph + ) + + def test_torch_check_invalid_args(self): + """Test torch._check with invalid arguments""" + + # Test too many arguments + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def too_many_args(x): + torch._check(True, "msg", "extra") + return x + + # Test invalid keyword argument + with self.assertRaisesRegex(RuntimeError, "unexpected keyword argument"): + + @torch.jit.script + def invalid_kwarg(x): + torch._check(True, invalid_arg="msg") + return x + + # Test duplicate cond argument (positional + keyword) + with self.assertRaisesRegex( + RuntimeError, "multiple values for argument 'cond'" + ): + + @torch.jit.script + def duplicate_cond(x): + torch._check(True, cond=False) + return x + + # Test missing required cond argument + with self.assertRaisesRegex(RuntimeError, "missing required argument 'cond'"): + + @torch.jit.script + def missing_cond(x): + torch._check(message="msg only") + return x + + # Test no arguments at all + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def no_args(x): + torch._check() + return x + + # Test too many total arguments (positional + keyword) + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def too_many_total_args(x): + torch._check(True, "msg", cond=False) + return x + class TestTensorBuiltins(JitTestCase): def test_tensor_properties(self): diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index d595c793e79b..d6addfddca1a 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -4,6 +4,7 @@ import os import re import sys +import threading import types import typing import typing_extensions @@ -773,6 +774,25 @@ def forward(self, x): mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) + def test_thread_safe_error_stacks(self): + # prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack] + callstacks = [] + + def callstack_creator(): + factory = torch._C._jit_tree_views.SourceRangeFactory( + "source code", "a.py", 1, 0 + ) + x = torch._C.CallStack("a", factory.make_range(1, 0, 1)) + callstacks.append(x) + del x + + t = threading.Thread(target=callstack_creator) + t.start() + t.join() + del t + del callstacks[0] + self.assertTrue(len(callstacks) == 0) + def test_override_instance_method_ignore(self): class M(torch.nn.Module): @torch.jit.ignore diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index ad715598e580..64e6349e0364 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -2842,6 +2842,7 @@ def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): @parametrize_test("strided", [False, True]) # Test with both contiguous and non-contiguous inputs. @parametrize_test("contiguous", [False, True]) + @expectedFailureMPS # No double support def test_conv_backend( self, device, @@ -4057,13 +4058,22 @@ def test_conv3d_64bit_indexing(self, device): @largeTensorTest("20GB") @largeTensorTest("64GB", "cpu") def test_depthwise_conv_64bit_indexing(self, device): - x = torch.randn(1, 2, 32800, 32800, dtype=torch.half) + x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to( + memory_format=torch.channels_last + ) c = nn.Conv2d( 2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.half - ) + ).to(memory_format=torch.channels_last) + yref = c(x) + y = c.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y, atol=1e-3, rtol=1e-4) + del y, yref + + # try a batch-splittable case + x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last) yref = c(x) y = c.to(device=device)(x.to(device=device)) - self.assertEqual(yref, y, atol=5e-3, rtol=1e-4) + self.assertEqual(yref, y, atol=1e-3, rtol=1e-4) instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True) diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index e33385bcfa11..a8f77df22d31 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -504,6 +504,7 @@ def test_quantized_max_pool3d(self): class TestPoolingNNDeviceType(NNTestCase): + @expectedFailureMPS # No double, float shape prop does not work @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_adaptive_pooling_zero_batch(self, dtype, device): @@ -523,6 +524,7 @@ def test_adaptive_pooling_zero_batch(self, dtype, device): # when output_size = 0, in adaptive_{avg, max}_pool and its variants. # These tests are explicitly written because ErrorInputs does not support backward calls # Issue: https://github.com/pytorch/pytorch/issues/78868 + @expectedFailureMPS # No double, float shape prop does not work @onlyNativeDeviceTypes @dtypes(torch.float32, torch.float64) @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16) @@ -556,6 +558,7 @@ def test_adaptive_pooling_empty_output_size(self, dtype, device): with self.assertRaisesRegex(RuntimeError, error_msg): fn(input2, output_size).sum().backward() + @expectedFailureMPS # Error message does not match @onlyNativeDeviceTypes def test_adaptive_avg_pooling_backward_fails(self, device): grad_output = torch.randn(1, 2, 7, device=device) @@ -582,6 +585,7 @@ def test_adaptive_max_pooling_backward_fails(self, device): with self.assertRaisesRegex(RuntimeError, "expected dimensions"): torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_batch(self, device): mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5)) @@ -592,6 +596,7 @@ def test_FractionalMaxPool2d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, device=device) mod(inp) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_batch(self, device): mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device) @@ -602,6 +607,7 @@ def test_FractionalMaxPool3d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, 32, device=device) mod(inp) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_out_size(self, device): mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1]) @@ -609,6 +615,7 @@ def test_FractionalMaxPool2d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device)) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_out_size(self, device): mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1]) @@ -616,6 +623,7 @@ def test_FractionalMaxPool3d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device)) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_samples(self, device): samples = torch.rand([0, 16, 2], device=device) @@ -630,6 +638,7 @@ def test_FractionalMaxPool2d_zero_samples(self, device): with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"): mod(inp1) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_samples(self, device): samples = torch.rand([0, 16, 3], device=device) @@ -823,6 +832,7 @@ def test_MaxUnpool_index_errors( else: unpool(output, indices) + @expectedFailureMPS @onlyNativeDeviceTypes def test_AdaptiveMaxPool_zero_batch_dim(self, device): inp = torch.randn(0, 16, 50, device=device) @@ -962,6 +972,7 @@ def test_adaptive_avg_pool3d_output_size_one(self, device): c = out.size(1) self.assertEqual(out.stride(), [c, 1, 1, 1, 1]) + @expectedFailureMPS # Runtime Error not raised for mps @expectedFailureMeta # Runtime Error not raised for meta @onlyNativeDeviceTypes @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long) @@ -976,6 +987,7 @@ def test_adaptive_pooling_no_suppot_input(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "not implemented"): module(input) + @expectedFailureMPS # TODO: fixme @onlyNativeDeviceTypes @gcIfJetson @dtypes(torch.float, torch.double) @@ -1123,6 +1135,7 @@ def helper(n, c, h, w, ks): helper(1, 100000, 32, 32, ks=4) helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d + @expectedFailureMPS # TODO: Fixme @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1198,6 +1211,7 @@ def check(x, args, expected, memory_format): torch.channels_last, ) + @expectedFailureMPS # TODO: Fixme @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1722,6 +1736,7 @@ def test_maxpool_indices_no_batch_dim(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) + @expectedFailureMPS # Exception not raise @onlyNativeDeviceTypes # TODO: Fails on XLA @gcIfJetson def test_max_pool_nan_inf(self, device, dtype): @@ -1758,6 +1773,7 @@ def test_max_pool_nan_inf(self, device, dtype): res2 = fn(x2, 1 if adaptive else 3) self.assertTrue(math.isinf(res2.item())) + @expectedFailureMPS # float64 @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool2d(self, device): @@ -1820,6 +1836,7 @@ def test_fractional_max_pool2d_backward_fails(self, device): grad_output, input, kernel_size, output_size, indices ) + @expectedFailureMPS # float64 @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool3d(self, device): @@ -1867,6 +1884,7 @@ def func(x): x, (2, 2, 2), output_size=output_size, _random_samples=samples ) + @expectedFailureMPS # Not implemented @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) @onlyNativeDeviceTypes # TODO: Fails on XLA @@ -1896,6 +1914,7 @@ def test_fractional_max_pool_nan_inf(self, device, dtype): res2.backward(torch.randn_like(res2)) self.assertTrue(math.isinf(res2.item())) + @expectedFailureMPS # TODO: Fix me @onlyNativeDeviceTypes # TODO: RuntimeError message different on XLA def test_pooling_zero_stride(self, device): for op in ("max", "avg"): diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 9a8a171b5fe2..593cc524ebe7 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -600,6 +600,18 @@ def test_torchscript_exporter_raises_deprecation_warning(self): SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False ) + def test_model_output_can_be_none(self): + class ModelWithNoneOutput(torch.nn.Module): + def forward(self, x): + return x + 1, None + + onnx_program = torch.onnx.export( + ModelWithNoneOutput(), + (torch.randn(1, 1, 2),), + dynamo=True, + ) + onnx_testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/torchlib/error_reproduction.py b/test/onnx/torchlib/error_reproduction.py index 260a37b65f16..9fd1dace7767 100644 --- a/test/onnx/torchlib/error_reproduction.py +++ b/test/onnx/torchlib/error_reproduction.py @@ -205,7 +205,7 @@ def create_reproduction_report( onnxscript=={onnxscript.__version__} numpy=={np.__version__} torch=={torch.__version__}""" - short_test_name = test_name.split(".")[-1] + short_test_name = test_name.rsplit(".", maxsplit=1)[-1] reproduction_code = _REPRODUCTION_TEMPLATE.format( onnx_model_text=onnx_model_text, ort_inputs=input_text, @@ -245,7 +245,7 @@ def create_mismatch_report( error_text = str(error) error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) - short_test_name = test_name.split(".")[-1] + short_test_name = test_name.rsplit(".", maxsplit=1)[-1] diff = difflib.unified_diff( str(actual).splitlines(), str(expected).splitlines(), diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index a0cc967787e6..edbba9f6f8ee 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -208,11 +208,10 @@ def make_exporter(): # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. return pe - # This should fail. The 'PackageAObject' type defined from 'importer1' - # is not necessarily the same 'obj2's version of 'PackageAObject'. + # This succeeds because OrderedImporter.get_name() properly + # falls back to sys_importer which can find the original PackageAObject pe = make_exporter() - with self.assertRaises(pickle.PicklingError): - pe.save_pickle("obj", "obj.pkl", obj2) + pe.save_pickle("obj", "obj.pkl", obj2) # This should also fail. The 'PackageAObject' type defined from 'importer1' # is not necessarily the same as the one defined from 'importer2' diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index c15b5b220b0e..7da2898ffbe7 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -425,6 +425,11 @@ def fn(a, b, c): ) @skipCPUIf(True, "skip CPU device for testing profiling triton") def test_execution_trace_env_enabled_with_pt2(self, device): + # clean up the local cache for triton kernel + from torch._inductor.codecache import PyCodeCache as PyCodeCache + + PyCodeCache.cache_clear(purge=True) + import os os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1" @@ -439,7 +444,9 @@ def fn(a, b, c): a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3)) inputs = [a, b, c] - with torch._inductor.config.patch(compile_threads=1): + with torch._inductor.config.patch( + compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False + ): fn(*inputs) with profile( @@ -480,10 +487,12 @@ def fn(a, b, c): ) @skipCPUIf(True, "skip CPU device for testing profiling triton") def test_triton_fx_graph_with_et(self, device): - import os + # clean up the local cache for triton kernel + from torch._inductor.codecache import PyCodeCache as PyCodeCache - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1" - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1" + PyCodeCache.cache_clear(purge=True) + + import os @torchdynamo.optimize("inductor") def fn(a, b, c): @@ -503,12 +512,18 @@ def fn(a, b, c): ): fn(*inputs) + fp = tempfile.NamedTemporaryFile("w+t", suffix="fx_graph_et.json", delete=False) + fp.close() + et = ExecutionTraceObserver() + et.register_callback(fp.name) + et.set_extra_resource_collection(True) with profile( activities=torch.profiler.supported_activities(), record_shapes=True, schedule=torch.profiler.schedule( skip_first=0, wait=1, warmup=1, active=1, repeat=1 ), + execution_trace_observer=et, ) as p: for idx in range(10): with record_function(f"## LOOP {idx} ##"): @@ -550,23 +565,23 @@ def fn(a, b, c): ) assert ( fx_graph[2] - == "# %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})" # noqa: B950 + == '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})' # noqa: B950 ) assert ( fx_graph[3] - == "# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})" # noqa: B950 + == '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})' # noqa: B950 ) assert ( fx_graph[4] - == "# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})" # noqa: B950 + == '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})' # noqa: B950 ) assert ( fx_graph[5] - == "# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})" # noqa: B950 + == '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})' # noqa: B950 ) assert ( fx_graph[6] - == "# %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})" # noqa: B950 + == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})' # noqa: B950 ) assert fx_graph[7] == "# return %cos" diff --git a/test/profiler/test_python_tracer.py b/test/profiler/test_python_tracer.py index 389395d8027c..f7732b0b3893 100644 --- a/test/profiler/test_python_tracer.py +++ b/test/profiler/test_python_tracer.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: profiler"] import json +import subprocess import sys import time @@ -63,6 +64,46 @@ def test_monitoring_callback(self): name = monitoring.get_tool(2) self.assertEqual(name, None) + def test_unexpected_c_return_events(self): + code = """ +import threading +import time +import torch + +from threading import Event, Lock + +lock = Lock() +lock.acquire() +event1 = Event() +event2 = Event() +event3 = Event() + +def run(): + event1.set() + event2.wait() + lock.acquire() + event3.set() + +threading.Thread(target=run).start() + +with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): + event1.wait() + event2.set() + time.sleep(1) + +with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): + lock.release() + event3.wait() + """ + + result = subprocess.run( + [sys.executable, "-c", code], capture_output=True, text=True, check=True + ) + + self.assertFalse( + "Python replay stack is empty during pop operation" in result.stderr + ) + if __name__ == "__main__": run_tests() diff --git a/test/run_test.py b/test/run_test.py index e0bde4e6d52d..5e9548d4eab1 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -182,7 +182,6 @@ def __contains__(self, item): "dynamo/test_misc", "inductor/test_cpu_repro", "inductor/test_cpu_select_algorithm", - "inductor/test_aot_inductor_arrayref", "inductor/test_torchinductor_codegen_dynamic_shapes", "lazy/test_meta_kernel", "onnx/test_utility_funs", @@ -240,7 +239,6 @@ def __contains__(self, item): # some false errors "doctests", # new failures to investigate and fix - "cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic", "test_tensorboard", # onnx + protobuf failure, see # https://github.com/protocolbuffers/protobuf/issues/22104 @@ -1557,7 +1555,7 @@ def get_selected_tests(options) -> list[str]: if options.einops: selected_tests = list( filter( - lambda test_name: test_name.startswith("test/dynamo/test_einops"), + lambda test_name: test_name.startswith("dynamo/test_einops"), selected_tests, ) ) @@ -1584,6 +1582,7 @@ def get_selected_tests(options) -> list[str]: "inductor/test_mps_basic", "inductor/test_torchinductor", "inductor/test_aot_inductor", + "inductor/test_torchinductor_dynamic_shapes", ] else: # Exclude all mps tests otherwise diff --git a/test/slow_tests.json b/test/slow_tests.json index 457701b46b61..579e69d7e488 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,260 +1,239 @@ { - "EndToEndLSTM (__main__.RNNTest)": 200.1896718343099, - "MultiheadAttention (__main__.ModulesTest)": 141.92533365885416, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 210.3270060221354, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 105.85777706570096, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 115.53966522216797, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.45811038547092, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 177.51766967773438, - "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 74.74966557820638, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 68.23533376057942, - "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 61.625999450683594, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 134.07366434733072, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 188.88899739583334, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.63599904378255, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.27233378092448, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 105.4979985555013, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 633.0828002929687, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 91.86733309427898, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 481.1977776421441, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 491.7155592176649, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 124.39833196004231, - "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.104000091552734, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 81.22966766357422, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 69.64550145467122, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 175.67355600992838, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 125.82333374023438, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 369.5883280436198, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 418.0381130642361, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 312.76700168185766, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 84.68433380126953, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 86.41216786702473, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 60.670833587646484, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 84.44266510009766, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 86.69533284505208, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 63.40933354695638, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 375.11133829752606, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 64.89966583251953, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 386.1840108235677, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 66.45699818929036, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 227.58533223470053, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 236.75483194986978, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1000.12451171875, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 63.72516632080078, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 936.3953450520834, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.74933242797852, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 70.87016677856445, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.49433453877766, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.39149983723958, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.41349919637044, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.10983467102051, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 64.13150151570638, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 89.73133341471355, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 86.45633188883464, - "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 88.76399993896484, - "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.25218469125254, - "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.11777793036566, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 176.61566670735678, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 173.7596689860026, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 163.57832845052084, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 161.29700215657553, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 208.6990000406901, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 198.11366271972656, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 198.788330078125, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 121.93983332316081, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 119.3211669921875, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.11850102742513, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 121.52633412679036, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.41900126139323, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 120.74099985758464, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 92.1571667989095, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 93.97516759236653, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 93.90033213297527, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 102.24433135986328, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 237.9564997355143, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 263.09083048502606, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 70.44449869791667, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 78.58383433024089, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 66.97166633605957, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 81.04183451334636, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 89.63233439127605, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 94.67216491699219, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 168.28499857584634, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 171.91666666666666, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 166.12066650390625, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1279.8836669921875, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1132.968994140625, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1118.725341796875, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 973.7703247070312, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 972.6750081380209, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1209.7756754557292, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1256.0619710286458, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1281.5216471354167, - "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 917.3249918619791, - "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 733.1909790039062, - "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 724.7653401692709, - "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 726.2100219726562, - "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 705.0809936523438, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 517.8646697998047, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 521.0065002441406, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 130.64300028483072, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 124.43033345540364, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 128.03166707356772, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.71049880981445, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.55933380126953, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.66183217366536, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 69.40700022379558, - "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 74.34766642252605, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 112.48366800944011, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 116.27966562906902, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 117.50433603922527, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 106.86666615804036, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 94.00083287556966, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 62.15316645304362, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.82649993896484, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 61.87600072224935, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 69.6066665649414, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.90516599019368, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 102.65083312988281, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 85.81283442179362, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 70.68100102742513, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 98.76588948567708, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 229.82177903917102, - "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 81.8357684795673, - "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 135.92233530680338, - "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 141.42266845703125, - "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 74.59500092726488, - "test_conv3d_unary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 64.01784662099985, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 73.09766684638129, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 95.88766733805339, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 94.47416687011719, - "test_count_nonzero_all (__main__.TestBool)": 641.161878797743, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 307.93677775065106, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 302.5940024058024, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 81.91116714477539, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 88.2913335164388, - "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 67.36266835530598, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 60.49377780490451, - "test_fail_creation_ops.py (__main__.TestTyping)": 68.32106041185784, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 76.85566584269206, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 91.61366780598958, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 204.6830037434896, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 134.79716873168945, - "test_fuse_large_params_cpu (__main__.CpuTests)": 97.0917501449585, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.09088897705078, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 147.25677744547525, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 125.67216491699219, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 94.74416732788086, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 98.06850051879883, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 150.5540008544922, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 139.7729949951172, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 232.7606684366862, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 154.89383188883463, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 156.3326670328776, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 650.9168192545573, - "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.89266459147134, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 273.2460021972656, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.99511040581598, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 101.2813351949056, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 154.23166741265192, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.40700022379558, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 123.70700073242188, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 95.7520014444987, - "test_linear (__main__.TestStaticQuantizedModule)": 62.20888815985786, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 102.4893315633138, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 127.22689056396484, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 431.17966715494794, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 133.41966756184897, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 360.4186706542969, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 60.48455513848199, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.52433310614692, - "test_proper_exit (__main__.TestDataLoader)": 234.38233439127603, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 242.4615020751953, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 65.31966749827068, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 150.28666602240668, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 65.1363112979465, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 63.50664397345649, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 62.56345471468839, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 73.45999908447266, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.02366638183594, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.85933430989583, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 74.7816670735677, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 88.31666564941406, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.21133422851562, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.58400217692058, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.65733337402344, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 94.56866709391277, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 80.31666564941406, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 95.52099863688152, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.52433522542317, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.57466634114583, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.05966695149739, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.94766743977864, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 77.00899759928386, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 95.18199920654297, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.22000122070312, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 69.10733286539714, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 84.89466603597005, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.52066548665364, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 93.1520004272461, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.66366831461589, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 370.8893330891927, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 733.5455017089844, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 605.9030151367188, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1136.014139811198, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 72.65350023905437, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 64.6456667582194, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 207.27167002360025, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 91.64166768391927, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 167.19299825032553, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 64.22866694132487, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 116.8476676940918, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 70.6433334350586, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 137.72866566975912, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 87.72266642252605, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 78.25366719563802, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.75999959309895, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 68.58633486429851, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.43899959988065, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 155.9663340250651, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 110.39933268229167, - "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 85.31637557347615, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 136.4769990709093, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 113.9978896247016, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.96166737874348, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 89.43966674804688, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 149.7841674486796, - "test_tensor_split (__main__.TestVmapOperators)": 76.2336671680021, - "test_terminate_handler_on_crash (__main__.TestTorch)": 111.58677675988939, - "test_terminate_signal (__main__.ForkTest)": 136.8188896137807, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 136.99289169742002, - "test_terminate_signal (__main__.SpawnTest)": 140.61755683687, - "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 69.51326649983724, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 68.61666615804036, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 65.95349820454915, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.64900016784668, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.68766657511394, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 120.926331837972, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 104.47883415222168, - "test_unary_ops (__main__.TestTEFuserDynamic)": 172.1952222188314, - "test_unary_ops (__main__.TestTEFuserStatic)": 158.92655531565347, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 96.95966339111328, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 90.34199778238933, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 69.39216740926106, - "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 73.56816864013672, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 96.19633483886719, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 93.57866668701172, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 95.94100189208984, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 71.65300051371257, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 84.81466547648112, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 100.53633308410645, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 69.77733103434245, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 67.43849881490071, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 77.40583229064941, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 64.32900110880534, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 71.61133193969727, - "test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 60.90399932861328, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.39033381144206, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.00383377075195, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 143.61550013224283 + "EndToEndLSTM (__main__.RNNTest)": 192.05133056640625, + "MultiheadAttention (__main__.ModulesTest)": 139.78399658203125, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 87.68600040011935, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 65.84855567084418, + "test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 60.25300089518229, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.21100107828777, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 75.08200073242188, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 157.21666717529297, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 208.15966288248697, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 125.87799835205078, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 77.12099711100261, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 140.02066548665366, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1035.8856404622395, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 135.24966684977213, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 508.929680718316, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 505.31178114149304, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 136.39566548665366, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.21700286865234, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 75.41950098673503, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 223.36288791232639, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.77316665649414, + "test_cat_2k_args (__main__.TestTEFuserDynamic)": 115.93922015362315, + "test_cat_2k_args (__main__.TestTEFuserStatic)": 130.553553307222, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 345.87477620442706, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 444.5221184624566, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 320.5727776421441, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 113.46416600545247, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 112.7143325805664, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.17833370632596, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 74.29283396402995, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 112.0316670735677, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 100.49766794840495, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 461.6960042317708, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 456.4236653645833, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.10166422526044, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 282.37300364176434, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1475.5308430989583, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.82050069173177, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1480.9661661783855, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 76.27283477783203, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 77.9731674194336, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 75.6216672261556, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 78.13583374023438, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 79.3071657816569, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 73.1963342030843, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 73.24300003051758, + "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.95249938964844, + "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 60.023167292277016, + "test_comprehensive_logspace_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.90595825513204, + "test_comprehensive_logspace_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 60.20212459564209, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 146.75049845377603, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 134.19933319091797, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 131.4624989827474, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 63.848776499430336, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 63.11926663716634, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 63.54826672871908, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 128.72383244832358, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 125.754332224528, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 112.56066640218098, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 105.46999867757161, + "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 62.39555570814345, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 319.47683970133465, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 318.15632883707684, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 104.06650034586589, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 87.9704984029134, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 88.85649871826172, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 91.08616511027019, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 145.80900065104166, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 144.81166712443033, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1361.4583333333333, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1364.7848307291667, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1371.0353393554688, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 567.3706563313802, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 562.332997639974, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 75.43950017293294, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.2380002339681, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.18633397420247, + "test_comprehensive_nn_functional_unfold_cuda_complex128 (__main__.TestDecompCUDA)": 64.52433310614691, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 135.42366409301758, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 135.88899993896484, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 73.0211664835612, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 75.32600021362305, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 76.17533365885417, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 78.49149958292644, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 80.97866566975911, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 143.84516398111978, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 139.04916763305664, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 107.44683329264323, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 349.12533315022785, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 713.3404405381945, + "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.65333302815755, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 147.33233133951822, + "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 65.11533101399739, + "test_conv_bn_folded_vs_unfolded (__main__.TestQuantizeEagerQATNumerics)": 60.53688989910815, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 82.8076680501302, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 79.54511260986328, + "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 86.01536305745442, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 118.80933380126953, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 103.28283437093098, + "test_count_nonzero_all (__main__.TestBool)": 636.5518866644966, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 806.537343343099, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 86.1219991048177, + "test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 129.43338103521438, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 226.9676717122396, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 64.93344370524089, + "test_fail_random.py (__main__.TestTyping)": 69.7191998799642, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 89.57850011189778, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 91.1931660970052, + "test_fuse_large_params_cpu (__main__.CpuTests)": 68.59933344523112, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 157.28044637044272, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 155.77044677734375, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.154665629069, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 107.34999974568684, + "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 75.96997397985214, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 98.00283304850261, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 125.0576680501302, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.84066518147786, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 227.8953374226888, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 121.02666727701823, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 128.9303321838379, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 607.3985087076823, + "test_group_norm (__main__.TestQuantizedOps)": 94.22445230773, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 322.7479960123698, + "test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 126.8058580671038, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 74.46766620212131, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 98.24650065104167, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 165.09344482421875, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 117.98733266194661, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 125.10833231608073, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 96.8866678873698, + "test_linear (__main__.TestStaticQuantizedModule)": 177.4332241482205, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 99.29573364257813, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 112.58993326822916, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 70.74819436942602, + "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 106.39933342403836, + "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.2489998227074, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 581.2816569010416, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 515.0809936523438, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 65.59099833170573, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 130.8411119249132, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.907222747802734, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.92422188652886, + "test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 80.63411996126175, + "test_optimize_for_inference_cpu_torchvision (__main__.TestFXExperimental)": 70.60716595252354, + "test_out_variant_custom_op_dynamic_shapes (__main__.DynamicShapesMiscTests)": 61.15033358619327, + "test_proper_exit (__main__.TestDataLoader)": 224.09533182779947, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 258.17566172281903, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 61.226499239603676, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 159.05066765679254, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 63.150904201325915, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 62.33847640809559, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 99.43811119927301, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 81.92866770426433, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 90.84566497802734, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.01099904378255, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.23799896240234, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.45733388264973, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.5086669921875, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 76.81433359781902, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 86.00199890136719, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 86.0836664835612, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 73.06933339436848, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 98.68933614095052, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.80333201090495, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 78.26366678873698, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.90333557128906, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.47400156656902, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.05833435058594, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.04699961344402, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 69.11566670735677, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.11000061035156, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 83.76499938964844, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.46166483561198, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 93.64866638183594, + "test_qrnncell (__main__.TestDynamicQuantizedOps)": 76.3342770516562, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 578.3420003255209, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1415.7366739908855, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 764.0906778971354, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1710.9246826171875, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 97.7066650390625, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 350.8980000813802, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.1796646118164, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 271.30833435058594, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 76.83166758219402, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 166.40349833170572, + "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.98755560980902, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 106.40633392333984, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 189.75599924723306, + "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.40213343302409, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 119.15783309936523, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 122.17516708374023, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.66699981689453, + "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 165.6238899230957, + "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 155.86678059895834, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 76.51850128173828, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 77.36766730414496, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 163.50216674804688, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 135.39966328938803, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 161.2034437391493, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 145.5945544772678, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 122.7945556640625, + "test_softmax_view_reshape (__main__.HelionTests)": 174.26483281453451, + "test_std (__main__.TestQuantizedOps)": 91.47738643594978, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 150.35899583498636, + "test_terminate_handler_on_crash (__main__.TestTorch)": 110.8061129252116, + "test_terminate_signal (__main__.ForkTest)": 134.98833089901342, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 135.13266838259167, + "test_terminate_signal (__main__.SpawnTest)": 139.0918925603231, + "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 83.97499879201253, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 166.78876847487228, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 76.76449902852376, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 74.20233408610027, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 77.21166737874348, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 126.05833435058594, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 124.58566665649414, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 102.95399856567383, + "test_unary_ops (__main__.TestTEFuserDynamic)": 94.66122142473857, + "test_unary_ops (__main__.TestTEFuserStatic)": 97.9681122303009, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 94.58433278401692, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 80.96083323160808, + "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 84.94333267211914, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 93.61533101399739, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 99.49200185139973, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 60.70061842600504, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 98.77016703287761, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 80.70883369445801, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 117.87966664632161, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.81652414231073, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 138.76616923014322, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 66.88895261855353, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 66.50699996948242, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 98.47683461507161, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 115.15083122253418, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 102.98050053914388, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 132.38116709391275, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 124.73283131917317, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 159.73250325520834 } \ No newline at end of file diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 0ea224d704cb..21731bd275b6 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -1,5 +1,6 @@ # Owner(s): ["module: tests"] +import gc import sys import unittest @@ -156,6 +157,83 @@ def test_generic_event_behavior(self): ): event1.elapsed_time(event2) + @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") + def test_memory_stats(self): + # Ensure that device allocator is initialized + acc = torch.accelerator.current_accelerator() + tmp = torch.randn(100, device=acc) + del tmp + gc.collect() + self.assertTrue(torch._C._accelerator_isAllocatorInitialized()) + torch.accelerator.empty_cache() + + pool_type = ["all", "small_pool", "large_pool"] + metric_type = ["peak", "current", "allocated", "freed"] + stats_type = [ + "allocated_bytes", + "reserved_bytes", + "active_bytes", + "requested_bytes", + ] + mem_stats = torch.accelerator.memory_stats() + expected_stats = [ + f"{st}.{pt}.{mt}" + for st in stats_type + for pt in pool_type + for mt in metric_type + ] + missing_stats = [stat for stat in expected_stats if stat not in mem_stats] + self.assertEqual( + len(missing_stats), + 0, + f"Missing expected memory statistics: {missing_stats}", + ) + + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertGreaterEqual(prev_allocated, 0) + self.assertGreaterEqual(prev_reserved, 0) + self.assertGreater(prev_max_allocated, 0) + self.assertGreater(prev_max_reserved, 0) + tmp = torch.ones(256, device=acc) + self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated) + self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved) + del tmp + gc.collect() + torch.accelerator.empty_cache() + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated) + self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved) + torch.accelerator.reset_accumulated_memory_stats() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device=acc) + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + if __name__ == "__main__": run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index 5d7f81eeb4fb..7ce40e59dd4b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -109,6 +109,10 @@ def graph_desc(fn): class TestAutograd(TestCase): + def tearDown(self): + torch.autograd._force_original_view_tracking(False) + super(TestCase, self).tearDown() + def test_copy_slices_graph_task_updates(self): def f1(x, y): out = x.clone().view(-1) @@ -610,8 +614,6 @@ def unpack(x): with disable_gc(): unpack_hook_ref = scope() - if torch._dynamo.is_compiling(): - torch._dynamo.reset() self.assertIsNone(unpack_hook_ref()) def test_will_engine_execute_node(self): @@ -1194,6 +1196,33 @@ def fn(x, reduce=True): tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0]) ) + def test_gradient_edge_graph_ownership(self): + # Ensure we own the graph properly + class Clone(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, gX): + return gX.clone() + + inp = torch.rand(1, requires_grad=True).clone() + + # C++ Node + out = inp.clone() + edge = torch.autograd.graph.get_gradient_edge(out) + torch.autograd.backward(edge) + del out + torch.autograd.backward(edge) + + # python Node + out = Clone.apply(inp) + edge = torch.autograd.graph.get_gradient_edge(out) + torch.autograd.backward(edge) + del out + torch.autograd.backward(edge) + def test_grad_nonleaf(self): x_init = torch.randn(2, 2, requires_grad=True) x = x_init @@ -12394,6 +12423,29 @@ def test_resize_version_bump(self, device): x.resize_as_(y) self.assertEqual(x._version, 2) + @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + def test_zero_dim_param_mixed_device_grad(self, device): + # cpu 0-dim params with an accelerator device grad + # https://github.com/pytorch/pytorch/issues/160084 + class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + + def forward(self, x): + return x * self.a + self.b + + # Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here + model = RegressionModel() + inputs = torch.randn(4, 10, device=device) + out = model(inputs) + out.sum().backward() + self.assertIsNotNone(model.a.grad) + self.assertIsNotNone(model.b.grad) + self.assertEqual(model.a.grad.device, torch.device("cpu")) + self.assertEqual(model.b.grad.device, torch.device("cpu")) + class TestAllowMutationOnSaved(TestCase): def assertClonedLenEqual(self, ctx, n): diff --git a/test/test_cuda.py b/test/test_cuda.py index 39065eea1a9c..9755835853ee 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -373,6 +373,42 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) + def test_memory_stats(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="cuda") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + def test_check_error(self): # Assert this call doesn't raise. torch.cuda.check_error(0) @@ -1503,6 +1539,7 @@ def run(dev: torch.device) -> int: ) @largeTensorTest("20GB", "cuda") + @serialTest() def test_randint_generation_for_large_numel(self) -> None: numel = 2**31 + 1 s = torch.randint(2, (numel,), device="cuda", dtype=torch.int8).sum() @@ -5331,6 +5368,7 @@ def test_mempool_empty_cache(self): segments = torch.cuda.memory._snapshot()["segments"] self.assertTrue(len(segments) > 0, "expected more than one segment") + @serialTest() def test_mempool_empty_cache_inactive(self): torch.cuda.empty_cache() allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True) @@ -5560,6 +5598,7 @@ def test_mempool_expandable(self): out_0 = torch.randn(nelem_1mb, device="cuda") torch.cuda.memory._set_allocator_settings("expandable_segments:False") + @serialTest() def test_mempool_ctx_multithread(self): torch.cuda.empty_cache() segments = torch.cuda.memory._snapshot()["segments"] @@ -6479,6 +6518,7 @@ def test_autocast_rnn(self): for grad, grad_control in zip(grads, grads_control): self.assertEqual(grad.half(), grad_control) + @serialTest() def test_autocast_cache_leak(self): # Reported at https://github.com/pytorch/pytorch/issues/48049 # Test is used to check, if autocast recaches the same parameters @@ -6493,7 +6533,7 @@ def test_autocast_cache_leak(self): first_iter_mem = torch.cuda.memory_allocated() for _ in range(3): out = linear(data) - self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + self.assertEqual(first_iter_mem, torch.cuda.memory_allocated()) def test_autocast_checkpointing(self): model = torch.nn.Sequential( diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index b713edeb7a95..5a494f548742 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1769,7 +1769,7 @@ def f(x): Developer debug context: _torch_testing.numpy_nonzero.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0036.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html""", ) # pre-existing problem: torch.compile(dynamic=True) will, by default, diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 931c32774004..8f4e74d85177 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -3133,6 +3133,15 @@ def test_pin_memory(self): self.assertTrue(sample["a_tensor"].is_pinned()) self.assertTrue(sample["another_dict"]["a_number"].is_pinned()) + @skipIfXpu + @skipIfRocm + @unittest.skipIf(TEST_CUDA, "Test for when CUDA is not available") + def test_pin_memory_no_cuda(self): + loader = DataLoader(self.dataset, batch_size=2, pin_memory=True) + for sample in loader: + self.assertFalse(sample["a_tensor"].is_pinned()) + self.assertFalse(sample["another_dict"]["a_number"].is_pinned()) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_pin_memory_device(self): loader = DataLoader( diff --git a/test/test_dlpack.py b/test/test_dlpack.py index f734126b5e7c..b960575cc634 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -5,6 +5,7 @@ from torch.testing._internal.common_device_type import ( deviceCountAtLeast, dtypes, + dtypesIfMPS, instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -13,10 +14,14 @@ skipCUDAIfRocm, skipMeta, ) -from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.common_dtype import ( + all_mps_types_and, + all_types_and_complex_and, +) from torch.testing._internal.common_utils import ( IS_JETSON, run_tests, + skipIfMPS, skipIfTorchDynamo, TestCase, ) @@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_dlpack_capsule_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(to_dlpack(x)) @@ -72,6 +78,7 @@ def test_dlpack_capsule_conversion(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_dlpack_protocol_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(x) @@ -80,7 +87,8 @@ def test_dlpack_protocol_conversion(self, device, dtype): @skipMeta @onlyNativeDeviceTypes def test_dlpack_shared_storage(self, device): - x = make_tensor((5,), dtype=torch.float64, device=device) + dtype = torch.bfloat16 if device.startswith("mps") else torch.float64 + x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(to_dlpack(x)) z[0] = z[0] + 20.0 self.assertEqual(z, x) @@ -120,12 +128,14 @@ def test_dlpack_conversion_with_streams(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_from_dlpack(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) self.assertEqual(x, y) @skipMeta + @skipIfMPS # MPS crashes with noncontiguous now @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -189,6 +199,7 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_from_dlpack_dtype(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 7aa530ae3296..9baad91da79d 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1721,6 +1721,16 @@ def test_nonzero_stride(self): self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous()) + def test_nan_to_num(self): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + with fake_mode: + x = torch.randn(5, 10).t() + y = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + + self.assertEqual(x.size(), y.size()) + self.assertEqual(x.stride(), y.stride()) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_torch_load_with_fake_mode(self): model = torch.nn.Linear(5, 10) @@ -2476,5 +2486,81 @@ def forward( self.assertBypasses("unrepresented symbol in output", 2) +class FakeTensorPreferDeviceType(TestCase): + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_fake_tensor_prefer_device_type(self): + """ + Test that fake_tensor_prefer_device_type configuration works correctly + for device mismatch scenarios. + """ + + # Create a custom operation that would normally cause device mismatch + def mixed_device_op(a, b): + # This simulates an operation where 'a' is on MTIA/CUDA but 'b' is created on CPU + cpu_tensor = torch.arange(a.shape[0], device="cpu") + return a + cpu_tensor.unsqueeze(-1) + + with FakeTensorMode(): + # Test default behavior (should raise error on device mismatch) + cuda_tensor = torch.randn(3, 4, device="cuda") + + # Without the config, this should raise a device mismatch error + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + # Test with prefer_device_type set to "cuda" + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device="cuda") + + # This should now work and prefer the CUDA device + result = mixed_device_op(cuda_tensor, None) + + # The result should be on CUDA device (preferred device type) + self.assertEqual(result.device.type, "cuda") + self.assertEqual(result.shape, (3, 4)) + self.assertTrue(isinstance(result, FakeTensor)) + + # Test that the configuration doesn't affect normal operations + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + # Normal same-device operations should work as before + x = torch.randn(2, 3, device="cuda") + y = torch.randn(2, 3, device="cuda") + result = x + y + self.assertEqual(result.device.type, "cuda") + + # CPU operations should still work + x_cpu = torch.randn(2, 3, device="cpu") + y_cpu = torch.randn(2, 3, device="cpu") + result_cpu = x_cpu + y_cpu + self.assertEqual(result_cpu.device.type, "cpu") + + # Test that the configuration is properly scoped + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device="cuda") + + # After exiting the config context, should raise error again + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + def test_fake_tensor_prefer_device_type_cpu_only(self): + """ + Test that fake_tensor_prefer_device_type works correctly when only CPU tensors are involved. + """ + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + # When all tensors are CPU, the result should still be CPU + x = torch.randn(2, 3, device="cpu") + y = torch.randn(2, 3, device="cpu") + result = x + y + self.assertEqual(result.device.type, "cpu") + self.assertTrue(isinstance(result, FakeTensor)) + + if __name__ == "__main__": run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index a5ca220dcb52..7ac128d6bac8 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -43,7 +43,7 @@ TEST_WITH_ROCM, TestCase, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator" @@ -1375,7 +1375,7 @@ def test_foreach_copy_with_multi_dtypes_large_input(self): ref_out = torch.empty_like(self_tensor).copy_(src_tensor) self.assertEqual(self_tensor, ref_out) - @requires_cuda + @requires_cuda_and_triton @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) def test_foreach_copy_with_different_device_inputs(self, device, dtype, op): if dtype in (torch.complex128, torch.complex64): diff --git a/test/test_fx.py b/test/test_fx.py index 55e98df70248..ba80f69828df 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4660,7 +4660,6 @@ def tearDown(self): "linear": BUILT_IN_FUNC, "logsigmoid": BUILT_IN_FUNC, "one_hot": BUILT_IN_FUNC, - "pad": ARG_TYPE_MISMATCH, "pairwise_distance": BUILT_IN_FUNC, "pdist": BUILT_IN_FUNC, "pixel_shuffle": BUILT_IN_FUNC, @@ -4693,12 +4692,6 @@ def tearDown(self): "max_unpool3d": PROXY_ITERATED, "fold": PROXY_ITERATED, "unfold": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "layer_norm": ARG_TYPE_MISMATCH, - "rms_norm": ARG_TYPE_MISMATCH, - "lp_pool1d": ARG_TYPE_MISMATCH, "affine_grid": CONTROL_FLOW, "alpha_dropout": CONTROL_FLOW, "batch_norm": CONTROL_FLOW, @@ -4732,9 +4725,6 @@ def tearDown(self): "leaky_relu": CONTROL_FLOW, "local_response_norm": CONTROL_FLOW, "margin_ranking_loss": CONTROL_FLOW, - "max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "max_pool3d_with_indices": ARG_TYPE_MISMATCH, "mse_loss": CONTROL_FLOW, "multi_head_attention_forward": CONTROL_FLOW, "multi_margin_loss": CONTROL_FLOW, diff --git a/test/test_indexing.py b/test/test_indexing.py index 3870734f60d3..c1b4612db9e3 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -16,6 +16,7 @@ dtypesIfCPU, dtypesIfCUDA, dtypesIfMPS, + expectedFailureMPS, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, @@ -183,6 +184,7 @@ def delitem(): @onlyNativeDeviceTypes @dtypes(torch.half, torch.double) + @dtypesIfMPS(torch.half) # TODO: add bf16 there? def test_advancedindex(self, device, dtype): # Tests for Integer Array Indexing, Part I - Purely integer array # indexing @@ -1193,6 +1195,7 @@ def func1(x, i, v): out_cpu = func1(t, ind, val) self.assertEqual(out_cuda.cpu(), out_cpu) + @expectedFailureMPS # Doubles not supported @onlyNativeDeviceTypes def test_index_put_accumulate_duplicate_indices(self, device): for i in range(1, 512): diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 8d3a8090c67a..c3e26d37da1b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -2939,7 +2939,10 @@ def test_unsupported(self, device, dtype, op): @slowTest @onlyCPU - @ops(op_db, dtypes=OpDTypes.supported) + @ops( + [op for op in op_db if get_name(op) not in known_failures], + dtypes=OpDTypes.supported, + ) def test_nnc_correctness(self, device, dtype, op): if not op.supports_tracing: self.skipTest("Requires tracing support") diff --git a/test/test_linalg.py b/test/test_linalg.py index 100a175fade9..ac668fee049d 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -7765,7 +7765,7 @@ def dyn_quant_matmul_4bit( all_elements_within_threshold, "Some elements have error >= 0.06" ) - @onlyCPU + @onlyNativeDeviceTypes @parametrize("m", [32, 64]) @parametrize("k", [32, 64]) @parametrize("n", [48, 64]) @@ -7811,6 +7811,32 @@ def weight_int8pack_mm(a, b_int8pack, b_scales): mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) + @slowTest + @onlyCPU + @largeTensorTest('12GB', device='cpu') + def test__int8_mm_large_shape(self, device): + torch.manual_seed(1) + m = 65536 + k = 64 + n = 50400 + a = torch.rand((m, k), dtype=torch.bfloat16, device=device) + b = torch.rand((n, k), dtype=torch.bfloat16, device=device) + + def convert_weight_to_int8pack(b): + b_int8pack, b_scales, _ = _dynamically_quantize_per_channel( + b, -128, 127, torch.int8 + ) + return b_int8pack, b_scales + + def weight_int8pack_mm(a, b_int8pack, b_scales): + return torch._weight_int8pack_mm( + a, b_int8pack, b_scales + ) + + b_int8pack, b_scales = convert_weight_to_int8pack(b) + # should pass without segfault + weight_int8pack_mm(a, b_int8pack, b_scales) + @onlyCPU @parametrize("m", [32, 35, 36, 40, 64]) @parametrize("k", [32, 35, 36, 40, 64]) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 96a51c33386b..09a2d1f4e5dd 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -544,7 +544,7 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) else: B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] elif op == "2d/3d": - n, k = 7, 13 + n, k = 7, 259 # k is larger here, to validate iterating over k tiles on an op n_align = (n + align - 1) // align * align k_align = (k + align - 1) // align * align if a_row_major: diff --git a/test/test_mps.py b/test/test_mps.py index dd6a488be77e..25e8836c761f 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -29,6 +29,7 @@ from torch.testing._internal.common_dtype import get_all_dtypes, integral_types import torch.backends.mps from torch.distributions import Uniform, Exponential +from torch.utils._python_dispatch import TorchDispatchMode from functools import partial from torch.testing._internal.common_methods_invocations import ( @@ -198,6 +199,13 @@ def test_scaled_dot_product_attention_autocast(self, dtype): y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) self.assertEqual(y.to(y_autocast.dtype), y_autocast) + def test_conv_transpose3d_autocast_fp32(self): + m = nn.ConvTranspose3d(16, 33, 3, stride=2).to("mps") + x = torch.randn(20, 16, 10, 50, 100, device="mps") + with torch.amp.autocast(device_type="mps"): + y = m(x) + self.assertEqual(y.dtype, torch.float32) + def test_gradscaler_mps(self): # big model to force chunking/depth in the gradscaler dispatch class Model(nn.Module): @@ -7735,6 +7743,8 @@ def helper(shape, alpha, op_name, inplace): y = torch.arange(32, device='mps', dtype=torch.int32) self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2)) self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2)) + # Regression test for https://github.com/pytorch/pytorch/issues/160208 + self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2)) # Test add def test_add_scalars(self): @@ -9446,6 +9456,78 @@ def test_fast_full_attention(self, dtype, contiguous, head_dim, with_mask): self.run_fast_attention_test(q, k, v, with_mask) + + +class TestSDPAMetaDispatchMode(TorchDispatchMode): + """ + TorchDispatchMode which intercepts the + _scaled_dot_product_attention_math_for_mps aten operator to check that the + meta kernel is correct. + """ + + def __init__(self, test): + self.test = test + super().__init__() + + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + if func != torch.ops.aten._scaled_dot_product_attention_math_for_mps.default: + return res + + meta_args, meta_kwargs = pytree.tree_map_only(torch.Tensor, lambda t: t.to(device="meta"), (args, kwargs)) + meta_res = func(*meta_args, **meta_kwargs) + + def format_res(res): + return [ + (t.shape, t.stride(), t.dtype) if isinstance(t, torch.Tensor) else t + for t in pytree.tree_flatten(res)[0] + ] + + # Format the output so that we only look at the tensor metadata + self.test.assertEqual(format_res(res), format_res(meta_res)) + return res + + +def create_sdpa_meta_test(): + """ + Creates a new class which takes every test in TestSDPA and adds the + TestSDPAMetaDispatchMode context in order to test the + scaled_dot_product_attention_for_mps meta kernel. This allows us to test all + the branches for the sdpa op. If there are changes to the sdpa kernel + without changing the meta kernel, a torch.compile guard will catch the issue + but not necessarily export. + """ + orig_test_cls = TestSDPA + + new_test_cls = type(f"{orig_test_cls.__name__}Meta", orig_test_cls.__bases__, {}) + new_test_cls.__qualname__ = new_test_cls.__name__ + + for name in dir(orig_test_cls): + if name.startswith("test_"): + fn = getattr(orig_test_cls, name) + if not callable(fn): + setattr(new_test_cls, name, getattr(orig_test_cls, name)) + continue + + new_name = f"{name}_meta" + + def new_fn(self, *args, **kwargs): + with TestSDPAMetaDispatchMode(self): + fn(self, *args, **kwargs) + + new_fn.__name__ = new_name + + setattr(new_test_cls, new_name, new_fn) + + elif not hasattr(new_test_cls, name): + setattr(new_test_cls, name, getattr(orig_test_cls, name)) + + return new_test_cls + +TestSDPAMeta = create_sdpa_meta_test() +instantiate_parametrized_tests(TestSDPAMeta) + class TestGatherScatter(TestCaseMPS): def test_slicing_with_step(self): # Slicing with step @@ -12504,10 +12586,11 @@ def test_reduction_utils(self, dtype): if not dtype.is_floating_point: return - x[5] = torch.nan + idx = 25 + x[idx] = torch.nan lib.do_max(z0, z1, x) - self.assertTrue(z0.isnan().all().item(), "results are {z0}, but all elements shold have been nan") - self.assertTrue((z1 == 5).all().item(), "results are {z1}, but all elements shold have been 5") + self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements shold have been nan") + self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements shold have been {idx}") @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16]) def test_atomic_add(self, dtype): @@ -12622,6 +12705,67 @@ def test_resize(self): sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0) self.assertEqual(sparse, sparse_cpu) + @parametrize("dtype", [torch.int8, torch.int16, torch.uint8, torch.int32, torch.int64, + torch.float32, torch.float16, torch.bfloat16, torch.bool]) + def test_coalesce(self, dtype): + indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps") + values = torch.tensor([1., 2., 3., 4.], dtype=dtype, device="mps") + size = (2, 3) + indices_cpu = indices.cpu() + values_cpu = values.cpu() + sparse_mps = torch.sparse_coo_tensor(indices, values, size, device="mps") + sparse_cpu = torch.sparse_coo_tensor(indices_cpu, values_cpu, size, device="cpu") + coalesced_mps = sparse_mps.coalesce() + coalesced_cpu = sparse_cpu.coalesce() + + self.assertTrue(coalesced_mps.is_coalesced()) + self.assertTrue(coalesced_cpu.is_coalesced()) + self.assertEqual(coalesced_mps._nnz(), 2) + self.assertEqual(coalesced_mps.cpu(), coalesced_cpu) + + def test_already_coalesced_tensor(self): + already_coalesced = self._get_basic_sparse_coo() + result = already_coalesced.coalesce() + self.assertTrue(result.is_coalesced()) + self.assertEqual(result._indices().cpu(), already_coalesced._indices().cpu()) + self.assertEqual(result._values().cpu(), already_coalesced._values().cpu()) + + def test_coalesce_empty_sparse_tensor(self): + empty_indices = torch.zeros((2, 0), dtype=torch.int64, device="mps") + empty_values = torch.tensor([], dtype=torch.float32, device="mps") + empty_sparse = torch.sparse_coo_tensor(empty_indices, empty_values, (3, 3), device="mps") + empty_coalesced = empty_sparse.coalesce() + self.assertTrue(empty_coalesced.is_coalesced()) + self.assertEqual(empty_coalesced._nnz(), 0) + + def test_coalesce_large_tensor(self): + size = (1000000, 1000000) + num_elements = 1000 + + # 800 unique random positions + unique_indices = torch.randint(0, size[0], (2, 800), dtype=torch.int64) + # 200 duplicates by repeating some of the first 200 indices + duplicate_indices = unique_indices[:, :200] + indices = torch.cat([unique_indices, duplicate_indices], dim=1) + # shuffle indices to mix duplicates with unique entries + perm = torch.randperm(indices.size(1)) + indices = indices[:, perm] + + values = torch.randn(num_elements, dtype=torch.float32) + indices_mps = indices.to("mps") + values_mps = values.to("mps") + sparse_mps = torch.sparse_coo_tensor(indices_mps, values_mps, size, device="mps") + sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu") + + self.assertFalse(sparse_mps.is_coalesced()) + coalesced_mps = sparse_mps.coalesce() + coalesced_cpu = sparse_cpu.coalesce() + self.assertTrue(coalesced_mps.is_coalesced()) + self.assertTrue(coalesced_cpu.is_coalesced()) + self.assertEqual(coalesced_mps._nnz(), coalesced_cpu._nnz()) + self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices()) + self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values()) + # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. # This requires mps to be properly registered in the device generic test framework which is not the @@ -12637,6 +12781,7 @@ def test_resize(self): instantiate_parametrized_tests(TestSDPA) instantiate_parametrized_tests(TestSmoothL1Loss) instantiate_parametrized_tests(TestMetalLibrary) +instantiate_parametrized_tests(TestSparseMPS) if __name__ == "__main__": run_tests() diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 116b29033708..f4473aacfb8b 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4444,12 +4444,18 @@ def test_jagged_op_different_output_shape_dim( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) def test_softmax_dim( self, device, dtype, requires_grad, components_require_grad, + func, ): """ Softmax passes when reducing on valid reduction dimensions. @@ -4468,7 +4474,7 @@ def test_softmax_dim( for reduce_dim, _ in reduce_dims: nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) - out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) + out_actual = func(nt, dim=reduce_dim) torch._dynamo.disable(self.assertEqual)( len(out_actual.shape), len(output_shape) ) # disable if running on dynamo @@ -4498,12 +4504,10 @@ def test_softmax_dim( reduce_dim, reduce_dim_expected = reduce_dim_tuple if nt.dim() > reduce_dim: - out_actual = torch.nn.functional.softmax( - nt, dim=reduce_dim - ) # nested tensor - out_expected = torch.nn.functional.softmax( - nt.values(), dim=reduce_dim_expected - ) # dense tensor of dimensions 1 less than out_actual + # nested tensor + out_actual = func(nt, dim=reduce_dim) + # dense tensor of dimensions 1 less than out_actual + out_expected = func(nt.values(), dim=reduce_dim_expected) self.assertTrue( torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) ) @@ -4601,8 +4605,13 @@ def test_softmax_dim_reduce_ragged_idx_1( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) def test_softmax_reduce_batch_dim( - self, device, dtype, requires_grad, components_require_grad + self, device, dtype, requires_grad, components_require_grad, func ): """ Softmax on NestedTensor fails when trying to reduce across batch dimension. @@ -4627,7 +4636,7 @@ def test_softmax_reduce_batch_dim( RuntimeError, "not supported when reducing across the batch dimension for NestedTensor", ): - out = torch.nn.functional.softmax(nt, dim=reduce_dim) + out = func(nt, dim=reduce_dim) @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @@ -6751,11 +6760,10 @@ def check_forward_backward(skip_backward=False): and check_cudnn and (dtype == torch.float16 or dtype == torch.bfloat16) ): - with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.CUDNN_ATTENTION - ): - check_forward_backward() + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @@ -8573,12 +8581,6 @@ def f(values, offsets): sample_match_fn=lambda device, sample: ("batch_dim" in sample.name), name="broken_select_backward_unbacked", ), - # Bug: no idea what's going on here; needs investigation within AOTAutograd - XFailRule( - op_match_fn=lambda device, op: (op.full_name == "nan_to_num"), - sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name), - name="crazy_aot_autograd_bug1", - ), ] COMPILE_BACKWARD_SKIPS_AND_XFAILS = [ diff --git a/test/test_nn.py b/test/test_nn.py index 8c27ff20f74e..904b819a6fc4 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -579,6 +579,22 @@ def test_register_buffer_allows_overwriting_with_same_name(self): m.buffer_name = Buffer(buffer3) self.assertEqual(m.buffer_name, Buffer(buffer3)) + def test_register_buffer_allows_tensor_like_object(self): + class TensorLike: + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + raise NotImplementedError(f"TensorLike.__torch_function__: {func}") + + buffer1 = TensorLike() + buffer2 = TensorLike() + m = nn.Module() + m.register_buffer('buffer_name', buffer1) + self.assertEqual(m.buffer_name, buffer1) + self.assertEqual(m.get_buffer('buffer_name'), buffer1) + m.buffer_name = buffer2 + self.assertEqual(m.buffer_name, buffer2) + self.assertEqual(m.get_buffer('buffer_name'), buffer2) + def test_get_buffer(self): m = nn.Module() buffer1 = torch.randn(2, 3) @@ -8750,6 +8766,7 @@ def rms_norm_reference_fn(i, normalized_shape, weight, eps=None): @onlyNativeDeviceTypes @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32) def test_rmsnorm_epsilon(self, device, dtype): def rms_norm_reference_fn(i, normalized_shape): eps = torch.finfo(i.dtype).eps @@ -8924,6 +8941,7 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps): Y_cpu = group_norm(X.cpu()) self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5) + @expectedFailureMPS # Double is not supported on MPS @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_pad(self, device, dtype): @@ -8955,6 +8973,7 @@ def test_pad(self, device, dtype): out.fill_(4) self.assertTrue(torch.all(torch.abs(inputs) < 2)) + @expectedFailureMPS # Unsupported float64/complex128 @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_ReplicationPad_empty(self, device, dtype): @@ -9093,6 +9112,7 @@ def test_Bilinear_empty(self, device): self.assertEqual(inp1.grad, torch.zeros_like(inp1)) self.assertEqual(inp2.grad, torch.zeros_like(inp2)) + @expectedFailureMPS # Double not supported @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes def test_TransformerEncoderLayer_empty(self, device): @@ -9122,6 +9142,7 @@ def test_TransformerEncoderLayer_empty(self, device): _test_module_empty_input(self, encoder_layer, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerEncoder_empty(self, device): for batch_first, input_shape in [(True, (0, 10, 512)), @@ -9132,6 +9153,7 @@ def test_TransformerEncoder_empty(self, device): _test_module_empty_input(self, transformer_encoder, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerDecoderLayer_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9142,6 +9164,7 @@ def test_TransformerDecoderLayer_empty(self, device): self._test_module_empty_inputs(decoder_layer, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerDecoder_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9153,6 +9176,7 @@ def test_TransformerDecoder_empty(self, device): self._test_module_empty_inputs(transformer_decoder, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_Transformer_empty(self, device): for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]: @@ -9288,6 +9312,7 @@ def test_ReflectionPad3d_large(self, device): self.assertEqual(x.grad, ref_x.grad) + @expectedFailureMPS # Unimplemented margin_loss @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_MarginLoss_empty(self, device, dtype): @@ -9354,6 +9379,7 @@ def test_mse_loss_error(self, device): with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): F.mse_loss(i, t) + @expectedFailureMPS # TODO: Fixme, and raise assert on empty tensor @onlyNativeDeviceTypes def test_Unfold_empty(self, device): inp = torch.randn(0, 3, 3, 4, device=device) @@ -9577,6 +9603,7 @@ def verify_reduction_scalars(input, reduction, output): verify_reduction_scalars(input, reduction, output) # verify that bogus reduction strings are errors + @expectedFailureMPS # CTCLoss unimplemented @onlyNativeDeviceTypes def test_invalid_reduction_strings(self, device): input = torch.randn(3, 5, requires_grad=True, device=device) @@ -10063,6 +10090,7 @@ def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize @parametrize_test("align_corners", [True, False]) @parametrize_test("mode", ["bilinear", "bicubic"]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) + @expectedFailureMPS # double device type @onlyNativeDeviceTypes def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format): # Forward AD does not support XLA because XLA tensors don't have storage @@ -10132,6 +10160,7 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory @parametrize_test("num_channels", [3, 5]) @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"]) @parametrize_test("dtype", integral_types() + floating_types()) + @skipIfMPS # Error message is wrong for some dtypes @onlyNativeDeviceTypes def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype): x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device) @@ -11454,6 +11483,7 @@ def test_hardsigmoid_grad(self, device): self.assertTrue(gradcheck(F.hardsigmoid, (inputs,))) # currently fails on XLA + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @onlyNativeDeviceTypes def test_hardswish_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10 @@ -11661,6 +11691,7 @@ def test_batchnorm_simple_average_mixed(self, device, dtype): self._test_batchnorm_simple_average(device, dtype, torch.float) @onlyNativeDeviceTypes + @expectedFailureMPS # Unsupported Border padding mode @dtypes(torch.float, torch.double) def test_grid_sample_nan_inf(self, device, dtype): input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype) @@ -12773,6 +12804,7 @@ def test_threshold_inplace_overlap(self, device): F.threshold(x, 0.5, 0.5, inplace=True) F.threshold_(x, 0.5, 0.5) + @expectedFailureMPS # Double is unsupported @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss_default_parity(self, device): # Test for `nn.TripletMarginWithDistanceLoss` and @@ -12807,6 +12839,7 @@ def test_triplet_margin_with_distance_loss_default_parity(self, device): self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n), (anchor, positive, negative))) + @expectedFailureMPS # Double is unsupported @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss(self, device): # Test for parity between `nn.TripletMarginWithDistanceLoss` and diff --git a/test/test_numa_binding.py b/test/test_numa_binding.py index e1637b2aad96..e89d06174f38 100644 --- a/test/test_numa_binding.py +++ b/test/test_numa_binding.py @@ -2,16 +2,19 @@ from __future__ import annotations +import multiprocessing.spawn as spawn +import os import subprocess import sys +import tempfile from dataclasses import dataclass from typing import Any, Optional -from unittest import skipIf, skipUnless +from unittest import skipUnless from unittest.mock import mock_open, patch import torch from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes -from torch.distributed.numa.binding import ( +from torch.numa.binding import ( _get_ranges_str_from_ints, _get_set_of_int_from_ranges_str, AffinityMode, @@ -35,12 +38,10 @@ class MockDeviceProperties: _real_open = open +_real_mkstemp = tempfile.mkstemp -@skipIf( - sys.platform == "win32", - "Windows is missing various os module attributes like sched_getaffinity", -) +@skipUnless(sys.platform == "linux", "Only linux currently supported") @skipUnless( torch.distributed.is_available(), "Need access to some distributed submodules" ) @@ -53,26 +54,44 @@ def setUp(self) -> None: self._mock_num_logical_cpus = 0 self._mock_num_numa_nodes = 0 self._mock_num_sockets = 0 + self._temp_file_paths = [] self._context_managers_to_apply_to_all_tests = [ patch("torch.cuda.device_count", self._mock_device_count), patch("torch.cuda.get_device_properties", self._mock_get_device_properties), patch("torch.cuda.is_available", self._mock_is_available), + # Implicitly used by dynamo + patch("torch.cuda.get_rng_state"), patch("builtins.open", new=self._mock_open), patch("os.listdir", new=self._mock_listdir), patch("os.sched_getaffinity", new=self._mock_sched_getaffinity), patch("shutil.which", return_value="/usr/bin/numactl"), - patch("subprocess.run"), + patch("torch.numa.binding.run"), + patch("torch.numa.binding.mkstemp", self._mock_mkstemp), ] for context_manager in self._context_managers_to_apply_to_all_tests: context_manager.__enter__() def tearDown(self) -> None: + # Clean up temporary files + for temp_file_path in self._temp_file_paths: + try: + os.unlink(temp_file_path) + except FileNotFoundError: + # File may have already been deleted or doesn't exist + pass + for context_manager in self._context_managers_to_apply_to_all_tests: context_manager.__exit__(None, None, None) super().tearDown() + def _mock_mkstemp(self, *args, **kwargs): + # Just keep track of temp files so we can delete them + fd, path = _real_mkstemp(*args, **kwargs) + self._temp_file_paths.append(path) + return fd, path + def _add_mock_hardware( self, *, @@ -204,7 +223,7 @@ def _mock_get_device_properties(self, index: int) -> MockDeviceProperties: def _mock_open(self, path: str, *args, **kwargs) -> Any: if path in self._mock_file_path_to_contents: return mock_open(read_data=self._mock_file_path_to_contents[path])() - if path.startswith("/sys/"): + if isinstance(path, str) and path.startswith("/sys/"): raise FileNotFoundError(f"File {path} was not mocked.") # Looks like CI is calling open and intending to open an actual file in some places. # Need this to make the CI pass. @@ -222,8 +241,8 @@ def _mock_listdir(self, target_path: str) -> set[str]: def _mock_sched_getaffinity(self, pid: int) -> set[int]: return set(range(self._mock_num_logical_cpus)) - def _start_test_processes_and_get_command_args_for_local_rank( - self, *, numa_options: Optional[NumaOptions], local_rank: int + def _start_processes_for_str_entrypoint_and_get_Popen_args( + self, *, numa_options: Optional[NumaOptions], target_local_rank: int ) -> tuple[str, ...]: """ Calls start_processes like elastic_launch ultimately would @@ -250,10 +269,58 @@ def _start_test_processes_and_get_command_args_for_local_rank( call_args = next( call_args for call_args in mock_popen.call_args_list - if call_args.kwargs.get("env", {}).get("LOCAL_RANK") == str(local_rank) + if call_args.kwargs.get("env", {}).get("LOCAL_RANK") + == str(target_local_rank) ) return call_args.kwargs["args"] + def _start_processes_for_callable_entrypoint_and_get_executable_contents( + self, *, numa_options: Optional[NumaOptions], target_local_rank: int + ) -> str: + active_local_rank = None + executable_path = None + + def _mock_process_start(self: Any) -> None: + nonlocal active_local_rank + active_local_rank = self._args[1] + spawn.get_command_line() + self._target(*self._args) + + original_get_command_line = spawn.get_command_line + + def _mock_get_command_line(*args, **kwargs) -> list[str]: + nonlocal executable_path + result = original_get_command_line(*args, **kwargs) + if active_local_rank == target_local_rank: + executable_path = result[0] + + return result + + with ( + patch("multiprocessing.context.SpawnProcess.start", _mock_process_start), + patch("multiprocessing.spawn.get_command_line", _mock_get_command_line), + patch("multiprocessing.process.BaseProcess.sentinel", 1), + # Prevent hanging + patch( + "multiprocessing.synchronize.Event.wait", + lambda self, timeout=None: None, + ), + ): + start_processes( + name="test_process", + entrypoint=lambda x: x, + args=dict.fromkeys(range(self._mock_device_count()), (0,)), + envs={ + i: {"LOCAL_RANK": str(i)} for i in range(self._mock_device_count()) + }, + logs_specs=DefaultLogsSpecs(), + numa_options=numa_options, + ) + + assert executable_path is not None + with open(executable_path) as executable_file: + return executable_file.read() + def test_node_numa_binding(self) -> None: self._add_mock_hardware( num_sockets=4, @@ -263,8 +330,9 @@ def test_node_numa_binding(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=11 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=11, ) self.assertEqual( command_args, @@ -273,7 +341,6 @@ def test_node_numa_binding(self) -> None: ( "numactl", "--cpunodebind=5", - "--preferred=5", "echo", "Hello, world!", ), @@ -288,8 +355,8 @@ def test_no_numa_binding_if_numa_options_not_provided(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=None, local_rank=11 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=None, target_local_rank=11 ) self.assertEqual( command_args, @@ -340,20 +407,18 @@ def test_fallback(self) -> None: ) with ( - patch("torch.distributed.numa.binding.signpost_event") as signpost_patch, + patch("torch.numa.binding.signpost_event") as signpost_patch, patch( - "subprocess.run", + "torch.numa.binding.run", side_effect=subprocess.CalledProcessError(1, "numactl"), ), ): - command_args = ( - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions( - affinity_mode=AffinityMode.NODE, - should_fall_back_if_binding_fails=True, - ), - local_rank=0, - ) + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions( + affinity_mode=AffinityMode.NODE, + should_fall_back_if_binding_fails=True, + ), + target_local_rank=0, ) self.assertIn( "subprocess.CalledProcessError", @@ -387,6 +452,25 @@ def test_explicit_numa_options_overrides_default(self) -> None: NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), ) + def test_fork_start_method_does_not_call_get_default_numa_options(self) -> None: + # Inner import to avoid crashing if not torch.distributed.is_available() + from torch.distributed.launcher.api import LaunchConfig + + with patch( + "torch.distributed.launcher.api.get_default_numa_options" + ) as mock_get_default_numa_options: + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=1, + start_method="fork", + # Don't provide numa_options + ) + # Verify get_default_numa_options was not called + mock_get_default_numa_options.assert_not_called() + # Verify numa_options is None when start_method is fork + self.assertIsNone(launch_config.numa_options) + def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None: self._add_mock_hardware( num_sockets=4, @@ -396,15 +480,15 @@ def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=15 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), + target_local_rank=15, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=6-7", - "--preferred-many=6-7", "echo", "Hello, world!", ), @@ -419,15 +503,15 @@ def test_socket_numa_binding_with_single_numa_per_socket(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=7 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), + target_local_rank=7, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=3", - "--preferred=3", "echo", "Hello, world!", ), @@ -442,8 +526,9 @@ def test_exclusive_numa_binding(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args_0 = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=0 + command_args_0 = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), + target_local_rank=0, ) self.assertEqual( command_args_0, @@ -451,14 +536,14 @@ def test_exclusive_numa_binding(self) -> None: "numactl", # Gets an extra physical core due to odd number of physical cores on numa node "--physcpubind=0-3", - "--preferred=0", "echo", "Hello, world!", ), ) - command_args_1 = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=1 + command_args_1 = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), + target_local_rank=1, ) self.assertEqual( command_args_1, @@ -466,7 +551,6 @@ def test_exclusive_numa_binding(self) -> None: "numactl", # Does not get an extra physical core, since the 1st GPU already took the extra. "--physcpubind=4-5", - "--preferred=0", "echo", "Hello, world!", ), @@ -485,9 +569,9 @@ def test_exclusive_raises_if_too_few_physical_cores(self) -> None: RuntimeError, "There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.", ): - self._start_test_processes_and_get_command_args_for_local_rank( + self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), - local_rank=1, + target_local_rank=1, ) def test_core_complex_numa_binding_with_extra_l3(self) -> None: @@ -499,9 +583,9 @@ def test_core_complex_numa_binding_with_extra_l3(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=3, + target_local_rank=3, ) self.assertEqual( command_args, @@ -509,7 +593,6 @@ def test_core_complex_numa_binding_with_extra_l3(self) -> None: "numactl", # The second L3 on the second numa node "--physcpubind=24-29", - "--preferred=1", "echo", "Hello, world!", ), @@ -524,9 +607,9 @@ def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=3, + target_local_rank=3, ) self.assertEqual( command_args, @@ -535,7 +618,6 @@ def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None: # There are only 2 L3 caches, so the 4th GPU shares the same # cores as the 3rd GPU. "--physcpubind=6-11", - "--preferred=1", "echo", "Hello, world!", ), @@ -552,11 +634,9 @@ def test_core_complex_prefers_caches_with_more_cpus(self) -> None: # Only some subset of the CPUs are available this time. with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}): - command_args = ( - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=0, - ) + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), + target_local_rank=0, ) self.assertEqual( @@ -565,7 +645,6 @@ def test_core_complex_prefers_caches_with_more_cpus(self) -> None: "numactl", # Binds to the second L3 because it has the most available CPUs "--physcpubind=6-7,9", - "--preferred=0", "echo", "Hello, world!", ), @@ -584,42 +663,20 @@ def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None: num_physical_core_per_l3_cache=1, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=0, + target_local_rank=0, ) self.assertEqual( command_args, ( "numactl", "--physcpubind=0-1", - "--preferred=0", "echo", "Hello, world!", ), ) - def test_raises_error_if_numa_options_provided_for_callable_entrypoint( - self, - ) -> None: - # Inner import to avoid crashing if not torch.distributed.is_available() - from torch.distributed.elastic.agent.server.api import WorkerSpec - - def mock_entrypoint() -> None: - pass - - with self.assertRaisesRegex(ValueError, r".*numa_options.*"): - # not relevant to test, just pass in an arbitrary value - mock_rdzv_handler: Any = 0 - WorkerSpec( - role="trainer", - # Only str entrypoint (e.g. "echo") is currently supported - entrypoint=mock_entrypoint, - local_world_size=8, - rdzv_handler=mock_rdzv_handler, - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), - ) - def test_raises_error_if_numactl_unavailable(self) -> None: self._add_mock_hardware( num_sockets=1, @@ -632,8 +689,9 @@ def test_raises_error_if_numactl_unavailable(self) -> None: patch("shutil.which", return_value=None), self.assertRaisesRegex(RuntimeError, r".*numactl.*"), ): - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0 + self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=0, ) def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: @@ -654,20 +712,50 @@ def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: contents="-1", ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=0, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=0", - "--preferred=0", "echo", "Hello, world!", ), ) + def test_callable_entrypoint_basic(self) -> None: + self._add_mock_hardware( + num_sockets=4, + num_numa_nodes_per_socket=2, + num_gpus_per_numa_node=2, + num_l3_caches_per_numa_node=4, + num_physical_core_per_l3_cache=2, + ) + + executable_contents = ( + self._start_processes_for_callable_entrypoint_and_get_executable_contents( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=11, + ) + ) + self.assertEqual( + executable_contents, + # There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be + # on numa node 11 // 2 = 5. + f"""#!/bin/bash + +# If this file is more than a few minutes old and still exists on your machine, +# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such +# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163. + +rm -- "$0" +numactl --cpunodebind=5 {sys.executable} "$@" +""", + ) + def test_get_set_of_int_from_ranges_str(self) -> None: self.assertEqual( _get_set_of_int_from_ranges_str("0-2,4,6-7"), {0, 1, 2, 4, 6, 7} diff --git a/test/test_ops.py b/test/test_ops.py index 201b0323a86f..2d5af9966690 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1601,6 +1601,16 @@ def _tensor_requires_grad(x): ) == 0: return + if TEST_WITH_TORCHDYNAMO: + # NOTE: Also for TEST_WITH_TORCHINDUCTOR tests + # Under compile, some ops may be decomposed into supported ops + # So it is okay to have supported_but_unclaimed_* + if ( + len(claimed_but_unsupported_forward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + # Reference operators often support additional dtypes, and that's OK if op in python_ref_db: if ( diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index e0480ba6a684..9faa5ce4b894 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1,7 +1,6 @@ # Owner(s): ["module: __torch_dispatch__"] # ruff: noqa: F841 -import logging import pickle import sys import tempfile @@ -1718,49 +1717,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertEqual(s.device_index, 2) self.assertEqual(s.device_type, 3) - def test_subclass_autograd_device_check(self) -> None: - class NonWrapperSubclass(torch.Tensor): - elem: torch.Tensor - - __slots__ = ["elem"] - - @staticmethod - def __new__(cls, elem, *args, **kwargs): - # Wrong device here! - r = torch.Tensor._make_subclass( - cls, elem.to("meta"), elem.requires_grad - ) - # ...the real tensor is held as an element on the tensor. - r.elem = elem - return r - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(e): - return e.elem if isinstance(e, NonWrapperSubclass) else e - - def wrap(e): - return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e - - rs = tree_map( - wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - ) - logging.getLogger("NonWrapperSubclass").info( - f"{func.__module__}.{func.__name__}", # noqa: G004 - args, - kwargs, - rs, - ) - return rs - - x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True)) - y = torch.randn(2, requires_grad=True) - z = x * y - self.assertIsInstance(z, NonWrapperSubclass) - z.sum().backward(torch.tensor(1)) - self.assertEqual(x.grad, y) - self.assertEqual(y.grad, x) - def test_none_wrapping(self): # A Tensor subclass that returns None when doing add # See LoggingTensor above for more details on the subclass @@ -2513,6 +2469,19 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): with Mode(): torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + def test_dispatch_uint64(self): + class DummyMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs): + self.last_args = args + return func(*args, **kwargs) + + # Value that could not be intepreted as signed int64 + uarg = 2**63 + 1 + with DummyMode() as m: + a = torch.full((3, 3), uarg, dtype=torch.uint64) + self.assertEqual(m.last_args[1], uarg) + self.assertTrue((a == uarg).all().item()) + class TestPythonDispatcher(TestCase): def test_basic(self): diff --git a/test/test_schema_check.py b/test/test_schema_check.py index 29ea36fd8a5f..91d9a484d3c8 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -14,9 +14,12 @@ from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests +from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) + + def secretly_aliasing(x): return x.view(-1) @@ -493,9 +496,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with SchemaInfoBindTestMode(self) as schemaInfoCheck: x.add(x) - class TestSchemaCheckModeOpInfo(JitTestCase): @ops(op_db, dtypes=OpDTypes.supported) + @slowTestIf(IS_WINDOWS) def test_schema_correctness(self, device, dtype, op): # Currently torch.equal isn't supported with torch.complex32 # There's also errors with complex64 and complex128 diff --git a/test/test_serialization.py b/test/test_serialization.py index 3413366608f4..8fa78cb5da4b 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -61,6 +61,7 @@ ) from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 from torch.utils._import_utils import import_dill +from pickle import UnpicklingError if not IS_WINDOWS: @@ -1356,6 +1357,39 @@ def test_weights_only_error(self, unsafe_global): "file an issue with the following so that we can make `weights_only=True`"): torch.load(f, weights_only=True) + def test_weights_only_blocked_func_error_msg(self): + import datetime + import zoneinfo + + data = { + "a": torch.tensor([1, 2, 3]), + "b": datetime.datetime(2025, 1, 1, 12, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")), + } + with tempfile.NamedTemporaryFile() as f: + torch.save(data, f) + f.seek(0) + + with torch.serialization.safe_globals([datetime.datetime, getattr, zoneinfo.ZoneInfo]): + with self.assertRaisesRegex(UnpicklingError, ".*_unpickle.*zoneinfo.ZoneInfo.*"): + torch.load(f) + + + def test_weights_only_with_zoneinfo_unpickle_registration_success(self): + import datetime + import zoneinfo + + data = { + "a": torch.tensor([1, 2, 3]), + "b": datetime.datetime(2025, 1, 1, 12, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")), + } + with tempfile.NamedTemporaryFile() as f: + torch.save(data, f) + f.seek(0) + + with torch.serialization.safe_globals([datetime.datetime, getattr, zoneinfo.ZoneInfo, zoneinfo.ZoneInfo._unpickle]): + loaded_data = torch.load(f) + self.assertEqual(loaded_data, data) + @parametrize('weights_only', (False, True)) def test_serialization_math_bits(self, weights_only): t = torch.randn(1, dtype=torch.cfloat) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 360dc058212a..669f165529e7 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -215,7 +215,7 @@ def test_stable_sort(self, device, dtype): ) @onlyCUDA - @dtypes(torch.uint8) + @dtypes(torch.float16) @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) diff --git a/test/test_torch.py b/test/test_torch.py index a44831eb4ac0..d55fd1aeb6e8 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -66,6 +66,7 @@ get_all_qint_dtypes, all_types_complex_float8_and, ) from torch.testing._internal.two_tensor import TwoTensor +from torch.testing._internal.common_utils import IS_WINDOWS if TEST_WITH_TORCHINDUCTOR: from torch._inductor.test_case import TestCase @@ -158,6 +159,7 @@ def test_constants(self, device): self.assertEqual(torch.inf, math.inf) @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) @@ -190,6 +192,7 @@ def test_int64_upsample3d(self, device, dtype): @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) + @slowTestIf(IS_WINDOWS) def test_storage(self, device, dtype): v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9) self.assertEqual(v.storage()[0], v[0][0]) @@ -220,6 +223,7 @@ def test_storage(self, device, dtype): torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) + @slowTestIf(IS_WINDOWS) def test_storage_setitem(self, device, dtype): # Skip quantized dtypes for CUDA, since they're not supported if torch.device(device).type == 'cuda': @@ -251,6 +255,7 @@ def test_storage_setitem(self, device, dtype): @skipIfTorchDynamo("Not a suitable test for TorchDynamo") @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) def test_storage_use_count(self, device): a = torch.randn(10, device=device) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) @@ -261,6 +266,7 @@ def test_storage_use_count(self, device): @xfailIfTorchDynamo @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_tensor_storage_type(self, device, dtype): a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) @@ -271,6 +277,7 @@ def test_tensor_storage_type(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) + @slowTestIf(IS_WINDOWS) def test_tensor_from_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -288,6 +295,7 @@ def test_tensor_from_storage(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_set_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -326,6 +334,7 @@ def _check_storage_meta(self, s, s_check): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_typed_storage_meta(self, device, dtype): args_list = [ [], @@ -339,6 +348,7 @@ def test_typed_storage_meta(self, device, dtype): self._check_storage_meta(s, s_check) @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) def test_untyped_storage_meta(self, device): args_list = [ [], @@ -353,6 +363,7 @@ def test_untyped_storage_meta(self, device): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_from_tensor(self, device, dtype): t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) t = t_check.to('meta') @@ -362,6 +373,7 @@ def test_storage_meta_from_tensor(self, device, dtype): self._check_storage_meta(s, s_check) @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_errors(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -402,6 +414,7 @@ def test_storage_meta_errors(self, device, dtype): @onlyCPU @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_ok(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -417,6 +430,7 @@ def test_module_share_memory(self): model.share_memory() @dtypes(torch.float32, torch.complex64) + @slowTestIf(IS_WINDOWS) def test_deepcopy(self, device, dtype): from copy import deepcopy a = torch.randn(5, 5, dtype=dtype, device=device) @@ -444,6 +458,7 @@ def test_deepcopy(self, device, dtype): self.assertEqual(deepcopy(a).foo, 3) @dtypes(torch.float32, torch.complex64) + @slowTestIf(IS_WINDOWS) def test_deepcopy_scalar(self, device, dtype): from copy import deepcopy a = torch.tensor(5, dtype=dtype, device=device) @@ -3696,6 +3711,7 @@ def ref_index_select(src, dim, idx): # FIXME: find a test suite for the take operator @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_take(self, device, dtype): idx_size = (4,) @@ -8337,7 +8353,7 @@ def test_print(self): self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''') torch.set_printoptions(sci_mode=False) self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''') + self.assertExpectedInline(str(x), '''tensor([100.0000, 0.0100])''') torch.set_printoptions(sci_mode=None) # reset to the default value # test no leading space if all elements positive diff --git a/test/test_transformers.py b/test/test_transformers.py index 89db8d798c26..05a21569aeac 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -49,7 +49,6 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, - SM90OrLater, tf32_on_and_off, tf32_enabled, ) @@ -2657,6 +2656,7 @@ def test_cudnn_attention_gqa(self, device): @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + @unittest.expectedFailure # cuDNN currently doesn't support this on SM100+/fails graph validation def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) @@ -2667,7 +2667,7 @@ def test_cudnn_attention_d256_heuristic(self, device): v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True): + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True): actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) actual.backward(torch.randn_like(actual)) @@ -2705,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device): @skipIfRocm # No cuDNN Attention - @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + @unittest.skipIf(True, "broken as of cuDNN 9.10") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 b, h = 1, 2 @@ -2720,7 +2720,6 @@ def test_cudnn_attention_fail_d128(self, device): ISSM90 = device_cap == (9, 0) ISSM100 = device_cap == (10, 0) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): - # SM90/100 support d <= 256 as of cuDNN 9.5.1+ if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501: torch.nn.functional.scaled_dot_product_attention(q, k, v) else: @@ -3156,15 +3155,19 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + device_capability = None + if "cuda" in str(device): + device_capability = torch.cuda.get_device_capability() + prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ + prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0)) + # TODO we are currently disabling this by default, lets assert that this returns # FlashAttention, we need to change when we make remove opt-in for cudnn - if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows + elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index d7d9a2b1aab6..9939e8e76ce9 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -54,6 +54,8 @@ ) from torch.utils import _pytree as pytree +from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf + if TEST_SCIPY: import scipy @@ -271,6 +273,7 @@ def _helper_reference_numerics( # and noncontiguities. @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_normal(self, device, dtype, op): tensors = generate_elementwise_unary_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -279,6 +282,7 @@ def test_reference_numerics_normal(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_small(self, device, dtype, op): if dtype in (torch.bool,): raise self.skipTest("bool has no small values") @@ -290,6 +294,7 @@ def test_reference_numerics_small(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_large(self, device, dtype, op): if dtype in (torch.bool, torch.uint8, torch.int8): raise self.skipTest("bool, uint8, and int8 dtypes have no large values") @@ -304,6 +309,7 @@ def test_reference_numerics_large(self, device, dtype, op): reference_filtered_ops, allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), ) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_extremal(self, device, dtype, op): tensors = generate_elementwise_unary_extremal_value_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -312,6 +318,7 @@ def test_reference_numerics_extremal(self, device, dtype, op): # Tests for testing (non)contiguity consistency @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_vs_every_other(self, device, dtype, op): contig = make_tensor( (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -328,6 +335,7 @@ def test_contig_vs_every_other(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_vs_transposed(self, device, dtype, op): contig = make_tensor( (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -344,6 +352,7 @@ def test_contig_vs_transposed(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig(self, device, dtype, op): shapes = [(5, 7), (1024,)] for shape in shapes: @@ -360,6 +369,7 @@ def test_non_contig(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig_index(self, device, dtype, op): contig = make_tensor( (2, 2, 1, 2), @@ -378,6 +388,7 @@ def test_non_contig_index(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig_expand(self, device, dtype, op): shapes = [(1, 3), (1, 7), (5, 7)] for shape in shapes: @@ -399,6 +410,7 @@ def test_non_contig_expand(self, device, dtype, op): ) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_size1(self, device, dtype, op): contig = make_tensor( (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] @@ -414,6 +426,7 @@ def test_contig_size1(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_size1_large_dim(self, device, dtype, op): contig = make_tensor( (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), @@ -435,6 +448,7 @@ def test_contig_size1_large_dim(self, device, dtype, op): # Tests that computation on a multiple batches is the same as # per-batch computation. @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_batch_vs_slicing(self, device, dtype, op): input = make_tensor( (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 5aa30483deba..fd0fa0290c94 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -11,15 +11,16 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfMPS, + expectedFailureMPS, instantiate_device_type_tests, onlyCPU, onlyNativeDeviceTypes, - onlyNativeDeviceTypesAnd, skipLazy, skipMeta, skipXLA, ) from torch.testing._internal.common_dtype import ( + all_mps_types_and, all_types_and, all_types_and_complex_and, complex_types, @@ -157,8 +158,11 @@ def test_conj_self(self, device, dtype): @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool)) def test_view_dtype_new(self, device, dtype): dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + if device.startswith("mps"): + del dtypes[torch.float64] del dtypes[torch.bool] def generate_inputs(): @@ -271,6 +275,7 @@ def calc_expected_size_and_stride(a, view_dtype): # has a greater element size than the original dtype @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_view_dtype_upsize_errors(self, device, dtype): dtype_size = torch._utils._element_size(dtype) @@ -372,6 +377,7 @@ def fn(contiguous_input=True, dim0=0, dim1=1): @onlyNativeDeviceTypes @dtypes(*complex_types(), torch.complex32) + @dtypesIfMPS(torch.cfloat, torch.chalf) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): t = torch.randn(3, 4, dtype=dtype, device=device) @@ -398,9 +404,7 @@ def fn(contiguous_input=True): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) - @dtypesIfMPS( - *integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32) - ) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_view_tensor_split(self, device, dtype): a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9) a_split_dim0 = a.tensor_split(7, 0) @@ -412,6 +416,7 @@ def test_view_tensor_split(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_hsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_hsplit = torch.hsplit(t, 2) @@ -422,6 +427,7 @@ def test_view_tensor_hsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_vsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_vsplit = torch.vsplit(t, 2) @@ -432,6 +438,7 @@ def test_view_tensor_vsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_dsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_dsplit = torch.dsplit(t, 2) @@ -440,9 +447,9 @@ def test_view_tensor_dsplit(self, device, dtype): t[2, 2, 2] = 7 self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) - @onlyNativeDeviceTypesAnd("mps") + @onlyNativeDeviceTypes @dtypes(*all_types_and(torch.half, torch.bfloat16)) - @dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32)) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_imag_noncomplex(self, device, dtype): t = torch.ones((5, 5), dtype=dtype, device=device) @@ -451,6 +458,7 @@ def test_imag_noncomplex(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*complex_types()) + @dtypesIfMPS(torch.cfloat) def test_real_imag_view(self, device, dtype): def compare_with_numpy(contiguous_input=True): t = torch.randn(3, 3, dtype=dtype, device=device) @@ -481,6 +489,7 @@ def compare_with_numpy(contiguous_input=True): self.assertEqual(a[5:].imag, a.imag[5:]) @onlyNativeDeviceTypes + @expectedFailureMPS @dtypes(*complex_types()) def test_conj_imag_view(self, device, dtype) -> None: t = _make_tensor((4, 5), dtype, device) @@ -512,6 +521,12 @@ def test_conj_view_with_shared_memory(self, device) -> None: all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), ) ) + @dtypesIfMPS( + *product( + [torch.cfloat, torch.chalf], + all_mps_types_and(torch.cfloat, torch.chalf, torch.bool), + ) + ) @suppress_warnings def test_set_real_imag(self, device, dtypes): x = torch.randn(10, dtype=dtypes[0], device=device) diff --git a/test/test_xpu.py b/test/test_xpu.py index cd5275418c44..beb5a53a4a6b 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,5 +1,6 @@ # Owner(s): ["module: intel"] +import gc import re import subprocess import sys @@ -520,6 +521,42 @@ def test_device_memory_allocated(self): ) del a + def test_memory_stats(self): + gc.collect() + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + torch.xpu.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="xpu") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + @skipXPUIf( int(torch.version.xpu) < 20250000, "Test requires SYCL compiler version 2025.0.0 or newer.", diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index 103a71d5debb..1164a2b67636 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -12,6 +12,9 @@ import numpy as np import torch +import torch._inductor.decomposition +from torch._higher_order_ops.out_dtype import out_dtype +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( dtypes, @@ -1364,6 +1367,85 @@ def test_mm_with_offset(self, device): cpu_out = torch.matmul(a.cpu(), b.cpu()) self.assertEqual(gpu_out.cpu(), cpu_out) + @parametrize("m", [0, 8, 17]) + @parametrize("k", [0, 16, 32]) + @parametrize("n", [16, 32]) + @parametrize("use_transpose_a", [True, False]) + @parametrize("use_transpose_b", [True, False]) + @parametrize("non_contig_type", [0, 1, 2]) + def test__int_mm( + self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type + ): + # non_contig_type: + # 0: the whole data buffer is contiguous (can be transposed) + # 1: stride of one dimension is 1, but the whole buffer is not contiguous + # 2: Neither stride is 1 + + def genf_int_float(x, y, use_transpose, non_contig_type): + if use_transpose: + x, y = y, x + if non_contig_type != 0: + y = y * 2 + x_int8 = torch.randint(-128, 127, (x, y), dtype=torch.int8, device=device) + x_float = x_int8.to(torch.float32) + if non_contig_type == 1: + x_int8 = x_int8[:, : y // 2] + x_float = x_float[:, : y // 2] + elif non_contig_type == 2: + x_int8 = x_int8[:, ::2] + x_float = x_float[:, ::2] + if use_transpose: + return x_int8.t(), x_float.t() + return x_int8, x_float + + if non_contig_type != 0 and (m == 0 or k == 0): + return + a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type) + b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type) + c_int32 = torch._int_mm(a_int8, b_int8) + self.assertTrue(c_int32.dtype is torch.int32) + self.assertEqual(c_int32.device, torch.device(device)) + self.assertEqual(c_int32.float(), torch.mm(a_float, b_float)) + c_int32_result = c_int32.new_empty(c_int32.size()) + # Checking out variant + torch._int_mm(a_int8, b_int8, out=c_int32_result) + self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float)) + + def test_out_dtype_inductor_decomp_trace(self, device) -> None: + def func(x, w): + return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w) + + w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=device) + x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=device) + + # Check that make_fx with inductor decomps produces _int_mm + decomp_table = torch._inductor.decomposition.select_decomp_table() + gm = make_fx(func, decomp_table, tracing_mode="symbolic")(x, w) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x_1, w_1): + _int_mm = torch.ops.aten._int_mm.default(x_1, w_1); x_1 = w_1 = None + return _int_mm""", + ) + + def test_out_dtype_int_mm_default_trace(self, device) -> None: + def func(x, w): + return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w) + + w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=device) + x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=device) + + # By default, out_dtype is preserved in the trace + gm = make_fx(func, tracing_mode="symbolic")(x, w) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x_1, w_1): + out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, x_1, w_1); x_1 = w_1 = None + return out_dtype""", + ) + instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) diff --git a/third_party/composable_kernel b/third_party/composable_kernel index df6023e305f3..7fe50dc3da20 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit df6023e305f389bbf7249b0c4414e649f3ad6598 +Subproject commit 7fe50dc3da2069d6645d9deb8c017a876472a977 diff --git a/third_party/fbgemm b/third_party/fbgemm index 0adf628317e0..21c7d30c526c 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 0adf628317e0cea414f66dcca901e0b85280fdb1 +Subproject commit 21c7d30c526c0f1ad873ecc632dca6cfa8a69067 diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 52791a2fd214..dacda0567d9f 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 52791a2fd214b2a9dc5759d36725909c1daa7f2e +Subproject commit dacda0567d9f23d4bc503e1c4f84aa65f33ac38a diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index ece345fda4a2..5e5b69b4cb4e 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -7,6 +7,7 @@ LIBUV_COMMON_SRCS = [ "third_party/libuv/src/inet.c", "third_party/libuv/src/random.c", "third_party/libuv/src/strscpy.c", + "third_party/libuv/src/strtok.c", "third_party/libuv/src/threadpool.c", "third_party/libuv/src/timer.c", "third_party/libuv/src/uv-common.c", @@ -37,9 +38,7 @@ LIBUV_POSIX_SRCS = [ LIBUV_LINUX_SRCS = LIBUV_POSIX_SRCS + [ "third_party/libuv/src/unix/proctitle.c", - "third_party/libuv/src/unix/linux-core.c", - "third_party/libuv/src/unix/linux-inotify.c", - "third_party/libuv/src/unix/linux-syscalls.c", + "third_party/libuv/src/unix/linux.c", "third_party/libuv/src/unix/procfs-exepath.c", "third_party/libuv/src/unix/random-getrandom.c", "third_party/libuv/src/unix/random-sysctl-linux.c", @@ -60,6 +59,7 @@ cc_library( "third_party/libuv/src/unix/*.h", ], ), + copts = ["-D_GNU_SOURCE"], visibility = ["//visibility:public"], ) @@ -151,7 +151,7 @@ cc_library( ".", ], copts = [ - "-std=c++14", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [ @@ -168,7 +168,7 @@ cc_library( ".", ], copts = [ - "-std=c++14", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [ diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index db16e3565273..b353d5d0d598 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -2227,6 +2227,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], # doesn't cover iphonesimulator-x86_64 "ovr_config//runtime:arm64-linux-ubuntu-neon": [":arm64_lib"], + "ovr_config//runtime:fbcode-arm64": [":arm64_lib"], "ovr_config//runtime:platform010": [":x86_and_x86_64_lib"], }), ) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index f3cfe7166aa7..b84ebb55a901 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -3a9419c8bb6a98dd3e3cd473c36691fb4abeae40 +1f7a57f50745a429b7da10dddf2e366687659b87 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a778c1a85da0..c050c6cbdc4c 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2904,6 +2904,10 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) +- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 457b224354fb..9d43de80f129 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -88,7 +88,8 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( - not check_negative_env_flag("USE_CUDA") + not check_negative_env_flag("USE_DISTRIBUTED") + and not check_negative_env_flag("USE_CUDA") and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index 2a9cee36f7bc..4bc268022e28 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -24,6 +24,7 @@ Traceback, ) from tools.flight_recorder.components.utils import ( + add_stack_id_in_entries, align_trace_from_beginning, check_current_entry_match, check_no_missing_dump_files, @@ -391,6 +392,9 @@ def build_db( # Ensure version is consistent across all ranks. check_version(version_by_ranks, version) entries = align_trace_from_beginning(entries) + stack_id_trace_map: dict[str, int] = {} + if args.just_print_entries: + entries, stack_id_trace_map = add_stack_id_in_entries(entries) # flattened database groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( @@ -402,7 +406,9 @@ def build_db( check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships, _pg_guids, args) + just_print_entries( + entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map + ) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index ea9b0cf3918c..abd7f5372133 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -67,6 +67,7 @@ def __init__(self: "JobConfig"): ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") + self.parser.add_argument("--print_stack_trace", action="store_true") def parse_args( self: "JobConfig", args: Optional[Sequence[str]] diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index dd2eb109aa56..7634226bae52 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -78,9 +78,9 @@ def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]: if prefix is None: prefix = _determine_prefix(files) for f in files: - if f.find(prefix) != 0: + if (offset := f.find(prefix)) == -1: continue - details[f] = read_dump(prefix, os.path.join(root, f)) + details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f)) filecount += 1 if not version: version = str(details[f]["version"]) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 597ee8e3ceda..ded30fb077cd 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -417,6 +417,7 @@ def __init__( else: self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] + self.stack_id = event.get("stack_id", -1) self.p2p_seq_id = event["p2p_seq_id"] self.input_dtypes = event["input_dtypes"] self.output_dtypes = event["output_dtypes"] @@ -456,6 +457,7 @@ def __repr__(self) -> str: f"pg_name={self.pg_name}", f"pg_description={self.pg_desc}", f"pg_size={self.pg_size}", + f"stack_id={self.stack_id}", f"state={self.state}", ) return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 73ec2a13d3be..b68266c79b2c 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -616,6 +616,7 @@ def just_print_entries( _memberships: dict[str, set[Any]], _pg_guids: dict[tuple[str, int], str], args: argparse.Namespace, + stack_id_trace_map: dict[str, int], ) -> None: rows = [] ranks = sorted(all_entries.keys()) @@ -650,6 +651,17 @@ def just_print_entries( logger.info(tabulate(rows, headers=headers)) + if stack_id_trace_map and args.print_stack_trace: + headers = ["stack_id", "frame_stack"] + rows = [] + + for frame, stack_id in sorted( + stack_id_trace_map.items(), key=lambda item: item[1] + ): + rows.append([str(stack_id), frame]) + + logger.info(tabulate(rows, headers=headers)) + def check_no_missing_dump_files( entries: dict[int, Any], memberships: list[Membership] @@ -677,6 +689,27 @@ def get_version_detail(version: str) -> tuple[int, int]: return major, minor +def add_stack_id_in_entries( + entries: dict[int, list[dict[str, Any]]], +) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]: + stack_id = 0 + stack_id_trace_map = {} + for rank in entries: + for dump in entries[rank]: + if dump.get("frames", []): + frames = str(dump["frames"]) + if frames not in stack_id_trace_map: + stack_id_trace_map[frames] = stack_id + dump["stack_id"] = stack_id + stack_id += 1 + else: + dump["stack_id"] = stack_id_trace_map[frames] + else: + dump["stack_id"] = -1 + + return entries, stack_id_trace_map + + def align_trace_from_beginning( entries: dict[int, list[dict[str, Any]]], ) -> dict[int, list[dict[str, Any]]]: diff --git a/tools/linter/adapters/_linter/bracket_pairs.py b/tools/linter/adapters/_linter/bracket_pairs.py index 23f08c9ff739..323f4da88bce 100644 --- a/tools/linter/adapters/_linter/bracket_pairs.py +++ b/tools/linter/adapters/_linter/bracket_pairs.py @@ -16,9 +16,10 @@ def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: """Returns a dictionary mapping opening to closing brackets""" braces: dict[int, int] = {} stack: list[int] = [] + in_fstring = False for i, t in enumerate(tokens): - if t.type == token.OP: + if t.type == token.OP and not in_fstring: if t.string in BRACKETS: stack.append(i) elif inv := BRACKETS_INV.get(t.string): @@ -34,9 +35,11 @@ def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: raise ParseError(t, f"Mismatched braces '{b}' at {begin}") elif t.type == FSTRING_START: stack.append(FSTRING_START) + in_fstring = True elif t.type == FSTRING_END: if stack.pop() != FSTRING_START: raise ParseError(t, "Mismatched FSTRING_START/FSTRING_END") + in_fstring = False if stack: raise ParseError(t, "Left open") return braces diff --git a/tools/linter/adapters/black_linter.py b/tools/linter/adapters/black_linter.py deleted file mode 100644 index c22a89032cfb..000000000000 --- a/tools/linter/adapters/black_linter.py +++ /dev/null @@ -1,225 +0,0 @@ -from __future__ import annotations - -import argparse -import concurrent.futures -import json -import logging -import os -import subprocess -import sys -import time -from enum import Enum -from typing import BinaryIO, NamedTuple - - -IS_WINDOWS: bool = os.name == "nt" - - -class LintSeverity(str, Enum): - ERROR = "error" - WARNING = "warning" - ADVICE = "advice" - DISABLED = "disabled" - - -class LintMessage(NamedTuple): - path: str | None - line: int | None - char: int | None - code: str - severity: LintSeverity - name: str - original: str | None - replacement: str | None - description: str | None - - -def as_posix(name: str) -> str: - return name.replace("\\", "/") if IS_WINDOWS else name - - -def _run_command( - args: list[str], - *, - stdin: BinaryIO, - timeout: int, -) -> subprocess.CompletedProcess[bytes]: - logging.debug("$ %s", " ".join(args)) - start_time = time.monotonic() - try: - return subprocess.run( - args, - stdin=stdin, - capture_output=True, - shell=IS_WINDOWS, # So batch scripts are found. - timeout=timeout, - check=True, - ) - finally: - end_time = time.monotonic() - logging.debug("took %dms", (end_time - start_time) * 1000) - - -def run_command( - args: list[str], - *, - stdin: BinaryIO, - retries: int, - timeout: int, -) -> subprocess.CompletedProcess[bytes]: - remaining_retries = retries - while True: - try: - return _run_command(args, stdin=stdin, timeout=timeout) - except subprocess.TimeoutExpired as err: - if remaining_retries == 0: - raise err - remaining_retries -= 1 - logging.warning( - "(%s/%s) Retrying because command failed with: %r", - retries - remaining_retries, - retries, - err, - ) - time.sleep(1) - - -def check_file( - filename: str, - retries: int, - timeout: int, -) -> list[LintMessage]: - try: - with open(filename, "rb") as f: - original = f.read() - with open(filename, "rb") as f: - proc = run_command( - [sys.executable, "-mblack", "--stdin-filename", filename, "-"], - stdin=f, - retries=retries, - timeout=timeout, - ) - except subprocess.TimeoutExpired: - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.ERROR, - name="timeout", - original=None, - replacement=None, - description=( - "black timed out while trying to process a file. " - "Please report an issue in pytorch/pytorch with the " - "label 'module: lint'" - ), - ) - ] - except (OSError, subprocess.CalledProcessError) as err: - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.ADVICE, - name="command-failed", - original=None, - replacement=None, - description=( - f"Failed due to {err.__class__.__name__}:\n{err}" - if not isinstance(err, subprocess.CalledProcessError) - else ( - "COMMAND (exit code {returncode})\n" - "{command}\n\n" - "STDERR\n{stderr}\n\n" - "STDOUT\n{stdout}" - ).format( - returncode=err.returncode, - command=" ".join(as_posix(x) for x in err.cmd), - stderr=err.stderr.decode("utf-8").strip() or "(empty)", - stdout=err.stdout.decode("utf-8").strip() or "(empty)", - ) - ), - ) - ] - - replacement = proc.stdout - if original == replacement: - return [] - - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.WARNING, - name="format", - original=original.decode("utf-8"), - replacement=replacement.decode("utf-8"), - description="Run `lintrunner -a` to apply this patch.", - ) - ] - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Format files with black.", - fromfile_prefix_chars="@", - ) - parser.add_argument( - "--retries", - default=3, - type=int, - help="times to retry timed out black", - ) - parser.add_argument( - "--timeout", - default=90, - type=int, - help="seconds to wait for black", - ) - parser.add_argument( - "--verbose", - action="store_true", - help="verbose logging", - ) - parser.add_argument( - "filenames", - nargs="+", - help="paths to lint", - ) - args = parser.parse_args() - - logging.basicConfig( - format="<%(threadName)s:%(levelname)s> %(message)s", - level=logging.NOTSET - if args.verbose - else logging.DEBUG - if len(args.filenames) < 1000 - else logging.INFO, - stream=sys.stderr, - ) - - with concurrent.futures.ThreadPoolExecutor( - max_workers=os.cpu_count(), - thread_name_prefix="Thread", - ) as executor: - futures = { - executor.submit(check_file, x, args.retries, args.timeout): x - for x in args.filenames - } - for future in concurrent.futures.as_completed(futures): - try: - for lint_message in future.result(): - print(json.dumps(lint_message._asdict()), flush=True) - except Exception: - logging.critical('Failed at "%s".', futures[future]) - raise - - -if __name__ == "__main__": - main() diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index 137e4637bdb4..05a7a8acf932 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -41,11 +41,6 @@ def main() -> None: parser.add_argument( "--dry-run", help="do not install anything, just print what would be done." ) - parser.add_argument( - "--no-black-binary", - help="do not use pre-compiled binaries from pip for black.", - action="store_true", - ) args = parser.parse_args() @@ -97,8 +92,6 @@ def main() -> None: "Package {package_name} did not have a version specified. " "Please specify a version to produce a consistent linting experience." ) - if args.no_black_binary and "black" in package_name: - pip_args.append(f"--no-binary={package_name}") dry_run = args.dry_run == "1" if dry_run: diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 55ffa429e7f9..ce5f8252a20f 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -2,7 +2,6 @@ import argparse import concurrent.futures -import fnmatch import json import logging import os @@ -13,7 +12,6 @@ from pathlib import Path from typing import NamedTuple -import black import isort import usort @@ -21,44 +19,6 @@ IS_WINDOWS: bool = os.name == "nt" REPO_ROOT = Path(__file__).absolute().parents[3] -# TODO: remove this when it gets empty and remove `black` in PYFMT -USE_BLACK_FILELIST = re.compile( - "|".join( - ( - r"\A\Z", # empty string - *map( - fnmatch.translate, - [ - # ** - # .ci/** - # .github/** - # benchmarks/** - # functorch/** - # tools/** - # torchgen/** - # test/** - # test/[a-h]*/** - # test/[i-j]*/** - # test/[k-m]*/** - # test/optim/** - # test/[p-z]*/**, - # torch/** - # torch/_[a-c]*/** - # torch/_[e-h]*/** - # torch/_i*/** - # torch/_[j-z]*/** - # torch/[a-c]*/** - # torch/d*/** - # torch/[e-m]*/** - # torch/optim/** - # torch/[p-z]*/** - "torch/[p-z]*/**", - ], - ), - ) - ) -) - class LintSeverity(str, Enum): ERROR = "error" @@ -118,23 +78,6 @@ def run_usort(content: str, path: Path) -> str: return usort.usort_string(content, path=path, config=usort_config) -def run_black(content: str, path: Path) -> str: - black_config = black.parse_pyproject_toml(black.find_pyproject_toml((str(path),))) # type: ignore[attr-defined,arg-type] - # manually patch options that do not have a 1-to-1 match in Mode arguments - black_config["target_versions"] = { - black.TargetVersion[ver.upper()] # type: ignore[attr-defined] - for ver in black_config.pop("target_version", []) - } - black_config["string_normalization"] = not black_config.pop( - "skip_string_normalization", False - ) - black_mode = black.Mode(**black_config) - black_mode.is_pyi = path.suffix.lower() == ".pyi" - black_mode.is_ipynb = path.suffix.lower() == ".ipynb" - - return black.format_str(content, mode=black_mode) - - def run_ruff_format(content: str, path: Path) -> str: try: return subprocess.check_output( @@ -166,10 +109,7 @@ def check_file(filename: str) -> list[LintMessage]: # NB: run isort first to enforce style for blank lines replacement = run_isort(replacement, path=path) replacement = run_usort(replacement, path=path) - if USE_BLACK_FILELIST.match(path.absolute().relative_to(REPO_ROOT).as_posix()): - replacement = run_black(replacement, path=path) - else: - replacement = run_ruff_format(replacement, path=path) + replacement = run_ruff_format(replacement, path=path) if original == replacement: return [] diff --git a/tools/linter/adapters/test_device_bias_linter.py b/tools/linter/adapters/test_device_bias_linter.py index 9901d5f3fe52..a2079e4fe810 100644 --- a/tools/linter/adapters/test_device_bias_linter.py +++ b/tools/linter/adapters/test_device_bias_linter.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ This lint verifies that every Python test file (file that matches test_*.py or -*_test.py in the test folder) has a cuda hard code in `requires_gpu()` -decorated function to ensure that the test not fail on other GPU. - +*_test.py in the test folder) has a cuda hard code in `requires_gpu()` or +`requires_triton()` decorated function or `if HAS_GPU:` guarded main section, +to ensure that the test not fail on other GPU devices. """ from __future__ import annotations @@ -39,21 +39,59 @@ class LintMessage(NamedTuple): DEVICE_BIAS = ["cuda", "xpu", "mps"] +GPU_RELATED_DECORATORS = {"requires_gpu", "requires_triton"} + + +def is_main_has_gpu(tree: ast.AST) -> bool: + def _contains_has_gpu(node: ast.AST) -> bool: + if isinstance(node, ast.Name) and node.id in ["HAS_GPU", "RUN_GPU"]: + return True + elif isinstance(node, ast.BoolOp): + return any(_contains_has_gpu(value) for value in node.values) + elif isinstance(node, ast.UnaryOp): + return _contains_has_gpu(node.operand) + elif isinstance(node, ast.Compare): + return _contains_has_gpu(node.left) or any( + _contains_has_gpu(comp) for comp in node.comparators + ) + elif isinstance(node, (ast.IfExp, ast.Call)): + return False + return False + + for node in ast.walk(tree): + # Detect if __name__ == "__main__": + if isinstance(node, ast.If): + if ( + isinstance(node.test, ast.Compare) + and isinstance(node.test.left, ast.Name) + and node.test.left.id == "__name__" + ): + if any( + isinstance(comp, ast.Constant) and comp.value == "__main__" + for comp in node.test.comparators + ): + for inner_node in node.body: + if isinstance(inner_node, ast.If) and _contains_has_gpu( + inner_node.test + ): + return True + return False class DeviceBiasVisitor(ast.NodeVisitor): - def __init__(self, filename: str): + def __init__(self, filename: str, is_gpu_test_suite: bool) -> None: self.filename = filename self.lint_messages: list[LintMessage] = [] + self.is_gpu_test_suite = is_gpu_test_suite - def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: + def _has_proper_decorator(self, node: ast.FunctionDef) -> bool: for d in node.decorator_list: - if isinstance(d, ast.Name) and d.id == "requires_gpu": + if isinstance(d, ast.Name) and d.id in GPU_RELATED_DECORATORS: return True if ( isinstance(d, ast.Call) and isinstance(d.func, ast.Name) - and d.func.id == "requires_gpu" + and d.func.id in GPU_RELATED_DECORATORS ): return True return False @@ -62,7 +100,6 @@ def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: def _check_keyword_device(self, subnode: ast.keyword, msg_prefix: str) -> None: if subnode.arg != "device": return - val = subnode.value if isinstance(val, ast.Constant) and any( bias in val.value for bias in DEVICE_BIAS @@ -105,15 +142,26 @@ def _check_device_methods(self, subnode: ast.Call, msg_prefix: str) -> None: f"{msg_prefix} .to('{arg.value}'), suggest to use .to(GPU_TYPE)", ) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Check if the function is decorated with @requires_gpu, which indicates - # that the function is intended to run on GPU devices (e.g., CUDA or XPU), - # but ensure it does not hardcode the device to CUDA. - if not self._has_requires_gpu_decorator(node): - self.generic_visit(node) - return - - msg_prefix = "`@requires_gpu` function should not hardcode" + def _check_with_statement(self, node: ast.With, msg_prefix: str) -> None: + for item in node.items: + ctx_expr = item.context_expr + if isinstance(ctx_expr, ast.Call): + func = ctx_expr.func + if ( + isinstance(func, ast.Attribute) + and func.attr == "device" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + and ctx_expr.args + and isinstance(ctx_expr.args[0], ast.Constant) + and any(bias in ctx_expr.args[0].value for bias in DEVICE_BIAS) + ): + self.record( + ctx_expr, + f"{msg_prefix} `with torch.device('{ctx_expr.args[0].value}')`, suggest to use torch.device(GPU_TYPE)", + ) + + def _check_node(self, node: ast.AST, msg_prefix: str) -> None: for subnode in ast.walk(node): if isinstance(subnode, ast.keyword): self._check_keyword_device(subnode, msg_prefix) @@ -121,7 +169,19 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: subnode.func, ast.Attribute ): self._check_device_methods(subnode, msg_prefix) + elif isinstance(subnode, ast.With): + self._check_with_statement(subnode, msg_prefix) + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if self._has_proper_decorator(node): + msg_prefix = ( + "`@requires_gpu` or `@requires_triton` function should not hardcode" + ) + self._check_node(node, msg_prefix) + elif self.is_gpu_test_suite: + # If the function is guarded by HAS_GPU in main(), we still need to check for device bias + msg_prefix = "The test suites is shared amount GPUS, should not hardcode" + self._check_node(node, msg_prefix) self.generic_visit(node) def record(self, node: ast.AST, message: str) -> None: @@ -144,16 +204,16 @@ def check_file(filename: str) -> list[LintMessage]: with open(filename) as f: source = f.read() tree = ast.parse(source, filename=filename) - checker = DeviceBiasVisitor(filename) + is_gpu_test_suite = is_main_has_gpu(tree) + checker = DeviceBiasVisitor(filename, is_gpu_test_suite) checker.visit(tree) - return checker.lint_messages def main() -> None: parser = argparse.ArgumentParser( - description="Detect Device bias in python functions decorated with [require_gpu]" - " that may potentially break support for other GPU devices.", + description="Detect Device bias in functions decorated with requires_gpu/requires_triton" + " or guarded by HAS_GPU block in main() that may break other GPU devices.", fromfile_prefix_chars="@", ) parser.add_argument( diff --git a/tools/nightly.py b/tools/nightly.py index c0af8bccf152..ba66eb702228 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -436,7 +436,7 @@ def python( check=check, text=True, encoding="utf-8", - env={**self._env, **env}, + env={**os.environ, **self._env, **env}, **popen_kwargs, ) @@ -481,7 +481,7 @@ def uv( check=check, text=True, encoding="utf-8", - env={**self._env, **env, "UV_PYTHON": str(python)}, + env={**os.environ, **self._env, **env, "UV_PYTHON": str(python)}, **popen_kwargs, ) diff --git a/tools/packaging/build_wheel.py b/tools/packaging/build_wheel.py index 16e9a87bd963..10c4516a3280 100644 --- a/tools/packaging/build_wheel.py +++ b/tools/packaging/build_wheel.py @@ -4,6 +4,7 @@ import contextlib import logging import os +import re import subprocess import sys import tempfile @@ -16,11 +17,12 @@ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) ROOT_PATH = Path(__file__).absolute().parent.parent.parent SETUP_PY_PATH = ROOT_PATH / "setup.py" REQUIREMENTS_PATH = ROOT_PATH / "requirements.txt" +PYPROJECT_TOML_PATH = ROOT_PATH / "pyproject.toml" def run_cmd( @@ -45,6 +47,79 @@ def interpreter_version(interpreter: str) -> str: return str(version_string.split(" ")[1]) +def get_supported_python_versions() -> list[str]: + """Extract supported Python versions from pyproject.toml classifiers.""" + with open(PYPROJECT_TOML_PATH) as f: + content = f.read() + + # Find Python version classifiers + pattern = r'"Programming Language :: Python :: (\d+\.\d+)"' + matches = re.findall(pattern, content) + + # Sort versions and return them + return sorted(matches, key=lambda x: tuple(map(int, x.split(".")))) + + +def find_python_interpreters(mode: str) -> list[str]: + """Find Python interpreters based on the specified mode.""" + if mode == "manylinux": + return _find_manylinux_interpreters() + else: + raise ValueError(f"Unsupported mode: {mode}") + + +def _find_manylinux_interpreters() -> list[str]: + """Find Python interpreters in manylinux format (/opt/python/).""" + supported_versions = get_supported_python_versions() + interpreters = [] + + python_root = Path("/opt/python") + if not python_root.exists(): + logger.warning("Path /opt/python does not exist, no interpreters found") + return [] + + # Find all python3 binaries in /opt/python/ + python_binaries = list(python_root.glob("*/bin/python3")) + + for python_path in python_binaries: + try: + # Check if it's PyPy (skip it) + version_output = run_cmd( + [str(python_path), "--version"], capture_output=True + ) + version_string = version_output.stdout.decode("utf-8").strip() + + if "PyPy" in version_string: + logger.debug("Skipping PyPy interpreter: %s", python_path) + continue + + # Extract Python version (e.g., "Python 3.9.1" -> "3.9") + match = re.search(r"Python (\d+\.\d+)", version_string) + if not match: + logger.debug("Could not parse version from: %s", version_string) + continue + + python_version = match.group(1) + + # Check if this version is supported + if python_version in supported_versions: + interpreters.append(str(python_path)) + logger.debug( + "Found supported Python %s at %s", python_version, python_path + ) + else: + logger.debug( + "Python %s not in supported versions: %s", + python_version, + supported_versions, + ) + + except subprocess.CalledProcessError as e: + logger.debug("Failed to get version for %s: %s", python_path, e) + continue + return interpreters + + @contextlib.contextmanager def venv(interpreter: str) -> Iterator[str]: # Should this use EnvBuilder? Probably, maybe a good todo in the future @@ -100,6 +175,16 @@ def parse_args() -> argparse.Namespace: " should ideally be full paths, (default: %(default)s)" ), ) + parser.add_argument( + "--find-python", + type=str, + choices=["manylinux"], + help=( + "Automatically find Python interpreters based on the specified mode. " + "Available modes: 'manylinux' (searches /opt/python/ for interpreters " + "matching supported versions in pyproject.toml)" + ), + ) parser.add_argument( "-d", "--destination", @@ -112,7 +197,26 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - pythons = args.python or [sys.executable] + + if args.find_python: + if args.python: + logger.warning( + "Both --python and --find-python specified. Using --find-python and ignoring --python." + ) + pythons = find_python_interpreters(args.find_python) + if not pythons: + logger.error( + "No Python interpreters found with --find-python %s", args.find_python + ) + sys.exit(1) + logger.info( + "Found %d supported Python interpreters: %s", + len(pythons), + ", ".join(pythons), + ) + else: + pythons = args.python or [sys.executable] + build_times: dict[str, float] = dict() if len(pythons) > 1 and args.destination == "dist/": diff --git a/tools/packaging/split_wheel.py b/tools/packaging/split_wheel.py deleted file mode 100644 index fd52c39a22b0..000000000000 --- a/tools/packaging/split_wheel.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Script to build split pytorch wheels - -What is split build / why is it important? - > Split build is splitting the PyTorch build into a libtorch & - > PyTorch python frontend package. This allows us to to publish - > both as separate packages and opens up our ability to have users - > install different libtorch backends per their PyTorch frontend - > - > Example: opening up the door to things like: - > pip install torch[cuda] - > pip install torch[rocm] - > pip install torch[cpu] - > etc. - -Why does this exist? - > Currently our split build requires you to invoke setup.py twice - > Which ends up complicating the build process and adds some level - > of complexity to our setup.py / build invocation for split builds. - > Ideally this script will eventually not be needed but for - > development purposes we should have an easy way to invoke this script -""" - -import argparse -import logging -import os -import subprocess -import sys -from pathlib import Path -from typing import Optional - - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -# NOTE: This will need to be updated if this script is ever moved -ROOT_PATH = Path(__file__).absolute().parents[2] -SETUP_PY_PATH = ROOT_PATH / "setup.py" - - -def requirements_installed() -> bool: - try: - import setuptools # type: ignore[import-untyped] # noqa: F401 - - return True - except ImportError: - logger.error( - "Requirements not installed, run the following command to install:" - ) - logger.error( - " > %s -m pip install -r %s/requirements.txt", sys.executable, ROOT_PATH - ) - return False - - -def setup_py(cmd_args: list[str], extra_env: Optional[dict[str, str]] = None) -> None: - if extra_env is None: - extra_env = {} - cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args] - logger.debug("+ %s", " ".join(cmd)) - subprocess.run( - cmd, - # Give the parent environment to the subprocess - env={**os.environ, **extra_env}, - check=True, - ) - - -def split_build(cmd: str) -> None: - logger.info("Running %s for libtorch wheel", cmd) - setup_py( - [cmd], - extra_env={"BUILD_LIBTORCH_WHL": "1", "BUILD_PYTHON_ONLY": "0"}, - ) - logger.info("Running %s for torch wheel", cmd) - # NOTE: Passing CMAKE_FRESH=1 is necessary here since the torch frontend has it's - # own cmake files that it needs to generate - setup_py( - [cmd], - extra_env={ - "BUILD_LIBTORCH_WHL": "0", - "BUILD_PYTHON_ONLY": "1", - "CMAKE_FRESH": "1", - }, - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - command_subparser = parser.add_subparsers(dest="command") - # Ideally these should mirror setuptools commands if we need support here for that - command_subparser.add_parser("install") - command_subparser.add_parser("bdist_wheel") - command_subparser.add_parser("develop") - return parser.parse_args() - - -def main() -> None: - args = parse_args() - if not requirements_installed(): - sys.exit(1) - split_build(args.command) - - -if __name__ == "__main__": - main() diff --git a/tools/test/set_linter_testdata/python_code.py.txt b/tools/test/set_linter_testdata/python_code.py.txt index 59d2826286a0..e805a3ca92be 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt +++ b/tools/test/set_linter_testdata/python_code.py.txt @@ -30,6 +30,9 @@ class A: set = A().set +# An f string as in https://github.com/pytorch/pytorch/issues/159056 +f_string = f" {h:{w}} " + # Braced sets set1 = {1} diff --git a/tools/test/set_linter_testdata/python_code.py.txt.json b/tools/test/set_linter_testdata/python_code.py.txt.json index 772fba0149f1..22935a7904df 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.json +++ b/tools/test/set_linter_testdata/python_code.py.txt.json @@ -47,7 +47,7 @@ "char": 7, "code": "SET_LINTER", "description": null, - "line": 35, + "line": 38, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -58,7 +58,7 @@ "char": 9, "code": "SET_LINTER", "description": null, - "line": 35, + "line": 38, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -69,7 +69,7 @@ "char": 7, "code": "SET_LINTER", "description": null, - "line": 36, + "line": 39, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -80,7 +80,7 @@ "char": 12, "code": "SET_LINTER", "description": null, - "line": 36, + "line": 39, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -91,7 +91,7 @@ "char": 15, "code": "SET_LINTER", "description": null, - "line": 38, + "line": 41, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -102,7 +102,7 @@ "char": 36, "code": "SET_LINTER", "description": null, - "line": 38, + "line": 41, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -113,7 +113,7 @@ "char": 17, "code": "SET_LINTER", "description": null, - "line": 41, + "line": 44, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -124,7 +124,7 @@ "char": 22, "code": "SET_LINTER", "description": null, - "line": 41, + "line": 44, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -135,7 +135,7 @@ "char": 30, "code": "SET_LINTER", "description": null, - "line": 41, + "line": 44, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -146,7 +146,7 @@ "char": 50, "code": "SET_LINTER", "description": null, - "line": 41, + "line": 44, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -157,7 +157,7 @@ "char": 10, "code": "SET_LINTER", "description": null, - "line": 44, + "line": 47, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -168,7 +168,7 @@ "char": 51, "code": "SET_LINTER", "description": null, - "line": 44, + "line": 47, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -179,7 +179,7 @@ "char": 75, "code": "SET_LINTER", "description": null, - "line": 44, + "line": 47, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -190,7 +190,7 @@ "char": 77, "code": "SET_LINTER", "description": null, - "line": 44, + "line": 47, "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -203,9 +203,9 @@ "description": null, "line": null, "name": "Suggested fixes for set_linter", - "original": "# Basic tests\nimport tempfile\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = set()\nb = \"set()\"\nc = set\nd = c.set\nf = (\n set(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# Braced sets\n\nset1 = {1}\nset2 = {1, 2}\n\niterator_set = {i for i in range(10)}\n\n# A dict with two sets.\ndict_set = {\"a\": {2, 3}, \"b\": {i for i in range(3)}}\n\n# A set containing an object constructed with a dict and a set\nsos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})}\n", + "original": "# Basic tests\nimport tempfile\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = set()\nb = \"set()\"\nc = set\nd = c.set\nf = (\n set(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# An f string as in https://github.com/pytorch/pytorch/issues/159056\nf_string = f\" {h:{w}} \"\n\n# Braced sets\n\nset1 = {1}\nset2 = {1, 2}\n\niterator_set = {i for i in range(10)}\n\n# A dict with two sets.\ndict_set = {\"a\": {2, 3}, \"b\": {i for i in range(3)}}\n\n# A set containing an object constructed with a dict and a set\nsos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})}\n", "path": "tools/test/set_linter_testdata/python_code.py.txt", - "replacement": "# Basic tests\nimport tempfile\nfrom torch.utils._ordered_set import OrderedSet\n\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = OrderedSet()\nb = \"set()\"\nc = OrderedSet\nd = c.set\nf = (\n OrderedSet(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# Braced sets\n\nset1 = OrderedSet([1])\nset2 = OrderedSet([1, 2])\n\niterator_set = OrderedSet([i for i in range(10)])\n\n# A dict with two sets.\ndict_set = {\"a\": OrderedSet([2, 3]), \"b\": OrderedSet([i for i in range(3)])}\n\n# A set containing an object constructed with a dict and a set\nsos_set = OrderedSet([Something({i: i + 1 for i in range(3)}, OrderedSet([i + 1 for i in range(3)]))])\n", + "replacement": "# Basic tests\nimport tempfile\nfrom torch.utils._ordered_set import OrderedSet\n\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = OrderedSet()\nb = \"set()\"\nc = OrderedSet\nd = c.set\nf = (\n OrderedSet(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# An f string as in https://github.com/pytorch/pytorch/issues/159056\nf_string = f\" {h:{w}} \"\n\n# Braced sets\n\nset1 = OrderedSet([1])\nset2 = OrderedSet([1, 2])\n\niterator_set = OrderedSet([i for i in range(10)])\n\n# A dict with two sets.\ndict_set = {\"a\": OrderedSet([2, 3]), \"b\": OrderedSet([i for i in range(3)])}\n\n# A set containing an object constructed with a dict and a set\nsos_set = OrderedSet([Something({i: i + 1 for i in range(3)}, OrderedSet([i + 1 for i in range(3)]))])\n", "severity": "error" } ] diff --git a/tools/test/set_linter_testdata/python_code.py.txt.lintrunner b/tools/test/set_linter_testdata/python_code.py.txt.lintrunner index 901cd664f96d..4926368e9ab1 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.lintrunner +++ b/tools/test/set_linter_testdata/python_code.py.txt.lintrunner @@ -30,106 +30,106 @@ tools/test/set_linter_testdata/python_code.py.txt:12:4: Builtin `set` is depreca 13 | ) 14 | ) -tools/test/set_linter_testdata/python_code.py.txt:35:8: Builtin `set` is deprecated - 33 | # Braced sets - 34 | - 35 | set1 = {1} - ^ - 36 | set2 = {1, 2} +tools/test/set_linter_testdata/python_code.py.txt:38:8: Builtin `set` is deprecated + 36 | # Braced sets 37 | + 38 | set1 = {1} + ^ + 39 | set2 = {1, 2} + 40 | -tools/test/set_linter_testdata/python_code.py.txt:35:10: Builtin `set` is deprecated - 33 | # Braced sets - 34 | - 35 | set1 = {1} - ^ - 36 | set2 = {1, 2} +tools/test/set_linter_testdata/python_code.py.txt:38:10: Builtin `set` is deprecated + 36 | # Braced sets 37 | + 38 | set1 = {1} + ^ + 39 | set2 = {1, 2} + 40 | -tools/test/set_linter_testdata/python_code.py.txt:36:8: Builtin `set` is deprecated - 34 | - 35 | set1 = {1} - 36 | set2 = {1, 2} - ^ +tools/test/set_linter_testdata/python_code.py.txt:39:8: Builtin `set` is deprecated 37 | - 38 | iterator_set = {i for i in range(10)} + 38 | set1 = {1} + 39 | set2 = {1, 2} + ^ + 40 | + 41 | iterator_set = {i for i in range(10)} -tools/test/set_linter_testdata/python_code.py.txt:36:13: Builtin `set` is deprecated - 34 | - 35 | set1 = {1} - 36 | set2 = {1, 2} - ^ +tools/test/set_linter_testdata/python_code.py.txt:39:13: Builtin `set` is deprecated 37 | - 38 | iterator_set = {i for i in range(10)} + 38 | set1 = {1} + 39 | set2 = {1, 2} + ^ + 40 | + 41 | iterator_set = {i for i in range(10)} -tools/test/set_linter_testdata/python_code.py.txt:38:16: Builtin `set` is deprecated - 36 | set2 = {1, 2} - 37 | - 38 | iterator_set = {i for i in range(10)} +tools/test/set_linter_testdata/python_code.py.txt:41:16: Builtin `set` is deprecated + 39 | set2 = {1, 2} + 40 | + 41 | iterator_set = {i for i in range(10)} ^ - 39 | - 40 | # A dict with two sets. + 42 | + 43 | # A dict with two sets. -tools/test/set_linter_testdata/python_code.py.txt:38:37: Builtin `set` is deprecated - 36 | set2 = {1, 2} - 37 | - 38 | iterator_set = {i for i in range(10)} +tools/test/set_linter_testdata/python_code.py.txt:41:37: Builtin `set` is deprecated + 39 | set2 = {1, 2} + 40 | + 41 | iterator_set = {i for i in range(10)} ^ - 39 | - 40 | # A dict with two sets. - -tools/test/set_linter_testdata/python_code.py.txt:41:18: Builtin `set` is deprecated - 39 | - 40 | # A dict with two sets. - 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} - ^ 42 | - 43 | # A set containing an object constructed with a dict and a set + 43 | # A dict with two sets. -tools/test/set_linter_testdata/python_code.py.txt:41:23: Builtin `set` is deprecated - 39 | - 40 | # A dict with two sets. - 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} - ^ +tools/test/set_linter_testdata/python_code.py.txt:44:18: Builtin `set` is deprecated 42 | - 43 | # A set containing an object constructed with a dict and a set + 43 | # A dict with two sets. + 44 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 45 | + 46 | # A set containing an object constructed with a dict and a set -tools/test/set_linter_testdata/python_code.py.txt:41:31: Builtin `set` is deprecated - 39 | - 40 | # A dict with two sets. - 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} - ^ +tools/test/set_linter_testdata/python_code.py.txt:44:23: Builtin `set` is deprecated 42 | - 43 | # A set containing an object constructed with a dict and a set + 43 | # A dict with two sets. + 44 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 45 | + 46 | # A set containing an object constructed with a dict and a set -tools/test/set_linter_testdata/python_code.py.txt:41:51: Builtin `set` is deprecated - 39 | - 40 | # A dict with two sets. - 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} - ^ +tools/test/set_linter_testdata/python_code.py.txt:44:31: Builtin `set` is deprecated 42 | - 43 | # A set containing an object constructed with a dict and a set + 43 | # A dict with two sets. + 44 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 45 | + 46 | # A set containing an object constructed with a dict and a set -tools/test/set_linter_testdata/python_code.py.txt:44:11: Builtin `set` is deprecated +tools/test/set_linter_testdata/python_code.py.txt:44:51: Builtin `set` is deprecated 42 | - 43 | # A set containing an object constructed with a dict and a set - 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} + 43 | # A dict with two sets. + 44 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 45 | + 46 | # A set containing an object constructed with a dict and a set + +tools/test/set_linter_testdata/python_code.py.txt:47:11: Builtin `set` is deprecated + 45 | + 46 | # A set containing an object constructed with a dict and a set + 47 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} ^ -tools/test/set_linter_testdata/python_code.py.txt:44:52: Builtin `set` is deprecated - 42 | - 43 | # A set containing an object constructed with a dict and a set - 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} +tools/test/set_linter_testdata/python_code.py.txt:47:52: Builtin `set` is deprecated + 45 | + 46 | # A set containing an object constructed with a dict and a set + 47 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} ^ -tools/test/set_linter_testdata/python_code.py.txt:44:76: Builtin `set` is deprecated - 42 | - 43 | # A set containing an object constructed with a dict and a set - 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} +tools/test/set_linter_testdata/python_code.py.txt:47:76: Builtin `set` is deprecated + 45 | + 46 | # A set containing an object constructed with a dict and a set + 47 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} ^ -tools/test/set_linter_testdata/python_code.py.txt:44:78: Builtin `set` is deprecated - 42 | - 43 | # A set containing an object constructed with a dict and a set - 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} +tools/test/set_linter_testdata/python_code.py.txt:47:78: Builtin `set` is deprecated + 45 | + 46 | # A set containing an object constructed with a dict and a set + 47 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} ^ diff --git a/tools/test/set_linter_testdata/python_code.py.txt.python b/tools/test/set_linter_testdata/python_code.py.txt.python index 3dbabd3e7845..52aaf12f2631 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.python +++ b/tools/test/set_linter_testdata/python_code.py.txt.python @@ -32,6 +32,9 @@ class A: set = A().set +# An f string as in https://github.com/pytorch/pytorch/issues/159056 +f_string = f" {h:{w}} " + # Braced sets set1 = OrderedSet([1]) diff --git a/tools/test/test_set_linter.py b/tools/test/test_set_linter.py index 9e879c8c6112..003096c3c408 100644 --- a/tools/test/test_set_linter.py +++ b/tools/test/test_set_linter.py @@ -77,6 +77,7 @@ def test_match_braced_sets(self) -> None: ("{i for i in range(2, 3)}", 1), ("{1, 2}", 1), ("{One({'a': 1}), Two([{}, {2}, {1, 2}])}", 3), + ('f" {h:{w}} "', 0), ) for s, expected in TESTS: pf = SetLinter.make_file(s) diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 28ff5bc3ff29..96aee230f89f 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -13,7 +13,7 @@ def parse_test_module(test: str) -> str: - return test.split(".")[0] + return test.split(".", maxsplit=1)[0] def discover_tests( diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index e698cf3586dd..e0ef858b96b2 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -186,7 +186,7 @@ def get_dep_modules(test: str) -> set[str]: def parse_test_module(test: str) -> str: - return test.split(".")[0] + return test.split(".", maxsplit=1)[0] def print_to_stderr(message: str) -> None: diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8d761068d1e6..1632147f0220 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -265,7 +265,7 @@ add_custom_command( OUTPUT "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi" COMMAND - ${CMAKE_COMMAND} -E env PYTHONPATH="${TORCH_ROOT}" + ${CMAKE_COMMAND} -E env --modify PYTHONPATH=path_list_prepend:"${TORCH_ROOT}" -- "${Python_EXECUTABLE}" ${TORCH_SRC_DIR}/utils/data/datapipes/gen_pyi.py DEPENDS "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi.in" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9e03c7dba830..fb7e9c5ce56e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2435,6 +2435,11 @@ def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... +def _accelerator_isAllocatorInitialized() -> _bool: ... +def _accelerator_emptyCache() -> None: ... +def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... +def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... +def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f0413764cda6..9007d3fbf5a0 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -315,6 +315,7 @@ class Backend: def options(self) -> Options: ... def rank(self) -> int: ... def size(self) -> int: ... + def name(self) -> str: ... def abort(self) -> None: ... def shutdown(self) -> None: ... def eager_connect_single_device(self, device: torch.device | None) -> None: ... diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 6261679dcdef..117795db5ac3 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -2,12 +2,9 @@ import enum import types from typing import Optional, overload -from torch._dynamo.types import ( - DynamoCallback, - DynamoGuardCompleteHook, - DynamoGuardHook, - GuardFn, -) +from torch._dynamo.guards import GuardManagerWrapper +from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook +from torch._guards import CompileId def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -25,14 +22,20 @@ def raise_sigtrap() -> None: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... + def update_diff_guard_root_manager(self) -> None: ... code: types.CodeType + compile_id: CompileId + # If we run into circular issues, just use object + guard_manager: GuardManagerWrapper next: _CacheEntry | None class _PrecompileEntry: - guard_manager: GuardFn + guard_manager: GuardManagerWrapper class _ExtraState: - def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... + def invalidate( + self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper + ) -> None: ... class _FrameAction(enum.IntEnum): DEFAULT = 0 @@ -69,7 +72,9 @@ py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... def _load_precompile_entry( - code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType + code: types.CodeType, + guard_manager: GuardManagerWrapper, + dynamo_code: types.CodeType, ) -> None: ... def _reset_precompile_entries(code: types.CodeType) -> None: ... def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 772eb5bd50d0..64800504f479 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -7,8 +7,15 @@ class GlobalStateGuard: def check(self) -> bool: ... def reason(self) -> str: ... -class LeafGuard: ... -class GuardDebugInfo: ... +class LeafGuard: + def verbose_code_parts(self) -> list[str]: ... + +class RelationalGuard: ... + +class GuardDebugInfo: + verbose_code_parts: list[str] + result: bool + num_guards_executed: int class GuardManager: def check(self, value) -> bool: ... @@ -36,6 +43,84 @@ class GuardManager: example_value, guard_manager_enum, ) -> GuardManager: ... + def grad_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def generic_getattr_manager( + self, + attr: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def get_generic_dict_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def list_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tuple_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def set_getitem_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def func_defaults_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def func_kwdefaults_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tuple_iterator_getitem_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def weakref_call_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def call_function_no_args_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... def global_weakref_manager( self, global_name: str, @@ -91,7 +176,44 @@ class GuardManager: example_value, guard_manager_enum, ) -> GuardManager: ... - + def get_root(self) -> RootGuardManager: ... + def get_source(self) -> str: ... + def fail_count(self) -> int: ... + def get_child_managers(self) -> list[GuardManager]: ... + def repr(self) -> str: ... + def type_of_guarded_value(self) -> str: ... + def get_leaf_guards(self) -> list[LeafGuard]: ... + def get_accessors(self) -> list[GuardManager]: ... + def is_guarded_value_immutable(self) -> bool: ... + def is_tag_safe(self) -> bool: ... + def is_tag_safe_root(self) -> bool: ... + def has_no_accessors(self) -> bool: ... + def has_object_aliasing_guard(self) -> bool: ... + def get_type_of_guarded_value(self) -> type: ... + def type_dict_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def type_mro_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def code_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def closure_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... # Leaf guards def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ... def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ... @@ -106,7 +228,94 @@ class GuardManager: def add_torch_function_mode_stack_guard( self, initial_stack, verbose_code_parts: list[str] ) -> None: ... - def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ... + def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ... + def add_dict_length_check_guard( + self, value, verbose_code_parts: list[str] + ) -> None: ... + def add_length_check_guard(self, value, verbose_code_parts: list[str]) -> None: ... + def add_true_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_false_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_none_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_not_none_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_dispatch_key_set_guard( + self, + dispatch_key, + verbose_code_parts: list[str], + ) -> None: ... + def add_tensor_match_guard( + self, + value, + sizes, + strides, + tensor_name, + verbose_code_parts: list[str], + ptype, + dispatch_keys, + ) -> None: ... + def add_dynamic_indices_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_no_hasattr_guard( + self, + attr_name, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_contains_guard( + self, + contains, + key, + verbose_code_parts: list[str], + ) -> None: ... + def add_type_match_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_version_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_set_contains_guard( + self, + contains, + item, + verbose_code_parts: list[str], + ) -> None: ... + def add_tuple_iterator_length_guard( + self, + length, + type_id, + verbose_code_parts: list[str], + ) -> None: ... + def add_range_iterator_match_guard( + self, + start, + stop, + step, + type_id, + verbose_code_parts: list[str], + ) -> None: ... + def add_default_device_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def mark_tag_safe(self) -> None: ... + def mark_tag_safe_root(self) -> None: ... class RootGuardManager(GuardManager): def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... @@ -118,6 +327,7 @@ class RootGuardManager(GuardManager): def clone_manager( self, clone_filter_fn: Callable[[GuardManager], bool] ) -> RootGuardManager: ... + def attach_compile_id(self, compile_id: str) -> None: ... class DictGuardManager(GuardManager): def get_key_manager( @@ -134,15 +344,29 @@ class DictGuardManager(GuardManager): example_value, guard_manager_enum, ) -> GuardManager: ... + def get_key_value_managers( + self, + ) -> dict[int, tuple[GuardManager, GuardManager]]: ... # Guard accessor stubs class GuardAccessor: ... class DictGetItemGuardAccessor(GuardAccessor): ... class GetGenericDictGuardAccessor(GuardAccessor): ... +class TypeDictGuardAccessor(GuardAccessor): ... +class TypeMROGuardAccessor(GuardAccessor): ... +class ClosureGuardAccessor(GuardAccessor): ... +class TupleGetItemGuardAccessor(GuardAccessor): ... +class TypeGuardAccessor(GuardAccessor): ... +class CodeGuardAccessor(GuardAccessor): ... +class FuncDefaultsGuardAccessor(GuardAccessor): ... +class FuncKwDefaultsGuardAccessor(GuardAccessor): ... + +class GetAttrGuardAccessor(GuardAccessor): + def get_attr_name(self) -> str: ... def install_object_aliasing_guard( - guard_managers: list[GuardManager], - tensor_names: list[str], + x: GuardManager, + y: GuardManager, verbose_code_parts: list[str], ): ... def install_no_tensor_aliasing_guard( diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index dd3e9e8fa2dd..208c18e392a4 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -648,7 +648,7 @@ def custom_op_from_existing(op): name = op.name().split("::")[-1] schema_str = str(op._schema) # CustomOp expects the schema string without the namespace - schema_str = schema_str.split("::")[-1] + schema_str = schema_str.rsplit("::", maxsplit=1)[-1] schema = FunctionSchema.parse(schema_str) return CustomOp(lib, ns, schema, name, op, _private_access=True) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 59c11803bb9f..02b921b30ee2 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -51,7 +51,6 @@ from .pgo import reset_code_state from .symbolic_convert import TensorifyState from .utils import ( - create_nested_fn_cache, graph_break_reasons, guard_failures, orig_code_map, @@ -145,7 +144,6 @@ def reset() -> None: torch._dynamo.utils.warn_once_cache.clear() torch._dynamo.utils.user_obj_id_to_weakref.clear() torch._C._autograd._saved_tensors_hooks_set_tracing(False) - create_nested_fn_cache.clear() def reset_code_caches() -> None: diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index bda2494e7a9f..8f411a0d2472 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Provides functionality for compiling PyTorch's autograd (automatic differentiation) system. @@ -22,7 +20,8 @@ import operator import time from collections import Counter, defaultdict -from typing import Optional, TYPE_CHECKING, Union +from collections.abc import Generator, Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree @@ -44,10 +43,11 @@ AutogradLazyBackwardCompileInfo, CachedAutogradLazyBackwardCompileInfo, ) -from torch._guards import compile_context, CompileContext, CompileId +from torch._guards import compile_context, CompileContext, CompileId, Source from torch._logging import getArtifactLogger, trace_structured from torch._prims_common import clone_preserve_strides from torch._subclasses import FakeTensorMode +from torch._subclasses.fake_tensor import FakeTensor from torch.fx import GraphModule from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import ( @@ -61,6 +61,7 @@ ) from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv from torch.fx.traceback import preserve_node_meta, set_stack_trace +from torch.types import FloatLikeType, IntLikeType from torch.utils._ordered_set import OrderedSet from torch.utils._traceback import CapturedTraceback @@ -79,23 +80,23 @@ verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") -def snapshot_verbose_logging_enabled(): +def snapshot_verbose_logging_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled( "compiled_autograd_verbose" ) -def snapshot_cudagraph_enabled(): +def snapshot_cudagraph_enabled() -> bool: return torch._inductor.config.triton.cudagraphs -def maybe_clone(x): +def maybe_clone(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if x is not None: return clone_preserve_strides(x) return x -def extract_bw_module(CompiledFunction): +def extract_bw_module(CompiledFunction: Any) -> Callable[..., Any]: if isinstance( CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo ): @@ -123,13 +124,13 @@ def extract_bw_module(CompiledFunction): # So different semantics are needed, this implementation below will check # for NaNs at the end of the autograd call, instead of after each node class NaNChecker: - def __init__(self, accumulate_grad: bool): + def __init__(self, accumulate_grad: bool) -> None: self.accumulate_grad = accumulate_grad self.params_indices: list[int] = [] self.params_to_check: dict[str, torch.Tensor] = {} self.output_names: list[str] = [] - def prep_with_graph(self, graph: torch.fx.Graph): + def prep_with_graph(self, graph: torch.fx.Graph) -> None: inputs_node = next(iter(graph.nodes)) acc_grad_nodes = graph.find_nodes( op="call_function", target=call_accumulate_grad @@ -153,7 +154,7 @@ def prep_with_graph(self, graph: torch.fx.Graph): self.output_names = [node.name for node in output_nodes] - def prep_with_inputs(self, inputs: tuple[torch.Tensor]): + def prep_with_inputs(self, inputs: tuple[torch.Tensor]) -> None: if not self.accumulate_grad: # Using .grad, nothing to prep return @@ -169,7 +170,7 @@ def prep_with_inputs(self, inputs: tuple[torch.Tensor]): self.params_to_check[f"inputs[{idx}]"] = inputs[idx] - def check(self, out: tuple[torch.Tensor]): + def check(self, out: tuple[torch.Tensor]) -> None: if self.accumulate_grad: # Using .backward, graph outputs are empty assert not out @@ -202,10 +203,16 @@ def check(self, out: tuple[torch.Tensor]): # function is called. It's possible to avoid lazy binding and instead bind # all of this upfront (perhaps at import time) via codegen changes. class OpNamespace: - def __init__(self): + def __init__(self) -> None: self.custom_function_name_counter: Counter[str] = Counter() - def add(self, name, fn, is_custom_function, is_traceable): + def add( + self, + name: str, + fn: Callable[..., Any], + is_custom_function: bool, + is_traceable: bool, + ) -> str: if is_custom_function: name = "CppNode" + name count = self.custom_function_name_counter[name] @@ -219,28 +226,30 @@ def add(self, name, fn, is_custom_function, is_traceable): else: # C++ autograd function was not marked as traceable # Dynamo can't dry run it at compile time, so must fallback to eager - @torch._dynamo.disable - def run_non_traceable_cpp_in_eager(*args, **kwargs): + @torch._dynamo.disable # type: ignore[misc] + def run_non_traceable_cpp_in_eager(*args: Any, **kwargs: Any) -> Any: return result(*args, **kwargs) setattr(self, name, run_non_traceable_cpp_in_eager) return name - def get(self, name): + def get(self, name: str) -> Any: return getattr(self, name) class Op: - def __init__(self, name, fn, is_custom_function): + def __init__( + self, name: str, fn: Callable[..., Any], is_custom_function: bool + ) -> None: self.fn = fn self.is_custom_function = is_custom_function self.__name__ = name self.__module__ = "torch._dynamo.compiled_autograd.ops" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.fn(*args, **kwargs) - def __repr__(self): + def __repr__(self) -> str: return self.__module__ + "." + self.__name__ @@ -260,7 +269,7 @@ def __repr__(self): COMPILE_COUNTER = itertools.count() -def make_compile_context(compiled_autograd_id): +def make_compile_context(compiled_autograd_id: int) -> Any: return compile_context( CompileContext( CompileId( @@ -273,7 +282,7 @@ def make_compile_context(compiled_autograd_id): class AutogradCompilerInstance: - def __init__(self, compiler_fn) -> None: + def __init__(self, compiler_fn: Callable[..., Any]) -> None: self.compiler_fn = compiler_fn self.stack = contextlib.ExitStack() self.close = self.stack.close @@ -287,12 +296,12 @@ def __init__(self, compiler_fn) -> None: self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") self.hooks_proxy: Optional[Proxy] = None - def wrap_fake(self, x, source): + def wrap_fake(self, x: torch.Tensor, source: Optional[Source]) -> FakeTensor: assert isinstance(x, torch.Tensor) return self.fake_tensor_mode.from_tensor(x, source=source) @staticmethod - def source(name, idx) -> GetItemSource: + def source(name: str, idx: Any) -> GetItemSource: return GetItemSource(LocalSource(name), idx) def begin_capture( @@ -303,14 +312,12 @@ def begin_capture( origins: list[list[tuple[int, str]]], accumulate_grad: bool, check_nans: bool, - ): - global in_compiled_autograd_initial_trace + ) -> tuple[str, list[torch.Tensor], list[IntLikeType], list[FloatLikeType]]: counters["compiled_autograd"]["captures"] += 1 self.id = next(COMPILE_COUNTER) self.aot_id_counter: dict[int, int] = defaultdict(int) self.compile_context = make_compile_context(self.id) self.compile_context.__enter__() - in_compiled_autograd_initial_trace = True self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None self.start_time_ns = time.time_ns() get_chromium_event_logger().log_event_start( @@ -349,7 +356,7 @@ def begin_capture( self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins) # size inputs to symints - sizes = [ + sym_sizes = [ self.shape_env.create_unspecified_symint_and_symbol( val, self.source("sizes", idx), @@ -361,8 +368,8 @@ def begin_capture( # We want to mark every size as dynamic, but since there's no way to # mark a primitive `int` as dynamic, we need to wrap it in a tensor. # In the graph, we unwrap it with `unwrap_maybe_dynamic_int` back into a primitive. - proxies = [self.sizes_proxy[i] for i in range(len(sizes))] # type: ignore[index] - for i, symint in enumerate(sizes): + proxies = [self.sizes_proxy[i] for i in range(len(sym_sizes))] # type: ignore[index] + for i, symint in enumerate(sym_sizes): proxies[i] = self.fx_tracer.create_proxy( "call_function", unwrap_maybe_dynamic_int, @@ -370,7 +377,7 @@ def begin_capture( {}, ) self.symnode_proxy_lookup[symint.node] = proxies[i] - proxies = self.bind_objects_to_proxies(sizes, proxies, sizes_origins) + proxies = self.bind_objects_to_proxies(sym_sizes, proxies, sizes_origins) for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -410,14 +417,14 @@ def begin_capture( return ( str(CompileContext.current_compile_id()), inputs, - sizes, - scalars, + sym_sizes, + scalars, # type: ignore[return-value] ) def log_compile_reasons( self, compile_reasons: list[str], - ): + ) -> None: assert compile_reasons trace_structured( "artifact", @@ -430,13 +437,13 @@ def log_compile_reasons( def proxy_call_aot_backward( self, - pinputs, - psaved_tensors, - saved_tensors, - pctx, - ctx, - maybe_backward_state_idx, - ): + pinputs: Sequence[Any], + psaved_tensors: Sequence[torch.Tensor], + saved_tensors: Sequence[torch.Tensor], + pctx: Any, + ctx: Any, + maybe_backward_state_idx: Optional[int], + ) -> Sequence[Any]: # The AOTBackward call consists of three things: the prologue, the # backward graph, and the epilogue. # Our strategy is: @@ -466,7 +473,11 @@ def proxy_call_aot_backward( ) @torch._dynamo.allow_in_graph # type: ignore[misc] - def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): + def call_aot_bwd_prologue( + ctx_saved_tensors: Sequence[torch.Tensor], + ctx_symints: Sequence[IntLikeType], + *flat_args: Sequence[Any], + ) -> Any: out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( ctx_saved_tensors, ctx_symints, @@ -492,8 +503,8 @@ def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index] # Copy-paste the AOT backward graph into the compiled autograd graph - def copy_paste_aot_backward_graph(): - def num_inputs(graph): + def copy_paste_aot_backward_graph() -> list[torch.Tensor]: + def num_inputs(graph: torch.fx.Graph) -> int: num_args = 0 for node in graph.nodes: if node.op == "placeholder": @@ -505,7 +516,7 @@ def num_inputs(graph): # set up the proxy inputs to bw_module # the calling convention is: [*symints, *args (primals and tangents), backward_state] - num_args = num_inputs(bw_module.graph) + num_args = num_inputs(bw_module.graph) # type: ignore[attr-defined] pall_args = [ pgrads[i] for i in range(num_args - int(pbackward_state is not None)) ] @@ -531,11 +542,11 @@ def num_inputs(graph): deduped_aot_id += f"_{self.aot_id_counter[aot_id]}" self.aot_id_counter[aot_id] += 1 - def make_unique(node_name): + def make_unique(node_name: str) -> str: # make it both informative and unique return f"aot{deduped_aot_id}_{node_name}" - for node in bw_module.graph.nodes: + for node in bw_module.graph.nodes: # type: ignore[attr-defined] if node.op == "placeholder": ph = pall_args[args_idx].node ph.name = make_unique(node.name) @@ -583,7 +594,7 @@ def make_unique(node_name): # In general we don't know what the shapes of the outputs are, so allocate # some dummy sizes for them. - def dummy(): + def dummy() -> torch.Tensor: with disable_proxy_modes_tracing(): return torch.zeros(0, 0, 0, 0, 123) @@ -595,9 +606,11 @@ def dummy(): outputs = copy_paste_aot_backward_graph() - def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args): - @torch._dynamo.allow_in_graph - def make_subclass(*unwrapped_args): + def proxy_subclass_constructor( + subclass_meta: Any, is_runtime: bool, unwrapped_args: Sequence[Any] + ) -> torch.Tensor: + @torch._dynamo.allow_in_graph # type: ignore[misc] + def make_subclass(*unwrapped_args: Any) -> Any: return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args) @@ -624,13 +637,13 @@ def make_subclass(*unwrapped_args): def proxy_call_backward( self, - inputs, - output_metadatas, - saved_tensors, + inputs: Sequence[Any], + output_metadatas: Sequence[Any], + saved_tensors: Sequence[torch.Tensor], backward_idx: int, ctx: torch.autograd.function.BackwardCFunction, maybe_backward_state_idx: Optional[int], - ): + ) -> tuple[Optional[torch.Tensor], ...]: assert self.hooks_proxy is not None pctx = self.hooks_proxy[backward_idx] # type: ignore[index] pinputs = self.to_proxy(inputs) @@ -675,14 +688,14 @@ def proxy_call_backward( def call_copy_slices_prologue( self, - inputs, - base_sizes, - base_strides, - base_storage_offset, - view_sizes, - view_strides, - view_storage_offset, - ): + inputs: Sequence[Any], + base_sizes: Sequence[Any], + base_strides: Sequence[Any], + base_storage_offset: Any, + view_sizes: Sequence[Any], + view_strides: Sequence[Any], + view_storage_offset: Any, + ) -> Sequence[torch.Tensor]: args = ( inputs, self.to_proxy(base_sizes), @@ -694,28 +707,48 @@ def call_copy_slices_prologue( ) return self.proxy_call(copy_slices_prologue, args, [None] * 3) - def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice): + def call_copy_slices_epilogue( + self, + needs_input_grad: Sequence[bool], + result: torch.Tensor, + res: Sequence[Any], + grad_slice: torch.Tensor, + ) -> Sequence[torch.Tensor]: return self.proxy_call( copy_slices_epilogue, (needs_input_grad, result, res, grad_slice), [None] * len(needs_input_grad), ) - def allocate_dummy(self): + def allocate_dummy(self) -> torch.Tensor: with disable_proxy_modes_tracing(): # Weird quantity so it's easy to grep return torch.zeros([0, 123456789]) - def bind_function(self, fn_name, fn, is_custom_function, is_traceable): + def bind_function( + self, + fn_name: str, + fn: Callable[..., Any], + is_custom_function: bool, + is_traceable: bool, + ) -> str: """Binds ops.fn_name = fn""" return ops.add(fn_name, fn, is_custom_function, is_traceable) - def apply_functional(self, fn_name, grads, args, output_metadata): + def apply_functional( + self, + fn_name: str, + grads: Sequence[Any], + args: Any, + output_metadata: Sequence[Any], + ) -> Sequence[torch.Tensor]: """Proxies a call to ops.fn_name(grads, *args) into the graph""" op = ops.get(fn_name) return self.proxy_call(op, (grads, *args), output_metadata) - def proxy_call(self, fn, args, output_metadata): + def proxy_call( + self, fn: Callable[..., Any], args: Any, output_metadata: Sequence[Any] + ) -> Sequence[torch.Tensor]: """Proxies a call to fn(*args) into the graph""" flat_args, _ = pytree.tree_flatten(args) proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) @@ -726,7 +759,9 @@ def proxy_call(self, fn, args, output_metadata): self.bind_objects_to_proxies(result, [proxy_out[i] for i in range(len(result))]) return result - def validate_outputs(self, _, outputs, args, output_metadata): + def validate_outputs( + self, _: Any, outputs: Sequence[Any], args: Any, output_metadata: Sequence[Any] + ) -> Sequence[torch.Tensor]: """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" op = ops.get("validate_outputs") proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) @@ -737,7 +772,7 @@ def validate_outputs(self, _, outputs, args, output_metadata): self.bind_objects_to_proxies(outputs, new_proxy_outputs) return outputs - def accumulate(self, old_var, new_var): + def accumulate(self, old_var: Any, new_var: Any) -> torch.Tensor: old_var_proxy = self.to_proxy(old_var) new_var_proxy = self.to_proxy(new_var) proxy_out = self.fx_tracer.create_proxy( @@ -747,7 +782,9 @@ def accumulate(self, old_var, new_var): self.bind_objects_to_proxies([result], [proxy_out]) return result - def accumulate_grad(self, variable, grad, has_post_hooks): + def accumulate_grad( + self, variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool + ) -> None: self.fx_tracer.create_proxy( "call_function", call_accumulate_grad, @@ -759,7 +796,9 @@ def accumulate_grad(self, variable, grad, has_post_hooks): kwargs={}, ) - def proxy_call_hook(self, hook, *args, **kwargs): + def proxy_call_hook( + self, hook: Callable[..., Any], *args: Any, **kwargs: Any + ) -> torch.fx.Proxy: return self.fx_tracer.create_proxy( "call_function", call_hook, @@ -770,7 +809,7 @@ def proxy_call_hook(self, hook, *args, **kwargs): kwargs, ) - def unpack_hook(self, hook_id, data_id): + def unpack_hook(self, hook_id: int, data_id: int) -> torch.Tensor: assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] data = self.packed_data_proxy[data_id] # type: ignore[index] @@ -783,7 +822,9 @@ def unpack_hook(self, hook_id, data_id): self.bind_objects_to_proxies([out], [proxy]) return out - def tensor_pre_hook(self, inputs, hook_id, i: int): + def tensor_pre_hook( + self, inputs: list[torch.Tensor], hook_id: int, i: int + ) -> list[torch.Tensor]: assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxy = self.proxy_call_hook( @@ -792,11 +833,13 @@ def tensor_pre_hook(self, inputs, hook_id, i: int): hook_type="tensor_pre_hook", ) with disable_proxy_modes_tracing(): - inputs[i] = maybe_clone(inputs[i]) + inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment] self.bind_objects_to_proxies([inputs[i]], [proxy]) return inputs - def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int): + def cpp_tensor_pre_hook( + self, inputs: list[torch.Tensor], hook_id: int, i: int + ) -> list[torch.Tensor]: proxy = self.fx_tracer.create_proxy( "call_function", torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks, @@ -804,11 +847,11 @@ def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int): {}, ) with disable_proxy_modes_tracing(): - inputs[i] = maybe_clone(inputs[i]) + inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment] self.bind_objects_to_proxies([inputs[i]], [proxy]) return inputs - def pre_hook(self, inputs, hook_id): + def pre_hook(self, inputs: Sequence[Any], hook_id: int) -> list[torch.Tensor]: assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxies = self.proxy_call_hook( @@ -821,7 +864,9 @@ def pre_hook(self, inputs, hook_id): self.bind_objects_to_proxies(inputs, proxies) return inputs - def post_hook(self, outputs, inputs, hook_id): + def post_hook( + self, outputs: list[torch.Tensor], inputs: Sequence[torch.Tensor], hook_id: int + ) -> list[torch.Tensor]: assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxies = self.proxy_call_hook( @@ -831,11 +876,13 @@ def post_hook(self, outputs, inputs, hook_id): hook_type="post_hook", ) with disable_proxy_modes_tracing(): - outputs = [maybe_clone(x) for x in outputs] + outputs = [maybe_clone(x) for x in outputs] # type: ignore[misc] self.bind_objects_to_proxies(outputs, proxies) return outputs - def post_acc_grad_hook(self, input, hook_id): + def post_acc_grad_hook( + self, input: torch.Tensor, hook_id: int + ) -> list[torch.Tensor]: assert isinstance(input, torch.Tensor) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] @@ -845,16 +892,16 @@ def post_acc_grad_hook(self, input, hook_id): hook_type="post_acc_grad_hook", ) with disable_proxy_modes_tracing(): - input = [maybe_clone(input)] - self.bind_objects_to_proxies(input, [proxy]) - return input + res = [maybe_clone(input)] + self.bind_objects_to_proxies(res, [proxy]) + return res # type: ignore[return-value] # Note: [Compiled autograd and cudagraphs] # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. - def move_graph_nodes_to_cuda(self, graph) -> list[int]: + def move_graph_nodes_to_cuda(self, graph: torch.fx.Graph) -> list[int]: to_move: dict[int, torch.fx.Node] = {} has_cuda_inputs = False nodes = list(graph.nodes) @@ -903,7 +950,7 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: return [] - def is_sym_node(self, node): + def is_sym_node(self, node: Any) -> bool: return ( isinstance(node, torch.fx.Node) and node.op == "call_function" @@ -911,7 +958,7 @@ def is_sym_node(self, node): in [torch.ops.aten.sym_size.int, torch.ops.aten.sym_numel.default] ) - def dce(self): + def dce(self) -> None: # Most of these removed nodes would have been removed during Dynamo and AOTDispatch # Remove some of these nodes earlier to improve compilation speed @@ -921,7 +968,7 @@ def dce(self): unpack_nodes.update(node.users.keys()) assert i == len(_graph_placeholders) - 1 - def is_impure(node): + def is_impure(node: torch.fx.Node) -> bool: if node in unpack_nodes or ( node.op == "call_function" and node.target in _impure_targets ): @@ -933,7 +980,7 @@ def is_impure(node): after = len(self.fx_tracer.graph.nodes) verbose_log.debug("DCE removed %d nodes", before - after) - def remove_unused_sizes(self): + def remove_unused_sizes(self) -> set[int]: used_sizes = [] unused_sizes = [] @@ -967,12 +1014,10 @@ def remove_unused_sizes(self): return used_sizes_idx - def create_graph_module(self, id): + def create_graph_module(self, id: str) -> GraphModule: return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) - def end_capture(self, outputs): - global in_compiled_autograd_initial_trace - + def end_capture(self, outputs: Any) -> tuple[Callable[..., Any], Any]: self.fx_tracer.create_proxy( "call_function", FakeCompiledAutogradEngine._exec_final_callbacks_stub, @@ -1050,7 +1095,14 @@ def end_capture(self, outputs): payload_fn=lambda: graph.print_readable(print_output=False), ) - def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): + def runtime_wrapper( + compiled_fn: Callable[..., Any], + inputs: Any, + sizes: Any, + scalars: Any, + hooks: Any, + packed_inputs: Any, + ) -> tuple[Any, Any]: global in_compiled_autograd_region try: in_compiled_autograd_region = True @@ -1089,26 +1141,25 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): log_pt2_compile_event=True, ) self.compile_context.__exit__(None, None, None) - in_compiled_autograd_initial_trace = False return runtime_wrapper, self.compiler_fn(graph) @staticmethod - def get_all_nodes(args): + def get_all_nodes(args: Sequence[Any]) -> list[torch.fx.Node]: # filter out non-Node args, like None nodes = [n for n in args if type(n) is torch.fx.Node] return nodes @staticmethod - def is_placeholder(node): + def is_placeholder(node: torch.fx.Node) -> bool: if node.op == "placeholder" or ( node.op == "call_function" and node.target == operator.getitem - and node.args[0].op == "placeholder" + and node.args[0].op == "placeholder" # type: ignore[union-attr, arg-type] ): return True return False - def reorder_accumulate_grad_nodes(self): + def reorder_accumulate_grad_nodes(self) -> None: """ Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them as soon as possible. This @@ -1129,7 +1180,7 @@ def reorder_accumulate_grad_nodes(self): if getitem_node is not None: arg.append(getitem_node) - def delay_unpack_hook_nodes(self): + def delay_unpack_hook_nodes(self) -> None: """ We can delay unpack hooks until they are needed, even later than in the eager autograd engine. """ @@ -1142,7 +1193,7 @@ def delay_unpack_hook_nodes(self): first_user = min(node.users) first_user.prepend(node) - def reorder_tensor_pre_hook_nodes(self): + def reorder_tensor_pre_hook_nodes(self) -> None: """ Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules @@ -1162,7 +1213,7 @@ def reorder_tensor_pre_hook_nodes(self): input_node.append(getitem_node) getitem_node.append(node) - def reorder_pre_hook_nodes_to_schedule_asap(self): + def reorder_pre_hook_nodes_to_schedule_asap(self) -> None: """ In this function, we schedule the pre hooks as soon as possible. This does not match eager behavior (schedule pre hook right before its @@ -1190,7 +1241,7 @@ def reorder_pre_hook_nodes_to_schedule_asap(self): hook_block.append(n) for a, b in zip(to_remove, to_append): input_nodes.remove(a) - input_nodes.append(b) + input_nodes.append(b) # type: ignore[arg-type] arg = max(input_nodes) # last input if arg is not node.prev and not self.is_placeholder(arg): @@ -1198,7 +1249,7 @@ def reorder_pre_hook_nodes_to_schedule_asap(self): for n in hook_block: getitem_node.append(n) - def reorder_pre_hook_nodes_to_mimic_eager(self): + def reorder_pre_hook_nodes_to_mimic_eager(self) -> None: """ Usage of AOTAutograd causes all the pre_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them @@ -1233,7 +1284,7 @@ def reorder_pre_hook_nodes_to_mimic_eager(self): for getitem in users: registered_node.prepend(getitem) - def reorder_post_acc_grad_hook_nodes(self): + def reorder_post_acc_grad_hook_nodes(self) -> None: """ Usage of AOTAutograd causes all the post_acc_grad_hook nodes to get pushed to the end of the graph. This differs from eager mode, which @@ -1269,7 +1320,7 @@ def reorder_post_acc_grad_hook_nodes(self): acc_grad_node.append(getitem_node) getitem_node.append(node) - def reorder_post_hook_nodes(self): + def reorder_post_hook_nodes(self) -> None: """ Usage of AOTAutograd causes all the post_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them as @@ -1326,7 +1377,7 @@ def reorder_post_hook_nodes(self): arg.append(getitem_node) getitem_node.append(node) - def to_proxy(self, t): + def to_proxy(self, t: Any) -> Any: if t is None: return None if isinstance(t, list): @@ -1343,8 +1394,11 @@ def to_proxy(self, t): return proxy_tensor.proxy def bind_objects_to_proxies( - self, objects, proxies, origins: Optional[list[tuple[int, str]]] = None - ): + self, + objects: Sequence[Any], + proxies: Any, + origins: Optional[list[tuple[int, str]]] = None, + ) -> Sequence[Any]: if isinstance(proxies, torch.fx.Proxy): if origins: assert len(origins) == len(objects) @@ -1361,7 +1415,7 @@ def bind_objects_to_proxies( track_tensor_tree(objects, proxies, constant=None, tracer=self.fx_tracer) return proxies - def bind_backward_state(self, index: int): + def bind_backward_state(self, index: int) -> BackwardState: assert self.hooks_proxy is not None proxy = self.hooks_proxy[index] # type: ignore[index] bw_state = BackwardState() @@ -1373,7 +1427,7 @@ def set_node_origin( node_name: str, nodecall_index: int, pyobj: Optional[torch.autograd.Function], - ): + ) -> None: maybe_aot_id = "" if pyobj is not None: forward_cls = pyobj._forward_cls # type: ignore[attr-defined] @@ -1399,9 +1453,6 @@ def set_node_origin( # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" compiled_autograd_enabled_force_eager = False -# global flag to check if we are capturing for compiled autograd -in_compiled_autograd_initial_trace = False - # global flag to check if we are processing graphs produced from a compiled autograd graph in_compiled_autograd_region = False @@ -1411,7 +1462,11 @@ def set_node_origin( @contextlib.contextmanager -def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True): +def _enable( + compiler_fn: Callable[..., Any], + dynamic: bool = True, + ignore_active_disable_ctx: bool = True, +) -> Generator[None, None, None]: # The entrypoint to enable CA. # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather # than using this context manager directly. If you are torch.compiling the corresponding @@ -1483,7 +1538,7 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True): @contextlib.contextmanager -def _disable(): +def _disable() -> Generator[None, None, None]: ( prior_compiler, prior_dynamic, @@ -1506,13 +1561,12 @@ def _disable(): # return to starting state of a new process def reset() -> None: - global compiled_autograd_enabled, in_compiled_autograd_initial_trace + global compiled_autograd_enabled compiled_autograd_enabled = False assert not in_compiled_autograd_region torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) torch._C._dynamo.compiled_autograd.set_verbose_logger(None) torch._C._dynamo.compiled_autograd.clear_cache() - in_compiled_autograd_initial_trace = False global COMPILE_COUNTER COMPILE_COUNTER = itertools.count() @@ -1520,14 +1574,14 @@ def reset() -> None: # Reimplementation of part of CopySlices::apply in Python. # The shared code is really similar so we're not going to try to deduplicate. def copy_slices_prologue( - inputs, - base_sizes, - base_strides, - base_storage_offset, - view_sizes, - view_strides, - view_storage_offset, -): + inputs: Sequence[torch.Tensor], + base_sizes: Sequence[IntLikeType], + base_strides: Sequence[IntLikeType], + base_storage_offset: IntLikeType, + view_sizes: Sequence[IntLikeType], + view_strides: Sequence[IntLikeType], + view_storage_offset: IntLikeType, +) -> list[torch.Tensor]: grad = inputs[0] result = grad.new_empty_strided(base_sizes, base_strides) assert grad is not None @@ -1539,14 +1593,21 @@ def copy_slices_prologue( # Reimplementation of part of CopySlices::apply in Python. # The shared code is really similar so we're not going to try to deduplicate. -def copy_slices_epilogue(needs_input_grad, result, res, grad_slice): - grad_inputs = [None] * len(needs_input_grad) +def copy_slices_epilogue( + needs_input_grad: Sequence[bool], + result: torch.Tensor, + res: Sequence[Optional[torch.Tensor]], + grad_slice: torch.Tensor, +) -> list[Optional[torch.Tensor]]: + grad_inputs: list[Optional[torch.Tensor]] = [None] * len(needs_input_grad) for i in range(len(needs_input_grad)): if needs_input_grad[i]: if res[i] is None: continue if i == 0: - grad_slice.copy_(res[i]) + to_copy = res[i] + assert to_copy is not None + grad_slice.copy_(to_copy) grad_inputs[i] = result else: grad_inputs[i] = res[i] diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0d83b7078eae..b8b7561dde16 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -354,6 +354,25 @@ # Skips guards on func.__defaults__ if the element to be guarded is a constant skip_guards_on_constant_func_defaults = True + +# The recursive-dict-tag guard relies on the class/function identity staying +# stable. We therefore assume that the following function dunder attributes +# are **never rebound** to a different object: +# +# • __code__ • __closure__ +# • __defaults__ • __kwdefaults__ +# • __annotations__ • __mro__ +# +# It is fine to mutate the objects they already point to (e.g. tweak an element +# inside __defaults__), but assignments like +# +# foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED +# +# would invalidate the optimization. This type of rebinding is rare, so we +# assume that the rebinding never happens for guard purposes. Set the flag +# below to False only in environments where such rebinding is known to occur. +assume_dunder_attributes_remain_unchanged = True + # Speedup guard execution of nested nn modules by recursively checking for dict # tags to avoid full guard execution. use_recursive_dict_tags_for_guards = True diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index bba4d9c98086..fb27c2993543 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,30 +225,35 @@ def fx_forward_from_src_skip_result( return result -def log_dynamo_start(code: CodeType, skip: int = 0) -> None: +def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: convert_frame_intern = structured.intern_string(__file__) + # Extract and filter the stack + stack = list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] # Initialize the ChromiumEventLogger on start torch._logging.trace_structured( "dynamo_start", - lambda: { - "stack": list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) - + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] - }, + lambda: {"stack": stack}, ) + stack_strings = [ + f"Line: {frame['line']}, Name: {frame['name']}, Filename: {frame['filename']}" + for frame in stack + ] + return stack_strings + def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ @@ -1160,7 +1165,7 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - log_dynamo_start(code, skip) + stack_trace = log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1300,6 +1305,7 @@ def format_func_info(code: CodeType) -> str: "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), + "stack_trace": stack_trace, } # TODO: replace with CompileEventLogger.compilation_metrics # There are some columns here not in PT2 Compile Events diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fd85b5d28e03..63c2ed9e9bad 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -113,7 +113,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from torch._dynamo.package import CompilePackage + from torch._dynamo.package import CompilePackage, DynamoCaptureOutput from torch._dynamo.repro.after_dynamo import WrapBackendDebug from torch._subclasses import fake_tensor from torch.fx.node import Argument, Node, Target @@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None: set_code_exec_strategy( code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) + + +@dataclass +class BackendInput: + graph_module: torch.fx.GraphModule + example_inputs: tuple[Any, ...] + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode + + +@dataclass +class CaptureOutput: + """ + Core data structure that contains the all the information dynamo generates + from fullgraph=True. Ideally, this is should be the "return" type if dynamo + has a standard API to return compilation artifacts. + """ + + dynamo_artifacts: DynamoCaptureOutput + backend_inputs: dict[str, BackendInput] + + +def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]: + """ + A helper function which wraps a model and returns a callable like optimize(). + The callable can be called with normal inputs like torch.compile()-ed functions + and user can dump dynamo compilation artifacts through `get_artifacts()` call. + + The CaptureOutput is separated into two parts: + 1. Dynamo specific information from DynamoCaptureOutput, which includes: + - guards + - generated bytecode + - python source information + 2. Backend specific information (indexed by unique backend id) such as: + - fx graph + - example inputs + + Example: + def fn(*args): + ... + + compiled_fn = fullgraph_capture(fn) + compiled_fn(args) + compiled_fn(another_args) + artifacts = compiled_fn.get_artifacts() + """ + from torch._dynamo.package import CompilePackage + + package = CompilePackage(model) + + backend_inputs: dict[str, BackendInput] = {} + + def _backend( + gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...] + ) -> torch.fx.GraphModule: + from torch._guards import TracingContext + + fake_mode = TracingContext.get().fake_mode + assert fake_mode is not None + backend_id = gm._backend_id + assert isinstance(backend_id, str) + backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode) + return gm + + # TODO For now we use eval_frame to give us the frame. This is can be simplified to + # a manual frame creation helper. + optimized_model = optimize(nopython=True, backend=_backend, package=package)(model) + + @functools.wraps(model) + def capture_context(*args: Any, **kwargs: Any) -> Any: + return optimized_model(*args, **kwargs) + + def get_artifacts() -> CaptureOutput: + cache_entry = package.cache_entry() + assert len(cache_entry.codes) == 1 + return CaptureOutput( + dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs + ) + + capture_context.get_artifacts = get_artifacts # type: ignore[attr-defined] + return capture_context diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 5039cf63526c..063617039131 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -264,7 +264,14 @@ class UnsafeScriptObjectError(TorchDynamoException): class UncapturedHigherOrderOpError(TorchDynamoException): - pass + def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None: + super().__init__(msg) + self.msg = msg + self.real_stack = ( + real_stack + if real_stack is not None + else torch._guards.TracingContext.extract_stack() + ) class IncorrectUsage(Exception): @@ -527,7 +534,7 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]: A string containing the documentation URL if found, otherwise None. """ GRAPH_BREAK_SITE_URL = ( - "https://pytorch-labs.github.io/compile-graph-break-site/gb/" # @lint-ignore + "https://meta-pytorch.github.io/compile-graph-break-site/gb/" # @lint-ignore ) registry = _load_graph_break_registry() diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 1e42f414478d..7c25d683b475 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2647,9 +2647,47 @@ "Try to construct `torch.nn.Parameter()` outside the compiled region.", "If this is not possible, turn `graph_break_on_nn_param_ctor` off", "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ], - "Additional_Info": [ - "Try to construct nn.Parameter() outside the compiled region. If this is not possible, turn `graph_break_on_nn_param_ctor` off" + ] + } + ], + "GB0265": [ + { + "Gb_type": "FakeScriptObject missing method implementation", + "Context": "value={self.value}, method={name}", + "Explanation": "TorchScript object {self.value} doesn't define the method {name}.", + "Hints": [ + "Ensure the method {name} is implemented in {self.value}.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0266": [ + { + "Gb_type": "Weird method call on TorchScript object", + "Context": "value={self.value}, method={name}", + "Explanation": "This particular method call ({name}) is not supported (e.g. calling `__setattr__`). Most method calls to TorchScript objects should be supported.", + "Hints": [ + "Avoid calling this method." + ] + } + ], + "GB0267": [ + { + "Gb_type": "Attempted to access non-callable attribute of TorchScript object", + "Context": "value={self.value}, method={name}", + "Explanation": "Attribute accesses of TorchScript objects to non-callable attributes are not supported.", + "Hints": [ + "Use method calls instead of attribute access." + ] + } + ], + "GB0268": [ + { + "Gb_type": "Unsupported kwargs for itertools.product", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'repeat', but got {','.join(set(kwargs.keys()) - {'repeat'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." ] } ] diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a80aa51799ce..445224319b97 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Core guard system for Dynamo that detects when compiled code needs to be recompiled due to changes in program state. Guards are conditions that must remain true for previously-compiled @@ -40,6 +38,7 @@ from copy import deepcopy from inspect import currentframe from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAliasType, TypeVar from weakref import ReferenceType import torch @@ -49,16 +48,30 @@ from torch._C._dynamo.guards import ( check_obj_id, check_type_id, + ClosureGuardAccessor, + CodeGuardAccessor, dict_version, DictGetItemGuardAccessor, DictGuardManager, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, GetGenericDictGuardAccessor, + GuardAccessor, + GuardDebugInfo, + GuardManager, install_no_tensor_aliasing_guard, install_object_aliasing_guard, install_storage_overlapping_guard, install_symbolic_shape_guard, + LeafGuard, profile_guard_manager, + RelationalGuard, RootGuardManager, + TupleGetItemGuardAccessor, + TypeDictGuardAccessor, + TypeGuardAccessor, + TypeMROGuardAccessor, ) from torch._dynamo.source import ( get_global_source_name, @@ -67,6 +80,7 @@ is_from_flatten_script_object_source, is_from_local_source, is_from_optimizer_source, + is_from_skip_guard_source, is_from_unspecialized_builtin_nn_module_source, TensorProperty, TensorPropertySource, @@ -83,6 +97,7 @@ Source, StorageOverlap, ) +from torch._inductor.utils import IndentedBuffer from torch._logging import structured from torch._utils_internal import justknobs_check from torch.fx.experimental.symbolic_shapes import ( @@ -105,6 +120,8 @@ CallFunctionNoArgsSource, CallMethodItemSource, ChainedSource, + ClosureSource, + CodeSource, ConstantSource, ConstDictKeySource, DataclassFieldsSource, @@ -132,6 +149,8 @@ TorchFunctionModeStackSource, TorchSource, TupleIteratorGetItemSource, + TypeDictSource, + TypeMROSource, TypeSource, UnspecializedBuiltinNNModuleSource, UnspecializedNNModuleSource, @@ -178,11 +197,14 @@ if TYPE_CHECKING: - from sympy import Symbol + from collections.abc import Generator, KeysView, Sequence - from torch._dynamo.output_graph import OutputGraphGuardsState + from sympy import Symbol + from torch._C import DispatchKeySet + from torch._dynamo.output_graph import OutputGraph +T = TypeVar("T") log = logging.getLogger(__name__) guards_log = torch._logging.getArtifactLogger(__name__, "guards") recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") @@ -192,6 +214,28 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") +dunder_attrs_assumed_constants = ( + "__defaults__", + "__kwdefaults__", + "__code__", + "__closure__", + "__annotations__", + "__func__", + "__mro__", +) + + +class IndentedBufferWithPrefix(IndentedBuffer): + def prefix(self) -> str: + return "| " * (self._indent * self.tabwidth) + + def writeline(self, line: str, skip_prefix: bool = False) -> None: # type: ignore[override] + if skip_prefix: + super().writeline(line) + else: + super().writeline("+- " + line) + + class GuardManagerWrapper: """ A helper class that contains the root guard manager. An instance of this @@ -200,37 +244,38 @@ class is stored in the Dynamo cache entry, so that the cache entry can the check_nopybind from C++. """ - def __init__(self, root=None): + def __init__(self, root: Optional[RootGuardManager] = None) -> None: if root is None: self.root = RootGuardManager() else: self.root = root - self.diff_guard_root = None - self.closure_vars = None - self.args = None - self.code_parts = [] - self.verbose_code_parts = None - self.global_scope = None - self.guard_fail_fn = None - self.cache_entry = None - self.extra_state = None - self.id_matched_objs = {} - self.no_tensor_aliasing_sources = [] + self.diff_guard_root: Optional[RootGuardManager] = None + self.closure_vars: Optional[dict[str, Any]] = None + self.args: Optional[list[str]] = None + self.code_parts: list[str] = [] + self.verbose_code_parts: Optional[list[str]] = None + self.global_scope: Optional[dict[str, Any]] = None + self.guard_fail_fn: Optional[Callable[[GuardFail], None]] = None + self.cache_entry: Optional[CacheEntry] = None + self.extra_state: Optional[ExtraState] = None + self.id_matched_objs: dict[str, ReferenceType[object]] = {} + self.no_tensor_aliasing_sources: list[str] = [] - self.printed_relational_guards = set() + self.printed_relational_guards: set[RelationalGuard] = set() self.diff_guard_sources: OrderedSet[str] = OrderedSet() @contextmanager - def _preserve_printed_relational_guards(self): + def _preserve_printed_relational_guards(self) -> Generator[None, None, None]: self.printed_relational_guards = set() try: yield finally: self.printed_relational_guards = set() - def collect_diff_guard_sources(self): + # TODO: clarify what fn and attributes guard manager has to get the right things here + def collect_diff_guard_sources(self) -> OrderedSet[str]: # At the time of finalize, we have only marked guard managers with # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal # and collect all the nodes in the tree (branches) that lead to tensor @@ -240,7 +285,7 @@ def collect_diff_guard_sources(self): # 0, so we collect them as well. Later on, we accumulate the diff guard # sources for all the guard managers. - def visit_dict_manager(node): + def visit_dict_manager(node: DictGuardManager) -> bool: is_diff_guard_node = ( node.get_source() in self.diff_guard_sources or node.fail_count() > 0 ) @@ -254,7 +299,7 @@ def visit_dict_manager(node): return is_diff_guard_node - def visit_manager(node): + def visit_manager(node: GuardManager) -> bool: assert not isinstance(node, DictGuardManager) is_diff_guard_node = ( @@ -268,7 +313,7 @@ def visit_manager(node): return is_diff_guard_node - def visit(node): + def visit(node: GuardManager) -> bool: if node is None: return False if isinstance(node, DictGuardManager): @@ -279,18 +324,18 @@ def visit(node): return self.diff_guard_sources - def finalize(self): + def finalize(self) -> None: if config.use_recursive_dict_tags_for_guards and justknobs_check( "pytorch/compiler:use_recursive_dict_tags_for_guards" ): self.find_tag_safe_roots() self.prepare_diff_guard_manager() - def prepare_diff_guard_manager(self): + def prepare_diff_guard_manager(self) -> None: self.collect_diff_guard_sources() self.populate_diff_guard_manager() - def find_tag_safe_roots(self): + def find_tag_safe_roots(self) -> None: """ Identify ``tag safe nodes`` and ``tag safe roots`` within a guard tree. @@ -348,10 +393,20 @@ def find_tag_safe_roots(self): subset that are tag safe roots. """ - def visit_dict_manager(node): + def check_tag_safety( + node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...] + ) -> bool: + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + return all( + isinstance(accessor, accepted_accessors) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + + def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: # Just recurse through the key and value dict managers and check if # all of them are tag safe nodes. - assert node.is_guarded_value_dict() + assert issubclass(node.get_type_of_guarded_value(), dict) tag_safe_roots = [] is_subtree_tag_safe = True @@ -378,7 +433,7 @@ def visit_dict_manager(node): node.mark_tag_safe() return tag_safe_roots - def visit_manager(node): + def visit_manager(node: GuardManager) -> list[GuardManager]: assert not isinstance(node, DictGuardManager) # Collect the subtree tag safe roots @@ -390,12 +445,12 @@ def visit_manager(node): # If the node guards a tensor, mark it tag safe only if there # are no accessors. Presence of accessors means presence of # symbolic shape guards. - if node.is_guarded_value_tensor(): + if issubclass(node.get_type_of_guarded_value(), torch.Tensor): if node.has_no_accessors() and not node.has_object_aliasing_guard(): node.mark_tag_safe() else: node.mark_tag_safe() - elif node.is_guarded_value_dict(): + elif issubclass(node.get_type_of_guarded_value(), dict): accessors = node.get_accessors() child_mgrs = node.get_child_managers() is_subtree_tag_safe = all( @@ -404,13 +459,9 @@ def visit_manager(node): ) if is_subtree_tag_safe: node.mark_tag_safe() - elif node.is_guarded_value_nn_module(): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, GetGenericDictGuardAccessor) - and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module): + is_subtree_tag_safe = check_tag_safety( + node, (GetGenericDictGuardAccessor, TypeGuardAccessor) ) if is_subtree_tag_safe: node.mark_tag_safe() @@ -419,9 +470,80 @@ def visit_manager(node): return [ node, ] + elif ( + node.get_type_of_guarded_value() + in ( + types.FunctionType, + types.MethodType, + staticmethod, + classmethod, + ) + and config.assume_dunder_attributes_remain_unchanged + ): + # Assumption: callers will not reassignthe attributes + # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__. + # Mutating the objects those attributes point to is fine; + # rebinding the attribute itself is not. + # Example ─ allowed: foo.__defaults__[0].bar = 99 + # forbidden: foo.__defaults__ = (3, 4) + is_subtree_tag_safe = check_tag_safety( + node, + ( + CodeGuardAccessor, + ClosureGuardAccessor, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, + ), + ) + + for accessor in node.get_accessors(): + if isinstance(accessor, GetAttrGuardAccessor): + is_subtree_tag_safe &= ( + accessor.get_attr_name() in dunder_attrs_assumed_constants + ) + + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), types.CellType): + is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,)) + + is_subtree_tag_safe &= all( + isinstance(accessor, GetAttrGuardAccessor) + and accessor.get_attr_name() == "cell_contents" + for accessor in node.get_accessors() + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif ( + issubclass(node.get_type_of_guarded_value(), tuple) + and node.get_source().endswith(dunder_attrs_assumed_constants) + and config.assume_dunder_attributes_remain_unchanged + ): + # We trust tuples obtained from a function’s __closure__ or + # __defaults__. Any *other* tuple-valued attribute can be + # silently replaced—for example: + # + # foo.bar = (1, 2) # original + # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see + # + # Therefore only tuples from __closure__ / __defaults__ participate in the + # recursive-dict-tag optimization; all others are ignored. + is_subtree_tag_safe = check_tag_safety( + node, (TupleGetItemGuardAccessor,) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), type): + is_subtree_tag_safe = check_tag_safety( + node, (TypeDictGuardAccessor, TypeMROGuardAccessor) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + return tag_safe_roots - def visit(node): + def visit(node: GuardManager) -> list[GuardManager]: if node is None: return [] if isinstance(node, DictGuardManager): @@ -430,10 +552,10 @@ def visit(node): tag_safe_roots = visit(self.root) for node in tag_safe_roots: - if node.is_guarded_value_nn_module(): + if issubclass(node.get_type_of_guarded_value(), torch.nn.Module): node.mark_tag_safe_root() - def populate_diff_guard_manager(self): + def populate_diff_guard_manager(self) -> None: self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) # Ensure that that C++ side points to the updated diff guard manager. @@ -446,29 +568,35 @@ def populate_diff_guard_manager(self): if self.cache_entry: self.cache_entry.update_diff_guard_root_manager() - def clone_with_chosen_sources(self, chosen_sources): - def filter_fn(node_mgr): + def clone_with_chosen_sources( + self, chosen_sources: OrderedSet[str] + ) -> RootGuardManager: + def filter_fn(node_mgr: GuardManager) -> bool: return node_mgr.get_source() in chosen_sources return self.root.clone_manager(filter_fn) - def get_guard_lines(self, guard): + def get_guard_lines(self, guard: LeafGuard) -> list[str]: guard_name = guard.__class__.__name__ parts = guard.verbose_code_parts() parts = [guard_name + ": " + part for part in parts] return parts - def get_manager_line(self, guard_manager, accessor_str=None): + def get_manager_line( + self, guard_manager: GuardManager, accessor_str: Optional[str] = None + ) -> str: source = guard_manager.get_source() t = guard_manager.__class__.__name__ s = t + ": source=" + source if accessor_str: s += ", " + accessor_str - s += f", type={guard_manager.type_of_guarded_value()}" + s += f", type={guard_manager.get_type_of_guarded_value()}" s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})" return s - def construct_dict_manager_string(self, mgr, body): + def construct_dict_manager_string( + self, mgr: DictGuardManager, body: IndentedBufferWithPrefix + ) -> None: for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): body.writeline(f"KeyValueManager pair at index={idx}") with body.indent(): @@ -480,10 +608,12 @@ def construct_dict_manager_string(self, mgr, body): body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") self.construct_manager_string(val_mgr, body) - def construct_manager_string(self, mgr, body): + def construct_manager_string( + self, mgr: GuardManager, body: IndentedBufferWithPrefix + ) -> None: with body.indent(): for guard in mgr.get_leaf_guards(): - if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if isinstance(guard, RelationalGuard): if guard not in self.printed_relational_guards: self.printed_relational_guards.add(guard) body.writelines(self.get_guard_lines(guard)) @@ -509,19 +639,7 @@ def construct_manager_string(self, mgr, body): ) self.construct_manager_string(child_mgr, body) - def __str__(self): - from torch._inductor.utils import IndentedBuffer - - class IndentedBufferWithPrefix(IndentedBuffer): - def prefix(self): - return "| " * (self._indent * self.tabwidth) - - def writeline(self, line, skip_prefix=False): - if skip_prefix: - super().writeline(line) - else: - super().writeline("+- " + line) - + def __str__(self) -> str: with self._preserve_printed_relational_guards(): body = IndentedBufferWithPrefix() body.tabwidth = 1 @@ -534,29 +652,29 @@ def writeline(self, line, skip_prefix=False): body.writelines(self.get_guard_lines(guard)) return body.getvalue() - def check(self, x): + def check(self, x: Any) -> bool: # Only needed for debugging purposes. return self.root.check(x) - def check_verbose(self, x): + def check_verbose(self, x: Any) -> GuardDebugInfo: # Only needed for debugging purposes. return self.root.check_verbose(x) - def populate_code_parts_for_debugging(self): + def populate_code_parts_for_debugging(self) -> None: # This should be called when the guard manager is fully populated relational_guards_seen = set() - def get_code_parts(leaf_guard): + def get_code_parts(leaf_guard: LeafGuard) -> list[str]: code_parts = [] for verbose_code_part in leaf_guard.verbose_code_parts(): code_part = verbose_code_part.split("#")[0].rstrip() code_parts.append(code_part) return code_parts - def visit(mgr): + def visit(mgr: GuardManager) -> None: nonlocal relational_guards_seen for guard in mgr.get_leaf_guards(): - if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if isinstance(guard, RelationalGuard): if guard not in relational_guards_seen: self.code_parts.extend(get_code_parts(guard)) relational_guards_seen.add(guard) @@ -569,7 +687,7 @@ def visit(mgr): visit(self.root) -def from_numpy(a): +def from_numpy(a: Any) -> torch.Tensor: # If not numpy array, piggy back on e.g. tensor guards to check type # Re-enable torch function since we disable it on leaf guards # we need it to properly construct the tensor if a default device is set @@ -579,7 +697,7 @@ def from_numpy(a): # For user stack printing @functools.cache -def uninteresting_files(): +def uninteresting_files() -> set[str]: import torch._dynamo.external_utils import torch._dynamo.polyfills @@ -595,7 +713,7 @@ def uninteresting_files(): _CLOSURE_VARS: Optional[dict[str, object]] = None -def _get_closure_vars(): +def _get_closure_vars() -> dict[str, object]: global _CLOSURE_VARS if _CLOSURE_VARS is None: _CLOSURE_VARS = { @@ -631,7 +749,7 @@ def _ast_unparse(node: ast.AST) -> str: strip_function_call = torch._C._dynamo.strip_function_call -def get_verbose_code_part(code_part: str, guard: Guard) -> str: +def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str: extra = "" if guard is not None: if guard.user_stack: @@ -649,14 +767,14 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str: def get_verbose_code_parts( - code_parts: Union[str | list[str]], guard: Guard + code_parts: Union[str, list[str]], guard: Optional[Guard] ) -> list[str]: if not isinstance(code_parts, list): code_parts = [code_parts] return [get_verbose_code_part(code_part, guard) for code_part in code_parts] -def convert_int_to_concrete_values(dim) -> Optional[int]: +def convert_int_to_concrete_values(dim: Any) -> Optional[int]: if dim is None: return None if not is_symbolic(dim): @@ -666,11 +784,18 @@ def convert_int_to_concrete_values(dim) -> Optional[int]: return dim.node.maybe_as_int() -def convert_to_concrete_values(size_or_stride): +def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]: return [convert_int_to_concrete_values(dim) for dim in size_or_stride] -def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys): +def get_tensor_guard_code_part( + value: torch.Tensor, + name: str, + sizes: list[Optional[int]], + strides: list[Optional[int]], + pytype: type, + dispatch_keys: DispatchKeySet, +) -> str: dispatch_key = ( dispatch_keys | torch._C._dispatch_tls_local_include_set() ) - torch._C._dispatch_tls_local_exclude_set() @@ -684,7 +809,7 @@ def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_key return guard_str -def get_key_index(dct, key): +def get_key_index(dct: dict[Any, Any], key: Any) -> int: # Ensure that we call dict.keys and not value.keys (which can call # overridden keys method). In the C++ guards, we relied on PyDict_Next # to traverse the dictionary, which uses the internal data structure and @@ -692,7 +817,7 @@ def get_key_index(dct, key): return list(builtin_dict_keys(dct)).index(key) -def get_key_index_source(source, index): +def get_key_index_source(source: Any, index: Any) -> str: return f"list(dict.keys({source}))[{index}]" @@ -720,8 +845,12 @@ class NNModuleAttrAccessorInfo: def getitem_on_dict_manager( - source, base_guard_manager, base_example_value, example_value, guard_manager_enum -): + source: Union[DictGetItemSource, DictSubclassGetItemSource], + base_guard_manager: DictGuardManager, + base_example_value: Any, + example_value: Any, + guard_manager_enum: GuardManagerType, +) -> GuardManager: base_source_name = source.base.name() if isinstance(source.index, ConstDictKeySource): index = source.index.index @@ -760,7 +889,7 @@ def getitem_on_dict_manager( ) -def match_on_id_for_tensor(guard): +def match_on_id_for_tensor(guard: Guard) -> bool: source = guard.originating_source # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads # to a new tensor every time and therefore id differs. @@ -787,7 +916,7 @@ class GuardManagerType(enum.Enum): @functools.cache -def code_framelocals_names_reversed_cached(code: types.CodeType): +def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]: return list(reversed(code_framelocals_names(code))) @@ -795,16 +924,16 @@ class GuardBuilder(GuardBuilderBase): def __init__( self, f_code: types.CodeType, - id_ref: Callable[[Any, str], str], + id_ref: Callable[[object, str], int], source_ref: Callable[[Source], str], - lookup_weakrefs: Callable[[object], ReferenceType[object]], + lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]], local_scope: dict[str, object], global_scope: dict[str, object], guard_manager: GuardManagerWrapper, check_fn_manager: CheckFunctionManager, serialization_mode: Optional[str] = None, - runtime_global_scope: Optional[dict[str, Any]] = None, - ): + runtime_global_scope: Optional[dict[str, object]] = None, + ) -> None: self.f_code = f_code self.id_ref = id_ref self.source_ref = source_ref @@ -835,7 +964,7 @@ def __init__( # Collect the guard managers and debug info to insert no tensor aliasing # guards. self.no_tensor_aliasing_names: list[str] = [] - self.no_tensor_aliasing_guard_managers: list[GuardManagerWrapper] = [] + self.no_tensor_aliasing_guard_managers: list[GuardManager] = [] self.check_fn_manager: CheckFunctionManager = check_fn_manager @@ -844,6 +973,7 @@ def __init__( # to access the same object - self._module["param"] is same as # self.param. self.key_order_guarded_dict_ids = set() + assert self.check_fn_manager.output_graph is not None for source in self.check_fn_manager.output_graph.guard_on_key_order: self.key_order_guarded_dict_ids.add(id(self.get(source.name()))) @@ -853,17 +983,20 @@ def __init__( self.id_matched_objs: dict[str, ReferenceType[object]] = {} # Save the guard managers to avoid repeatedly traversing sources. - self._cached_guard_managers: dict[ - str, torch._C._dynamo.guards.GuardManager - ] = {} + self._cached_guard_managers: dict[str, GuardManager] = {} self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self.object_aliasing_guard_codes: list[tuple[str, str]] = [] self.serialization_mode = serialization_mode self.guard_nn_modules = config.guard_nn_modules and justknobs_check( "pytorch/compiler:guard_nn_modules" ) + self.already_guarded_not_present_in_generic_dict: OrderedSet[ + tuple[str, str] + ] = OrderedSet() - def guard_on_dict_keys_and_ignore_order(self, example_value, guard): + def guard_on_dict_keys_and_ignore_order( + self, example_value: dict[Any, Any], guard: Guard + ) -> None: dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): raise NotImplementedError( @@ -891,7 +1024,7 @@ def guard_on_dict_keys_and_ignore_order(self, example_value, guard): guard_manager_enum=guard_manager_enum, ) - def guard_on_dict_keys_and_order(self, value, guard): + def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None: # Add key managers for the DictGuardManager. Then add either an # ID_MATCH or EQUALS_MATCH guard on the key. dict_mgr = self.get_guard_manager(guard) @@ -930,7 +1063,7 @@ def guard_on_dict_keys_and_order(self, value, guard): ) @staticmethod - def _get_generic_dict_manager_example_value(example_value): + def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]: # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115, # reported in https://github.com/python/cpython/issues/125608, # fixed by https://github.com/python/cpython/pull/125611), we cannot take @@ -949,14 +1082,14 @@ def _get_generic_dict_manager_example_value(example_value): def getattr_on_nn_module( self, - source, - base_guard_manager, - base_example_value, - example_value, - base_source_name, - source_name, - guard_manager_enum, - ): + source: AttrSource, + base_guard_manager: GuardManager, + base_example_value: Any, + example_value: Any, + base_source_name: str, + source_name: str, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: """ This tries to avoid calling the expensive nn module custom getattr method by checking if the attribute is accessible via __dict__. For attributes that @@ -975,8 +1108,13 @@ def getattr_on_nn_module( """ def getitem_on_dict_mgr( - mgr, key, source_name, base_example_value, example_value, guard_manager_enum - ): + mgr: GuardManager, + key: Any, + source_name: str, + base_example_value: Any, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: if isinstance(mgr, DictGuardManager): # Case where the user code relies on key order, e.g., # named_parameters @@ -1086,6 +1224,7 @@ def getitem_on_dict_mgr( ) if l2_key: + assert l2_source_name is not None and l2_guard_manager_enum is not None return getitem_on_dict_mgr( mgr=l1_mgr, key=l2_key, @@ -1096,14 +1235,20 @@ def getitem_on_dict_mgr( ) return l1_mgr - def requires_key_order_guarding(self, source): + def requires_key_order_guarding(self, source: Source) -> bool: source_name = source.name() if source_name == "": return False obj_id = id(self.get(source_name)) return obj_id in self.key_order_guarded_dict_ids - def get_guard_manager_type(self, source, example_value): + def get_guard_manager_type( + self, + source: Source, + example_value: Optional[ + Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]] + ], + ) -> GuardManagerType: guard_manager_enum = GuardManagerType.GUARD_MANAGER if self.requires_key_order_guarding(source): # Fix this if condition @@ -1119,10 +1264,10 @@ def get_guard_manager_type(self, source, example_value): guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER return guard_manager_enum - def manager_guards_on_keys(self, mgr_enum): + def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool: return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER - def get_global_guard_manager(self): + def get_global_guard_manager(self) -> GuardManager: return self.guard_manager.root.globals_dict_manager( f_globals=self.runtime_global_scope, source="G", @@ -1130,7 +1275,7 @@ def get_global_guard_manager(self): guard_manager_enum=GuardManagerType.GUARD_MANAGER, ) - def get_guard_manager_from_source(self, source): + def get_guard_manager_from_source(self, source: Source) -> GuardManager: root_guard_manager = self.guard_manager.root example_value = None @@ -1209,6 +1354,20 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, TypeDictSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.type_dict_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, TypeMROSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.type_mro_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype( source, ( @@ -1254,12 +1413,13 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, (AttrSource, UnspecializedParamBufferSource)): assert base_guard_manager # to make mypy happy - + assert isinstance(source, AttrSource) if ( isinstance(base_example_value, torch.nn.Module) and get_custom_getattr(base_example_value) is unpatched_nn_module_getattr ): + assert base_source_name out = self.getattr_on_nn_module( source, base_guard_manager, @@ -1279,6 +1439,7 @@ def get_guard_manager_from_source(self, source): elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): assert base_guard_manager # to make mypy happy assert isinstance(base_example_value, (dict, collections.OrderedDict)) + assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource)) if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) out = getitem_on_dict_manager( @@ -1495,6 +1656,20 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CodeSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.code_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, ClosureSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.closure_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" @@ -1503,16 +1678,16 @@ def get_guard_manager_from_source(self, source): self._cached_guard_managers[source.name()] = out return out - def get_guard_manager(self, guard: Guard): + def get_guard_manager(self, guard: Guard) -> GuardManager: return self.get_guard_manager_from_source(guard.originating_source) def add_python_lambda_leaf_guard_to_root( self, - code_parts, - verbose_code_parts, - closure_vars=None, - is_epilogue=True, - ): + code_parts: list[str], + verbose_code_parts: list[str], + closure_vars: Optional[dict[str, object]] = None, + is_epilogue: bool = True, + ) -> None: if closure_vars is None: closure_vars = _get_closure_vars() # Adds a lambda leaf guard to the root guard manager. It wraps the @@ -1567,8 +1742,16 @@ def arg_ref(self, guard: Union[str, Guard]) -> str: return name - def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): - attr_source = AttrSource(guard.originating_source, attr_name) + def _guard_on_attribute( + self, + guard: Guard, + attr_name: str, + guard_fn: Callable[[GuardBuilderBase, Guard], Any], + ) -> None: + if attr_name == "__code__": + attr_source = CodeSource(guard.originating_source) + else: + attr_source = AttrSource(guard.originating_source, attr_name) # type: ignore[assignment] # Copy the stack info new_guard = Guard( attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack @@ -1576,10 +1759,13 @@ def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): new_guard.create(self) # Note: the order of the guards in this file matters since we sort guards on the same object by lineno - def HASATTR(self, guard: Guard): + def HASATTR(self, guard: Guard) -> None: source = guard.originating_source if isinstance(source, NNModuleSource): source = source.base + if isinstance(source, CodeSource): + # No need to guard that a function has a __code__ attribute + return assert isinstance(source, AttrSource), f"invalid source {guard.name}" base_source = source.base base = base_source.name() @@ -1611,7 +1797,7 @@ def HASATTR(self, guard: Guard): and get_custom_getattr(base_example_value) is unpatched_nn_module_getattr ): - return self.getattr_on_nn_module( + self.getattr_on_nn_module( source, base_manager, base_example_value, @@ -1630,14 +1816,18 @@ def HASATTR(self, guard: Guard): else: base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard)) - def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: + def NOT_PRESENT_IN_GENERIC_DICT( + self, guard: Guard, attr: Optional[Any] = None + ) -> None: assert attr is not None ref = self.arg_ref(guard) val = self.get(guard.name) - assert isinstance(val, torch.nn.Module) base_manager = self.get_guard_manager(guard) + if (ref, attr) in self.already_guarded_not_present_in_generic_dict: + return + mod_dict_source = f"{guard.name}.__dict__" mod_generic_dict_manager = base_manager.get_generic_dict_manager( source=mod_dict_source, @@ -1649,6 +1839,7 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: mod_generic_dict_manager.add_dict_contains_guard( False, attr, get_verbose_code_parts(code, guard) ) + self.already_guarded_not_present_in_generic_dict.add((ref, attr)) def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` @@ -1670,7 +1861,7 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id, get_verbose_code_parts(code, guard) ) - def DICT_VERSION(self, guard: Guard): + def DICT_VERSION(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( "DICT_VERSION guard cannot be serialized." @@ -1688,7 +1879,7 @@ def DICT_VERSION(self, guard: Guard): val, get_verbose_code_parts(code, guard) ) - def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None: dict_ref = self.arg_ref(guard) maybe_not = "not " if invert else "" @@ -1699,7 +1890,7 @@ def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): not invert, key, get_verbose_code_parts(code, guard) ) - def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): + def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None: set_ref = self.arg_ref(guard) item = key contains = not invert # install_dict_contains_guard inverts "contains" @@ -1712,7 +1903,7 @@ def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): contains, item, get_verbose_code_parts(code, guard) ) - def BOOL_MATCH(self, guard: Guard): + def BOOL_MATCH(self, guard: Guard) -> None: # checks val == True or val == False ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1729,7 +1920,7 @@ def BOOL_MATCH(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def NONE_MATCH(self, guard: Guard): + def NONE_MATCH(self, guard: Guard) -> None: # checks `val is None` ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1741,12 +1932,12 @@ def NONE_MATCH(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def ID_MATCH(self, guard: Guard): + def ID_MATCH(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.") return self.id_match_unchecked(guard) - def id_match_unchecked(self, guard: Guard): + def id_match_unchecked(self, guard: Guard) -> None: # ___check_obj_id is same as `id(x) == y` if isinstance(guard.originating_source, TypeSource): # optional optimization to produce cleaner/faster guard code @@ -1776,7 +1967,7 @@ def id_match_unchecked(self, guard: Guard): if weak_id is not None: self.id_matched_objs[local_name] = weak_id - def NOT_NONE_MATCH(self, guard: Guard, value=None): + def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch.Tensor) @@ -1787,7 +1978,7 @@ def NOT_NONE_MATCH(self, guard: Guard, value=None): get_verbose_code_parts(code, guard) ) - def DISPATCH_KEY_SET_MATCH(self, guard: Guard): + def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch._C.DispatchKeySet) @@ -1797,28 +1988,30 @@ def DISPATCH_KEY_SET_MATCH(self, guard: Guard): val, get_verbose_code_parts(code_parts, guard) ) - def NAME_MATCH(self, guard: Guard): - self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + def NAME_MATCH(self, guard: Guard) -> None: + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) # type: ignore[arg-type] - def DUAL_LEVEL(self, guard: Guard): + def DUAL_LEVEL(self, guard: Guard) -> None: # Invalidate dual level if current dual level is different than the one # in the fx graph + assert self.check_fn_manager.output_graph is not None dual_level = self.check_fn_manager.output_graph.dual_level code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] - self._set_guard_export_info(guard, [code]) + self._set_guard_export_info(guard, code) # TODO(anijain2305) - Consider this moving this guard to C++ forward_ad = torch.autograd.forward_ad - def fn(x): + def fn(x: Any) -> bool: return forward_ad._current_level == dual_level self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def FUNCTORCH_STACK_MATCH(self, guard: Guard): + def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None: # Invalidate functorch code if current level is different than # the one when FX graph was generated + assert self.check_fn_manager.output_graph is not None cis = self.check_fn_manager.output_graph.functorch_layers states = [ci.get_state() for ci in cis] code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] @@ -1827,20 +2020,22 @@ def FUNCTORCH_STACK_MATCH(self, guard: Guard): # TODO(anijain2305) - Consider this moving this guard to C++ compare_fn = torch._functorch.pyfunctorch.compare_functorch_state - def fn(x): + def fn(x: Any) -> bool: return compare_fn(states) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard): + def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None: get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks are_inline_hooks = ( torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable ) - def hooks_ids_fn(hooks): + def hooks_ids_fn( + hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]], + ) -> Optional[tuple[int, ...]]: if not are_inline_hooks(hooks): return None @@ -1854,27 +2049,27 @@ def hooks_ids_fn(hooks): ] self._set_guard_export_info(guard, code) - def fn(x): + def fn(x: Any) -> bool: return guard_hooks_ids == hooks_ids_fn(get_hooks()) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): + def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None: value = self.get(guard.name) original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) if hasattr(value, "__metadata_guard__"): verify_guard_fn_signature(value) - def metadata_checker(x): + def metadata_checker(x: Any) -> bool: return value.__metadata_guard__( original_metadata, x.__tensor_flatten__()[1] ) else: - def metadata_checker(x): + def metadata_checker(x: Any) -> bool: return x.__tensor_flatten__()[1] == original_metadata global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" @@ -1882,7 +2077,7 @@ def metadata_checker(x): metadata_checker, get_verbose_code_parts(global_name, guard) ) - def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) if np: @@ -1990,7 +2185,7 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): self._set_guard_export_info(guard, code) return - def CONSTANT_MATCH(self, guard: Guard): + def CONSTANT_MATCH(self, guard: Guard) -> None: val = self.get(guard.name) if istype(val, bool): self.BOOL_MATCH(guard) @@ -2001,7 +2196,7 @@ def CONSTANT_MATCH(self, guard: Guard): else: self.EQUALS_MATCH(guard) - def NN_MODULE(self, guard: Guard): + def NN_MODULE(self, guard: Guard) -> None: # don't support this in serialization because it uses unsupported ID_MATCH if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( @@ -2013,7 +2208,7 @@ def NN_MODULE(self, guard: Guard): assert istype(val.training, bool) if not self.guard_nn_modules: # If guard_nn_modules is true, we will guard on the right set of guards - self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type] else: exc.unimplemented_v2( gb_type="Attempted to guard on uninitialized nn.Module", @@ -2025,7 +2220,7 @@ def NN_MODULE(self, guard: Guard): ], ) - def FUNCTION_MATCH(self, guard: Guard): + def FUNCTION_MATCH(self, guard: Guard) -> None: """things like torch.add and user defined functions""" # don't support this in serialization because it uses unsupported ID_MATCH if self.serialization_mode == "save": @@ -2034,7 +2229,7 @@ def FUNCTION_MATCH(self, guard: Guard): ) return self.ID_MATCH(guard) - def CLOSURE_MATCH(self, guard: Guard): + def CLOSURE_MATCH(self, guard: Guard) -> None: """matches a closure by __code__ id.""" # don't support this in serialization because it uses unsupported FUNCTION_MATCH if self.serialization_mode == "save": @@ -2044,12 +2239,12 @@ def CLOSURE_MATCH(self, guard: Guard): val = self.get(guard.name) # Strictly only want user-defined functions if type(val) == types.FunctionType and hasattr(val, "__code__"): - self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) - self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) + self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type] + self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type] else: self.FUNCTION_MATCH(guard) - def BUILTIN_MATCH(self, guard: Guard): + def BUILTIN_MATCH(self, guard: Guard) -> None: if self.serialization_mode == "save": # Record which builtin variables are used for pruning later. if isinstance(guard.originating_source, DictGetItemSource): @@ -2060,7 +2255,7 @@ def BUILTIN_MATCH(self, guard: Guard): return self.ID_MATCH(guard) - def SEQUENCE_LENGTH(self, guard): + def SEQUENCE_LENGTH(self, guard: Guard) -> None: # This guard is used to check length of PySequence objects like list, # tuple, collections.deque etc ref = self.arg_ref(guard) @@ -2086,7 +2281,7 @@ def SEQUENCE_LENGTH(self, guard): len(value), get_verbose_code_parts(code, guard) ) - def TUPLE_ITERATOR_LEN(self, guard): + def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2102,7 +2297,7 @@ def TUPLE_ITERATOR_LEN(self, guard): tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) ) - def RANGE_ITERATOR_MATCH(self, guard): + def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2121,7 +2316,7 @@ def RANGE_ITERATOR_MATCH(self, guard): ) # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards - def DUPLICATE_INPUT(self, guard, source_b): + def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: if self.serialization_mode == "save": if name := get_local_source_name(source_b): self.check_fn_manager.additional_used_local_vars.add(name) @@ -2161,7 +2356,7 @@ def DUPLICATE_INPUT(self, guard, source_b): get_verbose_code_parts(code, guard), ) - def WEAKREF_ALIVE(self, guard): + def WEAKREF_ALIVE(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( "WEAKREF_ALIVE guard cannot be serialized." @@ -2173,7 +2368,7 @@ def WEAKREF_ALIVE(self, guard): get_verbose_code_parts(code, guard) ) - def MAPPING_KEYS_CHECK(self, guard): + def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: """Guard on the key order of types.MappingProxyType object""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2183,7 +2378,7 @@ def MAPPING_KEYS_CHECK(self, guard): self._set_guard_export_info(guard, code) self.get_guard_manager(guard).add_mapping_keys_guard(value, code) - def DICT_KEYS_MATCH(self, guard): + def DICT_KEYS_MATCH(self, guard: Guard) -> None: """Insert guard to check that the keys of a dict are same""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2208,29 +2403,30 @@ def DICT_KEYS_MATCH(self, guard): else: self.guard_on_dict_keys_and_ignore_order(value, guard) - def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None: """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards""" if config.skip_nnmodule_hook_guards: # This is unsafe if you add/remove a hook on nn module variable return self.SEQUENCE_LENGTH(guard) - def GRAD_MODE(self, guard: Guard): + def GRAD_MODE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def TORCH_FUNCTION_STATE(self, guard: Guard): + def TORCH_FUNCTION_STATE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def FSDP_TRAINING_STATE(self, guard: Guard): + def FSDP_TRAINING_STATE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def DEFAULT_DEVICE(self, guard: Guard): + def DEFAULT_DEVICE(self, guard: Guard) -> None: """Guard on CURRENT_DEVICE per torch.utils._device""" assert guard.source is GuardSource.GLOBAL + assert self.check_fn_manager.output_graph is not None code = [ f"utils_device.CURRENT_DEVICE == {self.check_fn_manager.output_graph.current_device!r}" ] @@ -2240,9 +2436,10 @@ def DEFAULT_DEVICE(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def SHAPE_ENV(self, guard: Guard): + def SHAPE_ENV(self, guard: Guard) -> None: assert guard.name == "" output_graph = self.check_fn_manager.output_graph + assert output_graph is not None if self.serialization_mode == "load": assert self.check_fn_manager.shape_code_parts is not None shape_code_parts = self.check_fn_manager.shape_code_parts @@ -2259,7 +2456,7 @@ def SHAPE_ENV(self, guard: Guard): fs = output_graph.tracked_fakes input_contexts = [a.symbolic_context for a in fs] - def get_sources(t_id, dim): + def get_sources(t_id: int, dim: int) -> list[Source]: # Looks up base sources mapped to a tensor id and uses them to create # sources for the corresponding tensor dimension. return [ @@ -2267,6 +2464,7 @@ def get_sources(t_id, dim): for source in output_graph.tracked_fakes_id_to_source[t_id] ] + assert output_graph.shape_env is not None if output_graph.export_constraints: names: dict[str, tuple[int, int]] = {} source_pairs: list[tuple[Source, Source]] = [] @@ -2275,7 +2473,7 @@ def get_sources(t_id, dim): ] = [] phantom_symbols: dict[str, Symbol] = {} relaxed_sources: set[Source] = set() - for constraint in output_graph.export_constraints: + for constraint in output_graph.export_constraints: # type: ignore[attr-defined] if constraint.t_id in output_graph.tracked_fakes_id_to_source: torch.export.dynamic_shapes._process_equalities( constraint, @@ -2299,15 +2497,15 @@ def get_sources(t_id, dim): else: equalities_inputs = None - def _get_code_parts(langs): + def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: return output_graph.shape_env.produce_guards_verbose( - [a.fake for a in fs], + [a.fake for a in fs], # type: ignore[misc] [a.source for a in fs], - input_contexts=input_contexts, + input_contexts=input_contexts, # type: ignore[arg-type] equalities_inputs=equalities_inputs, source_ref=self.source_ref, # Export keeps static. - ignore_static=(not self.check_fn_manager.output_graph.export), + ignore_static=(not output_graph.export), langs=langs, ) @@ -2315,7 +2513,7 @@ def _get_code_parts(langs): try: # For exporting we need the python code parts python_code_parts, verbose_code_parts, cpp_code_parts = ( - _get_code_parts(("python", "verbose_python", "cpp")) + _get_code_parts(("python", "verbose_python", "cpp")) # type: ignore[assignment] ) python_fallback = False except OverflowError: @@ -2332,7 +2530,7 @@ def _get_code_parts(langs): # When exporting, we may work with the shape constraints some more in # postprocessing, so don't freeze yet - if not self.check_fn_manager.output_graph.export: + if not output_graph.export: output_graph.shape_env.freeze() if self.serialization_mode == "save": @@ -2476,7 +2674,7 @@ def _get_code_parts(langs): closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, ) - def TENSOR_MATCH(self, guard: Guard, value=None): + def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module(): return # For tensors that are part of the Dynamo extracted Fx graph module, an @@ -2529,6 +2727,7 @@ def TENSOR_MATCH(self, guard: Guard, value=None): # The list of tensor fields and calls we care about can be found in `terms` below. # TODO(voz): We are missing storage offset in all our tensor guards? code: list[str] = [] + assert self.check_fn_manager.output_graph is not None if self.check_fn_manager.output_graph.export: self.TYPE_MATCH(guard) terms = [ @@ -2580,7 +2779,12 @@ def TENSOR_MATCH(self, guard: Guard, value=None): verbose_code_parts = get_verbose_code_parts( get_tensor_guard_code_part( - value, tensor_name, size, stride, pytype, dispatch_keys + value, + tensor_name, + size, + stride, + pytype, + dispatch_keys, # type: ignore[arg-type] ), guard, ) @@ -2656,8 +2860,12 @@ def TENSOR_MATCH(self, guard: Guard, value=None): # A util that in the case of export, adds data onto guards def _set_guard_export_info( - self, guard, code_list, provided_guarded_object=None, provided_func_name=None - ): + self, + guard: Guard, + code_list: list[str], + provided_guarded_object: Optional[Any] = None, + provided_func_name: Optional[str] = None, + ) -> None: # WARNING: It is important that cur_frame/caller do NOT stay in # the current frame, because they will keep things live longer # than they should. See TestMisc.test_release_module_memory @@ -2735,7 +2943,7 @@ class ExprCounter(ast.NodeVisitor): def __init__(self, config: PyExprCSEPass.Config) -> None: self._config = config - def visit(self, node: ast.AST) -> Any: + def visit(self, node: ast.AST) -> None: if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): self._config.expr_count[_ast_unparse(node)] += 1 super().visit(node) @@ -2803,7 +3011,7 @@ def replace(self, expr: str) -> tuple[list[str], str]: return replacer.preface, _ast_unparse(new_node) -def must_add_nn_module_guards(guard): +def must_add_nn_module_guards(guard: Guard) -> bool: # For config.guard_nn_modules=False, we can skip all the guards that # originate from inside of nn module except for a few categories. return ( @@ -2818,11 +3026,11 @@ def must_add_nn_module_guards(guard): class DeletedGuardManagerWrapper(GuardManagerWrapper): - def __init__(self, reason): + def __init__(self, reason: str) -> None: super().__init__() self.invalidation_reason = reason - def populate_diff_guard_manager(self): + def populate_diff_guard_manager(self) -> None: self.diff_guard_root = None @@ -2837,24 +3045,35 @@ class ShapeCodeParts: @dataclasses.dataclass class GuardsState: - output_graph: OutputGraphGuardsState + output_graph: OutputGraph shape_code_parts: Optional[ShapeCodeParts] +class _Missing: + pass + + class GuardsStatePickler(pickle.Pickler): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.fake_mode = torch._subclasses.FakeTensorMode() self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() @classmethod - def _unpickle_module(cls, state): + def _unpickle_module(cls, state: Any) -> torch.nn.Module: mod = torch.nn.Module() mod.__setstate__(state) return mod @classmethod - def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw, grad): + def _unpickle_tensor( + cls, + meta_tensor: torch.Tensor, + device: torch.device, + pytype: type, + dispatch_keys_raw: int, + grad: torch.Tensor, + ) -> torch.Tensor: fake_mode = torch._subclasses.FakeTensorMode() tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() ret = tensor_converter.from_meta_and_device( @@ -2869,15 +3088,21 @@ def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw, grad): @classmethod def _unpickle_traceable_wrapper_subclass( - cls, meta_tensor, device, pytype, dispatch_keys_raw, ctx, inner_data - ): + cls, + meta_tensor: torch.Tensor, + device: torch.device, + pytype: type, + dispatch_keys_raw: int, + ctx: Any, + inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]], + ) -> torch.Tensor: # Unpickle the inner tensor components. These could also be subclass instances. inner_tensors = {} for attr, unpickle_func, unpickle_func_args in inner_data: inner_tensors[attr] = unpickle_func(*unpickle_func_args) outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride() - out = type(meta_tensor).__tensor_unflatten__( + out = type(meta_tensor).__tensor_unflatten__( # type: ignore[attr-defined] inner_tensors, ctx, outer_size, outer_stride ) out.pytype = pytype @@ -2885,22 +3110,32 @@ def _unpickle_traceable_wrapper_subclass( return out @classmethod - def _unpickle_python_module(cls, alias: str): + def _unpickle_python_module(cls, alias: str) -> types.ModuleType: return importlib.import_module(alias) @classmethod - def _unpickle_dispatch_key_set(cls, raw_repr: int): + def _unpickle_dispatch_key_set(cls, raw_repr: int) -> torch._C.DispatchKeySet: return torch._C.DispatchKeySet.from_raw_repr(raw_repr) @classmethod - def _unpickle_functorch_interpreter(cls, json: bytes): + def _unpickle_functorch_interpreter( + cls, json: bytes + ) -> torch._C._functorch.CInterpreter: return torch._C._functorch.CInterpreter.deserialize(json) @classmethod - def _unpickle_mapping_proxy(cls, d): + def _unpickle_mapping_proxy( + cls, d: dict[Any, Any] + ) -> types.MappingProxyType[Any, Any]: return types.MappingProxyType(d) - def reducer_override(self, obj): + @classmethod + def _unpickle_c_op(cls, name: str) -> Any: + return getattr(torch.ops._C, name) + + def reducer_override( + self, obj: Any + ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]: import sympy if isinstance(obj, torch.Tensor) and obj.device.type != "meta": @@ -2964,6 +3199,27 @@ def reducer_override(self, obj): elif isinstance(obj, types.MappingProxyType): return type(self)._unpickle_mapping_proxy, (obj.copy(),) + elif isinstance( + obj, torch._ops.OpOverloadPacket + ) and obj._qualified_op_name.startswith("_C::"): + return type(self)._unpickle_c_op, (obj.__name__,) + + elif ( + obj.__class__.__module__ == "builtins" + and obj.__class__.__name__ == "PyCapsule" + ): + # Skipping PyCapsule since there isn't much to be guarded about them. + return _Missing, () + + elif isinstance(obj, types.CodeType): + # We only do ID_MATCH on code objects which is already banned from guards serialization. + return _Missing, () + + elif inspect.isfunction(obj) and (obj.__code__.co_flags & inspect.CO_NESTED): + # Skipping nested function since CLOSURE_MATCH is banned from guards serialization. + assert obj.__qualname__ != obj.__name__ + return _Missing, () + if type(obj).__qualname__ != type(obj).__name__: raise torch._dynamo.exc.PackageError( f"Type {type(obj)} for object {obj} cannot be saved " @@ -2992,9 +3248,9 @@ def pickle_guards_state(state: GuardsState) -> bytes: class CheckFunctionManager: def __init__( self, - f_code, - output_graph=None, - cache_entry=None, + f_code: types.CodeType, + output_graph: Optional[OutputGraph] = None, + cache_entry: Optional[CacheEntry] = None, guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, guard_filter_fn: Optional[ Callable[[list[GuardFilterEntry]], list[bool]] @@ -3037,7 +3293,7 @@ def __init__( ): _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) - def guard_filter_fn(guards): + def guard_filter_fn(guards: list[GuardFilterEntry]) -> list[bool]: ret = [] for keep, g in zip(_guard_filter_fn(guards), guards): if not keep: @@ -3057,6 +3313,7 @@ def guard_filter_fn(guards): return ret sorted_guards = sorted(guards or (), key=Guard.sort_key) + assert output_graph is not None builder, guard_manager = self.build_guards( sorted_guards, existing_diff_guard_sources, @@ -3067,7 +3324,7 @@ def guard_filter_fn(guards): if guard_filter_fn: - def make_guard_filter_entry(guard): + def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: MISSING = object() name = strip_local_scope(guard.name) if name == "": @@ -3087,15 +3344,15 @@ def make_guard_filter_entry(guard): is_global = get_global_source_name(guard.originating_source) is not None guard_fn = guard.create_fn if isinstance(guard_fn, functools.partial): - guard_fn = guard.create_fn.func + guard_fn = guard.create_fn.func # type: ignore[attr-defined] return GuardFilterEntry( name=name, has_value=has_value, value=value, guard_type=guard_fn.__name__, - derived_guard_types=tuple(guard.guard_types) - if guard.guard_types - else (), + derived_guard_types=( + tuple(guard.guard_types) if guard.guard_types else () + ), is_global=is_global, orig_guard=guard, ) @@ -3141,7 +3398,7 @@ def make_guard_filter_entry(guard): if not output_graph.export and self.guards_serialization_mode != "load": if not self.guard_manager.check(output_graph.local_scope): reasons = get_guard_fail_reason_helper( - self.guard_manager, # type: ignore[arg-type] + self.guard_manager, output_graph.local_scope, CompileContext.current_compile_id(), ) @@ -3174,12 +3431,13 @@ def make_guard_filter_entry(guard): CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) self.guards_state: Optional[bytes] = None + assert self.output_graph is not None builtins_dict_name = self.output_graph.name_of_builtins_dict_key_in_fglobals if self.guards_serialization_mode == "save": used_global_vars = set() used_local_vars = set() - def prune_variable(source): + def prune_variable(source: Source) -> None: if name := get_global_source_name(source): assert isinstance(name, str) # Leave out the builtins dict key, as we will special handle @@ -3204,10 +3462,10 @@ def prune_variable(source): for source in self.output_graph.guard_on_key_order: prune_variable(source) - def normalize_create_fn(x): + def normalize_create_fn(x: Any) -> Any: if isinstance(x, functools.partial): - def _ref(x): + def _ref(x: Any) -> Any: if isinstance(x, (TensorWeakRef, weakref.ref)): return x() return x @@ -3227,7 +3485,7 @@ def _ref(x): k: v for k, v in output_graph_guards_state.global_scope[ builtins_dict_name - ].items() + ].items() # type: ignore[attr-defined] if k in self.used_builtin_vars } output_graph_guards_state = dataclasses.replace( @@ -3255,7 +3513,7 @@ def _ref(x): ), ) guards_state = GuardsState( - output_graph=output_graph_guards_state, + output_graph=output_graph_guards_state, # type: ignore[arg-type] shape_code_parts=self.shape_code_parts, ) self.guards_state = pickle_guards_state(guards_state) @@ -3278,18 +3536,18 @@ def _ref(x): def build_guards( self, - sorted_guards, - existing_diff_guard_sources, - f_code, - output_graph, - serialization_mode=None, - ): + sorted_guards: list[Guard], + existing_diff_guard_sources: OrderedSet[str], + f_code: types.CodeType, + output_graph: OutputGraph, + serialization_mode: Optional[str] = None, + ) -> tuple[GuardBuilder, GuardManagerWrapper]: guard_manager = GuardManagerWrapper() guard_manager.diff_guard_sources = existing_diff_guard_sources w_builder = None - def source_ref(source): + def source_ref(source: Source) -> str: guard_source = source.guard_source() if guard_source is GuardSource.CONSTANT: # No need to track constants @@ -3313,10 +3571,10 @@ def source_ref(source): ) # Break retain cycle. See test_release_scope_memory - def cleanup_builder(weak_b): + def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None: b = weak_b() if b: - b.scope = None + b.scope = None # type: ignore[assignment] # Break retain cycle. See test_release_input_memory w_builder = weakref.ref(builder, cleanup_builder) @@ -3340,7 +3598,12 @@ def cleanup_builder(weak_b): guard.create(builder) return builder, guard_manager - def compile_check_fn(self, builder, guards_out, guard_fail_fn): + def compile_check_fn( + self, + builder: GuardBuilder, + guards_out: list[Guard], + guard_fail_fn: Optional[Callable[[GuardFail], None]], + ) -> None: # see parallel handling of ".0" / "___implicit0" in _eval_frame.c largs = builder.argnames largs += ["**___kwargs_ignored"] @@ -3351,6 +3614,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): verbose_code_parts = [] structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] + assert self.torch_function_mode_stack is not None torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard( self.torch_function_mode_stack ) @@ -3374,7 +3638,9 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): # Clear references to torch_function modes held in the list self.torch_function_mode_stack = None - def add_code_part(code_part, guard, log_only=False): + def add_code_part( + code_part: str, guard: Optional[Guard], log_only: bool = False + ) -> None: verbose_code_part = get_verbose_code_part(code_part, guard) guards_log.debug("%s", verbose_code_part) @@ -3544,7 +3810,7 @@ def add_code_part(code_part, guard, log_only=False): self.guard_manager.extra_state = None self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names - def invalidate(self, obj_str): + def invalidate(self, obj_str: str) -> None: # Some tests reveal that CheckFunctionManager has no attribute # guard_manager, but this case should not be of any concern. # This case doesn't seem easy to repro. @@ -3561,7 +3827,7 @@ def invalidate(self, obj_str): extra_state.invalidate(cache_entry, deleted_guard_manager) self.guard_manager = deleted_guard_manager - def id_ref(self, obj, obj_str): + def id_ref(self, obj: object, obj_str: str) -> int: """add a weakref, return the id""" try: if id(obj) not in self._weakrefs: @@ -3576,14 +3842,14 @@ def id_ref(self, obj, obj_str): pass # cannot weakref bool object return id(obj) - def lookup_weakrefs(self, obj): + def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]: """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" if id(obj) in self._weakrefs: return self._weakrefs[id(obj)] return None -def build_guard_function(code_parts, closure_args) -> tuple[str, str]: +def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]: from torch._inductor.utils import IndentedBuffer csepass = PyExprCSEPass() @@ -3592,6 +3858,7 @@ def build_guard_function(code_parts, closure_args) -> tuple[str, str]: def replace(expr: str) -> tuple[list[str], str]: return csepass.replace(expr) + except RecursionError: # If we hit recursion limits during CSE analysis, fall back to a no-op replace function # This can happen with extremely complex guard expressions @@ -3626,19 +3893,21 @@ def replace(expr: str) -> tuple[list[str], str]: return guard_body.getvalue(), make_guard_fn.getvalue() -def is_recompiles_enabled(): +def is_recompiles_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled("recompiles") -def is_recompiles_verbose_enabled(): +def is_recompiles_verbose_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") # this will only be used if cpp guards are disabled -def make_torch_function_mode_stack_guard(initial_stack): +def make_torch_function_mode_stack_guard( + initial_stack: list[torch.overrides.TorchFunctionMode], +) -> Callable[[], bool]: types = [type(x) for x in initial_stack] - def check_torch_function_mode_stack(): + def check_torch_function_mode_stack() -> bool: cur_stack = get_torch_function_mode_stack() if len(cur_stack) != len(types): @@ -3653,10 +3922,16 @@ def check_torch_function_mode_stack(): return check_torch_function_mode_stack -def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): +Scope = TypeAliasType("Scope", dict[str, object]) + + +def recompilation_reason_for_no_tensor_aliasing_guard( + guard_manager: GuardManagerWrapper, scope: Scope +) -> list[str]: + assert guard_manager.global_scope is not None global_scope = dict(guard_manager.global_scope) ids_to_source = collections.defaultdict(list) - for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined] + for tensor_source in guard_manager.no_tensor_aliasing_sources: global_scope["__compile_source__"] = tensor_source tensor_id = id(eval(tensor_source, global_scope, scope)) ids_to_source[tensor_id].append(tensor_source) @@ -3683,7 +3958,7 @@ def strip_local_scope(s: str) -> str: def get_guard_fail_reason_helper( - guard_manager: GuardFn, + guard_manager: GuardManagerWrapper, f_locals: dict[str, object], compile_id: Optional[CompileId], ) -> str: @@ -3692,6 +3967,8 @@ def get_guard_fail_reason_helper( Updates `guard_failures` with the generated reason. Only the first failed check of guard_manager is reported. """ + assert guard_manager.global_scope is not None + assert guard_manager.closure_vars is not None scope = {"L": f_locals, "G": guard_manager.global_scope["G"]} scope.update(guard_manager.closure_vars) reasons: list[str] = [] @@ -3699,7 +3976,7 @@ def get_guard_fail_reason_helper( no_tensor_aliasing_check_failed = False verbose_code_parts: list[str] = [] - guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] + guard_debug_info = guard_manager.check_verbose(f_locals) # For test_export_with_map_cond, the check_verbose fail even without the # C++ guard manager. We need to fix the issue to remove the comment. # assert not guard_debug_info.result @@ -3750,7 +4027,7 @@ def get_guard_fail_reason_helper( def get_guard_fail_reason( - guard_manager: GuardFn, + guard_manager: GuardManagerWrapper, code: types.CodeType, f_locals: dict[str, object], compile_id: CompileId, @@ -3774,7 +4051,7 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reasons( - cache_entry, frame: DynamoFrameType + cache_entry: Optional[CacheEntry], frame: DynamoFrameType ) -> list[str]: """ Return the list of guard failure reasons using cache_entry. @@ -3833,18 +4110,20 @@ def get_and_maybe_log_recompilation_reasons( return reasons -def update_diff_guard_managers_for_existing_cache_entries(cache_entry): +def update_diff_guard_managers_for_existing_cache_entries( + cache_entry: Optional[CacheEntry], +) -> OrderedSet[str]: first_cache_entry = cache_entry # On the first pass, go through the cache entries and accumulate the diff # guard sources. Different guard managers can fail with different sources. # So, we collect all of them first. - acc_diff_guard_sources = set() + acc_diff_guard_sources: OrderedSet[str] = OrderedSet() while cache_entry is not None: acc_diff_guard_sources.update( cache_entry.guard_manager.collect_diff_guard_sources() ) - cache_entry = cache_entry.next + cache_entry = cache_entry.next # type: ignore[assignment] # On the second pass, set the diff_guard_sources for each cache line to the # accumulated value. And the re-populate the diff guard manager. @@ -3852,7 +4131,7 @@ def update_diff_guard_managers_for_existing_cache_entries(cache_entry): while cache_entry is not None: cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources cache_entry.guard_manager.populate_diff_guard_manager() - cache_entry = cache_entry.next + cache_entry = cache_entry.next # type: ignore[assignment] # return the accumulated sources to set up the new cache line. return acc_diff_guard_sources @@ -3864,7 +4143,7 @@ def guard_error_hook( f_locals: dict[str, object], index: int, last: bool, -): +) -> None: print( f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" ) @@ -3884,7 +4163,7 @@ def guard_error_hook( set_guard_error_hook(guard_error_hook) -def unique(seq): +def unique(seq: Sequence[T]) -> Generator[T, None, None]: seen = set() for x in seq: if x not in seen: @@ -3892,7 +4171,9 @@ def unique(seq): seen.add(x) -def make_dupe_guard(obj_source, dupe_source): +def make_dupe_guard( + obj_source: Source, dupe_source: Source +) -> Optional[functools.partial[Any]]: # Note - we may end up in a situation where we invoke something like # def fn(x, y) # with fn(x, x) @@ -3926,7 +4207,7 @@ def make_dupe_guard(obj_source, dupe_source): return None -def install_guard(*guards, skip=0): +def install_guard(*guards: Guard, skip: int = 0) -> None: """ Add dynamo guards to the current tracing context. @@ -3942,4 +4223,7 @@ def install_guard(*guards, skip=0): add = TracingContext.get().guards_context.dynamo_guards.add for guard in guards: assert isinstance(guard, Guard) + + if is_from_skip_guard_source(guard.originating_source): + continue add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 355705efbdeb..caa7b6fef530 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Core graph building functionality for PyTorch's Dynamo system. This module contains the essential components for constructing and managing FX graphs during compilation: @@ -33,8 +31,11 @@ import sys import traceback import weakref +from collections.abc import Generator, Sequence from dataclasses import dataclass, field as dc_field +from types import CodeType from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +from typing_extensions import ParamSpec, TypeVar import sympy @@ -56,6 +57,7 @@ ) from torch._subclasses.fake_tensor import FakeTensor from torch._utils_internal import signpost_event +from torch.export.dynamic_shapes import _ConstraintTarget from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -65,6 +67,7 @@ ShapeEnv, Specialization, ) +from torch.fx.node import Target from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet @@ -153,9 +156,9 @@ if TYPE_CHECKING: + from torch._dynamo.package import CompilePackage from torch._dynamo.symbolic_convert import InstructionTranslatorBase - log = logging.getLogger(__name__) graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") @@ -187,31 +190,31 @@ class MutationInfo: class VariableTrackerCache: - def __init__(self): - self.cache = {} + def __init__(self) -> None: + self.cache: dict[VariableTrackerCacheKey, VariableTracker] = {} - def lookup(self, value, source): + def lookup(self, value: Any, source: Source) -> Optional[VariableTracker]: key = VariableTrackerCacheKey(id(value), source) if key not in self.cache: return None return self.cache[key] - def add(self, value, source, vt): + def add(self, value: Any, source: Source, vt: VariableTracker) -> None: key = VariableTrackerCacheKey(id(value), source) self.cache[key] = vt - def clone(self): + def clone(self) -> "VariableTrackerCache": # Needed for copy and restore graph state new_cache = VariableTrackerCache() new_cache.cache.update(self.cache) return new_cache - def clear(self): + def clear(self) -> None: self.cache.clear() @functools.cache -def _step_logger(): +def _step_logger() -> Any: return torchdynamo_logging.get_step_logger(log) @@ -222,16 +225,16 @@ class GraphCompileReason: reason: str user_stack: list[traceback.FrameSummary] - # Indicates if this was a graph compile reason due to graph break. + # Indicates if this was a graph break reason due to graph break. graph_break: bool = True - def __post_init__(self): + def __post_init__(self) -> None: if self.graph_break: graph_break_reasons.append(self) -def _get_gen_rand_values_fn(random_calls): - def _gen_rand_values(): +def _get_gen_rand_values_fn(random_calls: Any) -> Callable[[], list[Any]]: + def _gen_rand_values() -> list[Any]: return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] return _gen_rand_values @@ -248,16 +251,18 @@ def __init__(self, nn_modules: dict[str, torch.nn.Module]): def __repr__(self) -> str: return "FakeRootModule(...)" - def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]): + def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]) -> None: for k, v in nn_modules.items(): setattr(self, k, v) class WrapperBackend: - def __init__(self, backend: CompilerFn): + def __init__(self, backend: CompilerFn) -> None: self.backend: CompilerFn = backend - def __call__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]): + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> CompiledFn: self.restore = checkpoint_params(gm) self.gm = gm copy_gm = copy.deepcopy(self.gm) @@ -320,15 +325,15 @@ class OutputGraphGuardsState: _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None @property - def shape_env(self): + def shape_env(self) -> ShapeEnv: raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}") @property - def guards(self): + def guards(self) -> Optional[torch._guards.GuardsSet]: return self._guards @property - def aotautograd_guards(self): + def aotautograd_guards(self) -> Optional[list[torch._guards.GuardEnvExpr]]: return self._aotautograd_guards @@ -345,7 +350,7 @@ class StackLocalsMetadata: locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list) -def get_builtins_dict(global_scope): +def get_builtins_dict(global_scope: Scope) -> dict[str, Any]: # f_globals["__builtins__"] can be a dict or a module. This is an # implementation detail - # https://docs.python.org/3/library/builtins.html. @@ -382,16 +387,16 @@ def __init__( self, code_options: dict[str, Any], compiler_fn: Optional[CompilerFn], - root_tx, + root_tx: "InstructionTranslatorBase", export: bool, - export_constraints, - frame_state, + export_constraints: Sequence[_ConstraintTarget], + frame_state: Any, local_scope: Scope, global_scope: Scope, - f_code, - torch_function_mode_stack, - package, - ): + f_code: CodeType, + torch_function_mode_stack: list[torch.overrides.TorchFunctionMode], + package: Optional["CompilePackage"], + ) -> None: super().__init__( local_scope, global_scope, @@ -410,7 +415,7 @@ def __init__( # de-duplicate graph inputs by source and reuse the tracker self.input_source_to_var: dict[Source, VariableTracker] = {} self.export = export - self.export_constraints = export_constraints + self.export_constraints = export_constraints # type: ignore[assignment] self.frame_state = frame_state self.cleanup_hooks: list[Callable[[], Any]] = [] # compile_id is an id number for the current torch.compile @@ -575,7 +580,7 @@ def __init__( self.maybe_install_saved_tensors_hooks_subgraphs() ) - def mark_bytecode_tracing_start(self): + def mark_bytecode_tracing_start(self) -> None: self.compiler_trace_stack.enter_context( dynamo_timed( "bytecode_tracing", @@ -583,20 +588,22 @@ def mark_bytecode_tracing_start(self): ) ) - def mark_bytecode_tracing_stop(self): + def mark_bytecode_tracing_stop(self) -> None: self.compiler_trace_stack.close() - def install_builtins_dict_in_fglobals(self): + def install_builtins_dict_in_fglobals(self) -> str: f_builtins = get_builtins_dict(self.global_scope) return self.install_global("__builtins_dict__", f_builtins) - def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): + def add_backward_state_hook( + self, hook: VariableTracker, prefix: str = "hook" + ) -> tuple[str, torch.fx.Proxy]: name = f"{prefix}{len(self.backward_state)}" assert name not in self.backward_state self.backward_state[name] = hook return name, self.get_backward_state_proxy() - def get_backward_state_proxy(self): + def get_backward_state_proxy(self) -> torch.fx.Proxy: if self.backward_state_proxy is None: if self.export: unimplemented_v2( @@ -617,7 +624,7 @@ def get_backward_state_proxy(self): return self.backward_state_proxy # This gets its own helper function so guards DEBUG logs are more informative - def init_ambient_guards(self): + def init_ambient_guards(self) -> None: # Register a SHAPE_ENV guard to make sure we setup shape guards # that show up in ShapeEnv self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) @@ -675,7 +682,7 @@ def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]: assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0" return [pack_subgraph_name, unpack_subgraph_name] - def dump_guards_state(self): + def dump_guards_state(self) -> OutputGraphGuardsState: return OutputGraphGuardsState( local_scope=self.local_scope, global_scope=self.global_scope, @@ -693,7 +700,9 @@ def dump_guards_state(self): _aotautograd_guards=self.aotautograd_guards, ) - def synthetic_graph_input(self, fn, args): + def synthetic_graph_input( + self, fn: Callable[..., Any], args: tuple[Any, ...] + ) -> VariableTracker: """ call fn(*args) before the graph runs and turn the result into a fake input. """ @@ -719,45 +728,45 @@ def synthetic_graph_input(self, fn, args): ) return result - def add_cleanup_hook(self, fn: Callable[[], Any]): + def add_cleanup_hook(self, fn: Callable[[], Any]) -> None: self.cleanup_hooks.append(fn) - def call_cleanup_hooks(self): + def call_cleanup_hooks(self) -> None: for hook in reversed(self.cleanup_hooks): hook() self.cleanup_hooks.clear() @property - def root_tracer(self): + def root_tracer(self) -> "SubgraphTracer": return self.tracers[0] @property - def current_tracer(self): + def current_tracer(self) -> "SubgraphTracer": return self.tracers[-1] - def is_root_tracer(self): + def is_root_tracer(self) -> bool: # Helper to tell if we are inside the higher order operator tracing. return len(self.tracers) == 1 @property - def graph(self): + def graph(self) -> torch.fx.Graph: return self.current_tracer.graph # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. @graph.setter - def graph(self, value): + def graph(self, value: torch.fx.Graph) -> None: self.current_tracer.graph = value @property - def input_name_to_proxy(self): + def input_name_to_proxy(self) -> dict[str, fx.Proxy]: return self.current_tracer.input_name_to_proxy @property - def real_value_cache(self): + def real_value_cache(self) -> dict[fx.Node, torch.Tensor]: return self.current_tracer.real_value_cache @property - def bound_symbols(self): + def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"]]: return self.current_tracer.bound_symbols # If you are here, and you're looking for create_graph_input, @@ -766,17 +775,19 @@ def bound_symbols(self): # - self.root_tracer.create_graph_input # See NOTE [HigherOrderOperator tracing design] for more context. - def create_proxy(self, *args, **kwargs): + def create_proxy(self, *args: Any, **kwargs: Any) -> torch.fx.Proxy: return self.current_tracer.create_proxy(*args, **kwargs) - def create_node(self, *args, **kwargs): + def create_node(self, *args: Any, **kwargs: Any) -> torch.fx.Node: return self.current_tracer.create_node(*args, **kwargs) - def remove_node(self, *args, **kwargs): + def remove_node(self, *args: Any, **kwargs: Any) -> None: return self.current_tracer.remove_node(*args, **kwargs) @contextlib.contextmanager - def subtracer(self, source_target, prior_tracer): + def subtracer( + self, source_target: Optional[Target], prior_tracer: "SubgraphTracer" + ) -> Generator[fx.Tracer, None, None]: new_scope_ctx = enter_new_scope() try: if prior_tracer: @@ -800,16 +811,18 @@ def subtracer(self, source_target, prior_tracer): self.tracers.pop() @property - def output(self): + def output(self) -> "OutputGraph": return self @property - def fake_mode(self): + def fake_mode(self) -> torch._subclasses.FakeTensorMode: + assert self.tracing_context.fake_mode is not None return self.tracing_context.fake_mode @property - def shape_env(self): + def shape_env(self) -> ShapeEnv: assert self.tracing_context.fake_mode is not None + assert self.tracing_context.fake_mode.shape_env is not None return self.tracing_context.fake_mode.shape_env @property @@ -821,10 +834,12 @@ def nn_modules(self) -> dict[str, Any]: return self.tracing_context.module_context.nn_modules @property - def aotautograd_guards(self): + def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]: return self.tracing_context.guards_context.aotautograd_guards - def save_global_state(self, out=None): + def save_global_state( + self, out: Optional[dict[str, tuple[Callable[..., Any], bool]]] = None + ) -> None: """ Saves to out if it is provided. Else saves to the tracing context's global_state. """ @@ -860,23 +875,23 @@ def save_global_state(self, out=None): torch.is_autocast_cache_enabled(), ) - def push_tx(self, tx): + def push_tx(self, tx: "InstructionTranslatorBase") -> None: self._current_tx.append(tx) - def pop_tx(self): + def pop_tx(self) -> "InstructionTranslatorBase": return self._current_tx.pop() @property - def current_tx(self): + def current_tx(self) -> "InstructionTranslatorBase": return self.root_tx if not self._current_tx else self._current_tx[-1] - def count_calls(self): + def count_calls(self) -> int: return count_calls(self.graph) - def is_empty_graph(self): + def is_empty_graph(self) -> bool: return len(list(self.graph.nodes)) == 0 - def get_submodule(self, keys): + def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]: assert keys obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules for k in keys.split("."): @@ -886,7 +901,7 @@ def get_submodule(self, keys): obj = getattr(obj, k) return obj - def new_var(self, name="tmp"): + def new_var(self, name: str = "tmp") -> str: existing = set(self.code_options["co_varnames"]) # In common case, this will be O(1) while True: @@ -895,13 +910,13 @@ def new_var(self, name="tmp"): self.code_options["co_varnames"] += (var,) return var - def update_co_names(self, name): + def update_co_names(self, name: str) -> None: """Ensure self.code_options.co_names contains name""" if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) @staticmethod - def module_key_name(*names): + def module_key_name(*names: Any) -> str: # create a new unique name name = "_".join(map(str, names)) # Strip the guard lookup L/G access @@ -930,9 +945,9 @@ def register_static_attr_and_return_proxy( def register_attr_or_module( self, target: Union[torch.nn.Module, torch.Tensor, Any], - *names, - **options, - ): + *names: Any, + **options: Any, + ) -> VariableTracker: if is_dynamic_nn_module(target, self.export): # Instead of returning UnspecializedNNModuleVariable, call # VariableTracker.build so that it is tracked for mutation. @@ -959,12 +974,13 @@ def register_attr_or_module( # are registered as get_attr nodes in the root graph. tracer = self.root_tracer - def wrap_name(module_key): + def wrap_name(module_key: str) -> VariableTracker: assert self.param_name_to_source is not None self.param_name_to_source[module_key] = source # Check if the attr has already been registered. This can happen # when two different sources point to the same tensor. + assert self.root_tx is not None if target in self.root_tx.output.side_effects: return self.root_tx.output.side_effects[target] @@ -986,8 +1002,8 @@ def wrap_name(module_key): # different sources pointing to the same tensor object. vt = self.root_tx.output.side_effects.track_object_existing(target, vt) - assert "tensor_dict" not in vt.proxy.node.meta - vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target) + assert "tensor_dict" not in vt.as_proxy().node.meta + vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target) return vt @@ -997,7 +1013,7 @@ def wrap_name(module_key): if source: install_guard(source.make_guard(GuardBuilder.NN_MODULE)) - def wrap_name(module_key): + def wrap_name(module_key: str) -> VariableTracker: return NNModuleVariable(type(target), module_key, target, **options) else: @@ -1005,7 +1021,7 @@ def wrap_name(module_key): # from higher order ops. NNModuleVariable tracker can't be # sourceless, so let's return a unspecializedNNModule variable # tracker. - def wrap_name(module_key): + def wrap_name(module_key: str) -> VariableTracker: return variables.UnspecializedNNModuleVariable(target, **options) elif isinstance(target, (torch.SymInt, torch.SymFloat)): @@ -1016,7 +1032,7 @@ def wrap_name(module_key): # own storage # alas, this is like this for now - def wrap_name(module_key): + def wrap_name(module_key: str) -> VariableTracker: return SymNodeVariable.create( self, self.create_proxy("get_attr", module_key, (), {}), @@ -1027,7 +1043,7 @@ def wrap_name(module_key): # HACKY CODE REGION END else: - def wrap_name(module_key): + def wrap_name(module_key: str) -> VariableTracker: self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableTracker.build( @@ -1046,7 +1062,7 @@ def wrap_name(module_key): self.nn_modules[name] = target if isinstance(target, torch.nn.Module): - def register_leaf_name(leaf_name): + def register_leaf_name(leaf_name: str) -> None: assert self.param_name_to_source is not None new_source = ParamBufferSource(source, leaf_name) new_name = f"{name}.{leaf_name}" @@ -1067,7 +1083,9 @@ def register_leaf_name(leaf_name): return wrap_name(name) - def handle_aliases_for_stolen_lists(self, tx): + def handle_aliases_for_stolen_lists( + self, tx: "InstructionTranslatorBase" + ) -> tuple[list[Instruction], dict[Source, Source]]: # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive maybe_gm = self.local_scope.get("self") stolen_list_names = get_locals_to_steal(maybe_gm) @@ -1153,7 +1171,9 @@ def handle_aliases_for_stolen_lists(self, tx): # other parts of Dynamo like guards. return alias_insts, overridden_sources - def _get_stack_values_to_restore(self, tx, stack_pops): + def _get_stack_values_to_restore( + self, tx: "InstructionTranslatorBase", stack_pops: int + ) -> tuple[list[VariableTracker], list[str], StackLocalsMetadata]: """ Gets the stack + locals values belonging to tx that need to be restored. @@ -1242,9 +1262,9 @@ def compile_subgraph( self, tx: "InstructionTranslatorBase", reason: GraphCompileReason, - partial_convert=False, - stack_pops=0, - ): + partial_convert: bool = False, + stack_pops: int = 0, + ) -> list[StackLocalsMetadata]: """ Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens: - Call the compiled subgraph @@ -1465,7 +1485,12 @@ def compile_subgraph( return all_stack_locals_metas - def codegen_suffix(self, tx, stack_values, cg): + def codegen_suffix( + self, + tx: "InstructionTranslatorBase", + stack_values: list[VariableTracker], + cg: PyCodegen, + ) -> None: # NOTE: `codegen_save_tempvars` must run first to update `source` fields # for variables with `AttributeMutationNew`, as they don't implement # `reconstruct` themselves. @@ -1474,6 +1499,7 @@ def codegen_suffix(self, tx, stack_values, cg): assert not self.export for name, val in self.backward_state.items(): cg(val) + assert self.backward_state_var is not None cg.append_output(cg.create_load(self.backward_state_var)) cg.store_attr(name) self.side_effects.codegen_hooks(cg) @@ -1489,7 +1515,7 @@ def codegen_suffix(self, tx, stack_values, cg): cg.restore_stack(stack_values, value_from_source=not tx.export) self.side_effects.codegen_update_mutated(cg) - def cleanup_graph(self): + def cleanup_graph(self) -> None: """ Remove "creation_timestamp" from node meta @@ -1519,8 +1545,8 @@ def cleanup_graph(self): self.graph.erase_node(node1) self.graph.erase_node(node2) - def get_graph_sizes_structured(self): - ret = {} + def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]: + ret: dict[str, list[Union[int, str]]] = {} for node in self.graph.nodes: example_value = node.meta.get("example_value", None) if isinstance(example_value, torch._subclasses.FakeTensor): @@ -1528,7 +1554,7 @@ def get_graph_sizes_structured(self): ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] return ret - def get_graph_sizes(self, name: str): + def get_graph_sizes(self, name: str) -> str: graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" graph_sizes_str += f"===== {name} =====\n" for node in self.graph.nodes: @@ -1554,7 +1580,7 @@ def get_graph_sizes(self, name: str): return graph_sizes_str @contextlib.contextmanager - def restore_global_state(self): + def restore_global_state(self) -> Any: """ Momentarily restores the global state to what it was prior to tracing the current output """ @@ -1571,7 +1597,7 @@ def restore_global_state(self): GlobalContextCheckpointState(current_global_state) ) - def run_compiler_collective(self): + def run_compiler_collective(self) -> None: tx = self.root_tx assert tx is not None if (ds := tx.distributed_state) is not None and ds.all_states is None: @@ -1595,7 +1621,7 @@ def run_compiler_collective(self): ), dynamo_timed("compiler_collective", log_pt2_compile_event=True), ): - all_states = [None] * compile_pg.size() + all_states: list[Any] = [None] * compile_pg.size() dist.all_gather_object(all_states, ds.local_state, group=compile_pg) ds.all_states = all_states # Clear speculation log, because are tracing may diverge due to @@ -1603,7 +1629,12 @@ def run_compiler_collective(self): tx.speculation_log.clear() raise exc.CompileCollectiveRestartAnalysis - def compile_and_call_fx_graph(self, tx, rv, root): + def compile_and_call_fx_graph( + self, + tx: "InstructionTranslatorBase", + rv: list[VariableTracker], + root: FakeRootModule, + ) -> list[Instruction]: """ Generate code from self.graph and return the Instruction()s to call that generated code. @@ -1630,7 +1661,7 @@ def compile_and_call_fx_graph(self, tx, rv, root): {}, ) sub_gms = self.dedup_pass() - root.add_nn_modules(sub_gms) + root.add_nn_modules(sub_gms) # type: ignore[arg-type] self.current_tracer._maybe_preserve_original_meta(tx, output_node) if not config.do_not_emit_runtime_asserts: @@ -1773,8 +1804,8 @@ def compile_and_call_fx_graph(self, tx, rv, root): ) ) - @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") - def specialized_dispatch(*args, **kwargs): + @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") # type: ignore[misc] + def specialized_dispatch(*args: Any, **kwargs: Any) -> Any: for check_fn, specialization in specialization_guards: if check_fn(args): if specialization in specialization_cache: @@ -1903,16 +1934,16 @@ def _call_user_compiler( return compiled_fn - def dedup_pass(self): + def dedup_pass(self) -> dict[str, torch.fx.GraphModule]: if torch._dynamo.config.use_graph_deduplication: return apply_graph_deduplication(self) else: return {} - def install_subgraph(self, name, sub_gm): + def install_subgraph(self, name: str, sub_gm: torch.fx.GraphModule) -> str: next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) - sub_gm.__name__ = next_name - sub_gm.torchdynamo_force_dynamic = False + sub_gm.__name__ = next_name # type: ignore[assignment] + sub_gm.torchdynamo_force_dynamic = False # type: ignore[assignment] # This graph module is not present in the user space, so it can't be # accessed by a source. Set source=None. self.register_attr_or_module(sub_gm, next_name, source=None) @@ -1941,7 +1972,7 @@ def remove_unused_graphargs(self) -> None: assert self.should_exit # Miniature DCE pass, but only for obviously trivial operations - def is_static_true(b_node: fx.node.Argument): + def is_static_true(b_node: fx.node.Argument) -> bool: if b_node is True: return True if not isinstance(b_node, fx.Node): @@ -1960,7 +1991,7 @@ def is_static_true(b_node: fx.node.Argument): # doesn't have unbacked inputs, since it's all in the ShapeEnv return False - def is_symnode_arg(a: fx.node.Argument): + def is_symnode_arg(a: fx.node.Argument) -> bool: from torch.fx.experimental.sym_node import SymTypes if isinstance(a, (int, float, bool)): @@ -1972,7 +2003,7 @@ def is_symnode_arg(a: fx.node.Argument): # NB: We assume that you cannot do mutations on int/float/bool, # because they are immutable types, and therefore is always safe to # DCE. - def is_symnode_compute_node(node): + def is_symnode_compute_node(node: fx.Node) -> bool: from torch.fx.experimental.sym_node import SymTypes if node.op != "call_function": @@ -2006,7 +2037,7 @@ def is_symnode_compute_node(node): ): self.remove_node(node) - def placeholder_binds_symbol(node): + def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]: arg = node.meta["grapharg"] example = arg.example if isinstance(example, torch.SymInt) and isinstance( @@ -2015,7 +2046,7 @@ def placeholder_binds_symbol(node): return example.node.expr return None - def remove_unused(node): + def remove_unused(node: fx.Node) -> None: log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) # I'm not really sure why you need to delete these from the # node since the node is going to get removed @@ -2025,7 +2056,9 @@ def remove_unused(node): used_symbols: set[sympy.Symbol] = set() - def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): + def update_used_symbols( + used_symbols: set[sympy.Symbol], fake: Union[torch.SymInt, torch.Tensor] + ) -> None: used_symbols |= free_symbols(fake) recheck_placeholders = [] @@ -2119,7 +2152,7 @@ def add_output_instructions(self, prefix: list[Instruction]) -> None: self.output_instructions.extend(prefix) self.should_exit = True - def install_global_unsafe(self, name, value) -> None: + def install_global_unsafe(self, name: str, value: Any) -> None: """ WARNING: prefer the safer `install_global_by_id/install_global`. torch.compile instances should be independent of each other; @@ -2131,7 +2164,7 @@ def install_global_unsafe(self, name, value) -> None: self.installed_globals.add(name) self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) - def install_global_by_id(self, prefix, value) -> str: + def install_global_by_id(self, prefix: str, value: Any) -> str: """ Installs a global if it hasn't been installed already. This is determined by (prefix, id(value)) pair. @@ -2146,7 +2179,7 @@ def install_global_by_id(self, prefix, value) -> str: self.install_global_unsafe(name, value) return name - def install_global(self, prefix, value) -> str: + def install_global(self, prefix: str, value: Any) -> str: """ Installs a global, generating a unique name for it. @@ -2160,7 +2193,7 @@ def install_global(self, prefix, value) -> str: def cleanup(self) -> None: # There is a reference cycle between tracer and OutputGraph, causing # some of the tensor objects to be held alive for longer than necessary. - self.root_tx = None + self.root_tx = None # type: ignore[assignment] self.nn_modules.clear() self.param_name_to_source = None @@ -2183,7 +2216,7 @@ def add_graph_finalizer( ) -> None: self.register_finalizer_fns.append(register_finalizer) - def example_value_from_input_node(self, node: torch.fx.Node): + def example_value_from_input_node(self, node: torch.fx.Node) -> Any: """Extract the non-fake example tensor""" if node.op == "placeholder": return node.meta["grapharg"].example @@ -2200,16 +2233,18 @@ def example_value_from_input_node(self, node: torch.fx.Node): ) -def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): +def check_pt2_compliant_op( + output_graph: OutputGraph, kind: str, target: Any, args: Any, kwargs: Any +) -> None: if kind != "call_function": return - def encountered_compliant_op(target): + def encountered_compliant_op(target: torch._ops.OpOverload) -> None: if target.namespace in {"prim", "prims", "aten"}: return output_graph.compliant_custom_ops.add(target) - def encountered_non_compliant_op(target, msg): + def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> None: output_graph.non_compliant_ops.add(target) if config.only_allow_pt2_compliant_ops: unimplemented_v2( @@ -2275,15 +2310,24 @@ def encountered_non_compliant_op(target, msg): _compile_id_counter = itertools.count() +P = ParamSpec("P") +R = TypeVar("R") + class LazyProxy: - def __init__(self, tracer, fn, *args, **kwargs): + def __init__( + self, + tracer: "SubgraphTracer", + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: self.tracer = tracer self.fn = fn self.args = args self.kwargs = kwargs - def __call__(self): + def __call__(self) -> Any: return self.fn(*self.args, **self.kwargs) @@ -2295,7 +2339,13 @@ class SubgraphTracer(fx.Tracer): compiling and executing the graph. """ - def __init__(self, output_graph, parent=None, is_export=False, source_target=None): + def __init__( + self, + output_graph: "OutputGraph", + parent: Optional["SubgraphTracer"] = None, + is_export: bool = False, + source_target: Optional[Target] = None, + ) -> None: super().__init__() self.output_graph = weakref.proxy(output_graph) self.graph = torch.fx.Graph() @@ -2325,7 +2375,7 @@ def __init__(self, output_graph, parent=None, is_export=False, source_target=Non # need to keep track of what free variables were lifted so we can # rewrite the HigherOrderOperator call using the traced body_fn. # Dicts maintain the order of args for the HigherOrderOperator call. - self.lifted_freevars = {} + self.lifted_freevars: dict[fx.Proxy, fx.Proxy] = {} # map basic symbols (unbacked and unbacked) to their bound proxies. # There are only two cases where bound_symbols will be recorded: @@ -2354,15 +2404,15 @@ def __init__(self, output_graph, parent=None, is_export=False, source_target=Non self.debug_level: int = parent.debug_level + 1 if parent is not None else 0 self._cur_code = None - self._orig_gm_meta = None - self._orig_gm_lineno_map = None - self._orig_gm_firstlineno = None + self._orig_gm_meta: Optional[list[Any]] = None + self._orig_gm_lineno_map: Optional[dict[int, Optional[int]]] = None + self._orig_gm_firstlineno: Optional[int] = None # Each SubgraphTracer is associated with a source target, which indicates # which operator this subgraph is attached to. We compute a source_fn_stack # based on the source target. For the root tracer, it's set to []. # This is useful for debugging and transforming the exported graph. if self.parent is None: - self.source_fn_stack = [] + self.source_fn_stack: list[Any] = [] else: self.source_fn_stack = self.parent.source_fn_stack + [ (self.graph._target_to_str(source_target), source_target) @@ -2379,7 +2429,9 @@ def __init__(self, output_graph, parent=None, is_export=False, source_target=Non ) # preserve original meta if it is available - def _maybe_preserve_original_meta(self, tx, node): + def _maybe_preserve_original_meta( + self, tx: "InstructionTranslatorBase", node: fx.Node + ) -> None: if ( self._orig_gm_meta and self._orig_gm_lineno_map @@ -2401,14 +2453,14 @@ def _maybe_preserve_original_meta(self, tx, node): def create_proxy( self, - kind, - target, - args, - kwargs, - name=None, - type_expr=None, - proxy_factory_fn=None, - ): + kind: str, + target: Any, + args: Any, + kwargs: Any, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None, + ) -> fx.Proxy: # NOTE: [Nested SubgraphTracer and free_variable handling] # -------------------------------------------------------- # Read NOTE [HigherOrderOperator tracing design] first. @@ -2452,7 +2504,13 @@ def create_proxy( args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) rv = super().create_proxy( - kind, target, args, kwargs, name, type_expr, proxy_factory_fn + kind, + target, + args, + kwargs, + name, + type_expr, + proxy_factory_fn, # type: ignore[arg-type] ) # append stack trace to fx node @@ -2473,7 +2531,7 @@ def create_proxy( tx_code = tx.f_code header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) - def get_trace_call_log_str(): + def get_trace_call_log_str() -> str: line = get_instruction_source_311(tx_code, cur_inst).rstrip() return f"TRACE FX call {rv.node.name} from {header}\n{line}" @@ -2583,8 +2641,14 @@ def get_trace_call_log_str(): return rv def create_node( - self, op, target, args=None, kwargs=None, name=None, type_expr=None - ): + self, + op: str, + target: Target, + args: Any = None, + kwargs: Any = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> fx.Node: check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) if self.parent is not None: flat_args = pytree.arg_tree_leaves(*args, **kwargs) @@ -2602,7 +2666,7 @@ def create_node( # Note: we did not override erase_node since # we call self.graph.erase_node elsewhere - def remove_node(self, node): + def remove_node(self, node: fx.Node) -> None: if len(node.users) > 0: user_graph_nodes: list[torch.fx.Node] = [] for user in node.users.keys(): @@ -2625,8 +2689,13 @@ def remove_node(self, node): # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets # fixed. def create_graph_input( - self, name, type_expr, example_value, before=False, source=None - ): + self, + name: str, + type_expr: Any, + example_value: Any, + before: bool = False, + source: Optional[Source] = None, + ) -> fx.Proxy: if isinstance(example_value, torch.Tensor): self._input_versions_at_beginning.append(example_value._version) log.debug( @@ -2652,6 +2721,7 @@ def create_graph_input( # So we are a bit more strict about what sources can become inputs # in export if self.is_export and self.parent is None: + assert source is not None if not is_from_local_source(source, only_allow_input=True): self.output_graph.source_to_user_stacks.setdefault(source, []).append( TracingContext.extract_stack() @@ -2734,7 +2804,9 @@ def create_graph_input( return proxy # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details - def lift_tracked_freevar_to_input(self, proxy): + def lift_tracked_freevar_to_input( + self, proxy: fx.Proxy + ) -> Union[LazyProxy, fx.Proxy]: # You're doing something wrong if we are the root SubgraphTracer because # Dynamo adds tensors to graph inputs before creating a proxy for them. assert self.parent is not None, ( @@ -2774,7 +2846,7 @@ def lift_tracked_freevar_to_input(self, proxy): self.lifted_freevars[proxy] = new_proxy return new_proxy - def maybe_lift_tracked_freevar_to_input(self, arg): + def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any: """ If arg is a free variable, then lift it to be an input. Returns the new lifted arg (if arg was a freevar), else the @@ -2806,8 +2878,8 @@ def maybe_lift_tracked_freevar_to_input(self, arg): # LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies # for symbols that're not going to be used. def track_unbacked_symbols( - self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy] - ): + self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy] + ) -> None: # When binding the symbols in an exmaple_value, we bind the symbols # to the proxy's associated Tracer instead of current tracer. # This is because: @@ -2824,7 +2896,7 @@ def track_unbacked_symbols( tracer = e_proxy.tracer assert isinstance(tracer, SubgraphTracer) - def need_bind(s) -> bool: + def need_bind(s: Any) -> bool: from torch.fx.experimental.symbolic_shapes import is_symbolic return ( @@ -2834,7 +2906,9 @@ def need_bind(s) -> bool: and s.node.expr not in self.bound_symbols ) - def _proxy_with_example_value(example_value, *args, **kwargs): + def _proxy_with_example_value( + example_value: Any, *args: Any, **kwargs: Any + ) -> fx.Proxy: proxy = tracer.create_proxy(*args, **kwargs) set_example_value(proxy.node, example_value) return proxy @@ -2906,7 +2980,7 @@ def _proxy_with_example_value(example_value, *args, **kwargs): # See Note [Auto lift basic free symbols when create_graph_input] def _lift_basic_symbols( self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source] - ): + ) -> None: # The before arg is for inserting symints in the sizes/strides of a tensor # before the tensor. This ordering ensures that when we look at the tensor's # symbols, they're already lifted/tracked. E.g. this assumption is used @@ -2930,7 +3004,7 @@ def _lift_symbols_in_symint( self.parent._lift_basic_symbols(s, source) for s0 in self_to_be_bound: parent_proxy = self.parent.bound_symbols[s0] - example_val = parent_proxy.node.meta["example_value"] + example_val = parent_proxy.node.meta["example_value"] # type: ignore[union-attr] assert isinstance(example_val, torch.SymInt) ph = self.create_graph_input( str(s0), @@ -2945,7 +3019,7 @@ def _lift_symbols_in_symint( source.name() if source is not None else "subgraph inputs", self.debug_level, ) - self.lifted_freevars[parent_proxy] = ph + self.lifted_freevars[parent_proxy] = ph # type: ignore[index] # For root_tracer: else: assert len(self_to_be_bound) == 1, ( @@ -3055,7 +3129,7 @@ def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]: # Sort the symbols so that we can have a deterministic lifting order return sorted(to_be_bound, key=lambda s: s.name) - def has_input_mutation(self): + def has_input_mutation(self) -> MutationInfo: input_versions_at_beginning = self._input_versions_at_beginning input_nodes = [] @@ -3084,7 +3158,7 @@ def has_input_mutation(self): return MutationInfo(False, "") - def has_aliasing(self): + def has_aliasing(self) -> AliasingInfo: from torch._higher_order_ops.utils import _collect_fake_inputs input_storages: dict[StorageWeakRef, torch.fx.Node] = dict() diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index b15dc0b2fdf6..311a702dfa38 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -112,7 +112,17 @@ class InlinedSource: @dataclasses.dataclass -class _DynamoCodeCacheEntry: +class DynamoCaptureOutput: + """ + Core information generated from Dynamo for fullgraph=True. + """ + + guarded_codes: list[_GuardedCodeCacheEntry] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry(DynamoCaptureOutput): """ Contains the serializable information associated with a single code object in dynamo. To restore an execution of compiled code, we will need the following @@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry: python_code: SerializedCode python_module: str function_names: list[_FunctionId] - guarded_codes: list[_GuardedCodeCacheEntry] import_sources: dict[str, str] - backend_ids: list[_BackendId] code_source: Optional[str] install_to_global: bool has_compile_id: bool = False diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 9bdec2df05c2..5e12e0dc36a8 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -173,6 +173,7 @@ class CodeState: _INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None _CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None +_LOGGED_DYNAMIC_ALLOWLIST: bool = False @dataclasses.dataclass(frozen=True) @@ -521,9 +522,9 @@ def process_automatic_dynamic( def get_cache_key() -> Optional[str]: # TODO: info versions of these logs that log only once - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: warn_once( - "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" + "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" ) return None @@ -566,7 +567,7 @@ def code_state_path(cache_key: str) -> Optional[str]: def should_use_remote_dynamo_pgo_cache() -> bool: - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: @@ -616,6 +617,7 @@ def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: + global _LOGGED_DYNAMIC_ALLOWLIST code_id = CodeId.make(f_code) frame_state = get_code_state()[code_id] frame_whitelist = ",".join(_collect_dynamic_sources(frame_state)) @@ -624,6 +626,15 @@ def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: CompileEventLogger.pt2_compile( name, recompile_dynamic_whitelist=frame_whitelist ) + if not _LOGGED_DYNAMIC_ALLOWLIST: + torch._utils_internal.add_mlhub_insight( + category="dynamic_shapes_analysis", + insight="Dynamic shapes detected", + insight_description="PGO detected a recompilation due to dynamic shapes. \ + Please follow the instruction from the action link to reduce shape recompilations.", + ) + # add mlhub insight only once per job + _LOGGED_DYNAMIC_ALLOWLIST = True def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index 1548dc798903..d3544fa354fa 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -93,7 +93,11 @@ def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /): # type: ignore[no-unt if sentinel is _SENTINEL_MISSING: iterable = fn_or_iterable if hasattr(iterable, "__iter__"): - return iterable.__iter__() + iterator = iterable.__iter__() + if hasattr(iterator, "__next__"): + return iterator + else: + raise TypeError(f"'{type(iterator)}' object is not iterable") if hasattr(iterable, "__getitem__"): # Needs to be a new function to avoid iter becoming a generator def sequence_protocol(iterable): # type: ignore[no-untyped-def] diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 0126e738eb1e..745df38496ff 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -5,8 +5,9 @@ from __future__ import annotations import itertools +import operator import sys -from typing import Callable, overload, TYPE_CHECKING, TypeVar +from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar from typing_extensions import TypeAlias from ..decorators import substitute_in_graph @@ -17,9 +18,11 @@ __all__ = [ + "accumulate", "chain", "chain_from_iterable", "compress", + "cycle", "dropwhile", "islice", "tee", @@ -41,6 +44,35 @@ def chain(*iterables: Iterable[_T]) -> Iterator[_T]: yield from iterable +# Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate +@substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type] +def accumulate( + iterable: Iterable[_T], + func: Optional[Callable[[_T, _T], _T]] = None, + *, + initial: Optional[_T] = None, +) -> Iterator[_T]: + # call iter outside of the generator to match cypthon behavior + iterator = iter(iterable) + if func is None: + func = operator.add + + def _accumulate(iterator: Iterator[_T]) -> Iterator[_T]: + total = initial + if total is None: + try: + total = next(iterator) + except StopIteration: + return + + yield total + for element in iterator: + total = func(total, element) + yield total + + return _accumulate(iterator) + + @substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: # previous version of this code was: @@ -59,6 +91,24 @@ def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]: return (datum for datum, selector in zip(data, selectors) if selector) +# Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle +@substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type] +def cycle(iterable: Iterable[_T]) -> Iterator[_T]: + iterator = iter(iterable) + + def _cycle(iterator: Iterator[_T]) -> Iterator[_T]: + saved = [] + for element in iterable: + yield element + saved.append(element) + + while saved: + for element in saved: + yield element + + return _cycle(iterator) + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile @substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type] def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py index b131160db25e..5d01217fdbb6 100644 --- a/torch/_dynamo/replay_record.py +++ b/torch/_dynamo/replay_record.py @@ -15,8 +15,9 @@ import dataclasses from dataclasses import field +from io import BufferedReader, BufferedWriter from types import CellType, CodeType, ModuleType -from typing import Any, IO +from typing import Any, IO, Union from typing_extensions import Self from torch.utils._import_utils import import_dill @@ -51,12 +52,12 @@ class ExecutionRecord: builtins: dict[str, Any] = field(default_factory=dict) code_options: dict[str, Any] = field(default_factory=dict) - def dump(self, f: IO[str]) -> None: + def dump(self, f: Union[IO[str], BufferedWriter]) -> None: assert dill is not None, "replay_record requires `pip install dill`" dill.dump(self, f) @classmethod - def load(cls, f: IO[bytes]) -> Self: + def load(cls, f: Union[IO[bytes], BufferedReader]) -> Self: assert dill is not None, "replay_record requires `pip install dill`" return dill.load(f) diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 71f552a83b4a..136d2af1a608 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -34,6 +34,24 @@ from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack +from torch.utils._triton import has_triton + + +if has_triton(): + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction +else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + class Heuristics: # type: ignore[no-redef] + pass + + import torch import torch.fx as fx import torch.nn as nn @@ -58,6 +76,7 @@ ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.output_code import OutputCode from torch._library.fake_class_registry import FakeScriptObject @@ -302,6 +321,16 @@ def generate_compiler_repro_string( """ ).strip() + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -312,6 +341,7 @@ def generate_compiler_repro_string( from math import inf import torch._inductor.inductor_prims {distributed_imports} +{triton_imports} {generate_config_string(stable_output=stable_output)} @@ -330,6 +360,53 @@ def generate_compiler_repro_string( model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + + if isinstance(kernel, Autotuner): + if isinstance(kernel.fn, Heuristics): + model_str += "ERROR: Repro will not work as intended, " + model_str += ( + "triton.runtime.autotuner.Heuristics is not currently supported\n" + ) + break + + config_strs = [] + for kernel_config in kernel.configs: + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name + ) + fn_name = fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + + if len(kernel_side_table.constant_args) > 0: + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + model_str += NNModuleToString.convert(gm) writer = InputWriter(save_dir, stable_hash=stable_hash) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 03b0c3d1c7e0..6897ddd9b24c 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -266,6 +266,38 @@ def name(self) -> str: return f"object.__getattribute__({self.base.name()}, {self.member!r})" +# Represents obj.__dict__ where obj is a type object +@dataclasses.dataclass(frozen=True) +class TypeDictSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs("__dict__")) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + # type(ob).__dict__ can return a proxy of the dict. But in the C++ + # guard accessor, we are use type->tp_dict which is a dict. So, + # forcefully pass a dict object to ensure that the GuardManager + # registers that its working on a dict object. + return f"dict({self.base.name()}.__dict__)" + + +# Represents obj.__mro__ where object is type object +@dataclasses.dataclass(frozen=True) +class TypeMROSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs("__mro__")) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + return f"{self.base.name()}.__mro__" + + @dataclasses.dataclass(frozen=True) class LocalCellSource(Source): """ @@ -285,6 +317,34 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # local cell object should never be used for guards. +# Represents obj.__code__ where object is type object +@dataclasses.dataclass(frozen=True) +class CodeSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs("__code__")) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + return f"{self.base.name()}.__code__" + + +# Represents obj.__closure__ where object is type object +@dataclasses.dataclass(frozen=True) +class ClosureSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs("__closure__")) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + return f"{self.base.name()}.__closure__" + + # Represents tensor.grad source. It could be represented by AttrSource as well. # But, we could access grad field on tensor directly in C++ without going # through the Python bytecodes. Therefore, we use a separate source for grad @@ -342,6 +402,18 @@ def is_ephemeral(self) -> bool: return True +@dataclasses.dataclass(frozen=True) +class SkipGuardSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + self.base.reconstruct(codegen) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + return self.base.name() + + class TensorProperty(enum.Enum): SIZE = 0 STRIDE = 1 @@ -1006,6 +1078,14 @@ def is_from_nonlocal_source(source: Source) -> bool: ) +def is_from_closure_source(source: Source) -> bool: + if isinstance(source, ClosureSource): + return True + if isinstance(source, ChainedSource): + return is_from_closure_source(source.base) + return False + + def is_from_source(source: Source, target: Source) -> bool: if isinstance(source, ChainedSource): return is_from_source(source.base, target) @@ -1083,3 +1163,14 @@ def is_from_defaults(source: Source) -> bool: if isinstance(source, ChainedSource): return is_from_defaults(source.base) return False + + +@functools.lru_cache +def is_from_skip_guard_source(source: Source) -> bool: + if isinstance(source, SkipGuardSource): + return True + + if isinstance(source, ChainedSource): + return is_from_skip_guard_source(source.base) + + return False diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 129af950aa1a..8e5a1ef80393 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -108,6 +108,7 @@ GlobalWeakRefSource, LocalCellSource, LocalSource, + SkipGuardSource, Source, ) from .trace_rules import is_builtin_constant, is_forbidden @@ -145,6 +146,7 @@ from .variables.lazy import LazyVariableTracker from .variables.lists import ( BaseListVariable, + IteratorVariable, ListIteratorVariable, ListVariable, SliceVariable, @@ -442,6 +444,15 @@ def impl(self: "InstructionTranslator", inst: Instruction): return impl +def is_stdlib(mod): + if sys.version_info < (3, 10): + # For < 3.10, no easy way to identify a stdlib module name. + return False + if not isinstance(mod, types.ModuleType): + return False + return mod.__name__.split(".")[0] in sys.stdlib_module_names + + def _detect_and_normalize_assert_statement( self: "InstructionTranslatorBase", truth_fn: typing.Callable[[object], bool], @@ -4099,6 +4110,12 @@ def get_globals_source_and_value(self, name): # Dont use lazy vt because we will do a setattr afterwards fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment] + + if is_stdlib(fglobals_value): + # Users don't inplace mutate a stdlib attribute (like inspect, + # collections), skip guards that originate from the stdlib modules. + global_source = SkipGuardSource(global_source) # type: ignore[assignment] + return fglobals_value, fglobals_vt, global_source def _load_global(self, inst): @@ -4221,7 +4238,7 @@ def SEND(self, inst): assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] - if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or ( + if isinstance(tos, (IteratorVariable, LocalGeneratorObjectVariable)) or ( isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 85e44b7c7e48..f0f1dab4f9c8 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -204,9 +204,9 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: graph = OutputGraph( code_options={}, compiler_fn=None, - root_tx=None, + root_tx=None, # type: ignore[arg-type] export=False, - export_constraints=None, + export_constraints=[], frame_state={"_id": 0}, # TODO: shouldn't this be f_locals/f_globals from frame? local_scope=locals(), diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index a3beb561f186..56b5e508f058 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2963,6 +2963,7 @@ "torch.xpu.random.seed_all", "torch.xpu.random.seed", "torch.xpu.set_stream", + "torch.xpu.stream", "torch.xpu.synchronize", ], TorchInGraphFunctionVariable, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index fc42934d98d0..c6707fe12fbd 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utility functions and classes used throughout the TorchDynamo system. @@ -62,7 +60,7 @@ TypeVar, Union, ) -from typing_extensions import Literal, TypeAlias, TypeGuard, TypeIs +from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard, TypeIs import torch import torch._functorch.config @@ -106,6 +104,14 @@ ValuesView, ) + from torch._dynamo.replay_record import ExecutionRecord + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + InstructionTranslatorBase, + ) + from torch._dynamo.variables.base import VariableTracker + from torch._prims_common import DeviceLikeType + try: import numpy as np @@ -145,6 +151,7 @@ T = TypeVar("T") +_P = ParamSpec("_P") unpatched_nn_module_getattr = torch.nn.Module.__getattr__ unpatched_nn_module_call = torch.nn.Module.__call__ @@ -184,43 +191,43 @@ class ReinplaceCounters: # Track sizes of known not re-inplaced tensors (exclude dynamic shapes). @classmethod - def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int): + def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int) -> None: if bytes != 0: cls._values[f"missed_bytes_{trigger.name}"] += bytes # Track number of not re-inplaced tensors. @classmethod - def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int): + def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int) -> None: if count != 0: cls._values[f"missed_tensors_{trigger}"] += count @classmethod - def clear(cls): + def clear(cls) -> None: cls._values.clear() @classmethod - def get_total_missed(cls): + def get_total_missed(cls) -> int: sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_tensors_{trigger}", 0) return sum @classmethod - def get_total_missed_bytes(cls): + def get_total_missed_bytes(cls) -> int: sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_bytes_{trigger.name}", 0) return sum @classmethod - def log(cls): + def log(cls) -> None: # if not empty log. if cls._values: signpost_event("inductor", "reinplace_counters", cls._values) def tabulate( - rows: Union[list[tuple[str, object]], list[list[object]]], + rows: Union[list[tuple[str, Any]], list[list[Any]]], headers: Union[tuple[str, ...], list[str]], ) -> str: try: @@ -385,7 +392,7 @@ def log_instant_event( metadata: dict[str, Any], time_ns: Optional[int] = None, log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, - ): + ) -> None: if time_ns is None: time_ns = time.time_ns() chromium_log = get_chromium_event_logger() @@ -407,7 +414,7 @@ def add_data( log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object, - ): + ) -> None: """ Centralized API for adding data to various events Log an event to a toplevel "dynamo" event or metrics context @@ -450,7 +457,7 @@ def add_data( @staticmethod def add_toplevel( log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object - ): + ) -> None: """ Syntactic sugar for logging to the toplevel event """ @@ -464,7 +471,7 @@ def add_toplevel( @staticmethod def increment( event_name: str, log_level: CompileEventLogLevel, key: str, value: int - ): + ) -> None: """ Increments an existing field, or adds it """ @@ -497,7 +504,7 @@ def increment_toplevel( key: str, value: int = 1, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, - ): + ) -> None: """ Increments a value on the toplevel metric. By default, logs to metric. """ @@ -512,7 +519,7 @@ def increment_toplevel( @staticmethod def add_to_set( event_name: str, log_level: CompileEventLogLevel, key: str, value: Any - ): + ) -> None: """ Add metadata to a set of values with key . Creates a set if it doesn't exist. """ @@ -545,7 +552,7 @@ def add_to_set_toplevel( key: str, value: Any, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, - ): + ) -> None: """ Same as add to set, just does it automatically to the toplevel event instead of having to explicitly name it. Defaults to COMPILATION_METRIC log level. @@ -561,7 +568,7 @@ def add_to_set_toplevel( # Helper functions that are syntactic sugar @staticmethod - def chromium(event_name: str, **metadata: object): + def chromium(event_name: str, **metadata: object) -> None: """ Add to in chromium. Each key/value of metadata will appear in the chromium trace. should be the name of a timed event span passed to `dynamo_timed`. @@ -571,7 +578,7 @@ def chromium(event_name: str, **metadata: object): ) @staticmethod - def pt2_compile(event_name: str, **metadata: object): + def pt2_compile(event_name: str, **metadata: object) -> None: """ Add to in chromium and PT2 Compile Events. Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes @@ -584,7 +591,7 @@ def pt2_compile(event_name: str, **metadata: object): ) @staticmethod - def compilation_metric(overwrite: bool = False, **metadata: object): + def compilation_metric(overwrite: bool = False, **metadata: object) -> None: """ Add to the CompilationMetrics context. Also logs to PT2 Compile Events and chromium. @@ -598,7 +605,7 @@ def compilation_metric(overwrite: bool = False, **metadata: object): @staticmethod def instant( event_name: str, metadata: dict[str, Any], time_ns: Optional[int] = None - ): + ) -> None: """ Log an instant event to chromium logs with name at time . The `args` field in Perfetto will point to metadata. should be a value obtained from time.time_ns(). @@ -608,7 +615,7 @@ def instant( ) @staticmethod - def try_add_pt2_compile(event_name: str, **metadata: object): + def try_add_pt2_compile(event_name: str, **metadata: object) -> None: """ Adds to an existing pt2_compile event, but silently returns if the event doesn't exist or ChromiumEventLogger is not initialized. @@ -620,7 +627,7 @@ def try_add_pt2_compile(event_name: str, **metadata: object): chromium_log.try_add_event_data(event_name, **metadata) @staticmethod - def try_(method_fn, *args, **kwargs): + def try_(method_fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs) -> None: """ Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set """ @@ -791,7 +798,9 @@ def compile_times( ) -> tuple[list[str], list[object]]: ... -def compile_times(repr="str", aggregate: bool = False): +def compile_times( # type: ignore[misc] + repr: str = "str", aggregate: bool = False +) -> Union[str, None, tuple[list[str], list[str]]]: """ Get metrics about torchdynamo frontend/backend compilation times. @@ -805,7 +814,7 @@ def compile_times(repr="str", aggregate: bool = False): per metric. """ - def fmt_fn(values, item_fn=lambda x: x): + def fmt_fn(values: list[float], item_fn: Callable[[float], str] = str) -> str: if aggregate: return item_fn(sum(values)) return ", ".join(map(item_fn, values)) @@ -852,8 +861,8 @@ def __init__(self, maxsize: int = 4096) -> None: self.maxsize = maxsize self.reset() - def reset(self): - self.set = OrderedDict() + def reset(self) -> None: + self.set: OrderedDict[Any, Any] = OrderedDict() def add(self, key: Union[str, tuple[object, object]]) -> bool: if key in self.set: @@ -870,7 +879,7 @@ def add(self, key: Union[str, tuple[object, object]]) -> bool: graph_break_dup_warning_checker = DuplicateWarningChecker() -def setup_compile_debug(): +def setup_compile_debug() -> contextlib.ExitStack: compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" if compile_debug: @@ -883,7 +892,7 @@ def reset_graph_break_dup_checker() -> None: graph_break_dup_warning_checker.reset() -def add_file_handler(): +def add_file_handler() -> contextlib.ExitStack: log_path = os.path.join(get_debug_dir(), "torchdynamo") os.makedirs(log_path, exist_ok=True) @@ -896,7 +905,7 @@ def add_file_handler(): return exitstack -def setup_log_file(): +def setup_log_file() -> contextlib.ExitStack: exitstack = contextlib.ExitStack() if config.log_file_name is not None: log_file_handler = logging.FileHandler(config.log_file_name) @@ -908,12 +917,12 @@ def setup_log_file(): return exitstack -def gen_record_file_name(exc, code) -> str: +def gen_record_file_name(exc: Exception, code: CodeType) -> str: return f"{get_debug_dir()}/error_recordings/\ {code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" -def write_record_to_file(filename: str, exec_record) -> None: +def write_record_to_file(filename: str, exec_record: ExecutionRecord) -> None: try: if os.path.exists(filename): log.warning( @@ -939,7 +948,7 @@ def identity(x: T) -> T: return x -def hashable(x): +def hashable(x: Any) -> bool: try: hash(x) return True @@ -950,39 +959,39 @@ def hashable(x): return False -def nothing(*args, **kwargs): +def nothing(*args: Any, **kwargs: Any) -> None: pass class ExactWeakKeyDictionary: """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" - def __init__(self): - self.values = {} - self.refs = {} + def __init__(self) -> None: + self.values: dict[int, Any] = {} + self.refs: dict[int, weakref.ReferenceType[Any]] = {} - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: return self.values[id(key)] - def get(self, key, default=None): + def get(self, key: Any, default: Any = None) -> Any: return self.values.get(id(key), default) - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return id(key) in self.values - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: idx = id(key) if idx not in self.refs: self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) self.values[idx] = value - def _remove_id(self, idx): + def _remove_id(self, idx: int) -> None: if idx in self.values: del self.values[idx] if idx in self.refs: del self.refs[idx] - def clear(self): + def clear(self) -> None: self.refs.clear() self.values.clear() @@ -1001,7 +1010,7 @@ def istype( def istype(obj: object, allowed_types: Iterable[type]) -> bool: ... -def istype(obj, allowed_types): +def istype(obj: object, allowed_types: Any) -> bool: """isinstance() without subclasses""" if isinstance(allowed_types, (tuple, list, set)): return type(obj) in allowed_types @@ -1021,7 +1030,7 @@ def istype(obj, allowed_types): ) -def is_typing(value): +def is_typing(value: Any) -> bool: # _Final catches most of typing classes: # - Any # - Callable @@ -1035,7 +1044,7 @@ def is_typing(value): return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] -def is_numpy_int_type(value): +def is_numpy_int_type(value: Any) -> bool: if not np: return False @@ -1054,7 +1063,7 @@ def is_numpy_int_type(value): ) -def is_numpy_float_type(value): +def is_numpy_float_type(value: Any) -> bool: if not np: return False @@ -1166,11 +1175,11 @@ def is_wrapper_or_member_descriptor( ) -def unwrap_if_wrapper(fn): +def unwrap_if_wrapper(fn: Any) -> Any: return unwrap_with_attr_name_if_wrapper(fn)[0] -def unwrap_with_attr_name_if_wrapper(fn): +def unwrap_with_attr_name_if_wrapper(fn: Any) -> tuple[Any, Optional[str]]: # TODO(anijain2305) - Investigate if we can get rid of this function # unpack @torch._dynamo.optimize()(fn) wrapped function if is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): @@ -1181,14 +1190,14 @@ def unwrap_with_attr_name_if_wrapper(fn): return fn, attr_name -def is_numpy_ndarray(value): +def is_numpy_ndarray(value: Any) -> TypeGuard[np.ndarray]: # type: ignore[type-arg] if not np: return False return istype(value, np.ndarray) -def istensor(obj): +def istensor(obj: Any) -> bool: """Check of obj is a tensor""" tensor_list: tuple[type, ...] = ( torch.Tensor, @@ -1199,27 +1208,27 @@ def istensor(obj): return istype(obj, tensor_list) -def is_lazy_module(mod): +def is_lazy_module(mod: Any) -> bool: return isinstance(mod, LazyModuleMixin) @functools.lru_cache(4096) -def print_once(*args): +def print_once(*args: Any) -> None: print(*args) -def make_cell(val=None): +def make_cell(val: Any = None) -> types.CellType: """Some black magic to create a cell object that usually only exists in a closure""" x = val - def f(): + def f() -> Any: return x assert f.__closure__ is not None and len(f.__closure__) == 1 return f.__closure__[0] -def proxy_args_kwargs(args, kwargs): +def proxy_args_kwargs(args: Any, kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: try: proxy_args = tuple(arg.as_proxy() for arg in args) proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} @@ -1279,6 +1288,7 @@ class CompilationMetrics: compliant_custom_ops: Optional[set[str]] = None restart_reasons: Optional[set[str]] = None dynamo_time_before_restart_s: Optional[float] = None + stack_trace: Optional[list[str]] = None # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame @@ -1350,7 +1360,7 @@ class CompilationMetrics: recompile_user_contexts: Optional[set[str]] = None @classmethod - def create(cls, metrics: dict[str, Any]): + def create(cls, metrics: dict[str, Any]) -> CompilationMetrics: """ Factory method to create a CompilationMetrics from a dict of fields. Includes the logic to add legacy fields and any pre-processing, e.g., @@ -1475,15 +1485,15 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: fail_user_frame_filename=c.fail_user_frame_filename, fail_user_frame_lineno=c.fail_user_frame_lineno, # Sets aren't JSON serializable - non_compliant_ops=list(c.non_compliant_ops) - if c.non_compliant_ops is not None - else None, - compliant_custom_ops=list(c.compliant_custom_ops) - if c.compliant_custom_ops is not None - else None, - restart_reasons=list(c.restart_reasons) - if c.restart_reasons is not None - else None, + non_compliant_ops=( + list(c.non_compliant_ops) if c.non_compliant_ops is not None else None + ), + compliant_custom_ops=( + list(c.compliant_custom_ops) if c.compliant_custom_ops is not None else None + ), + restart_reasons=( + list(c.restart_reasons) if c.restart_reasons is not None else None + ), dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, has_guarded_code=c.has_guarded_code, dynamo_config=c.dynamo_config, @@ -1533,7 +1543,7 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]: # TypeSafeSerializer for json.dumps() # Skips complex types as values in config dict class TypeSafeSerializer(json.JSONEncoder): - def default(self, o): + def default(self, o: Any) -> Any: try: return super().default(o) except Exception: @@ -1574,7 +1584,7 @@ def record_compilation_metrics( metrics: dict[str, Any], exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], -): +) -> None: if torch._inductor.utils.should_use_remote_fx_graph_cache(): try: from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION @@ -1696,7 +1706,7 @@ def get_outermost_event(self) -> Optional[str]: stack = self.get_stack() return stack[0] if stack else None - def get_pt2_compile_substack(self): + def get_pt2_compile_substack(self) -> list[str]: """ A smaller subset of the main stack that gets used to log PT2 Compile Events internally. @@ -1712,7 +1722,7 @@ def get_event_data(self) -> dict[str, Any]: self.tls.event_data = {} return self.tls.event_data - def __init__(self): + def __init__(self) -> None: self.tls = threading.local() from . import config @@ -1727,7 +1737,7 @@ def __init__(self): # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) - def try_add_event_data(self, event_name: str, **kwargs) -> None: + def try_add_event_data(self, event_name: str, **kwargs: Any) -> None: """ Same as add_event_data, but will silently not log if the event isn't in the stack. """ @@ -1738,7 +1748,7 @@ def try_add_event_data(self, event_name: str, **kwargs) -> None: def add_event_data( self, event_name: str, - **kwargs, + **kwargs: Any, ) -> None: """ Adds additional metadata info to an in-progress event @@ -1755,7 +1765,7 @@ def add_event_data( event_data[event_name] = {} event_data[event_name].update(kwargs) - def increment(self, event_name: str, key: str, value: int): + def increment(self, event_name: str, key: str, value: int) -> None: """ Increment an integer event data field by the given amount """ @@ -1778,7 +1788,7 @@ def add_to_set( event_name: str, key: str, value: Any, - ): + ) -> None: """ Add a value to a set within a event_name's metadata if it exists """ @@ -1874,7 +1884,7 @@ def log_event_end( event_metadata, ) - def pop_stack(stack): + def pop_stack(stack: list[str]) -> None: while event_name != stack[-1]: # If the event isn't the most recent one to end, pop # off the stack until it is. @@ -2035,14 +2045,14 @@ class CleanupHook: scope: dict[str, Any] name: str - def __call__(self, *args): + def __call__(self, *args: Any) -> None: # Make sure we're not shutting down if CleanupManager is not None: CleanupManager.count -= 1 del self.scope[self.name] @staticmethod - def create(scope, name, val): + def create(scope: dict[str, Any], name: str, val: Any) -> CleanupHook: assert name not in scope CleanupManager.count += 1 scope[name] = val @@ -2053,7 +2063,7 @@ class CleanupManager(ExactWeakKeyDictionary): count = 0 instance: ClassVar[CleanupManager] - def _remove_id(self, idx): + def _remove_id(self, idx: int) -> None: for hook in self.values[idx]: hook() super()._remove_id(idx) @@ -2062,7 +2072,7 @@ def _remove_id(self, idx): CleanupManager.instance = CleanupManager() -def clone_tensor(x): +def clone_tensor(x: torch.Tensor) -> torch.Tensor: """Clone the tensor and its gradient""" y = x.clone().requires_grad_(x.requires_grad) if x.is_leaf and x.grad is not None: @@ -2070,14 +2080,16 @@ def clone_tensor(x): return y -def clone_input(x, *, dtype=None): +def clone_input( + x: torch.Tensor, *, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: """copy while preserving strides""" # TODO: this is questionable if is_fake(x): # this func fails on fake tensors in __torch_dispatch__ return x - def torch_clone(x): + def torch_clone(x: torch.Tensor) -> torch.Tensor: y = torch.clone(x) if x.is_leaf: y.requires_grad_(x.requires_grad) @@ -2154,7 +2166,7 @@ def clone_inputs( def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ... -def clone_inputs(example_inputs): +def clone_inputs(example_inputs: Any) -> Any: res: Union[dict[str, Any], list[Any]] if type(example_inputs) is dict: res = dict(example_inputs) @@ -2173,7 +2185,7 @@ def clone_inputs(example_inputs): return res -def skip_frame_if_in_functorch_mode(val: torch.Tensor): +def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None: try: val.data_ptr() # will throw for functorch tensors except RuntimeError as e: @@ -2187,7 +2199,7 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor): @contextmanager -def preserve_rng_state(): +def preserve_rng_state() -> Generator[None, None, None]: disable_functorch = torch._C._DisableFuncTorch disable_current_modes = torch.utils._python_dispatch._disable_current_modes with disable_current_modes(), disable_functorch(): @@ -2205,8 +2217,8 @@ def preserve_rng_state(): def is_jit_model( - model0, -): + model0: Any, +) -> bool: return isinstance( model0, ( @@ -2218,7 +2230,7 @@ def is_jit_model( ) -def torchscript(model, example_inputs, verbose=False): +def torchscript(model: Any, example_inputs: Any, verbose: bool = False) -> Any: if is_jit_model(model): # already done? return model @@ -2243,12 +2255,12 @@ def getfile(obj: Any) -> Optional[str]: return None -def is_namedtuple(obj): +def is_namedtuple(obj: Any) -> bool: """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" return is_namedtuple_cls(type(obj)) -def is_namedtuple_cls(cls): +def is_namedtuple_cls(cls: Any) -> bool: """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" try: if issubclass(cls, tuple): @@ -2279,7 +2291,7 @@ def is_namedtuple_cls(cls): @functools.lru_cache(1) -def namedtuple_fields(cls) -> tuple[str, ...]: +def namedtuple_fields(cls: type) -> tuple[str, ...]: """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" if cls is slice: return ("start", "stop", "step") @@ -2295,16 +2307,16 @@ class Marker: # frustrating ones e.g. torch.return_types.max assert cls.__module__ == "torch.return_types" - obj = cls(map(Marker, range(cls.n_fields))) + obj = cls(map(Marker, range(cls.n_fields))) # type: ignore[attr-defined] fields: dict[str, int] = {} for name in dir(obj): if name[0] != "_" and isinstance(getattr(obj, name), Marker): fields[name] = getattr(obj, name).index - assert len(fields) == cls.n_fields + assert len(fields) == cls.n_fields # type: ignore[attr-defined] return tuple(sorted(fields, key=fields.get)) # type: ignore[arg-type] -def checkpoint_params(gm): +def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]: with torch.no_grad(): rng_state = torch.clone(torch.random.get_rng_state()) if torch.cuda.is_available(): @@ -2314,7 +2326,7 @@ def checkpoint_params(gm): for param in itertools.chain(gm.parameters(), gm.buffers()) ] - def restore(): + def restore() -> None: with torch.no_grad(): torch.random.set_rng_state(rng_state) if torch.cuda.is_available(): @@ -2326,7 +2338,7 @@ def restore(): return restore -def timed(model, example_inputs, times=1): +def timed(model: Any, example_inputs: Any, times: int = 1) -> tuple[Any, float]: if torch.cuda.is_available(): synchronize = torch.cuda.synchronize else: @@ -2343,12 +2355,12 @@ def timed(model, example_inputs, times=1): return result, t1 - t0 # type: ignore[possibly-undefined] -def check_is_cuda(gm, example_inputs): +def check_is_cuda(gm: torch.fx.GraphModule, example_inputs: Any) -> bool: return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) @lru_cache(32) -def rot_n_helper(n): +def rot_n_helper(n: int) -> Callable[..., Any]: assert n > 1 vars = [f"v{i}" for i in range(n)] rotated = reversed(vars[-1:] + vars[:-1]) @@ -2392,7 +2404,7 @@ def rot_n_helper(n): """ -def is_safe_constant(v): +def is_safe_constant(v: Any) -> bool: if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return isinstance( @@ -2411,7 +2423,7 @@ def is_safe_constant(v): @functools.cache -def common_constants(): +def common_constants() -> set[int]: return { # We zero-one specialize shapes, so specialize these constants # too @@ -2426,7 +2438,7 @@ def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]: ) -def is_int_specialization_case(value, source): +def is_int_specialization_case(value: Any, source: Any) -> bool: from .source import is_from_defaults return not TracingContext.get().force_unspec_int_unbacked_size_like and ( @@ -2457,7 +2469,7 @@ def is_int_specialization_case(value, source): ) -def specialize_symnode(arg): +def specialize_symnode(arg: Any) -> Any: from .variables import ConstantVariable, LazyVariableTracker, SymNodeVariable # Guard and specialize @@ -2482,7 +2494,7 @@ def specialize_symnode(arg): return arg -def guard_if_dyn(arg): +def guard_if_dyn(arg: Any) -> Any: from .variables import ConstantVariable arg = specialize_symnode(arg) @@ -2493,11 +2505,11 @@ def guard_if_dyn(arg): return arg -def check_constant_args(args, kwargs): +def check_constant_args(args: Any, kwargs: Any) -> bool: return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) -def check_unspec_python_args(args, kwargs): +def check_unspec_python_args(args: Any, kwargs: Any) -> bool: from .variables.constant import ConstantVariable from .variables.tensor import UnspecializedPythonVariable @@ -2510,7 +2522,7 @@ def check_unspec_python_args(args, kwargs): return unspec_count > 0 -def check_unspec_or_constant_args(args, kwargs): +def check_unspec_or_constant_args(args: Any, kwargs: Any) -> bool: # A fused version of: # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) from .variables.tensor import UnspecializedPythonVariable @@ -2521,7 +2533,7 @@ def check_unspec_or_constant_args(args, kwargs): return True -def check_numpy_ndarray_args(args, kwargs): +def check_numpy_ndarray_args(args: Any, kwargs: Any) -> bool: from .variables.tensor import NumpyNdarrayVariable return any( @@ -2557,13 +2569,13 @@ def check_numpy_ndarray_args(args, kwargs): str_methods = {method for method in str.__dict__.values() if callable(method)} -def builtin_dict_keys(d): +def builtin_dict_keys(d: dict[Any, Any]) -> KeysView[Any]: # Avoids overridden keys method of the dictionary assert isinstance(d, dict) return dict.keys(d) -def get_items_from_dict(obj): +def get_items_from_dict(obj: dict[Any, Any]) -> Any: # Get items without calling the user defined __getitem__ or keys method. assert isinstance(obj, dict) if istype(obj, (dict, OrderedDict)): @@ -2574,29 +2586,29 @@ def get_items_from_dict(obj): return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] -def nn_module_new(cls): +def nn_module_new(cls: Any) -> Any: obj = object_new(cls) torch.nn.Module.__init__(obj) return obj -def product(it): +def product(it: Iterable[Any]) -> Any: return functools.reduce(operator.mul, it, 1) -def tuple_iterator_getitem(it, index): +def tuple_iterator_getitem(it: Any, index: int) -> Any: _, (obj,), start = it.__reduce__() return obj[start + index] -def dataclass_fields(cls): +def dataclass_fields(cls: Any) -> Any: return torch._dynamo.disable(dataclasses.fields)(cls) iter_next = next -def normalize_range_iter(range_iter) -> tuple[int, int, int]: +def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]: _, (range_obj,), maybe_idx = range_iter.__reduce__() # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been # already incremented by the current index. @@ -2606,14 +2618,14 @@ def normalize_range_iter(range_iter) -> tuple[int, int, int]: return (start, stop, step) -def to_subclass(t, cls): +def to_subclass(t: Any, cls: type) -> Any: return t.as_subclass(cls) dict_getitem = dict.__getitem__ -def dict_keys_getitem(d, n): +def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: # Call dict(d) to prevent calling overridden __iter__/keys dict_class = dict if isinstance(d, OrderedDict): @@ -2621,12 +2633,12 @@ def dict_keys_getitem(d, n): return next(itertools.islice(dict_class.keys(d), n, n + 1)) -def set_getitem(s, n): +def set_getitem(s: set[Any], n: int) -> Any: # Set ordering might not be stable return list(s)[n] -def enum_repr(value, local): +def enum_repr(value: Any, local: bool) -> str: # enum class can override __str__ method. Use __class__ and name attribute # to extract the class name and key name. name = value.__class__.__name__ @@ -2636,7 +2648,7 @@ def enum_repr(value, local): return local_name -def set_example_value(node, example_value): +def set_example_value(node: torch.fx.Node, example_value: Any) -> None: # NB: example_value is a bit of a misnomer, because this is always a fake # tensor of some sort. Furthermore, these example values serve as the # runtime state of Dynamo tracing, which means if metadata mutation @@ -2656,7 +2668,7 @@ def set_example_value(node, example_value): node.meta["unbacked_bindings"] = symbol_to_path -def _get_fake_tensor(vt): +def _get_fake_tensor(vt: VariableTracker) -> Any: fake_tensor = vt.as_proxy().node.meta.get("example_value") if not is_fake(fake_tensor): from . import graph_break_hints @@ -2676,7 +2688,7 @@ def slice_length(s: slice, seq_len: int) -> int: return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) -def raise_args_mismatch(tx, name): +def raise_args_mismatch(tx: InstructionTranslatorBase, name: str) -> None: from torch._dynamo.exc import raise_observed_exception from torch._dynamo.variables import ConstantVariable @@ -2687,13 +2699,13 @@ def raise_args_mismatch(tx, name): ) -def iter_contains(items, search, tx, check_tensor_identity=False): - from .variables import ( - BuiltinVariable, - ConstantVariable, - TensorVariable, - VariableTracker, - ) +def iter_contains( + items: Any, + search: Any, + tx: InstructionTranslator, + check_tensor_identity: bool = False, +) -> Any: + from .variables import BuiltinVariable, ConstantVariable, TensorVariable if search.is_python_constant(): found_const = any( @@ -2735,11 +2747,11 @@ def key_is_id( return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) -def key_to_id(value): +def key_to_id(value: Any) -> list[Any]: return [id(k) if key_is_id(k) else k for k in value.keys()] -def const_repr(x, *, local) -> str: +def const_repr(x: Any, *, local: Any) -> str: from .trace_rules import is_builtin_callable if isinstance(x, (list, tuple)): @@ -2760,7 +2772,7 @@ def const_repr(x, *, local) -> str: return x.__name__ elif isinstance(x, type): - def fullname(o): + def fullname(o: Any) -> str: klass = o.__class__ module = klass.__module__ if module == "builtins": @@ -2772,7 +2784,7 @@ def fullname(o): return f"{x!r}" -def dict_keys_repr(const_keys, *, local) -> str: +def dict_keys_repr(const_keys: Any, *, local: Any) -> str: keys_str = ",".join(const_repr(s, local=local) for s in const_keys) return "[" + keys_str + "]" @@ -2783,7 +2795,7 @@ def dict_keys_repr(const_keys, *, local) -> str: from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 -def get_safe_global_name(tx, root, obj): +def get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> str: # The global_mangled_class_name should be different for different # invocations of torch.compile. Otherwise, we can run into a situation # where multiple torch.compile invocations reuse the same global name, @@ -2793,14 +2805,16 @@ def get_safe_global_name(tx, root, obj): return f"{root}_{id(obj)}_c{tx.output.compile_id}" -def is_in(item: Any, *containers) -> bool: +def is_in(item: str, *containers: Any) -> bool: for container in containers: if item in container: return True return False -def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: +def get_unique_name_wrt( + prefix: str, *containers: Any, requires_suffix: bool = False +) -> str: """ Return a name that starts with `prefix` and is not in any of the `containers` (e.g., map, set). @@ -2816,7 +2830,7 @@ def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: raise AssertionError("unreachable") -def wrap_fake_exception(fn): +def wrap_fake_exception(fn: Callable[[], Any]) -> Any: try: return fn() except UnsupportedFakeTensorException as e: @@ -2833,12 +2847,14 @@ def wrap_fake_exception(fn): ) -def deepcopy_to_fake_tensor(obj, fake_mode): +def deepcopy_to_fake_tensor( + obj: Any, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode +) -> Any: with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): return wrap_fake_exception(lambda: copy.deepcopy(obj)) -def rmse(ref, res): +def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor: """ Calculate root mean squared error """ @@ -2846,19 +2862,19 @@ def rmse(ref, res): def same( - ref, - res, - fp64_ref=None, - cos_similarity=False, - tol=1e-4, - equal_nan=False, - exact_dtype=True, - relax_numpy_equality=False, - ignore_non_fp=False, - log_error=log.error, - use_larger_multiplier_for_smaller_tensor=False, + ref: Any, + res: Any, + fp64_ref: Any = None, + cos_similarity: bool = False, + tol: float = 1e-4, + equal_nan: bool = False, + exact_dtype: bool = True, + relax_numpy_equality: bool = False, + ignore_non_fp: bool = False, + log_error: Callable[..., None] = log.error, + use_larger_multiplier_for_smaller_tensor: bool = False, force_max_multiplier: bool = False, -): +) -> bool: """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref @@ -2939,7 +2955,7 @@ def same( assert not isinstance(ref, torch._subclasses.FakeTensor) assert not isinstance(res, torch._subclasses.FakeTensor) - def to_tensor(t): + def to_tensor(t: Any) -> Any: return t if isinstance(t, torch.Tensor) else torch.tensor(t) ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) @@ -2978,7 +2994,7 @@ def to_tensor(t): score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) if score < 0.99: log.warning("Similarity score=%s", score.detach().cpu().item()) - return score >= 0.99 + return bool(score >= 0.99) else: if not exact_dtype: ref = ref.to(res.dtype) @@ -3018,7 +3034,7 @@ def to_tensor(t): res_error = rmse(fp64_ref, res).item() - def get_multiplier(): + def get_multiplier() -> float: # In some particular cases, we expect high difference in results. # At the moment one of this cases is inductor freezing bfloat16 convolution const folding. # In case of it the res_error is at least one order of magnitude higher. @@ -3149,13 +3165,13 @@ def get_multiplier(): raise RuntimeError(f"unsupported type: {type(ref).__name__}") -def format_func_info(code): +def format_func_info(code: CodeType) -> str: short_filename = code.co_filename.split("/")[-1] return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" @contextlib.contextmanager -def disable_cache_limit(): +def disable_cache_limit() -> Generator[None, None, None]: prior = config.recompile_limit config.recompile_limit = sys.maxsize prior_acc_limit = config.accumulated_recompile_limit @@ -3184,7 +3200,7 @@ def disable_cache_limit(): # return same dir unless user changes config between calls @functools.cache -def _get_debug_dir(root_dir): +def _get_debug_dir(root_dir: str) -> str: dir_name = ( "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") @@ -3195,12 +3211,12 @@ def _get_debug_dir(root_dir): return os.path.join(root_dir, dir_name) -def get_debug_dir(): +def get_debug_dir() -> str: debug_root = config.debug_dir_root return _get_debug_dir(debug_root) -def extract_fake_example_value(node, required=True): +def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> Any: if "example_value" in node.meta and is_fake(node.meta["example_value"]): return node.meta["example_value"] elif required: @@ -3218,13 +3234,15 @@ def extract_fake_example_value(node, required=True): return None -def ensure_graph_fake(e, tx): +def ensure_graph_fake(e: Any, tx: InstructionTranslatorBase) -> Any: assert maybe_get_fake_mode(e) is tx.fake_mode return e -def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): - def visit(n: torch.fx.Node): +def get_fake_values_from_nodes( + tx: InstructionTranslatorBase, nodes: Any, allow_non_graph_fake: bool +) -> Any: + def visit(n: torch.fx.Node) -> Any: if n.op == "call_function" and "example_value" not in n.meta: # fake tensor validity is checked inside get_fake_value using # ensure_graph_fake @@ -3232,7 +3250,7 @@ def visit(n: torch.fx.Node): elif n.op == "get_attr" and "example_value" not in n.meta: assert n.target in tx.output.nn_modules - gm = tx.output.nn_modules[n.target] + gm = tx.output.nn_modules[n.target] # type: ignore[index] assert isinstance(gm, torch.fx.GraphModule) return gm @@ -3244,7 +3262,11 @@ def visit(n: torch.fx.Node): return torch.fx.node.map_arg(nodes, visit) -def get_fake_value(node, tx, allow_non_graph_fake=False): +def get_fake_value( + node: torch.fx.Node, + tx: InstructionTranslatorBase, + allow_non_graph_fake: bool = False, +) -> Any: """ Run the computation represented by `node` using fake tensors and return the result. @@ -3293,7 +3315,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) if op == "call_module": - nnmodule = tx.output.nn_modules[node.target] + nnmodule = tx.output.nn_modules[node.target] # type: ignore[index] if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): # In the case of a lazy module, we want to run @@ -3310,9 +3332,11 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): ): # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo. args = tuple( - float(arg) - if isinstance(arg, torch.SymFloat) and arg.node.hint is not None - else arg + ( + float(arg) + if isinstance(arg, torch.SymFloat) and arg.node.hint is not None + else arg + ) for arg in args ) @@ -3379,7 +3403,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): elif isinstance( cause, torch._subclasses.fake_tensor.UnsupportedOperatorException ): - op = cause.func + op = cause.func # type: ignore[assignment] import_suggestion = "" if isinstance(op, torch._ops.OpOverload): maybe_pystub = torch._C._dispatch_pystub( @@ -3443,12 +3467,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): _current_node = threading.local() -def get_current_node(): +def get_current_node() -> Optional[torch.fx.Node]: return getattr(_current_node, "value", None) @contextmanager -def set_current_node(node): +def set_current_node(node: torch.fx.Node) -> Generator[None, None, None]: old = get_current_node() _current_node.value = node try: @@ -3457,7 +3481,9 @@ def set_current_node(node): _current_node.value = old -def run_node(tracer, node, args, kwargs, nnmodule): +def run_node( + tracer: Any, node: torch.fx.Node, args: Any, kwargs: Any, nnmodule: Any +) -> Any: """ Runs a given node, with the given args and kwargs. @@ -3476,7 +3502,7 @@ def run_node(tracer, node, args, kwargs, nnmodule): with set_current_node(node): - def make_error_message(e): + def make_error_message(e: Any) -> str: return ( f"Dynamo failed to run FX node with fake tensors: {op} {node.target}(*{args}, **{kwargs}): got " + repr(e) @@ -3486,9 +3512,9 @@ def make_error_message(e): try: if op == "call_function": - return node.target(*args, **kwargs) + return node.target(*args, **kwargs) # type: ignore[operator] elif op == "call_method": - if not hasattr(args[0], node.target): + if not hasattr(args[0], node.target): # type: ignore[arg-type] from .exc import unimplemented_v2 unimplemented_v2( @@ -3497,7 +3523,7 @@ def make_error_message(e): explanation=make_error_message("attribute not defined"), hints=[], ) - return getattr(args[0], node.target)(*args[1:], **kwargs) + return getattr(args[0], node.target)(*args[1:], **kwargs) # type: ignore[arg-type] elif op == "call_module": assert nnmodule is not None return nnmodule(*args, **kwargs) @@ -3534,7 +3560,7 @@ def make_error_message(e): raise AssertionError(op) -def get_real_value(node, tracer): +def get_real_value(node: torch.fx.Node, tracer: Any) -> Any: """ Run the actual computation represented by `node` and return the result. This will execute any dependent nodes in the graph as well. @@ -3573,10 +3599,10 @@ def get_real_value(node, tracer): return real_value -def assert_no_fake_params_or_buffers(gm): +def assert_no_fake_params_or_buffers(gm: torch.fx.GraphModule) -> None: from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake - def stack_or_hint(t): + def stack_or_hint(t: Any) -> str: if FakeTensorConfig.debug: import traceback @@ -3594,21 +3620,21 @@ def stack_or_hint(t): ) -def fqn(obj: Any): +def fqn(obj: Any) -> str: """ Returns the fully qualified name of the object. """ return f"{obj.__module__}.{obj.__qualname__}" -def ifdynstaticdefault(count1, count2): +def ifdynstaticdefault(count1: Any, count2: Any) -> Any: if torch._dynamo.config.assume_static_by_default: return count1 else: return count2 -def import_submodule(mod: types.ModuleType): +def import_submodule(mod: types.ModuleType) -> None: """ Ensure all the files in a given submodule are imported """ @@ -3617,17 +3643,17 @@ def import_submodule(mod: types.ModuleType): importlib.import_module(f"{mod.__name__}.{filename[:-3]}") -def object_has_getattribute(value: Any): +def object_has_getattribute(value: Any) -> bool: return class_has_getattribute(type(value)) -def object_setattr_ignore_descriptor(obj, name, value): +def object_setattr_ignore_descriptor(obj: Any, name: str, value: Any) -> None: # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335 d = object.__getattribute__(obj, "__dict__") d[name] = value -def class_has_getattribute(cls: type): +def class_has_getattribute(cls: type) -> bool: try: if isinstance( inspect.getattr_static(cls, "__getattribute__"), @@ -3639,7 +3665,9 @@ def class_has_getattribute(cls: type): return False -def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): +def get_custom_getattr( + value: Any, ignore_nn_module_getattr: bool = False +) -> Optional[Any]: try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: @@ -3656,7 +3684,7 @@ class TensorStaticReason(enum.Enum): NN_MODULE_PROPERTY = 5 -def tensor_static_reason_to_message(reason: TensorStaticReason): +def tensor_static_reason_to_message(reason: TensorStaticReason) -> str: if reason == TensorStaticReason.PARAMETER: return "mark_dynamic on parameter, parameters are always static today." if reason == TensorStaticReason.NOT_TENSOR: @@ -3700,8 +3728,8 @@ def tensor_always_has_static_shape( return False, None -def lazy_format_graph_tabular(fn_name, gm): - def inner(): +def lazy_format_graph_tabular(fn_name: str, gm: torch.fx.GraphModule) -> Any: + def inner() -> str: try: from tabulate import tabulate # TODO: Check that this is installed except ImportError: @@ -3721,7 +3749,9 @@ def inner(): return LazyString(inner) -def format_bytecode(prefix, name, filename, line_no, code): +def format_bytecode( + prefix: str, name: str, filename: str, line_no: int, code: Any +) -> str: return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" @@ -3736,20 +3766,21 @@ def format_bytecode(prefix, name, filename, line_no, code): all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names -def nn_module_has_global_hooks(): +def nn_module_has_global_hooks() -> bool: # This is limited to backward hooks for now because NNModuleVariable # supports fwd hooks underneath. - return len(torch.nn.modules.module._global_backward_hooks) or len( - torch.nn.modules.module._global_backward_pre_hooks + return bool( + len(torch.nn.modules.module._global_backward_hooks) + or len(torch.nn.modules.module._global_backward_pre_hooks) ) def nn_module_get_all_hooks( - mod, - check_forward_hooks=False, - check_backward_hooks=False, - check_state_dict_hooks=False, -): + mod: torch.nn.Module, + check_forward_hooks: bool = False, + check_backward_hooks: bool = False, + check_state_dict_hooks: bool = False, +) -> list[Any]: """ Sometimes its useful to differentiate between types of hooks such as forward/backward/pre hooks executed during module.__call__, and state_dict hooks which are executed separately. @@ -3778,11 +3809,11 @@ def nn_module_get_all_hooks( def nnmodule_has_hooks( - mod, - check_forward_hooks=False, - check_backward_hooks=False, - check_state_dict_hooks=False, -): + mod: torch.nn.Module, + check_forward_hooks: bool = False, + check_backward_hooks: bool = False, + check_state_dict_hooks: bool = False, +) -> bool: """ Helper function to check if a module has any hooks attached to it. """ @@ -3795,7 +3826,7 @@ def nnmodule_has_hooks( return bool(hooks) -def to_numpy_helper(value): +def to_numpy_helper(value: Any) -> Any: """Convert tensor and tnp.ndarray to numpy.ndarray.""" if is_fake(value): return value @@ -3809,7 +3840,7 @@ def to_numpy_helper(value): return value -def numpy_to_tensor(value): +def numpy_to_tensor(value: Any) -> Any: """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" assert np is not None if isinstance(value, np.ndarray): @@ -3823,19 +3854,19 @@ def numpy_to_tensor(value): class numpy_to_tensor_wrapper: - def __init__(self, f): + def __init__(self, f: Any) -> None: self.f = f self.__name__ = "wrapped_" + self.f.__name__ def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: out = self.f(*args, **kwargs) return numpy_to_tensor(out) -def numpy_attr_wrapper(obj, name): +def numpy_attr_wrapper(obj: Any, name: str) -> Any: if isinstance(obj, tnp.ndarray): out = getattr(obj, name) return numpy_to_tensor(out) @@ -3847,14 +3878,14 @@ def numpy_attr_wrapper(obj, name): class numpy_method_wrapper: """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor.""" - def __init__(self, method: str): + def __init__(self, method: str) -> None: self.method = method self.__name__ = "wrapped_" + self.method def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: obj = args[0] if isinstance(obj, torch.Tensor): obj = tnp.ndarray(obj) @@ -3866,14 +3897,14 @@ def __call__(self, *args, **kwargs): class numpy_operator_wrapper: """Implements dunder methods for tnp.ndarray via functions from the operator library""" - def __init__(self, op: Callable[..., Any]): + def __init__(self, op: Callable[..., Any]) -> None: self.op = op self.__name__ = f"wrapped_{op.__name__}" def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: assert not kwargs args = ( @@ -3883,7 +3914,7 @@ def __call__(self, *args, **kwargs): return numpy_to_tensor(out) -def defake(x): +def defake(x: Any) -> Any: if not isinstance(x, FakeTensor): return x size: torch._prims_common.ShapeType @@ -3915,24 +3946,26 @@ def defake(x): return y -def _disable_side_effect_safety_checks_for_current_subtracer(fn, *args, **kwargs): +def _disable_side_effect_safety_checks_for_current_subtracer( + fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs +) -> Any: return fn(*args, **kwargs) -def is_utils_checkpoint(obj): +def is_utils_checkpoint(obj: Any) -> bool: # Lazy import to avoid circular dependencies import torch.utils.checkpoint return obj is torch.utils.checkpoint.checkpoint -def is_invoke_subgraph(obj): +def is_invoke_subgraph(obj: Any) -> bool: from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_placeholder return obj is invoke_subgraph_placeholder -def build_invoke_subgraph_variable(**options): +def build_invoke_subgraph_variable(**options: Any) -> Any: from .variables.higher_order_ops import TorchHigherOrderOperatorVariable return TorchHigherOrderOperatorVariable.make( @@ -3941,7 +3974,7 @@ def build_invoke_subgraph_variable(**options): ) -def build_checkpoint_variable(**options): +def build_checkpoint_variable(**options: Any) -> Any: import torch._higher_order_ops.wrap as higher_order_ops from .variables.higher_order_ops import TorchHigherOrderOperatorVariable @@ -3960,7 +3993,7 @@ def build_checkpoint_variable(**options): ) -def is_compile_supported(device_type): +def is_compile_supported(device_type: DeviceLikeType) -> Any: from .eval_frame import is_dynamo_supported type = torch.device(device_type).type @@ -4026,12 +4059,12 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: lines = segment.split("\n") # get character index given byte offset - def normalize(lineno, offset): + def normalize(lineno: int, offset: int) -> int: return _fix_offset(lines[lineno], offset) # Gets the next valid character index in `lines`, if # the current location is not valid. Handles empty lines. - def next_valid_char(lineno, col): + def next_valid_char(lineno: int, col: int) -> tuple[int, int]: while lineno < len(lines) and col >= len(lines[lineno]): col = 0 lineno += 1 @@ -4039,14 +4072,14 @@ def next_valid_char(lineno, col): return lineno, col # Get the next valid character index in `lines`. - def increment(lineno, col): + def increment(lineno: int, col: int) -> tuple[int, int]: col += 1 lineno, col = next_valid_char(lineno, col) assert lineno < len(lines) and col < len(lines[lineno]) return lineno, col # Get the next valid character at least on the next line - def nextline(lineno, col): + def nextline(lineno: int, col: int) -> tuple[int, int]: col = 0 lineno += 1 lineno, col = next_valid_char(lineno, col) @@ -4063,6 +4096,7 @@ def nextline(lineno, col): # -2 since end_lineno is 1-indexed and because we added an extra # bracket to `segment` when calling ast.parse cur_lineno = cast(int, expr.left.end_lineno) - 2 + assert expr.left.end_col_offset is not None cur_col = normalize(cur_lineno, expr.left.end_col_offset) cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) @@ -4095,12 +4129,14 @@ def nextline(lineno, col): # subscript^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '[' after value) left_lineno = cast(int, expr.value.end_lineno) - 2 + assert expr.value.end_col_offset is not None left_col = normalize(left_lineno, expr.value.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "[": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, expr.end_lineno) - 2 + assert expr.end_col_offset is not None right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) elif isinstance(expr, ast.Call): @@ -4109,12 +4145,14 @@ def nextline(lineno, col): # call^^^^^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '(' after func) left_lineno = cast(int, expr.func.end_lineno) - 2 + assert expr.func.end_col_offset is not None left_col = normalize(left_lineno, expr.func.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "(": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, expr.end_lineno) - 2 + assert expr.end_col_offset is not None right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) @@ -4253,14 +4291,14 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s return result -def get_static_address_type(t): +def get_static_address_type(t: Any) -> Any: if isinstance(t, torch.Tensor): return getattr(t, "_dynamo_static_input_type", None) return None -def is_rng_state_getter_or_setter(value): +def is_rng_state_getter_or_setter(value: Any) -> bool: getters = ( # The following two functions are not identical, so don't remove anyone! torch._C.Generator.get_state, @@ -4277,7 +4315,7 @@ def is_rng_state_getter_or_setter(value): return value in (*setters, *getters) -def is_tensor_base_attr_getter(value): +def is_tensor_base_attr_getter(value: Any) -> bool: return ( isinstance(value, types.MethodWrapperType) and value.__name__ == "__get__" @@ -4285,7 +4323,7 @@ def is_tensor_base_attr_getter(value): ) -def is_tensor_getset_descriptor(name): +def is_tensor_getset_descriptor(name: str) -> bool: try: attr = inspect.getattr_static(torch.Tensor, name) return type(attr) is types.GetSetDescriptorType @@ -4293,11 +4331,11 @@ def is_tensor_getset_descriptor(name): return False -def is_torch_function_object(value): +def is_torch_function_object(value: Any) -> bool: return hasattr(value, "__torch_function__") -def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: +def has_torch_function(vt: VariableTracker) -> bool: # This emulates # https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/disable_torch_function.cpp#L315-L323 from torch._dynamo.variables import UserDefinedObjectVariable @@ -4327,7 +4365,9 @@ def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool # see note [Tensor Fakification and Symbol Caching] -def to_fake_tensor(t, fake_mode): +def to_fake_tensor( + t: torch.Tensor, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode +) -> Any: symbolic_context = None source = None if tracing_context := torch._guards.TracingContext.try_get(): @@ -4341,7 +4381,7 @@ def to_fake_tensor(t, fake_mode): # NB: this works for both classes and instances -def is_frozen_dataclass(value): +def is_frozen_dataclass(value: Any) -> bool: return ( not object_has_getattribute(value) and not class_has_getattribute(value) @@ -4352,7 +4392,7 @@ def is_frozen_dataclass(value): ) -def get_first_attr(obj, *attrs): +def get_first_attr(obj: Any, *attrs: str) -> Any: """ Return the first available attribute or throw an exception if none is present. """ @@ -4364,13 +4404,15 @@ def get_first_attr(obj, *attrs): @contextlib.contextmanager -def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): +def maybe_enable_compiled_autograd( + should_enable: bool, fullgraph: bool = True, dynamic: bool = True +) -> Generator[Any, None, None]: if not should_enable: yield else: - def compiler_fn(gm): - def inner_compiler(gm_, example_inputs_): + def compiler_fn(gm: Any) -> Any: + def inner_compiler(gm_: Any, example_inputs_: Any) -> Any: torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 return torch._inductor.compile(gm_, example_inputs_) @@ -4382,7 +4424,7 @@ def inner_compiler(gm_, example_inputs_): yield ctx -def invalid_removeable_handle(): +def invalid_removeable_handle() -> RemovableHandle: # need a subclass so weakref works class Invalid(dict): # type: ignore[type-arg] pass @@ -4394,7 +4436,7 @@ class Invalid(dict): # type: ignore[type-arg] # Attribute changes to the original object/proxy will be reflected in the other. # This is useful for cases where we want a keep-alive reference to a module without increasing # its reference count. -def nn_module_proxy(mod): +def nn_module_proxy(mod: Any) -> Any: if not isinstance(mod, torch.nn.Module): return mod if isinstance(mod, torch.fx.GraphModule): @@ -4406,17 +4448,21 @@ def nn_module_proxy(mod): class GmWrapper(torch.nn.Module): - def __init__(self, gm, unflatten_fn): + def __init__( + self, gm: torch.fx.GraphModule, unflatten_fn: Callable[[list[Any]], Any] + ) -> None: super().__init__() self.gm = gm self.unflatten_fn = unflatten_fn - def forward(self, *args): + def forward(self, *args: Any) -> Any: args: list[Any] = list(args) return self.gm(*self.unflatten_fn(args)) -def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): +def flatten_graph_inputs( + gm: torch.fx.GraphModule, inputs: Any, compile_gm: Callable[[Any, Any], Any] +) -> Callable[..., Any]: """ Mutate inputs so that they are flat and wrap gm such that it accepts those inputs. This is needed for graphs that take @@ -4435,10 +4481,10 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): assert isinstance(inputs[0], list) boxed_inputs_count = len(inputs[0]) - def flatten_fn(args): + def flatten_fn(args: Any) -> Any: return args[0] + list(args[1:]) - def unflatten_fn(flat_args): + def unflatten_fn(flat_args: Any) -> Any: return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:]) compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs)) @@ -4450,7 +4496,7 @@ def unflatten_fn(flat_args): # note this doesn't check the spec, assuming it is the same flatten_fn = pytree.arg_tree_leaves - def wrapper(*args): + def wrapper(*args: Any) -> Any: flat_args = flatten_fn(args) # flat_args is a new list, so we need to clear references from the old list @@ -4463,18 +4509,18 @@ def wrapper(*args): return wrapper -def get_locals_to_steal(maybe_gm): +def get_locals_to_steal(maybe_gm: Any) -> list[Any]: if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): return [] return maybe_gm.meta.get("locals_to_steal", []) -def set_locals_to_steal(gm, locals_to_steal): +def set_locals_to_steal(gm: torch.fx.GraphModule, locals_to_steal: list[Any]) -> None: gm.meta["locals_to_steal"] = locals_to_steal class Lit: - def __init__(self, s): + def __init__(self, s: str) -> None: self.s = s def __repr__(self) -> str: @@ -4484,7 +4530,7 @@ def __repr__(self) -> str: warn_once_cache: set[str] = set() -def warn_once(msg, stacklevel=1): +def warn_once(msg: str, stacklevel: int = 1) -> None: # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. # https://github.com/pytorch/pytorch/issues/128427. # warn_once is a workaround: if the msg has been warned on before, then we will not @@ -4496,14 +4542,14 @@ def warn_once(msg, stacklevel=1): warnings.warn(msg, stacklevel=stacklevel + 1) -def strip_color_from_string(text): +def strip_color_from_string(text: str) -> str: # This regular expression matches ANSI escape codes ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") return ansi_escape.sub("", text) @contextlib.contextmanager -def _disable_saved_tensors_hooks_during_tracing(): +def _disable_saved_tensors_hooks_during_tracing() -> Generator[None, None, None]: # See NOTE: [Deferring tensor pack/unpack hooks until runtime] try: prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) @@ -4512,22 +4558,22 @@ def _disable_saved_tensors_hooks_during_tracing(): torch._C._autograd._saved_tensors_hooks_set_tracing(prior) -def is_parameter_freezing(): +def is_parameter_freezing() -> bool: return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(): +def get_torch_function_mode_stack() -> list[Any]: return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] -def get_torch_function_mode_stack_at(ind): +def get_torch_function_mode_stack_at(ind: int) -> Any: assert ind < _len_torch_function_stack() and ind >= 0 return torch._C._get_function_stack_at(ind) -def set_torch_function_mode_stack(stack): +def set_torch_function_mode_stack(stack: list[Any]) -> None: for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() @@ -4535,17 +4581,17 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) -def clear_torch_function_mode_stack(): +def clear_torch_function_mode_stack() -> None: for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() # call from C dynamo in order to inspect values in pdb -def _breakpoint_for_c_dynamo(*args): +def _breakpoint_for_c_dynamo(*args: Any) -> None: breakpoint() -def verify_guard_fn_signature(value): +def verify_guard_fn_signature(value: Any) -> None: fn = value.__metadata_guard__ sig = inspect.signature(fn) if len(sig.parameters) != 2: @@ -4562,7 +4608,7 @@ def verify_guard_fn_signature(value): ) -def does_not_override_dict_iter_methods(user_cls): +def does_not_override_dict_iter_methods(user_cls: Any) -> bool: return ( user_cls.items in (dict.items, OrderedDict.items) and user_cls.values in (dict.values, OrderedDict.values) @@ -4575,23 +4621,23 @@ def does_not_override_dict_iter_methods(user_cls): # __torch_function__ calls triggered on tensor properties in the pre graph # bytecode. @torch._disable_dynamo -def call_size(x, i): +def call_size(x: Any, i: int) -> int: return x.size(i) @torch._disable_dynamo -def call_stride(x, i): +def call_stride(x: Any, i: int) -> int: return x.stride(i) @torch._disable_dynamo -def call_storage_offset(x): +def call_storage_offset(x: Any) -> int: return x.storage_offset() # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. # To avoid ref cycles, it's important that no tensors are present here, so leave those out. -def _extract_tensor_dict(t): +def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]: KEYS_TO_COPY = [ "_dynamo_static_input_type", "tag", @@ -4610,13 +4656,13 @@ def _extract_tensor_dict(t): user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {} -def get_user_object_from_id(obj_id): +def get_user_object_from_id(obj_id: int) -> Any: obj = user_obj_id_to_weakref[obj_id]() assert obj is not None, "User object is no longer alive" return obj -def store_user_object_weakref(obj): +def store_user_object_weakref(obj: object) -> None: obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) @@ -4649,7 +4695,7 @@ def value(cls) -> int: @classmethod @contextmanager - def record(cls): + def record(cls) -> Generator[None, None, None]: try: if config.record_compile_time_instruction_count: cls.start() @@ -4659,7 +4705,7 @@ def record(cls): cls.end() -def set_feature_use(feature: str, usage: bool): +def set_feature_use(feature: str, usage: bool) -> None: """ Records whether we are using a feature Generally a feature is a JK. @@ -4677,7 +4723,7 @@ def set_feature_use(feature: str, usage: bool): ) -def get_optimize_ddp_mode(): +def get_optimize_ddp_mode() -> str: optimize_ddp = config.optimize_ddp if isinstance(optimize_ddp, bool): mode = "ddp_optimizer" if optimize_ddp else "no_optimization" @@ -4771,22 +4817,3 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() - - -class CreateNestedFnCache: - cache: dict[str, types.FunctionType] = {} - - @classmethod - def get(cls, key): - return cls.cache.get(key, None) - - @classmethod - def set(cls, key, value): - cls.cache[key] = value - - @classmethod - def clear(cls): - cls.cache.clear() - - -create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index d7e29ae66669..08d62ce607f0 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -80,7 +80,6 @@ ) from .iter import ( CountIteratorVariable, - CycleIteratorVariable, FilterVariable, IteratorVariable, ItertoolsVariable, @@ -169,7 +168,6 @@ "CreateTMADescriptorExperimentalVariable", "CreateTMADescriptorStableVariable", "CUDADeviceVariable", - "CycleIteratorVariable", "DataPtrVariable", "DefaultDictVariable", "DeletedVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f9d8e273068f..d4aac8041452 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -104,6 +104,7 @@ GetItemSource, GradSource, is_constant_source, + is_from_closure_source, is_from_global_source, is_from_nonlocal_source, is_from_optimizer_source, @@ -1332,9 +1333,16 @@ def build_key_value(i, k, v): and not is_traceable_wrapper_subclass_type(value) ): return TensorSubclassVariable(value, source=self.source) - # This is a userdefined class, so install an ID_MATCH even if its a - # global variable. - self.install_guards(GuardBuilder.ID_MATCH) + + if not is_from_closure_source(self.source): + # For closure source, the variable comes from LOAD_SUPER_ATTR, + # which calls self.__class__. This is internal Cpython + # implementation, and it is rare for the user to modify + # self.__class__ manually. + # For other cases, this is a userdefined class, so install an + # ID_MATCH even if its a global variable. + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedClassVariable( value, source=self.source, @@ -3239,7 +3247,6 @@ def _automatic_dynamic( ) if static_shapes and not is_dynamic_source(name): - record_automatic_dynamic(tx, name, e) return StatefulSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * e.dim(), dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index edb1169cb193..dc3929c9cce4 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -120,6 +120,7 @@ def is_hashable(x): variables.TypingVariable, variables.FunctoolsPartialVariable, variables.WeakRefVariable, + variables.TorchHigherOrderOperatorVariable, ), ) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 1cf015d10769..4bdcecf3b3c2 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -56,13 +56,19 @@ Unsupported, ) from ..guards import GuardBuilder, install_guard -from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource +from ..source import ( + AttrSource, + ClosureSource, + ConstantSource, + DefaultsSource, + GetItemSource, + SkipGuardSource, +) from ..utils import ( check_constant_args, check_unspec_or_constant_args, cmp_name_to_op_mapping, counters, - create_nested_fn_cache, identity, is_function, is_wrapper_or_member_descriptor, @@ -270,11 +276,6 @@ def _create_nested_fn( ): from types import FunctionType - # Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard. - cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals)) - if create_nested_fn_cache.get(cache_key): - return create_nested_fn_cache.get(cache_key) - func = FunctionType(code, f_globals, name, defaults, closure) func.__kwdefaults__ = kwdefaults @@ -286,7 +287,7 @@ def _create_nested_fn( # TypeError: __annotations__ must be set to a dict object assert annotations is None or isinstance(annotations, dict) func.__annotations__ = annotations - create_nested_fn_cache.set(cache_key, func) + return func @@ -303,6 +304,13 @@ def _create_nested_fn( def fn_var_getattr(tx, fn, source, name): source = source and AttrSource(source, name) + + if source and name == "__annotations__": + # We get a large number of silly guards from annotations from inspect + # module. Changing annotations is rare, and it impacting the extracted + # graph is even rarer. So skip guards. + source = SkipGuardSource(source) + try: subobj = inspect.getattr_static(fn, name) except AttributeError: @@ -416,6 +424,13 @@ def has_self(self): def get_globals(self): return self.fn.__globals__ + def get_source(self): + source = self.source + + if source and isinstance(self, variables.UserMethodVariable): + source = self.source_fn + return source + def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to @@ -428,7 +443,9 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if not isinstance(fn, FunctionType): raise TypeError("Only supports regular Python functions.") root_tx = parent.output.root_tx - result = bind_args_cached(fn, root_tx, self.source, args, kwargs) + + source = self.get_source() + result = bind_args_cached(fn, root_tx, source, args, kwargs) init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -441,10 +458,8 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if cell in side_effects: cell_var = side_effects[cell] - elif self.source: - closure_cell = GetItemSource( - AttrSource(self.source, "__closure__"), idx - ) + elif source: + closure_cell = GetItemSource(ClosureSource(source), idx) closure_cell_contents = AttrSource(closure_cell, "cell_contents") try: contents_var = VariableTracker.build( @@ -474,7 +489,8 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: def var_getattr(self, tx: "InstructionTranslator", name: str): if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) - return fn_var_getattr(tx, self.fn, self.source, name) + source = self.get_source() + return fn_var_getattr(tx, self.fn, source, name) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -1046,9 +1062,24 @@ def _build_inline_tracer(self, tx, args, kwargs): class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" - def __init__(self, fn, obj, **kwargs) -> None: + def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj + self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simplly use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that’s why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. + if source_fn is None and kwargs.get("source") is not None: + self.source_fn = AttrSource(kwargs.get("source"), "__func__") def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" @@ -1124,11 +1155,13 @@ def inspect_parameter_names(self): return super().inspect_parameter_names()[1:] def var_getattr(self, tx: "InstructionTranslator", name: str): - source = self.source and AttrSource(self.source, name) if name == "__self__": return self.obj if name == "__func__": - return VariableTracker.build(tx, self.fn, source) + # We might have a better way to access the function object, this + # information is stored in self.source_fn, use that to construct the + # variable tracker. + return VariableTracker.build(tx, self.fn, self.source_fn) return super().var_getattr(tx, name) @@ -1433,13 +1466,7 @@ def as_python_constant(self): @classmethod def create_with_source(cls, value, source): - if inspect.getattr_static(value, "_torchdynamo_orig_callable", False): - install_guard( - AttrSource(source, "_torchdynamo_orig_callable").make_guard( - GuardBuilder.FUNCTION_MATCH - ) - ) - elif not is_wrapper_or_member_descriptor(value): + if not is_wrapper_or_member_descriptor(value): # These descriptors are not guaranteed to return the same object on # attribute lookup. They are unlikely to be changed, so we can skip # guarding them. diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 1ee9bbb323df..d3334424c5f4 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -26,6 +26,7 @@ import logging import types import warnings +from collections.abc import Sequence from typing import Optional, TYPE_CHECKING import torch._C @@ -76,8 +77,19 @@ def graph_break_as_hard_error(*args, **kwargs): try: return fn(*args, **kwargs) except (Unsupported, ObservedException) as e: - msg = " Scroll up to find out what causes the graph break." - raise UncapturedHigherOrderOpError(reason + msg) from e + import sys + + if isinstance(e, Unsupported): + exc = UncapturedHigherOrderOpError( + f"{reason} Got {e.msg}", e.real_stack + ) + else: + msg = e.msg if hasattr(e, "msg") else type(e) + real_stack = e.real_stack if hasattr(e, "real_stack") else None + exc = UncapturedHigherOrderOpError( + f"{reason} Got {msg}", real_stack + ) + raise exc.with_traceback(sys.exc_info()[2]) from None return graph_break_as_hard_error @@ -899,7 +911,20 @@ def make(value, source=None, **kwargs): def call_function( self, tx: "InstructionTranslator", - args: list[VariableTracker], + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + + return self._call_function(tx, args, kwargs) + + def _call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: unimplemented(f"HigherOrderOperator {self.value.__name__}") @@ -913,7 +938,7 @@ class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable Wraps torch._functorch.autograd_function.custom_function_call """ - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -924,7 +949,7 @@ def call_function( torch._dynamo.variables.UserDefinedObjectVariable( self.value, source=self.source ), - source=AttrSource(AttrSource(self.source, "__call__"), "__func__"), + source=AttrSource(self.source, "__call__"), ).call_function(tx, args, kwargs) @@ -935,7 +960,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="Cond doesn't work unless it is captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -1119,7 +1144,7 @@ def __init__(self, hop, source, script_obj_var, method_name) -> None: self.script_obj_var = script_obj_var self.method_name = method_name - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1172,7 +1197,7 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="while_loop doesn't work unless it is captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1422,7 +1447,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="associative_scan must be captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1627,7 +1652,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="scan must be captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1828,7 +1853,7 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="map doesn't work unless it is captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1924,7 +1949,7 @@ def call_function( class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2074,7 +2099,7 @@ def create_wrapped_node( return proxy_args, {}, example_value, body_r, treespec, body_gmod, body_name - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2262,7 +2287,7 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": _check_supported_callable_arg(tx, args[0], "body_fn") @@ -2332,7 +2357,7 @@ def call_function( class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2370,7 +2395,7 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2431,7 +2456,7 @@ def call_function( class CheckpointHigherOrderVariable(WrapHigherOrderVariable): - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -2507,7 +2532,7 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable): def __init__(self, hop, source) -> None: super().__init__(hop, source) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -2594,7 +2619,7 @@ def call_function( class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable): - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2617,7 +2642,7 @@ def call_function( class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): - def call_function( + def _call_function( self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": from .builder import wrap_fx_proxy @@ -2654,7 +2679,7 @@ def to_proxy(self, tx, arg): else: return arg.as_proxy() - def call_function( + def _call_function( self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": from .builder import wrap_fx_proxy @@ -2686,7 +2711,7 @@ class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): here in the call to dynamo from compiled autograd. """ - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2777,17 +2802,12 @@ def create_scalar(): return proxy_args - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - from .torch_function import can_dispatch_torch_function, dispatch_torch_function - - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - from .builder import wrap_fx_proxy ( @@ -3264,7 +3284,7 @@ class BaseHOPVariable(WrapHigherOrderVariable): def python_type(self): return type(self.value) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -3361,7 +3381,7 @@ def install_subgraph_in_output_graph( @raise_hard_error_if_graph_break( reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", ) - def call_function( + def _call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -3406,7 +3426,6 @@ def call_function( _hop_name_to_variable_class = { "cond": CondHigherOrderVariable, "while_loop": WhileLoopHigherOrderVariable, - "map": MapHigherOrderVariable, "map_impl": MapHigherOrderVariable, "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable, "out_dtype": OutDtypeHigherOrderVariable, diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 362f45884e22..75c6712609e9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -16,9 +16,8 @@ """ import itertools -import operator import sys -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction @@ -60,86 +59,25 @@ def call_function( ) -> "VariableTracker": # See also: module `torch._dynamo.polyfills.itertools` - if ( - self.value is itertools.product - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] - items = [ - variables.TupleVariable(list(item)) for item in itertools.product(*seqs) - ] - return variables.ListIteratorVariable( - items, mutation_type=ValueMutationNew() - ) - elif self.value is itertools.accumulate: - from .builtin import BuiltinVariable - - if any(key not in ["initial", "func"] for key in kwargs.keys()): + if self.value is itertools.product: + if any(kw != "repeat" for kw in kwargs.keys()): unimplemented_v2( - gb_type="Unsupported kwargs for itertools.accumulate", + gb_type="Unsupported kwargs for itertools.product", context=f"call_function {self} {args} {kwargs}", - explanation=f"Expected kwargs: 'initial', 'func', but got " - f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}", + explanation=f"Expected kwargs: 'repeat', but got " + f"{','.join(set(kwargs.keys()) - {'repeat'})}", hints=[*graph_break_hints.USER_ERROR], ) - if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx): - seq = args[0].unpack_var_sequence(tx) - - if "func" in kwargs and len(args) == 1: - func = kwargs["func"].call_function - elif len(args) == 2: - func = args[1].call_function - elif len(args) == 1: - # Default to operator.add - func = BuiltinVariable(operator.add).call_function - else: - unimplemented_v2( - gb_type="Unsupported `func` in itertools.accumulate", - context=f"call_function {self} {args} {kwargs}", - explanation="Dynamo does not know how to get the " - "function to use for itertools.accumulate. " - "itertools.accumulate expects the `func` as the second " - "argument or as a keyword argument.", - hints=[*graph_break_hints.USER_ERROR], - ) + if "repeat" in kwargs.keys(): + r = kwargs["repeat"].as_python_constant() else: - unimplemented_v2( - gb_type="Unsupported arguments for itertools.accumulate", - context=f"call_function {self} {args} {kwargs}", - explanation="Dynamo does not know how to trace " - f"itertools.accumulate with args: {args} and kwargs: {kwargs}. " - "itertools.accumulate expects an iterable, an optional " - "binary function for accumulation, and an optional initial " - "value to set the starting state.", - hints=[ - "Make sure the arguments to itertools.accumulate are correct.", - *graph_break_hints.SUPPORTABLE, - ], - ) - - items = [] - acc = kwargs.get("initial") - if acc is not None: - items.append(acc) - for item in seq: - if acc is None: - acc = item - else: - try: - acc = func(tx, [acc, item], {}) - except Exception as e: - unimplemented_v2( - gb_type="Unexpected failure during itertools.accumulate() iteration", - context=f"call_function {self} {args} {kwargs}", - explanation="Unexpected failure in invoking function during accumulate. " - f"Failed running func {func}({item}{acc})", - hints=[*graph_break_hints.DIFFICULT], - from_exc=e, - ) - items.append(acc) - + r = 1 + seqs = [arg.force_unpack_var_sequence(tx) for arg in args] + items = [ + variables.TupleVariable(list(item)) + for item in itertools.product(*seqs, repeat=r) + ] return variables.ListIteratorVariable( items, mutation_type=ValueMutationNew() ) @@ -252,9 +190,23 @@ def keyfunc(x): return variables.CountIteratorVariable( *args, mutation_type=ValueMutationNew() ) - elif self.value is itertools.cycle: - return variables.CycleIteratorVariable( - *args, mutation_type=ValueMutationNew() + elif ( + self.value is itertools.permutations + and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) + and not kwargs + ): + if len(args) == 2: + r = args[1].as_python_constant() + else: + r = None + items = [ + variables.TupleVariable(list(item)) + for item in itertools.permutations( + args[0].force_unpack_var_sequence(tx), r + ) + ] + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() ) else: return super().call_function(tx, args, kwargs) @@ -380,54 +332,6 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.extend_output(create_call_function(2, False)) -class CycleIteratorVariable(IteratorVariable): - def __init__( - self, - iterator: IteratorVariable, - saved: Optional[list[VariableTracker]] = None, - saved_index: int = 0, - item: Optional[VariableTracker] = None, - **kwargs, - ) -> None: - if saved is None: - saved = [] - super().__init__(**kwargs) - self.iterator = iterator - self.saved = saved - self.saved_index = saved_index - self.item = item - - def next_variable(self, tx): - assert self.is_mutable() - - if self.iterator is not None: - try: - new_item = self.iterator.next_variable(tx) - if len(self.saved) > MAX_ITERATOR_LIMIT: - unimplemented_v2( - gb_type="input iterator to itertools.cycle has too many items", - context=f"next({self})", - explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}", - hints=[], - ) - tx.output.side_effects.mutation(self) - self.saved.append(new_item) - self.item = new_item - if self.item is None: - return self.next_variable(tx) - return self.item - except ObservedUserStopIteration: - handle_observed_exception(tx) - self.iterator = None - return self.next_variable(tx) - elif len(self.saved) > 0: - tx.output.side_effects.mutation(self) - self.saved_index = (self.saved_index + 1) % len(self.saved) - return self.item - else: - raise_observed_exception(StopIteration, tx) - - class ZipVariable(IteratorVariable): """ Represents zip(*iterables) @@ -475,6 +379,10 @@ def unpack_var_sequence(self, tx) -> list["VariableTracker"]: def next_variable(self, tx): assert self.is_mutable() + + if len(self.iterables) == 0: + raise_observed_exception(StopIteration, tx) + old_index = self.index args = [] @@ -638,7 +546,10 @@ def _next(): while True: item = _next() self.index += 1 - res = self.fn.call_function(tx, [item], {}) + if isinstance(self.fn, ConstantVariable) and self.fn.value is None: + res = item + else: + res = self.fn.call_function(tx, [item], {}) pred_res = variables.UserFunctionVariable( polyfills.predicate ).call_function(tx, [res], {}) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 923021e63294..f75f5b180c72 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -42,6 +42,7 @@ AttrSource, GenericAttrSource, GetItemSource, + TypeMROSource, TypeSource, WeakRefCallSource, ) @@ -134,9 +135,7 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): # Equivalent of something like type(L['self']).__mro__[1].attr_name if type_to_use_source: source = AttrSource( - GetItemSource( - AttrSource(type_to_use_source, "__mro__"), index - ), + GetItemSource(TypeMROSource(type_to_use_source), index), name, ) return resolved_getattr, source @@ -247,14 +246,14 @@ def call_method( # different from type(self) with polymorphism. cls_source = None if self.objvar.source: - cls_source = AttrSource(self.objvar.source, "__class__") + cls_source = TypeSource(self.objvar.source) cls_variable = VariableTracker.build( tx, self.objvar.value_type, cls_source ) - return variables.UserMethodVariable( - inner_fn.__func__, cls_variable, source=source - ).call_function(tx, args, kwargs) + return variables.UserFunctionVariable( + inner_fn.__func__, source=AttrSource(source, "__func__") + ).call_function(tx, [cls_variable, *args], kwargs) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( inner_fn, source=source diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index e6d1e669dad5..10ad8c4a1286 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -909,7 +909,11 @@ def set_nn_module_stack_source(self, source): @functools.cache def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} + supported = { + torch.nn.Module.__setattr__, + torch.nn.Module.__init__, + torch.nn.Module.__delattr__, + } return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -989,7 +993,7 @@ def call_function( fn = self.value_type.forward if self.source: - source = AttrSource(AttrSource(self.source, "__class__"), name) + source = self.get_source_by_walking_mro(name) else: source = None @@ -1017,7 +1021,7 @@ def call_method( if name in ["_call_impl", "_wrapped_call_impl"]: fn = getattr(self.value_type, name) if self.source: - source = AttrSource(AttrSource(self.source, "__class__"), name) + source = self.get_source_by_walking_mro(name) else: source = None @@ -1032,9 +1036,7 @@ def call_method( method = None if isinstance(method, staticmethod): - source = AttrSource( - AttrSource(AttrSource(self.source, "__class__"), name), "__func__" - ) + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") return tx.inline_user_function_return( variables.UserFunctionVariable(method.__func__, source=source), args, @@ -1093,9 +1095,10 @@ def call_method( # Handle submodules self.is_state_mutated = True - if method is torch.nn.Module.__setattr__ and isinstance( - args[1], variables.DeletedVariable - ): + if ( + method is torch.nn.Module.__setattr__ + and isinstance(args[1], variables.DeletedVariable) + ) or method is torch.nn.Module.__delattr__: # Trace through __delattr__ to track mutations on the module # members like `_modules``. return tx.inline_user_function_return( diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index bc2852c38d86..a120ab488ed9 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -25,7 +25,8 @@ import torch -from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported +from .. import graph_break_hints +from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported from .base import VariableTracker from .user_defined import UserDefinedObjectVariable @@ -75,14 +76,24 @@ def var_getattr(self, tx, name: str) -> VariableTracker: method = getattr(self.value, name, None) if method is None: - unimplemented( - f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?" + unimplemented_v2( + gb_type="FakeScriptObject missing method implementation", + context=f"value={self.value}, method={name}", + explanation=f"TorchScript object {self.value} doesn't define the method {name}.", + hints=[ + f"Ensure the method {name} is implemented in {self.value}.", + *graph_break_hints.USER_ERROR, + ], ) if not callable(method): - unimplemented( - "Only method calls on TorchScript objects can be supported safely." - " Please use method calls instead of attribute access." + unimplemented_v2( + gb_type="Attempted to access non-callable attribute of TorchScript object", + context=f"value={self.value}, method={name}", + explanation="Attribute accesses of TorchScript objects to non-callable attributes are not supported.", + hints=[ + "Use method calls instead of attribute access.", + ], ) return TorchHigherOrderOperatorVariable.make( @@ -100,4 +111,14 @@ def var_getattr(self, tx, name: str) -> VariableTracker: "Dynamo cannot safely trace script object due to graph break." ) def call_method(self, tx, name, args, kwargs): - unimplemented(f"call method {name} on script object is not safe.") + unimplemented_v2( + gb_type="Weird method call on TorchScript object", + context=f"value={self.value}, method={name}", + explanation=( + f"This particular method call ({name}) is not supported (e.g. calling `__setattr__`). " + "Most method calls to TorchScript objects should be supported." + ), + hints=[ + "Avoid calling this method.", + ], + ) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 7ee8c48b0ffb..4458468d8118 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -59,7 +59,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable -from .functions import UserMethodVariable +from .functions import UserFunctionVariable, UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -389,7 +389,7 @@ def _flatten_vts(vts): output = [] while vts: - vt = vts.pop() + vt = vts.popleft() if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): vt.realize() @@ -397,8 +397,10 @@ def _flatten_vts(vts): if vt.is_realized(): if isinstance(vt, ListVariable): vts.extend(vt.items) + continue elif isinstance(vt, ConstDictVariable): vts.extend(vt.items.values()) + continue output.append(vt) @@ -618,8 +620,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): elif isinstance(attr, property): getter_source = AttrSource(attr_source, "fget") getter = attr.fget - getter_var = UserMethodVariable(getter, self, source=getter_source) - return getter_var.call_function(tx, [], {}) + getter_var = UserFunctionVariable(getter, source=getter_source) + return getter_var.call_function(tx, [self], {}) elif isinstance(attr, classmethod): return UserMethodVariable( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index df78c793a81d..084a1e2149d0 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -60,8 +60,11 @@ AttrSource, CallFunctionNoArgsSource, DataclassFieldsSource, + DictGetItemSource, GetItemSource, RandomValueSource, + TypeDictSource, + TypeMROSource, TypeSource, UnspecializedParamBufferSource, ) @@ -135,6 +138,14 @@ def is_forbidden_context_manager(ctx): return ctx in f_ctxs +def is_cython_function(obj): + return ( + callable(obj) + and hasattr(type(obj), "__name__") + and type(obj).__name__ == "cython_function_or_method" + ) + + class UserDefinedVariable(VariableTracker): value: object @@ -242,6 +253,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke elif name == "__dict__": options = {"source": source} return variables.GetAttrVariable(self, name, **options) + elif name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) # Special handling of collections.OrderedDict.fromkeys() # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with @@ -284,10 +298,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke func = obj.__get__(None, self.value) return VariableTracker.build(tx, func, source) elif source: - # __mro__ is a member in < 3.12, an attribute in >= 3.12 - if inspect.ismemberdescriptor(obj) or ( - sys.version_info >= (3, 12) and name == "__mro__" - ): + if inspect.ismemberdescriptor(obj): return VariableTracker.build(tx, obj.__get__(self.value), source) if ConstantVariable.is_literal(obj): @@ -998,19 +1009,18 @@ def call_method( # check for methods implemented in C++ if isinstance(method, types.FunctionType): - source = ( - None - if self.source is None - else AttrSource(AttrSource(self.source, "__class__"), name) - ) + source = self.source + source_fn = None + if source: + source_fn = self.get_source_by_walking_mro(name) # TODO(jansel): add a guard to check for monkey patching? from ..mutation_guard import unpatched_nn_module_init if method is torch.nn.Module.__init__: method = unpatched_nn_module_init - return UserMethodVariable(method, self, source=source).call_function( - tx, args, kwargs - ) + return UserMethodVariable( + method, self, source_fn=source_fn, source=source + ).call_function(tx, args, kwargs) if method is list.__len__ and self.source and not (args or kwargs): install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) @@ -1224,12 +1234,40 @@ def get_source_by_walking_mro(self, name): for idx, klass in enumerate(type(self.value).__mro__): if name in klass.__dict__: - mro_source = AttrSource(self.cls_source, "__mro__") - klass_source = GetItemSource(mro_source, idx) - dict_source = AttrSource(klass_source, "__dict__") - # TODO(anijain2305) - This is a mapping proxy object. Ideally we - # should use DictGetItemSource here. - return GetItemSource(dict_source, name) + if idx != 0: + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, idx) + else: + klass_source = self.cls_source + dict_source = TypeDictSource(klass_source) + out_source = DictGetItemSource(dict_source, name) + + for absent_idx in range(1, idx): + # Insert a guard that the name is not present in the mro hierarchy + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, absent_idx) + dict_source = TypeDictSource(klass_source) + install_guard( + dict_source.make_guard( + functools.partial( + GuardBuilder.DICT_CONTAINS, key=name, invert=True + ) + ) + ) + # Insert a guard that the name is not present in the object __dict__ + if ( + self.source + and hasattr(self.value, "__dict__") + and name not in self.value.__dict__ + ): + install_guard( + self.source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name + ) + ) + ) + return out_source unimplemented_v2( gb_type="could not find name in object's mro", @@ -1339,15 +1377,28 @@ def var_getattr(self, tx: "InstructionTranslator", name): if subobj is torch.nn.Module.__init__: subobj = unpatched_nn_module_init + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + is_accessible_from_type_mro = ( + subobj_from_class is subobj + and self.cls_source is not None + and self.source is not None + ) + if isinstance(subobj, property): if self.source: # Read the class attribute to reach the property - source = AttrSource(AttrSource(self.source, "__class__"), name) + source = self.get_source_by_walking_mro(name) # Get the getter function source = AttrSource(source, "fget") - return variables.UserMethodVariable( - subobj.fget, self, source=source - ).call_function(tx, [], {}) + + # Avoid using UserMethodVariable here because there is no way to + # access the method object here. Direct inline by creating the + # UserFunctionVariable. + return variables.UserFunctionVariable( + subobj.fget, source=source + ).call_function(tx, [self], {}) elif isinstance(subobj, _collections._tuplegetter): # namedtuple fields are represented by _tuplegetter, and here we # emulate its `__get__`, which is implemented in C. @@ -1360,11 +1411,25 @@ def var_getattr(self, tx: "InstructionTranslator", name): # Safe because `staticmethod.__get__` basically won't trigger user # code and just returns the underlying `__func__`: # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100 + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a staticmethod object, so access the __func__ + # attribute to get to the actual function. + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") func = subobj.__get__(self.value) return VariableTracker.build(tx, func, source) elif isinstance(subobj, classmethod): + source_fn = None + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a classmethod object, so access the __func__ + # attribute to get to the actual function. + source_fn = AttrSource(self.get_source_by_walking_mro(name), "__func__") return variables.UserMethodVariable( - subobj.__func__, self.var_getattr(tx, "__class__"), source=source + subobj.__func__, + self.var_getattr(tx, "__class__"), + source_fn=source_fn, + source=source, ) elif isinstance(subobj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static({}, "fromkeys") @@ -1454,7 +1519,12 @@ def var_getattr(self, tx: "InstructionTranslator", name): func = subobj if inspect.ismethod(dynamic_subobj): - return variables.UserMethodVariable(func, self, source=source) + source_fn = None + if is_accessible_from_type_mro: + source_fn = self.get_source_by_walking_mro(name) + return variables.UserMethodVariable( + func, self, source_fn=source_fn, source=source + ) elif inspect.isfunction(dynamic_subobj): if is_utils_checkpoint(func): return build_checkpoint_variable(source=source) @@ -1485,10 +1555,17 @@ def var_getattr(self, tx: "InstructionTranslator", name): source = self._wrap_source(source) if subobj is not NO_SUCH_SUBOBJ: - if is_wrapper_or_member_descriptor(subobj): + if ( + is_wrapper_or_member_descriptor(subobj) + or torch._C._dynamo.utils.is_instancemethod(subobj) + or is_cython_function(subobj) + ): options = {"source": source} return variables.GetAttrVariable(self, name, **options) if source: + if is_accessible_from_type_mro: + source = self.get_source_by_walking_mro(name) + return variables.LazyVariableTracker.create(subobj, source) else: # Check if the subobj is accessible from the class itself. If the class source is known, we can create a diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index b1195cf42128..ef49c4f035a5 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import contextlib -from typing import Optional +from typing import Any, Optional import torch from torch.fx.graph_module import GraphModule @@ -9,14 +9,16 @@ _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" -def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) -> None: +def _node_metadata_hook( + node: torch.fx.Node, metadata: Optional[dict[str, Any]] = None +) -> None: """ Hook for adding the appropriate metadata to nodes that are created during a pass using graph.create_node. An example of how to use it: ``` with _set_node_metadata_hook(gm, - functools.partial(_node_metadata_hook, stack_trace="file") + functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"}) ): pass(gm) ``` @@ -44,7 +46,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) fake_res = node.target(*fake_args) node.meta["val"] = fake_res - node.meta["stack_trace"] = stack_trace node.meta["nn_module_stack"] = arg_meta.get( "nn_module_stack", { @@ -60,6 +61,12 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) f"{node.target.__class__.__name__}.{node.target.__name__}", ) + # Hook specified metadata takes precedence over all previously set + # metadata, so this goes last + if metadata is not None: + for k, v in metadata.items(): + node.meta[k] = v + @contextlib.contextmanager def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): diff --git a/torch/_export/passes/insert_custom_op_guards.py b/torch/_export/passes/insert_custom_op_guards.py index 4deecdf81822..bfea7b08c924 100644 --- a/torch/_export/passes/insert_custom_op_guards.py +++ b/torch/_export/passes/insert_custom_op_guards.py @@ -20,7 +20,8 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> _set_node_metadata_hook( gm, functools.partial( - _node_metadata_hook, stack_trace=node.meta.get("stack_trace") + _node_metadata_hook, + metadata={"stack_trace": node.meta.get("stack_trace")}, ), ), gm.graph.inserting_before(node), diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index 50472c02375c..5eb5512cde63 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<31664e4faa0eacd6f538ffed163078e190d9d2b98d762dd45b68eb1b7b12f0d1>> +// checksum<> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -330,19 +330,14 @@ struct ExportedProgram { 60: SchemaVersion schema_version; 70: list verifiers; 80: string torch_version; -} - -struct Program { - 200: map methods; + 90: map tensor_paths; + 100: map constant_paths; } struct Model { 10: string name; - 20: map tensorPaths; - 40: Program program; - 50: map delegates; - 60: map deviceAllocationMap; - 70: map constantPaths; + 80: ExportedProgram program; + 90: map variants; } struct AOTInductorModelPickleData { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 933d30310b72..dba719a60155 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 9) +SCHEMA_VERSION = (8, 10) TREESPEC_VERSION = 1 @@ -436,35 +436,35 @@ class ExportedProgram: verifiers: Annotated[list[str], 70] = field(default_factory=list) torch_version: Annotated[str, 80] = "<=2.4" + # key is the FQN of tensor in exported program + # value is the archive path of tensor payloads + # e.g. "L__self__linear.weight" : "/data/tensor/weight_1" + tensor_paths: Annotated[dict[str, str], 90] = field(default_factory=dict) + + # key is the FQN of constant in exported program (constant tensor or torchbind objs) + # value is the archive path of serialized constants + constant_paths: Annotated[dict[str, str], 100] = field(default_factory=dict) + ######################################################################### # Container types for inference tasks, not being used directly for export. ######################################################################### -@dataclass -class Program: - methods: Annotated[dict[str, ExportedProgram], 200] - - # This is the top-level model definition that be will serialized into the package @dataclass class Model: # unique identifier of the model in the package, e.g. local, remote, merge name: Annotated[str, 10] - # key is the FQN of tensor in exported program - # value is the archive path of tensor payloads - # e.g. "L__self__linear.weight" : "/data/tensor/L__self__linear.weight" - tensorPaths: Annotated[dict[str, str], 20] - # program exported from torch.export() - program: Annotated[Program, 40] - # Backend-specialized Lowered GraphModule - # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate - delegates: Annotated[dict[str, Program], 50] - deviceAllocationMap: Annotated[dict[str, str], 60] - # key is the FQN of constant in exported program (constant tensor or torchbind objs) - # value is the archive path of serialized constants - constantPaths: Annotated[dict[str, str], 70] + + # the main program exported from torch.export() + program: Annotated[ExportedProgram, 80] + + # a collection of ExportedPrograms that are related to the same model + # They can be used for different purposes, e.g. + # - different methods such as "encode" and "decode" for the same model + # - different delegates such as "aoti_sm80" and "aoti_sm90" + variants: Annotated[dict[str, ExportedProgram], 90] # diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 9167a6820ef4..bb087048a30c 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> +# checksum<> AOTInductorModelPickleData: kind: struct fields: @@ -131,6 +131,12 @@ ExportedProgram: torch_version: type: str default: <=2.4 + tensor_paths: + type: Dict[str, str] + default: '{}' + constant_paths: + type: Dict[str, str] + default: '{}' ExternKernelNode: kind: struct fields: @@ -298,16 +304,10 @@ Model: fields: name: type: str - tensorPaths: - type: Dict[str, str] program: - type: Program - delegates: - type: Dict[str, Program] - deviceAllocationMap: - type: Dict[str, str] - constantPaths: - type: Dict[str, str] + type: ExportedProgram + variants: + type: Dict[str, ExportedProgram] ModuleCallEntry: kind: struct fields: @@ -388,11 +388,6 @@ OutputTokenSpec: fields: arg: type: TokenArgument -Program: - kind: struct - fields: - methods: - type: Dict[str, ExportedProgram] RangeConstraint: kind: struct fields: @@ -534,5 +529,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 9 +- 10 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index ccc963397530..29b9766ae18a 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -448,6 +448,7 @@ class ForwardRef {{ ptr_ = std::make_unique(*other.ptr_); return *this; }} + ~ForwardRef(); const T& operator*() const {{ return *ptr_; }} @@ -519,6 +520,7 @@ class F64 {{ template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; }} // namespace _export }} // namespace torch """ diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 6c31a4a9ed8f..16f7ebcb7676 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -662,7 +662,10 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: gm, functools.partial( _node_metadata_hook, - stack_trace=node.meta.get("stack_trace"), + metadata={ + "stack_trace": node.meta.get("stack_trace"), + "nn_module_stack": node.meta.get("nn_module_stack"), + }, ), ), ): @@ -690,7 +693,10 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): "in insert_deferred_runtime_asserts" ) with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + gm, + functools.partial( + _node_metadata_hook, metadata={"stack_trace": stack_trace} + ), ): shape_env = _get_shape_env_from_gm(gm) if shape_env: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 37bcebe0518e..248c3a0ae673 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -95,7 +95,7 @@ class FXGraphCacheMiss(BypassAOTAutogradCache): def should_use_remote_autograd_cache(): - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False if config.enable_remote_autograd_cache is not None: return config.enable_remote_autograd_cache @@ -116,7 +116,7 @@ def should_use_remote_autograd_cache(): def should_use_local_autograd_cache(): - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False return config.enable_autograd_cache @@ -302,6 +302,42 @@ class AOTAutogradCacheDetails(FxGraphHashDetails): a safe and stable cache key for AOTAutograd. """ + def get_triton_source_codes_from_gm( + self, + gm: torch.fx.GraphModule, + ): + triton_kernels = [] + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if isinstance(node.target, torch._ops.OpOverloadPacket): + attrs = node.target._dir + for attr in attrs: + if custom_op := getattr(node.target, attr, None): + kernels = torch._library.triton.get_triton_kernels_for_op( + custom_op._name + ) + triton_kernels.extend(kernels) + elif isinstance(node.target, torch._ops.OpOverload): + kernels = torch._library.triton.get_triton_kernels_for_op( + node.target._name + ) + triton_kernels.extend(kernels) + + triton_kernel_source_codes = [] + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + for kernel in triton_kernels: + source_codes = user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + triton_kernel_source_codes.append(source_codes) + + return triton_kernel_source_codes + def __init__( self, gm: torch.fx.GraphModule, @@ -319,6 +355,7 @@ def __init__( [], [], ) + self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm) if hasattr(gm, "saved_tensors_hooks_pack_0"): diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 27cf699091ee..a1c6e795bfec 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -516,6 +516,48 @@ class InvokeSubgraphHopGraphs: new_num_saved_nodes: Optional[int] = None +def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}") + else: + env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}") + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + def run_joint_graph_passes_on_hops( joint_gm: torch.fx.GraphModule, joint_inputs: Any, @@ -553,51 +595,6 @@ def num_outputs(mod): def num_inputs(mod): return len(mod.graph.find_nodes(op="placeholder")) - def prepare_for_partitioner(mod, num_primals, num_fw_outputs): - # min-cut partitioner requires the placeholders to have primals and - # tangents string in the node.name. The signature of the joint graph is - # (*primals, *tangents) - - # We also have to update the output signature which is right now - # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the - # partitioner to work. - new_graph = torch.fx.Graph() - env = {} - - primals_counter = itertools.count(0) - tangents_counter = itertools.count(0) - - for idx, node in enumerate(mod.graph.nodes): - if node.op == "placeholder": - if idx < num_primals: - env[node] = new_graph.placeholder( - f"primals_{next(primals_counter)}" - ) - else: - env[node] = new_graph.placeholder( - f"tangents_{next(tangents_counter)}" - ) - env[node].meta = copy.copy(node.meta) - elif node.op == "output": - # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) - # The reason for having the reversed signature in the first - # place is to simplify step 3. - old_outputs = node.args[0] - new_outputs = ( - *old_outputs[-num_fw_outputs:], - *old_outputs[:-num_fw_outputs], - ) - new_outputs = [env[n] if n else None for n in new_outputs] - new_graph.output(tuple(new_outputs)) - else: - env[node] = new_graph.node_copy(node, lambda n: env[n]) - env[node].meta = copy.copy(node.meta) - - new_graph.lint() - - out = torch.fx.GraphModule(mod, new_graph) - return out - new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( lambda: InvokeSubgraphHopGraphs() ) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index b30a6437cb3b..cecfda2bcf1c 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -637,10 +637,15 @@ def _dup_fake_script_obj(fake_flat_args): if fw_metadata.num_intermediate_bases > 0: assert not req_subclass_dispatch, f"""\ -torch.compile is currently being used with tensor subclass inputs: -{",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs -that alias one another, which is currently unsupported in the subclass use case. If you run into this, -please file a github issue""" +torch.compile is currently being used with tensor subclass inputs. +We are attempting to a compile a graph with two graph outputs +that alias one another, specifically output indices: + + {[i for i, x in enumerate(fw_metadata.output_info) if x.output_type == OutputType.alias_of_intermediate]} + +ANY output aliasing (even for regular tensors) is currently unsupported if +there are any subclass outputs. If you run into this, please file a github +issue""" if aot_config.is_export: # aot_export: ban input metadata mutations for now to keep shared code paths simpler. @@ -1131,6 +1136,15 @@ def forward(*runtime_args: tuple[Any]): return forward +def boxed_nop_preserve_node_meta(fx_g, example_inputs): + def run(args): + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + def aot_export_joint_with_descriptors( stack: contextlib.ExitStack, mod: nn.Module, @@ -1140,6 +1154,8 @@ def aot_export_joint_with_descriptors( decompositions: Optional[dict] = None, keep_inference_input_mutations=False, ignore_shape_env=False, + fw_compiler: Optional[AOTDispatchCompiler] = boxed_nop_preserve_node_meta, + bw_compiler: Optional[AOTDispatchCompiler] = boxed_nop_preserve_node_meta, ) -> JointWithDescriptors: """ This API captures the joint graph for an nn.Module. However, unlike @@ -1201,8 +1217,6 @@ def aot_export_joint_with_descriptors( of the inputs to determine if inputs are parameters and their FQNs. """ - from torch._dynamo.backends.debugging import boxed_nop - ( functional_call, _params_buffers_flat, @@ -1219,8 +1233,8 @@ def aot_export_joint_with_descriptors( mod, args, kwargs, - boxed_nop, - boxed_nop, + fw_compiler, + bw_compiler, default_partition, decompositions, keep_inference_input_mutations, diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 2833a2b1631a..5bf2dee3e1d7 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -281,6 +281,17 @@ def remote_autograd_cache_default() -> Optional[bool]: # real tensor outputs. generate_fake_kernels_from_real_mismatches = False +# When there are device mismatches in FakeTensor device propagation, +# prefer a specific device type over others. This is particularly useful +# in full compiled mode where intermediate tensors with device mismatches +# represent only logical differences during compilation - these intermediate +# tensors will never physically materialize in the binary execution, so the +# device mismatch is not a real runtime concern. Enabling this allows the +# compiler to proceed with compilation by choosing the preferred device type +# for consistency. For example, set to "mtia" to prefer MTIA devices over +# CPU, or "cuda" to prefer CUDA devices over CPU. +fake_tensor_prefer_device_type: Optional[str] = None + # CUDAGraph save run_with_rng functionalization. # TODO: turn on by default graphsafe_rng_functionalization = True diff --git a/torch/_guards.py b/torch/_guards.py index fa6f9cc1e7bd..dd2ba4774792 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -267,7 +267,7 @@ class Guard: guard_types: Optional[list[str]] = None code_list: Optional[list[str]] = None obj_weakref: Optional[object] = None - guarded_class_weakref: Optional[type] = None + guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None stack: Optional[CapturedTraceback] = None user_stack: Optional[traceback.StackSummary] = None @@ -380,7 +380,7 @@ def is_local(self) -> bool: def set_export_info( self, guard_type: str, - guarded_class: Optional[type], + guarded_class: Optional[weakref.ReferenceType[Any]], code_list: list[str], obj_weakref: object, ) -> None: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 648d41b0b95a..10f6ca9f386c 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Optional, Union import torch -import torch._subclasses.functional_tensor import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._C._functorch import ( @@ -19,6 +18,7 @@ from torch._higher_order_ops.utils import ( _maybe_run_with_interpreter, _set_compilation_env, + create_bw_fn, materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, @@ -36,8 +36,6 @@ ) from torch.utils._python_dispatch import _get_current_dispatch_mode -from .utils import clone_outputs_aliasing_inputs - log = logging.getLogger(__name__) @@ -201,60 +199,6 @@ def _cond_op_wrapper(*args, **kwargs): ) -def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: - """ - For a fn that accepts flat inputs and returns flat outputs: - fw_out = fn(*args), - this function returns: - grad_args = bw_fn(*args_and_grad_output) - with the following invariants: - 1. args + fw_out has an 1-1 correspondence to args_and_grad_output - 2. grad_args has an 1-1 corresponsence to args - 3. for tensor arg whose requires_grad is False, its corresponding grad in - grad_args will be a zero tensor with the same shape. - """ - - from torch._functorch.aot_autograd import AOTConfig, create_joint - from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad - - dummy_aot_config = AOTConfig( - fw_compiler=None, # type: ignore[arg-type] - bw_compiler=None, # type: ignore[arg-type] - partition_fn=None, # type: ignore[arg-type] - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False, - ) - n_primals = len(args) - - bw_fn = create_joint( - prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config - ) - - def flat_fn(*args_and_grad_outs): - primals = args_and_grad_outs[:n_primals] - tangents = args_and_grad_outs[n_primals:] - grad_args = bw_fn(primals, tangents)[1] - assert len(args) == len(grad_args) - # In order to keep HOPs functional where the backward graph, - # would have outputs that are aliasing inputs. - # For example in cases where the backward of the function is simply - # passing the upstream gradients through. - maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) - - return [ - ( - torch.zeros_like(arg) - if isinstance(arg, torch.Tensor) and grad is None - else maybe_clone(grad) - ) - for grad, arg in zip(grad_args, primals) - ] - - return flat_fn - - def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): assert isinstance(operands, (list, tuple)), ( f"Cond operands must be a list or tuple of tensors and SymInts {operands}" diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 9f73df7ef478..332bde7e464f 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -13,7 +13,6 @@ from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, - make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -22,10 +21,11 @@ _from_fun, _stack_pytree, _unstack_pytree, - clone_outputs_aliasing_inputs, - prepare_fw_with_masks, + create_bw_fn, + materialize_as_graph, save_tensors_and_symints_for_backward, saved_tensors_and_symints, + split_into_chunks, ) @@ -40,77 +40,6 @@ def __call__(self, *args, **kwargs): map_impl = MapImpl() -def create_fw_bw_graph(f, num_mapped_args, *args): - mapped_xs = args[:num_mapped_args] - pos_args = args[num_mapped_args:] - - # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py - - with suspend_functionalization(), disable_functional_mode(): - with disable_proxy_modes_tracing(): - unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) - example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] - - example_pos_args = [ - _from_fun(arg) if isinstance(arg, torch.Tensor) else arg - for arg in pos_args - ] - example_flat_out = pytree.tree_map( - _from_fun, f(*example_xs, *example_pos_args) - ) - if any( - not isinstance(out, torch.Tensor) - for out in example_flat_out - if out is not None - ): - raise RuntimeError( - "Expect outputs of map only contains tensors or None. " - f"Got types {[type(out) for out in example_flat_out]}." - ) - example_grad = [_from_fun(out) for out in example_flat_out] - - fw_graph = make_fx(f)(*example_xs, *example_pos_args) - - from torch._functorch.aot_autograd import AOTConfig, create_joint - - dummy_aot_config = AOTConfig( - fw_compiler=None, # type: ignore[arg-type] - bw_compiler=None, # type: ignore[arg-type] - partition_fn=None, # type: ignore[arg-type] - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False, - ) - - def joint_f(*example_args): - joint_mapped_args = example_args[:joint_num_mapped] - args = example_args[joint_num_mapped:] - - mapped_input = joint_mapped_args[:num_mapped_args] - mapped_grads = joint_mapped_args[num_mapped_args:] - - joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config) - _, grads = joint( - list(mapped_input) + list(args), - [ - grad - for grad in mapped_grads - if grad is not None and grad.requires_grad - ], - ) - - # In order to keep map functional for backward graph, - # we clone outputs that are aliasing inputs - maybe_clone = clone_outputs_aliasing_inputs(example_args) - - return pytree.tree_map(maybe_clone, grads) - - joint_num_mapped = len(example_grad) + len(example_xs) - joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) - return fw_graph, joint_graph - - def map( f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree], xs: Union[pytree.PyTree, torch.Tensor], @@ -193,36 +122,88 @@ def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs): class MapAutogradOp(torch.autograd.Function): @staticmethod - def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): - save_tensors_and_symints_for_backward(ctx, flat_args) - ctx._joint_graph = joint_graph + def forward(ctx, f, num_mapped_args, *flat_args): + ctx._f = f ctx._num_mapped_args = num_mapped_args + ctx._num_pos_args = len(flat_args) - num_mapped_args + + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + save_tensors_and_symints_for_backward(ctx, flat_args) with torch._C._AutoDispatchBelowAutograd(): return ( - *map_impl( - fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] - ), + *map_impl(f, flat_args[:num_mapped_args], flat_args[num_mapped_args:]), ) @staticmethod def backward(ctx, *flat_grads): fw_args = saved_tensors_and_symints(ctx) - fw_mapped_args = fw_args[: ctx._num_mapped_args] - pos_args = fw_args[ctx._num_mapped_args :] - - grads = map_impl( - ctx._joint_graph, - fw_mapped_args + flat_grads, - pos_args, + num_mapped_args = ctx._num_mapped_args + num_pos_args = ctx._num_pos_args + num_grads = len(flat_grads) + + fw_mapped_args, pos_args = split_into_chunks( + fw_args, + [ + num_mapped_args, + num_pos_args, + ], ) - return None, None, None, *grads + + bw_f = create_bw_fn(ctx._f, fw_args) + + # Create a wrapper around thefor the bw_f + def bw_f_wrapper(*args): + # Dissect args and re-order them for the ``ctx._bw_f`` + # args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args] + # The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads + # The content of ``bw_f_primals`` are the fw_args, i.e., [*fw_mapped_args, *pos_args] + # The bw_f requires *bw_f_primals, *bw_f_tangents + fw_m_args, bw_f_tangents, pos_args = split_into_chunks( + args, [num_mapped_args, num_grads, num_pos_args] + ) + bw_f_primals = *fw_m_args, *pos_args + return bw_f(*bw_f_primals, *bw_f_tangents) + + def construct_args_single_step_bw(): + unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + unwrapped_grads = pytree.tree_map(_from_fun, flat_grads) + example_grads = _unstack_pytree(unwrapped_grads)[0] + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + return *example_xs, *example_grads, *example_pos_args + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + args_single_step_bw = construct_args_single_step_bw() + + # TODO: we need to materialize the bw graphs because dynamo is unable to + # trace through the joint function when torch.compile torch.autograd.grad. + fn_bw_gm = materialize_as_graph( + bw_f_wrapper, + args_single_step_bw, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + force_enable_grad=True, + ) + + grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args) + + return None, None, *grads def trace_map(proxy_mode, func_overload, f, xs, pos_args): - example_input = _unstack_pytree(xs)[0] - body_graph = f + with disable_proxy_modes_tracing(): + example_input = _unstack_pytree(xs)[0] + + body_graph = f - body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) + body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_") @@ -249,8 +230,7 @@ def map_dense(f, xs, pos_args): @map_impl.py_autograd_impl def map_autograd(f, xs, pos_args): num_mapped_args = len(xs) - fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) - flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(f, num_mapped_args, *xs, *pos_args) return flat_out diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index 40533532fbb5..38c07e37bdb8 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -111,8 +111,8 @@ def is_int_mm(op, output_dtype, args): and len(args) == 2 and args[0].dtype == torch.int8 and args[1].dtype == torch.int8 - and args[0].is_cuda - and args[1].is_cuda + and (args[0].is_cuda or args[0].is_xpu) + and (args[1].is_cuda or args[1].is_xpu) ) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 3cd5bf9ec4e2..4e636b396b38 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -1,22 +1,22 @@ # mypy: allow-untyped-defs import functools import itertools -from collections.abc import Sequence from typing import Any, Callable, Optional import torch import torch._prims_common as utils import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._higher_order_ops.cond import create_bw_fn from torch._higher_order_ops.utils import ( _maybe_compile_and_run_fn, check_meta_consistency, + create_bw_fn, first_slice_copy, materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, saved_tensors_and_symints, + split_into_chunks, unique_graph_id, validate_subgraph_args_types, ) @@ -95,14 +95,6 @@ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: return slc -def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: - it = iter(iterable) - assert sum(chunk_sizes) == len(iterable), ( - "the sum of all chunks needs to match the length of the iterable." - ) - return [list(itertools.islice(it, size)) for size in chunk_sizes] - - def call_operator(operator, *args): return pytree.tree_leaves(operator(*args)) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 34a9c5915254..4dd2bd145a90 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -461,11 +461,16 @@ def get_signature_value(idx: int, arg: Any) -> str: elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() ttir_module = src.make_ir(options, codegen_fns, context) - else: + elif make_ir_sig_params == 4: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() ttir_module = src.make_ir(options, codegen_fns, module_map, context) + else: + codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] + codegen_fns = backend.get_codegen_implementation(*codegen_args) + module_map = backend.get_module_map() + ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 25ef972864d5..ab0fc4e654c6 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import contextlib import functools +from collections.abc import Sequence from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import Any, Callable, Optional, overload, TypeVar, Union @@ -722,6 +723,69 @@ def saved_tensors_and_symints(ctx): return tuple(args) +def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: + assert sum(chunk_sizes) == len(iterable), ( + "the sum of all chunks needs to match the length of the iterable." + ) + elements = [] + idx = 0 + for size in chunk_sizes: + elements.append(iterable[idx : idx + size]) + idx += size + return elements + + +def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: + """ + For a fn that accepts flat inputs and returns flat outputs: + fw_out = fn(*args), + this function returns: + grad_args = bw_fn(*args_and_grad_output) + with the following invariants: + 1. args + fw_out has an 1-1 correspondence to args_and_grad_output + 2. grad_args has an 1-1 corresponsence to args + 3. for tensor arg whose requires_grad is False, its corresponding grad in + grad_args will be a zero tensor with the same shape. + """ + + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + n_primals = len(args) + + bw_fn = create_joint( + prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config + ) + + def flat_fn(*args_and_grad_outs): + primals = args_and_grad_outs[:n_primals] + tangents = args_and_grad_outs[n_primals:] + grad_args = bw_fn(primals, tangents)[1] + assert len(args) == len(grad_args) + + maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) + + return [ + ( + torch.zeros_like(arg) + if isinstance(arg, torch.Tensor) and grad is None + else maybe_clone(grad) + ) + for grad, arg in zip(grad_args, primals) + ] + + return flat_fn + + def get_dummy_aot_autograd_config(): from torch._functorch.aot_autograd import AOTConfig diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 0a12356de670..b23838306923 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -49,6 +49,7 @@ ) from torch._inductor.utils import clear_on_fresh_cache from torch._inductor.virtualized import V +from torch._utils_internal import log_triton_builds from torch.hub import _Faketqdm, tqdm from torch.utils._ordered_set import OrderedSet from torch.utils._triton import has_triton_package @@ -479,22 +480,29 @@ def get_result() -> CachingAutotuner: log_waitcounter=True, waitcounter_name_override="compile_triton", ): - start_ns = time_ns() - _set_triton_ptxas_path() - kernel = load_kernel() - kernel.set_compile_info(compile_id, is_backward) - kernel.precompile( - warm_cache_only=False, - static_triton_bundle_key=CompiledTritonKernels.key(source_code), - ) - elapsed_us = (time_ns() - start_ns) // 1000 - get_metrics_context().add_top_n( - "triton_kernel_compile_times_us", kernel_name, elapsed_us - ) - info = kernel.autotune_cache_info or {} - info["compile_time_us"] = elapsed_us - _add_triton_kernel_info(kernel_name, info) - return kernel + fail = None + try: + start_ns = time_ns() + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.set_compile_info(compile_id, is_backward) + kernel.precompile( + warm_cache_only=False, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + elapsed_us = (time_ns() - start_ns) // 1000 + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + return kernel + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) def multi_kernel(self, *args, **kwargs) -> Any: from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c936fbe92c67..dfaabd1ef594 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -764,7 +764,7 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( - {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + dict.fromkeys(meta.name for meta in self.input_tensor_meta) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) diff --git a/torch/_inductor/await_utils.py b/torch/_inductor/await_utils.py new file mode 100644 index 000000000000..a549674d5cd7 --- /dev/null +++ b/torch/_inductor/await_utils.py @@ -0,0 +1,176 @@ +import asyncio +import sys +import weakref +from asyncio import AbstractEventLoop, Future +from collections.abc import Awaitable, Coroutine, Generator, Iterator +from contextlib import contextmanager, ExitStack +from contextvars import Context +from typing import Any, Callable, Optional, Protocol, TypeVar + +from torch.utils._ordered_set import OrderedSet + + +T = TypeVar("T") +TCoro = Generator[Any, None, T] + +if sys.version_info >= (3, 11): + + class TaskFactory(Protocol): + def __call__( + self, + __loop: AbstractEventLoop, + __factory: Coroutine[None, None, object] | Generator[None, None, object], + __context: Context | None = None, + /, + ) -> asyncio.futures.Future[object]: ... + + TaskFactoryType = TaskFactory +else: + TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future] # type: ignore[valid-type] + + +def await_sync(awaitable: Awaitable[T]) -> T: + with get_loop() as loop: + return loop.run_until_complete(awaitable) + + +@contextmanager +def get_loop( + always_create_new_loop: bool = False, +) -> Iterator[AbstractEventLoop]: + try: + loop = asyncio.get_event_loop() + except RuntimeError as re: + if "There is no current event loop in thread" in str(re): + with _new_loop() as loop: + yield loop + return + else: + raise + + @contextmanager + def _restore_loop( + loop: asyncio.AbstractEventLoop, + ) -> Iterator[None]: + try: + yield + finally: + asyncio.set_event_loop(loop) + + @contextmanager + def _restore_running_loop() -> Iterator[None]: + loop_from_events = asyncio.events._get_running_loop() + asyncio.events._set_running_loop(None) + try: + yield + finally: + asyncio.events._set_running_loop(loop_from_events) + + with ExitStack() as stack: + if loop.is_running(): + stack.enter_context(_restore_running_loop()) + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop(loop.get_task_factory())) # type: ignore[arg-type] + elif loop.is_closed(): + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + elif always_create_new_loop: + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + yield loop + + +@contextmanager +def _new_loop( + task_factory: Optional[TaskFactoryType] = None, +) -> Iterator[asyncio.AbstractEventLoop]: + loop = asyncio.new_event_loop() + tasks = _patch_loop(loop) + + if task_factory: + # pyre-ignore[6] + loop.set_task_factory(task_factory) # type: ignore[arg-type] + + asyncio.set_event_loop(loop) + try: + yield loop + finally: + try: + _cancel_all_tasks(loop, tasks) + finally: + asyncio.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks( + loop: AbstractEventLoop, + tasks: OrderedSet[Future], # type: ignore[type-arg] +) -> None: + to_cancel = [task for task in tasks if not task.done()] + + if not to_cancel: + return + + # pyre-fixme[1001]: Awaitable assigned to `task` is never awaited. + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[type-arg] + tasks: weakref.WeakSet[Future] = weakref.WeakSet() # type: ignore[type-arg] + + task_factories: list[Optional[TaskFactoryType]] = [None] + + def _set_task_factory(factory: Optional[TaskFactoryType]) -> None: + task_factories[0] = factory + + def _get_task_factory() -> Optional[TaskFactoryType]: + return task_factories[0] + + def _safe_task_factory( + loop: AbstractEventLoop, + coro: TCoro, # type: ignore[type-arg] + *, + context: Context | None = None, + ) -> asyncio.Future: # type: ignore[valid-type, type-arg] + task_factory = task_factories[0] + if task_factory is None: + if sys.version_info >= (3, 11): + task = asyncio.Task(coro, loop=loop, context=context) + else: + task = asyncio.Task(coro, loop=loop) + # pyre-ignore[16]: `Task` has no attribute `_source_traceback`. + if task._source_traceback: # type: ignore[attr-defined] + del task._source_traceback[ # type: ignore[attr-defined] + -1 + ] # pragma: no cover # type: ignore[attr-defined] + else: + if sys.version_info >= (3, 11): + task = task_factory(loop, coro, context=context) # type: ignore[arg-type, call-arg, assignment] + else: + task = task_factory(loop, coro) # type: ignore[arg-type] + # `Union[Task[Any], Future[Any]]`. + tasks.add(task) + return task + + # pyre-ignore[6] + loop.set_task_factory(_safe_task_factory) # type: ignore[method-assign, arg-type] + # pyre-ignore[8] + loop.set_task_factory = _set_task_factory # type: ignore[method-assign, assignment] + # pyre-ignore[8] + loop.get_task_factory = _get_task_factory # type: ignore[method-assign, assignment] + + return tasks # type: ignore[return-value] diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 2ea260a3e956..aacb62c7a123 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -9,6 +9,7 @@ from . import config from .codecache import write_text +from .kernel_inputs import KernelInputs # noqa: TC001 from .metrics import get_metric_table, is_metric_table_enabled from .runtime.hints import DeviceProperties, ReductionHint from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse @@ -20,6 +21,7 @@ ROCmConfigHeuristic, XPUConfigHeuristic, ) +from .template_registry import get_template_heuristic from .virtualized import V @@ -71,58 +73,6 @@ def get_config_heuristics( else: return BaseConfigHeuristic() - # GEMM configs - def get_base_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - if config.max_autotune_gemm_search_space != "EXHAUSTIVE": - return mm_heuristics.get_mm_configs() - else: - return mm_heuristics.get_exhaustive_mm_configs() - - def get_extra_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_extra_mm_configs() - - def get_int8_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_int8_mm_configs() - - def get_mixed_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_mixed_mm_configs() - - def get_persistent_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_persistent_mm_configs() - - def get_scaled_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_scaled_mm_configs() - - def get_scaled_persistent_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_scaled_persistent_mm_configs() - - def get_mm_plus_mm_configs( - self, device_type: Optional[str] = "cuda" - ) -> partial[Generator[TritonConfig, None, None]]: - mm_heuristics = self.get_config_heuristics(device_type) - return mm_heuristics.get_mm_plus_mm_configs() - # Conv configs def get_conv_configs( self, device_type: Optional[str] = "cuda" @@ -131,6 +81,7 @@ def get_conv_configs( return conv_heuristics.get_conv_configs() # Flex attention configs + # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism def get_flex_attention_fwd_configs( self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" ) -> list[Any]: @@ -149,6 +100,37 @@ def get_flex_decode_configs( flex_heuristics = self.get_config_heuristics(device_type) return flex_heuristics.get_flex_decode_configs(head_dim, dtype) + def get_mm_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + template_name: str, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get generator of template parameters for MM templates using template-specific heuristics. + + Args: + kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices + layout: Output layout + template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma") + op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm") + + Yields: + Template parameter dictionaries ready for maybe_append_choice + """ + input_tensors = kernel_inputs.nodes() + if len(input_tensors) < 2: + raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}") + + # Extract device_type from kernel_inputs + device_type = kernel_inputs.device_type + assert device_type is not None, "get_mm_configs requires a valid device type" + # Get the appropriate template-specific heuristic + heuristic = get_template_heuristic(template_name, device_type, op_name) + + yield from heuristic.get_template_configs(kernel_inputs, layout, op_name) + def triton_kernel_kwargs( self, kernel_cls: type[TritonKernel], @@ -214,18 +196,6 @@ def should_use_persistent_reduction( features.reduction_numel, threshold ) # type: ignore[arg-types] - @staticmethod - def want_no_x_dim(features: SIMDKernelFeatures) -> bool: - """ - Heuristic to decide if we should drop the X dimension from a persistent reduction kernel. - So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1. - Strangely this is faster than a [1, RBLOCK] block in some cases. - """ - return ( - features.get_reduction_hint() == ReductionHint.INNER - and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256) - ) - @staticmethod def reduction_split_factor( device: torch.device, diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e80a7ef9755d..65317648a02e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -30,6 +30,7 @@ from datetime import timedelta from functools import lru_cache, partial from pathlib import Path +from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType from typing import ( @@ -75,6 +76,7 @@ get_ld_and_objcopy, get_name_and_dir_from_output_file_path, normalize_path_separator, + run_asm_build_object, ) from torch._inductor.cpu_vec_isa import pick_vec_isa from torch._inductor.custom_graph_pass import ( @@ -358,6 +360,36 @@ def get_hash( raise AssertionError(f"Unknown hash type {hash_type}") +class WritableTempFile: + """ + Avoid "Permission denied error" on Windows: + with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: + # Not writable on Windows: + # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile + + Example: + with WritableTempFile("w", suffix=".gv") as temp_file: + tree.to_dotfile(temp_file.name) + """ + + def __init__( + self, mode: str = "w", *, encoding: Any = None, suffix: Any = None + ) -> None: + self.mode = mode + self.encoding = encoding + self.suffix = suffix + + def __enter__(self) -> _TemporaryFileWrapper[Any]: + self.temp_file = tempfile.NamedTemporaryFile( + self.mode, encoding=self.encoding, suffix=self.suffix, delete=False + ) + return self.temp_file + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.temp_file.close() + os.unlink(self.temp_file.name) + + def write( content: Union[str, bytes], extension: str, @@ -1621,36 +1653,6 @@ def get_keys(cls) -> KeysView[str]: return cls.cache.keys() -class WritableTempFile: - """ - Avoid "Permission denied error" on Windows: - with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: - # Not writable on Windows: - # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile - - Example: - with WritableTempFile("w", suffix=".gv") as temp_file: - tree.to_dotfile(temp_file.name) - """ - - def __init__( - self, mode: str = "w", *, encoding: Any = None, suffix: Any = None - ) -> None: - self.mode = mode - self.encoding = encoding - self.suffix = suffix - - def __enter__(self) -> Any: - self.temp_file = tempfile.NamedTemporaryFile( - self.mode, encoding=self.encoding, suffix=self.suffix, delete=False - ) - return self.temp_file - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.temp_file.close() - os.unlink(self.temp_file.name) - - class AotCodeCompiler: """ Compile AOT Inductor generated code. @@ -1709,12 +1711,6 @@ def compile( wrapper_code = "\n".join((wrapper_code, kernel_code)) kernel_code = "" - from .utils import aoti_model_name_from_config - - model_class_name = "" - if config.aot_inductor.compile_standalone: - model_class_name = aoti_model_name_from_config() - wrapper_key, wrapper_path = write( wrapper_code, "wrapper.cpp", @@ -1747,6 +1743,8 @@ def compile( "model.h", ) ) as f: + # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone + model_class_name = config.aot_inductor.model_name_for_generated_files class_name = f"AOTInductorModel{model_class_name}" header_code = f.read() @@ -1761,7 +1759,7 @@ def compile( header_code, "h", specified_dir=specified_output_path, - key=f"{model_class_name}", + key=model_class_name, ) # Log the AOTInductor wrapper and kernel code, if needed. @@ -1862,8 +1860,9 @@ def _compile_consts(consts: bytes, platform: str) -> str: use_asm_build = False is_large_consts = len(consts) > 1024 + is_zero_size_consts = len(consts) == 0 - def format_consts_to_asm( + def format_consts_to_gnu_asm( consts: bytes, align_bytes: int, symbol_prefix: str, @@ -1885,7 +1884,7 @@ def format_consts_to_asm( consts_asm += f"\t.space {len(consts) - 8}\n" consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" - return consts_asm, "S" + return consts_asm, "weights.S" # Use c++ to convert consts to object file can support more compilers, such as msvc and icx. def format_consts_to_cpp( @@ -1910,21 +1909,73 @@ def format_consts_to_cpp( const_cpp += "\t\n" const_cpp += "};\t\n" const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" - return const_cpp, "cpp" + return const_cpp, "weights.cpp" + + def get_zero_consts_asm_code( + align_bytes: int, + symbol_prefix: str, + ) -> tuple[str, str]: + """ + This function handles zero-sized constants because the C++ standard prohibits zero-length arrays: + https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c + + On Windows (MSVC): + The compiler reports error C2466 for zero-sized arrays: + https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466 + Solution: Use assembly compilation to handle this case. + + Why not use Win32 assembly for all paths? + ml64 only supports alignment up to 16 bytes, which isn't optimal for performance. + + Cross-platform implementation: + Linux: Added '-pedantic' to disable zero-sized arrays in C++ compiler + Windows: MSVC naturally rejects zero-sized arrays by default + """ + if _IS_WINDOWS: + # Windows ml64 is max support align to 16, but it is no effect to zero size data. + asm_code = """ +option casemap:none +.data +?_binary_constants_bin_start@@3PAEA: +align 16 +?_binary_constants_bin_end@@3PAEA: +align 16 +public ?_binary_constants_bin_start@@3PAEA +public ?_binary_constants_bin_end@@3PAEA +end +""" + asm_ext = "asm" + else: + asm_code = f"\t.section\t{section_attr}\n" + asm_code += f"\t.balign {align_bytes}\n" + asm_code += ( + f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" + ) + asm_code += f"{symbol_prefix}_binary_constants_bin_start:\n" + asm_code += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" + asm_code += f"{symbol_prefix}_binary_constants_bin_end:\n" + asm_ext = "S" + return asm_code, asm_ext if use_asm_build: - consts_code, code_ext = format_consts_to_asm( + consts_code, code_ext = format_consts_to_gnu_asm( consts, ALIGN_BYTES, symbol_prefix, is_large_consts ) else: - consts_code, code_ext = format_consts_to_cpp( - consts, ALIGN_BYTES, symbol_prefix - ) + if is_zero_size_consts: + consts_code, code_ext = get_zero_consts_asm_code( + ALIGN_BYTES, symbol_prefix + ) + else: + consts_code, code_ext = format_consts_to_cpp( + consts, ALIGN_BYTES, symbol_prefix + ) _, consts_s = write( consts_code, code_ext, specified_dir=str(specified_sub_dir), + key=config.aot_inductor.model_name_for_generated_files, ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( @@ -1940,14 +1991,21 @@ def format_consts_to_cpp( BuildOption=object_build_options, ) consts_o = object_builder.get_target_file_path() - object_builder.build() + if use_asm_build is False and is_zero_size_consts: + run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent)) + else: + object_builder.build() if is_large_consts and use_asm_build: with open(consts_o, "r+b") as f: f.seek(0) hdr = f.read(1024) # Search for magic number and write the actual data over it - start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + start_idx = ( + hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + if sys.byteorder == "little" + else hdr.find(b"\x12\x34\x56\x78\x99\xab\xcd\xef") + ) assert start_idx != -1 f.seek(start_idx) pos = 0 @@ -2218,7 +2276,13 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: asm_files = [] if not _IS_WINDOWS: ld, objcopy = get_ld_and_objcopy(use_relative_path) + kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): + if kernel_name not in kernels: + # It is possible that CudaKernelParamCache contains more Triton kernels + # than what the current graph uses + continue + if asm_file := value["asm"]: asm_files.append(asm_file) @@ -2488,7 +2552,7 @@ def _get_cpp_prefix_header(device: str) -> Optional[str]: def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: """Given a device type (and optionally whether we're in AOT Inductor mode), returns the path to the cpp_wrapper header file to be precompiled.""" - base_device = device.split(":")[0] + base_device = device.split(":", maxsplit=1)[0] is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" return ( "torch/csrc/inductor/" diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index dad5a281e10a..40ebbed13ddd 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -253,6 +253,9 @@ def get_stride(self) -> list[sympy.Expr]: def get_name(self) -> str: return self.outer_name + def get_is_pinned(self) -> bool: + return False + def get_inputs_that_alias_output(self) -> list[str]: return [] @@ -359,8 +362,8 @@ def cpp_device_ptr(self) -> str: def tma_descriptor_helpers(self) -> str: raise NotImplementedError - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: # optionally return (scratch definition, arg name) raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 12584284631b..1ee9d033d4f9 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -216,13 +216,17 @@ def reduction_combine( reduction_type, var, next_value, + helper_val=None, index: Optional[sympy.Symbol] = None, src_dtype=None, ): is_bool = src_dtype == torch.bool if reduction_type == "sum": - conjunction = "|" if is_bool else "+" - return f"{var} {conjunction} {next_value}" + if helper_val: + return f"cascade_sum_combine({next_value}, &{helper_val})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" if reduction_type == "prod": return f"{var} * {next_value}" if reduction_type == "xor_sum": @@ -362,6 +366,31 @@ def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str): buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line) +def replace_cascade_sum_with_add(buffer: IndentedBuffer): + """ + Replaces `acc = cascade_sum_combine(value, ...)` with `acc = acc + value;` + """ + + pattern = r"(.*?)\s*=\s*cascade_sum_combine\(([^,]+),.*?\);" + for i, line in enumerate(buffer._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + content = line.line if isinstance(line, DeferredLine) else line + match = re.search(pattern, content) + if match: + acc, value = match.groups() + new_content = re.sub(pattern, f"{acc} = {acc} + {value};", content) + if isinstance(line, DeferredLine): + line.line = new_content + else: + buffer._lines[i] = new_content + + @functools.lru_cache def stride_at(index: sympy.Expr, var: sympy.Symbol): if not index.has(var): @@ -1874,6 +1903,15 @@ def index_expr(expr, dtype): class CppKernel(Kernel): + """ + Base class for C++ kernel code generation in PyTorch Inductor. + This class is responsible for generating C++ code from the intermediate representation. + + Args: + args: Kernel arguments used for code generation + num_threads: Number of threads for parallel execution + """ + overrides = CppOverrides # type: ignore[assignment] sexpr = cexpr newvar_prefix = "auto " @@ -1911,6 +1949,9 @@ def __init__(self, args, num_threads): self.welford_helper_cse = CSE( self.newvar_prefix, self.suffix, name_prefix="welford_helper" ) + self.cascade_helper_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="cascade_helper" + ) self.preloads = IndentedBuffer() self.poststores = IndentedBuffer() self.num_threads = num_threads # num_threads the kernel specialized for @@ -2126,6 +2167,123 @@ def finalize_reduction_prefix(self, size: Optional[int] = None): for gen_fn in self.reduction_prefix_generators: self.reduction_prefix.splice(gen_fn(size)) + def need_use_acc_helper(self, reduction_type, dtype, use_scalar): + # Check if we need accumulate helper for the reduction operation. + # using accumulate helper generates the necessary code to improve precision for + # sum and welford + # Note: using helper has non-negligible impact on performance + + # keep the original behavior for welford_reduce + # acc helper is not used for scalar welford_reduce + if reduction_type == "welford_reduce": + return not use_scalar + + # TODO add supports for more data types when needed + if reduction_type == "sum" and dtype == torch.float: + assert self.call_ranges is not None + reduction_size = functools.reduce( + operator.mul, self.call_ranges[self.reduction_depth :] + ) + if config.cpp.dynamic_threads: + # If dynamic threads, to be conservative, + # use reduction_size as the range size + rt_size = reduction_size + else: + rt_size = CeilDiv(reduction_size, parallel_num_threads()) + + # chunk size to balance accuracy and performance + chunk_size = 2**20 + + # use acc helper If cannot get size_hint + try: + rt_size_hint = V.graph.sizevars.size_hint(rt_size) + except Exception: + return True + + if rt_size_hint > chunk_size: + # use helper if the reduction size is too large + V.graph.sizevars.check_lt(chunk_size, rt_size) + return True + else: + V.graph.sizevars.check_leq(rt_size, chunk_size) + return False + + def _acc_helper_init( + self, + reduction_type, + helper_val, + helper_range, + dtype, + num_threads=None, + use_scalar=False, + ): + num_range_thread = ( + CeilDiv(helper_range, num_threads) if num_threads else helper_range + ) + num_range_thread_expr = cexpr_index(num_range_thread) + assert reduction_type in ["welford_reduce", "sum"] + chunk_size = 4096 if reduction_type == "welford_reduce" else 2**20 + num_chunks = CeilDiv(num_range_thread, chunk_size) + helper_type = ( + "WelfordHelper" + if reduction_type == "welford_reduce" + else "CascadeSumHelper" + ) + if use_scalar: + h_type = DTYPE_TO_CPP[dtype] + else: + h_type = ( + self._get_vec_type(dtype) + if hasattr(self, "_get_vec_type") + else DTYPE_TO_CPP[dtype] + ) + helper_init_line = ( + f"{helper_type}<{h_type}, {chunk_size}> {helper_val}" + f"(" + f"{num_range_thread_expr}" + f");" + ) + if reduction_type == "sum": + return helper_init_line + if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: + # When the number of chunks <= 1, there is no need to use cascade summation to improve + # reduction accuracy. We can initialize a static WelfordHelper to improve performance. + return f"static {helper_init_line}" + else: + return helper_init_line + + def _use_acc_helper( + self, reduction_type, acc, helper_val, helper_range, dtype, use_scalar=False + ): + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + self.non_parallel_reduction_prefix.writeline( + self._acc_helper_init( + reduction_type, helper_val, helper_range, dtype, None, use_scalar + ) + ) + self.local_reduction_init.writeline( + self._acc_helper_init( + reduction_type, helper_val, helper_range, dtype, num_threads, use_scalar + ) + ) + result = acc if use_scalar else f"{acc}_vec" + if reduction_type == "welford_reduce": + self.non_parallel_reduction_suffix.writeline( + f"{result} = welford_combine({result}, &{helper_val});" + ) + self.local_reduction_stores.writeline( + f"{result}_local = welford_combine({result}_local, &{helper_val});" + ) + else: + self.non_parallel_reduction_suffix.writeline( + f"{result} = cascade_sum_final(&{helper_val});" + ) + self.local_reduction_stores.writeline( + f"{result}_local = cascade_sum_final(&{helper_val});" + ) + def reduction(self, dtype, src_dtype, reduction_type, value): argmax_or_argmin = reduction_type in ("argmax", "argmin") reduction_key = src_dtype, reduction_type, value @@ -2144,13 +2302,36 @@ def reduction(self, dtype, src_dtype, reduction_type, value): acc, acc_type, reduction_type, init_dtype, reduction_init ) ) - assert self.reduction_depth is not None - index = self.itervars[self.reduction_depth] - for i in range(self.reduction_depth + 1, len(self.itervars)): - index = index * self.ranges[i] + self.itervars[i] - self.stores.writeline( - f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" - ) + + if self.need_use_acc_helper(reduction_type, dtype, True): + # use cascade_helper for vec kernel + reduction_size = functools.reduce( + operator.mul, self.ranges[self.reduction_depth :] + ) + helper_val = self.cascade_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + # rename the helper variable to distinguish it from vectorized version + scalar_helper_val = f"scalar_{helper_val}" + self._use_acc_helper( + reduction_type, + acc, + scalar_helper_val, + reduction_size, + dtype, + use_scalar=True, + ) + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, scalar_helper_val)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index=index)};" + ) self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) result = reduction_project(reduction_type, acc) @@ -2787,6 +2968,22 @@ def store(self, name, index, value, mode=None): raise NotImplementedError(f"store mode={mode}") def reduction(self, dtype, src_dtype, reduction_type, value): + """ + Perform vectorized reduction operation. + + This method handles vectorized reduction for different reduction types. + It manages special cases for low-precision floating point types and + employs precision improvement techniques for certain reduction operations. + + Args: + dtype: The output data type for the reduction result + src_dtype: The source data type of the input value + reduction_type: Type of reduction operation (sum, min, max, etc.) + value: The input value to reduce + + Returns: + The result of the reduction operation + """ # Note: For argmax and argmin on bool type, we always convert bool to float. # Fix issue: https://github.com/pytorch/pytorch/issues/143568 assert reduction_type in VECTORIZABLE_RTYPES @@ -2812,6 +3009,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) assert isinstance(acc, CppCSEVariable) acc_vec = f"{acc}_vec" + masked_acc = f"masked_{acc}" masked_acc_vec = f"masked_{acc_vec}" self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec] self.is_reduction = True @@ -2829,7 +3027,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_init_vec, ) ) - if reduction_type == "welford_reduce": + + use_acc_helper = self.need_use_acc_helper(reduction_type, dtype, False) + if use_acc_helper: # use masked acc_vec for tail vec kernel self.reduction_prefix_generators.append( self._gen_reduction_prefix( @@ -2841,16 +3041,21 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) ) - # use welford_helper for vec kernel + # use welford_helper/cascade_helper for vec kernel assert self.reduction_depth is not None reduction_size = functools.reduce( operator.mul, self.ranges[self.reduction_depth :] ) - welford_helper_val = self.welford_helper_cse.generate( - self.compute, f"reduction {reduction_key}", write=False - ) - masked_welford_helper_val = f"masked_{welford_helper_val}" - welford_helper_vec_range = ( + if reduction_type == "welford_reduce": + helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + else: + helper_val = self.cascade_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + masked_helper_val = f"masked_{helper_val}" + helper_vec_range = ( ( FloorDiv(reduction_size, self.ranges[self.tiling_idx]) * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) @@ -2860,7 +3065,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) else sympy.Integer(0) ) - masked_welford_helper_vec_range = ( + masked_helper_vec_range = ( ( FloorDiv(reduction_size, self.ranges[self.tiling_idx]) if self.tiling_idx >= self.reduction_depth @@ -2869,24 +3074,41 @@ def reduction(self, dtype, src_dtype, reduction_type, value): if self.ranges[self.tiling_idx] % self.tiling_factor else sympy.Integer(0) ) - self._use_welford_helper( - acc_vec, welford_helper_val, welford_helper_vec_range, dtype + # scalar helper for scalar sum is also needed when vec kernel is included + # Note: is it different from welford reduction as welford reduction of scalar version + # does not need helper, and the helper needs the information of reduction size to initialize + if reduction_type == "sum": + scalar_helper_val = f"scalar_{helper_val}" + self._use_acc_helper( + reduction_type, + acc, + scalar_helper_val, + reduction_size, + dtype, + use_scalar=True, + ) + self._use_acc_helper( + reduction_type, acc, helper_val, helper_vec_range, dtype ) - self._use_welford_helper( - masked_acc_vec, - masked_welford_helper_val, - masked_welford_helper_vec_range, + self._use_acc_helper( + reduction_type, + masked_acc, + masked_helper_val, + masked_helper_vec_range, dtype, ) # use masked acc_vec for tail vec kernel acc_vec_ = masked_acc_vec if self.tail_size else acc_vec - welford_helper_val_ = ( - masked_welford_helper_val if self.tail_size else welford_helper_val - ) - self.stores.writeline( - f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, welford_helper_val_)};" - ) + helper_val_ = masked_helper_val if self.tail_size else helper_val + if reduction_type == "sum": + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};" + ) + else: + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};" + ) else: assert self.reduction_depth is not None index = self.itervars[self.reduction_depth] @@ -2917,7 +3139,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): reduction_combine_fn=reduction_combine, reduction_init_fn=reduction_init, ) - if reduction_type == "welford_reduce": + if use_acc_helper: # use masked acc_vec for tail vec kernel self._gen_parallel_reduction_buffers( masked_acc_vec, @@ -2964,7 +3186,11 @@ def reduction(self, dtype, src_dtype, reduction_type, value): vec_dtype = torch.float if is_bool else dtype vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" - next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + result_vec = f"{acc_vec}" + if use_acc_helper: + assert reduction_type == "sum" + result_vec = f"{acc_vec} + {masked_acc_vec}" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {result_vec})" self.reduction_suffix.writeline( f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" @@ -2977,6 +3203,12 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_suffix.writeline( f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" ) + elif use_acc_helper: + assert reduction_type == "sum" + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {tmpvar} + {masked_tmpvar};" + ) result = reduction_project(reduction_type, tmpvar) self.reduction_cse.reduction_cache[reduction_key] = result @@ -2986,11 +3218,10 @@ def store_reduction(self, name, index, value): index = self.rename_indexing(index) var = self.args.output(name) out_dtype = V.graph.get_dtype(name) - dtype = ( - (out_dtype if out_dtype == torch.double else torch.float) - if out_dtype.is_floating_point - else torch.int64 - ) + if out_dtype.is_floating_point and out_dtype != torch.double: + dtype = torch.float + else: + dtype = out_dtype out_num_vectors = V.kernel._get_num_vectors(out_dtype) src_num_vectors = V.kernel._get_num_vectors(dtype) code = IndentedBuffer() @@ -3102,59 +3333,12 @@ def reduction_acc_type_vec(self, reduction_type, dtype): return f"{self._get_mask_type()}" return vec_type - def _welford_helper_init( - self, welford_helper_val, welford_helper_vec_range, dtype, num_threads=None - ): - vec_num_range_thread = ( - CeilDiv(welford_helper_vec_range, num_threads) - if num_threads - else welford_helper_vec_range - ) - vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) - chunk_size = 4096 - num_chunks = CeilDiv(vec_num_range_thread, chunk_size) - welford_helper_init_line = ( - f"WelfordHelper<{self._get_vec_type(dtype)}, {chunk_size}> {welford_helper_val}" - f"(" - f"{vec_num_range_thread_expr}" - f");" - ) - if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: - # When the number of chunks <= 1, there is no need to use cascade summation to improve - # reduction accuracy. We can initialize a static WelfordHelper to improve performance. - return f"static {welford_helper_init_line}" - else: - return welford_helper_init_line - - def _use_welford_helper( - self, acc_vec, welford_helper_val, welford_helper_vec_range, dtype - ): - num_threads = ( - "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() - ) - self.non_parallel_reduction_prefix.writeline( - self._welford_helper_init( - welford_helper_val, welford_helper_vec_range, dtype - ) - ) - self.local_reduction_init.writeline( - self._welford_helper_init( - welford_helper_val, welford_helper_vec_range, dtype, num_threads - ) - ) - self.non_parallel_reduction_suffix.writeline( - f"{acc_vec} = welford_combine({acc_vec}, &{welford_helper_val});" - ) - self.local_reduction_stores.writeline( - f"{acc_vec}_local = welford_combine({acc_vec}_local, &{welford_helper_val});" - ) - def reduction_combine_vec( self, reduction_type, var, next_value, - welford_helper_val=None, + helper_val=None, index: Optional[sympy.Symbol] = None, horizontal_reduction: Optional[bool] = None, src_dtype: Optional[torch.dtype] = torch.float32, @@ -3179,11 +3363,17 @@ def reduction_combine_vec( else f"at::vec::minimum({var}, {next_value})" ) elif reduction_type == "sum": - if self.tail_size: - return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + if helper_val: + if self.tail_size: + return f"cascade_sum_combine({next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" + else: + return f"cascade_sum_combine({next_value}, &{helper_val})" else: - conjunction = "|" if is_bool else "+" - return f"{var} {conjunction} {next_value}" + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" elif reduction_type == "prod": if self.tail_size: return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" @@ -3195,13 +3385,11 @@ def reduction_combine_vec( else: return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": - if welford_helper_val: + if helper_val: if self.tail_size: - return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{welford_helper_val})" + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" else: - return ( - f"welford_combine({var}, {next_value}, &{welford_helper_val})" - ) + return f"welford_combine({var}, {next_value}, &{helper_val})" else: if self.tail_size: return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" @@ -4330,9 +4518,12 @@ def gen_body(self, code: Optional[BracesBuffer] = None): def aggregate_reduction_buffers( self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"] ): - # CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. - # Here, we decide how to aggregate them together and place new reduction buffers - # under CppKernelProxy. + """ + CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. + Here, we decide how to aggregate them together and place new reduction buffers + under CppKernelProxy. + """ + def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): assert len(self.kernels) >= 2 main_loop_kernel = self.kernels[0] @@ -4376,6 +4567,9 @@ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): replace_acc_name( tail_loop_kernel.reduction_suffix, name, new_name ) + # If tail loop kernel is a scalar kernel, use direct sum instead of cascade_sum_combine + # as the reduction vars are extended: tmp_acc -> tmp_acc_arr[]. + replace_cascade_sum_with_add(tail_loop_kernel.stores) suffix_buf.splice( move_code_under_inner_loop( tail_loop_kernel.reduction_suffix, diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 7dd5fdc288ac..794a971adf08 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -22,13 +22,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, cpp_builder, ir -from ..utils import ( - _align, - aoti_model_name_from_config, - DeferredLineBase, - LineContext, - normalize_name, -) +from ..utils import _align, DeferredLineBase, LineContext, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import get_device_op_overrides, IndentedBuffer, Kernel @@ -64,11 +58,15 @@ def __init__(self): self.device = "cpu" # must be initialized prior to calling super().__init__() self.included_devices: OrderedSet[str] = OrderedSet() - self.model_class_name_suffix = "" - if config.aot_inductor.compile_standalone: - self.model_class_name_suffix = aoti_model_name_from_config() + self.model_class_name_suffix = ( + config.aot_inductor.model_name_for_generated_files + if config.aot_inductor.compile_standalone + else "" + ) self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}" + super().__init__() + self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" @@ -518,6 +516,8 @@ def gen_check(handle_kind, idx, name, tensor): def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: + self.codegen_additional_funcs() + if V.graph.const_module: self.header.splice(V.graph.const_module.wrapper_code.header) @@ -674,6 +674,9 @@ def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" ) + def codegen_additional_funcs(self): + pass + def codegen_model_kernels(self): self.prefix.writeline("namespace {") @@ -1275,7 +1278,17 @@ def generate_c_shim_extern_kernel_alloc( extern_kernel.get_kernel_name(), args, device ) - if not is_inplace: + if extern_kernel.python_kernel_name in ( + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.wait_tensor.default", + ): + # all_reduce_ is an inplace op and its returned tensor is not used anywhere. + # wait_tensor returns its input without any modification and the returned tensor is not used anywhere. + # In both cases, we can immediately delete the returned AtenTensorHandle to reduce its lifetime. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object({output_handle_name}));" + ) + elif not is_inplace: self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): @@ -1562,10 +1575,11 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), V.graph.get_allocation_size(buffer), + buffer.get_is_pinned(), ) def make_allocation( - self, name, device, dtype, shape, stride, allocation_shape=None + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False ): if allocation_shape is None: allocation_shape = shape @@ -1617,8 +1631,9 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" ) if allocation_size != size: @@ -1636,7 +1651,9 @@ def make_allocation( return f"RAIIAtenTensorHandle {name}({handle_name});" - def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: size = self.codegen_shape_tuple(shape) stride = self.codegen_shape_tuple(stride) tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" @@ -1653,11 +1670,14 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: ), f"&{tmp_name}", ] - self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" - ) - return f"RAIIAtenTensorHandle({tmp_name})" + # We return the lines instead of writing here because writing here is bug prune. + # If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle + # as well, otherwise you get memory leaks + allocations_to_write = [ + f"AtenTensorHandle {tmp_name};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));", + ] + return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write def codegen_reinterpret_view( self, diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index eb3390cbc39c..fd145ece606d 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -565,10 +565,18 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), buffer if self.can_stack_allocate_buffer(buffer) else None, + buffer.get_is_pinned(), ) def make_allocation( - self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + self, + name, + device, + dtype, + shape, + stride, + buffer_if_can_stack_allocate=None, + is_pinned=False, ): orig_stride = stride device_str = self.codegen_device(device) @@ -615,8 +623,9 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" ) return f"RAIIAtenTensorHandle {name}({name}_handle);" diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 430511ce4ebf..6bbbab859900 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -211,12 +211,17 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): ] arg_types = [arg_type_loookup[name] for name in call_args] arg_signatures = [triton_meta["signature"][name] for name in call_args] + scratch_spaces = { + name: params[name] + for name in ["global_scratch", "profile_scratch"] + if params.get(name, None) is not None + } call_args_str = wrapper.generate_args_decl( prefix, call_args, arg_types, arg_signatures, - workspace_size=params.get("global_scratch") or 0, + scratch_spaces=scratch_spaces, ) prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") launch_kernel_args = [ @@ -454,7 +459,7 @@ def generate_args_decl( arg_types, arg_signatures, is_triton_kernel=True, - workspace_size=0, + scratch_spaces: Optional[dict[str, int]] = None, ): """ Generates any declarations of args to pass into a kernel call, and then returns the arg names. @@ -572,22 +577,26 @@ def process_args(arg, arg_type, arg_signature=None): ): process_args(arg, arg_type, arg_signature) - if ( - is_triton_kernel - and ( - global_scratch := self.device_codegen.cpp_global_scratch( - next(self.arg_var_id), - workspace=TritonScratchWorkspace( - size=workspace_size, - generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), - ), + for scratch_name, workspace_size in (scratch_spaces or {}).items(): + if ( + is_triton_kernel + and ( + scratch := self.device_codegen.cpp_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=( + lambda: self.codegen_dtype(torch.uint8) + ), + ), + prefix=scratch_name, + ) ) - ) - is not None - ): - global_scratch_def, global_scratch_var = global_scratch - code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) - new_args.append(f"&{global_scratch_var}") + is not None + ): + scratch_def, scratch_var = scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def]) + new_args.append(f"&{scratch_var}") return ", ".join(new_args) diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index b953927f52be..aea4470f1c96 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -9,7 +9,7 @@ from ..virtualized import V from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_gpu import CppWrapperGpu -from .wrapper import PythonWrapperCodegen +from .wrapper import KernelCallLine, PythonWrapperCodegen class CppWrapperMps(CppWrapperGpu): @@ -47,14 +47,12 @@ def _generate_kernel_call_helper( """ Generates MPS kernel call code. It should look something like: ``` - auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel"); - auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get()); - mps_lib_0_func->runCommandBlock([&] { - mps_lib_0_func->startEncoding(); - aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0); - aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1); + get_mps_lib_0()->runCommandBlock([&] { + get_mps_lib_0()->startEncoding(); + aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0); + aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1); ... - mps_lib_0_func->dispatch(9); + get_mps_lib_0()->dispatch(9); }); ``` """ @@ -81,11 +79,11 @@ def _generate_kernel_call_helper( for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): if isinstance(arg_type, torch.dtype): new_args.append( - f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});" + f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});" ) elif arg_type in (int, sympy.core.symbol.Symbol): new_args.append( - f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});" + f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});" ) else: raise NotImplementedError( @@ -96,9 +94,11 @@ def _generate_kernel_call_helper( if threads is None: raise NotImplementedError("No threads or group_size provided") elif group_size is None: - new_args.append(f"{kernel_name}->dispatch({threads});\n") + new_args.append(f"get_{kernel_name}()->dispatch({threads});\n") else: - new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n") + new_args.append( + f"get_{kernel_name}()->dispatch({threads}, {group_size});\n" + ) # debug printer related logic for cpp kernel type. debug_printer_manager = V.graph.wrapper_code.debug_printer @@ -113,20 +113,11 @@ def _generate_kernel_call_helper( self.write_mps_kernel_call(kernel_name, new_args) def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: - # Only add handle definition if the kernel is not already used - lib_name = name[: -len("_func")] - if name not in self._used_kernel_names: - self._used_kernel_names.add(name) - - self.writeline( - f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");' - ) - self.writeline( - f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());" - ) - - self.writeline(f"{name}->runCommandBlock([&] {{") - self.writeline(f" {name}->startEncoding();") + # Initialization of the kernel function and kernel function handle + # variables have already been done at the beginning, which was + # codegen-ed in `codegen_mps_func_init` + self.writeline(f"get_{name}()->runCommandBlock([&] {{") + self.writeline(f" get_{name}()->startEncoding();") for call_arg in call_args: self.writeline(f" {call_arg}") self.writeline("});") @@ -138,3 +129,52 @@ def get_device_include_path(device: str) -> str: "#include \n" "#include " ) + + def codegen_additional_funcs(self) -> None: + """ + We want to codegen the mps kernel function variable initializations + ahead of time. This is so that if we reuse kernels within subgraphs, we + don't need to worry about the scope in which we're initializing the + variables. Instead we will just initialize the variables all at the top + level. + + The kernel function variable initializations should look something like: + ``` + const std::shared_ptr get_mps_lib_0() { + static const auto func = mps_lib_0.getKernelFunction("generated_kernel"); + return func; + } + AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { + static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get()); + return handle; + } + ``` + """ + + for line in self.lines: + if not isinstance(line, KernelCallLine): + continue + if line.device.type != "mps": + continue + + # Only add handle definition once + if line.kernel_name not in self._used_kernel_names: + self._used_kernel_names.add(line.kernel_name) + + self.prefix.writeline( + f"const std::shared_ptr get_{line.kernel_name}() {{" + ) + self.prefix.writeline( + f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");' + ) + self.prefix.writeline(" return func;") + self.prefix.writeline("}") + + self.prefix.writeline( + f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{" + ) + self.prefix.writeline( + f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());" + ) + self.prefix.writeline(" return handle;") + self.prefix.writeline("}") diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 224f0d2a423d..0a9c6b0ca4e5 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -177,6 +177,9 @@ def get_ld(node) -> Union[Expr, int]: def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: return [*self.get_layout_args(), *self.size_args] + def get_offset_args(self) -> list[Expr]: + return [node.get_layout().offset for node in self.named_nodes.values()] + @staticmethod def find_ld_idx(node: IRNode) -> int: strides = node.get_stride() @@ -264,6 +267,7 @@ def def_kernel( In this case, the `input_reorder` would be [2, 0, 1]. additional_size_args: Additional size arguments for epilogue inputs """ + # NB: name order matters here, it's used to match up offsets names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -285,6 +289,7 @@ def def_kernel( free_symbols: OrderedSet[Expr] = OrderedSet() for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: + # NB: named nodes must be populated in the order of names self.named_nodes[name] = node self.args.output_buffers[node.get_name()] = name @@ -306,14 +311,17 @@ def def_kernel( size_vars.extend(str(s) for s in free_symbols) self.size_args.extend(free_symbols) size_args = [f"const int {s}" for s in size_vars] - + offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()] runtime_arg_decls = ",".join( [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] ) if runtime_arg_decls: runtime_arg_decls += ", " - signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + signature = ( + f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\ + {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + ) self.signature = signature return signature @@ -346,10 +354,13 @@ def call_kernel( _, call_args, _, arg_types = self.args.python_argdefs() dynamic_shape_args = self.get_dynamic_shape_args() + offset_args = self.get_offset_args() call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + call_args.extend(offset_args) # type: ignore[arg-type] for arg in self.runtime_arg_values: - call_args.append(arg) - arg_types.extend("int" for _ in dynamic_shape_args) + call_args.append(str(arg)) + arg_types.extend("const int" for _ in dynamic_shape_args) + arg_types.extend("const int" for _ in offset_args) for arg in self.runtime_arg_info: arg_types.append(arg.ty) # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -425,15 +436,6 @@ def max_valid_index(self, node: IRNode, default=-1): max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] return max_valid_offset - def offset(self, node: IRNode) -> str: - """ - Generates code which represents offset of a given node. - """ - - if node is None: - return "0" - return str(node.get_layout().offset) # type: ignore[union-attr] - def ptr(self, node: IRNode) -> str: """ Generates code which represents pointer of a given node. @@ -444,8 +446,7 @@ def ptr(self, node: IRNode) -> str: arg_name = self.arg_name(node) if arg_name is None: return "nullptr" - offset = self.offset(node) - return arg_name if offset == "0" else f"{arg_name} + {offset}" + return f"{arg_name} + {arg_name}_offset" def size( self, diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index cc03ccbdda86..4aa0aeb46e07 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -43,7 +43,7 @@ class ArgInfo: class CUDATemplate(KernelTemplate): index_counter = itertools.count() # dict of cache key to (code, size_args) - code_cache: dict[str, tuple[str, tuple[int, ...]]] = {} + code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {} cache_clear = staticmethod(code_cache.clear) def __init__( @@ -113,8 +113,12 @@ def generate_code_and_args( key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: - code, size_args = self.code_cache[key] - extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + code, size_args, offset_args = self.code_cache[key] + extra_args = tuple( + list(size_args) + + list(offset_args) + + list(self.get_runtime_arg_values(**kwargs)) + ) return code, extra_args kernel_name = str(Placeholder.KERNEL_NAME) @@ -148,12 +152,15 @@ def generate_code_and_args( ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args()) if key is not None: - self.code_cache[key] = code, size_args + self.code_cache[key] = code, size_args, offset_args # extra args has runtime params, which shouldn't be cached - extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + extra_args = tuple( + list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs) + ) return code, extra_args diff --git a/torch/distributed/numa/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py similarity index 100% rename from torch/distributed/numa/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py index e42a13534e6f..605b93dff592 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -255,7 +255,8 @@ def render_stride(x: int) -> str: return f"{{{', '.join([render_stride(x) for x in stride])}}}" elif issubclass(arg_ty, ctypes.c_void_p): - return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {arg_renames.new_name(node.get_name())}" + name = arg_renames.new_name(node.get_name()) + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)" elif ( arg_ty in _CUTLASS_C_DTYPES ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 0ba067742294..147515e0decf 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -4,7 +4,6 @@ import torch -from ...utils import triton_version_uses_attrs_dict from ..common import ( DeviceOpOverrides, register_device_op_overrides, @@ -333,34 +332,33 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "CUdeviceptr" - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: - if triton_version_uses_attrs_dict(): - var_name = f"global_scratch_{idx}" - if workspace.size > 0: - size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" - stride_array = f"int64_t {var_name}_stride[] = {{1}};" - device_type = "cached_torch_device_type_cuda" - device_idx = "device_idx_" - - return ( - [ - f"{size_array}", - f"{stride_array}", - f"AtenTensorHandle {var_name}_handle;", - ( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " - f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" - ), - f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", - f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", - ], - var_name, - ) - else: - return [f"CUdeviceptr {var_name} = 0;"], var_name - return None + prefix = f"{prefix}_" if prefix else "" + var_name = f"{prefix}scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 06bf2653e02a..0e11bc100002 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -12,6 +12,7 @@ import torch.utils._pytree as pytree from torch._inductor.autotune_process import TensorMeta from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.select_algorithm import create_inputs_key @@ -593,11 +594,14 @@ def _add_cutlass_gemm_choices( ) if len(ops) == 0: - input_layouts = [node.get_layout() for node in input_nodes] - input_strides = [node.get_stride() for node in input_nodes] - output_layout = layout - warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 - log.warning(warning_msg) + log.info( + "No suitable Cutlass GEMM configs found, fallbacks used " + "( len(ops)=%d, output_layout=%s, input_layouts=%s, input_strides=%s )", + len(ops), + layout, + [node.get_layout() for node in input_nodes], + [node.get_stride() for node in input_nodes], + ) log.debug( "Added %d Cutlass gemm configs.", len(ops), @@ -919,6 +923,14 @@ def filter_op( ) return None + # only use stream k for static shape + if op.tile_scheduler.name == "StreamK": + static_shape = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + tuple(X.get_size()) + tuple(W.get_size()) + ) + if not static_shape: + return None + # Update op. op = copy.deepcopy(op) @@ -1308,7 +1320,7 @@ def test_call_statement( f"(({arg_type}){arg_name}_data.get())" for arg_type, arg_name in zip(arg_types, arg_names) ] - return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 def _render_evt( self, diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 8efec7eeca9f..12d7500975e5 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -10,6 +10,7 @@ import sympy import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from .. import config @@ -142,6 +143,17 @@ class Allocation(AllocationTreeNode): allocated: bool = False pool: Optional[AllocationPool] = None offset: Optional[sympy.Expr] = None + earliest_available: Optional[float] = None + + def __post_init__(self) -> None: + has_unbacked_sym = False + for s in self.node.get_layout().size: + if free_unbacked_symbols(s): + has_unbacked_sym = True + break + + if has_unbacked_sym: + self.earliest_available = self.get_live_ranges().begin @property def device(self): @@ -186,6 +198,9 @@ def __repr__(self): f"offset={self.offset})" ) + def get_earliest_available(self): + return self.earliest_available + @dataclasses.dataclass class Empty(AllocationTreeNode): @@ -377,14 +392,26 @@ class AllocationPool: names_to_del: list[str] = dataclasses.field(default_factory=list) creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) + def __post_init__(self) -> None: + for block in self.root.allocations: + if isinstance(block, Allocation): + self.update_restrict_live_range(block) + def allocate(self, block: Allocation, is_last: bool): - if self.restrict_live_range and not self.restrict_live_range.contains( - block.live_range + if ( + self.restrict_live_range is not None + and not self.restrict_live_range.contains(block.live_range) ): return False + block_earliest_available = block.get_earliest_available() + pool_begin = self.root.get_live_ranges().begin + if block_earliest_available and block_earliest_available > pool_begin: + return False + is_last = self.can_expand and is_last if self.root.allocate(block, is_last): + self.update_restrict_live_range(block) return True if is_last: @@ -392,9 +419,22 @@ def allocate(self, block: Allocation, is_last: bool): return False + def update_restrict_live_range(self, block: Allocation): + if block_earliest_available := block.get_earliest_available(): + if self.restrict_live_range is None: + self.restrict_live_range = LiveRange( + block_earliest_available, float("inf") + ) + else: + self.restrict_live_range = LiveRange( + min(self.restrict_live_range.begin, block_earliest_available), + self.restrict_live_range.end, + ) + def allocate_at_end(self, block): block.mark_allocated() self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + self.update_restrict_live_range(block) return True def finalize(self, name): @@ -408,7 +448,6 @@ def codegen_create(self, wrapper, code: IndentedBuffer): nbytes = self.root.get_symbolic_size() for block in self.root.allocations: if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): - # optimization: fuse first allocation and pool creation node = block.node code.writeline( wrapper.make_allocation( @@ -419,7 +458,6 @@ def codegen_create(self, wrapper, code: IndentedBuffer): stride=tuple(node.get_stride()), ) ) - self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name return else: code.writeline( @@ -577,7 +615,10 @@ def codegen(self, code: IndentedBuffer): pool.codegen_create(self.wrapper, code) pool.names_to_del.extend(self.group.names) - alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool( + self.wrapper + ) + code.writelines(allocation_lines_to_write) if alloc_from_pool in pool.creation_cache: code.writeline( self.wrapper.make_tensor_alias( diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 624f1b710705..8b59db126f05 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -535,7 +535,7 @@ def _new_idxvar( var_def = "threadgroup " if is_threadgroup else "" var_def += f"{dtype} {var_name}" if elem_count: - var_def += f"[{elem_count}]" + var_def += f"[{self.sexpr(elem_count)}]" if default_value is not None: assert not is_threadgroup, "Thread group var can not have default value" var_def += f" = {default_value}" @@ -585,9 +585,21 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: if reduction_idx: reduction_idx += " + " reduction_idx += f"{rd.name} * {acc_buf_size}" - acc_buf_size *= rd.numel - acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) - shmem_buf_size = ceildiv(acc_buf_size, self.simd_group_size) + + if isinstance(rd.numel, sympy.Integer): + acc_buf_size *= rd.numel + else: + acc_buf_size *= sympy.Symbol( + f"{rd.prefix}numel", integer=True, positive=True + ) + + acc_buf_size = sympy.Min(acc_buf_size, self.max_threadgroup_size) + acc_buf_size_str = self.sexpr(acc_buf_size) + shmem_buf_size = ( + ceildiv(acc_buf_size, self.simd_group_size) + if isinstance(acc_buf_size, sympy.Integer) + else self.simd_group_size + ) if reduction_type == "any": acc = self._new_idxvar(dtype) @@ -622,9 +634,10 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: acc_dtype, default_value=default_val, is_threadgroup=False ) self.compute.splice(f"{val} {reduction_op}= {value};") + return self.cse.generate( self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["max", "min"]: @@ -644,40 +657,43 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: ) return self.cse.generate( self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["argmin", "argmax"]: - acc_buf = self._new_idxvar(src_dtype, acc_buf_size) - acc_thread_var = f"{acc_buf}[{reduction_idx}]" + data_acc_buf = self._new_idxvar(src_dtype, shmem_buf_size) + idx_acc_buf = self._new_idxvar(dtype, shmem_buf_size) src_metal_type = DTYPE_TO_METAL[src_dtype] + cast_value = f"static_cast<{src_metal_type}>({value})" if not self.multistage_reduction_entry: - self.compute.splice( - f"{acc_thread_var} = static_cast<{src_metal_type}>({value});" + val = cast_value # type: ignore[assignment] + idx_val = f"static_cast<{DTYPE_TO_METAL[dtype]}>({reduction_idx})" + else: + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()" + val = self._new_idxvar( + src_dtype, default_value=limit_val, is_threadgroup=False ) - return self.cse.generate( - self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", - dtype=dtype, + idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment] + idx_var = next( + t for t in self.range_tree_nodes.values() if t.is_reduction ) - lim_fn = "lowest" if reduction_type.endswith("max") else "max" - self.indexing_code.writeline( - f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();" - ) - idx_var = next(t for t in self.range_tree_nodes.values() if t.is_reduction) - idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size) - cmp_op = ">" if reduction_type == "argmax" else "<" - idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" - self.indexing_code.splice(f"{idx_thread_var} = -1;") - self.compute.splice(f""" - if ({value} {cmp_op} {acc_thread_var}) {{ - {acc_thread_var} = {value}; - {idx_thread_var} = {idx_var.name}; - }} - """) + cmp_op = ">" if reduction_type == "argmax" else "<" + nan_suffix = ( + f" || ::metal::isnan({value}) " + if src_dtype.is_floating_point + else "" + ) + self.compute.splice(f""" + if ({value} {cmp_op} {val}{nan_suffix}) {{ + {val} = {value}; + {idx_val} = {idx_var.name}; + }} + """) return self.cse.generate( self.stores, - f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]", + f"c10::metal::threadgroup_{reduction_type}({data_acc_buf}, {idx_acc_buf}, " + f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size_str})", dtype=dtype, ) if reduction_type == "welford_reduce": @@ -686,7 +702,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( self.compute, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", dtype=torch.float32, ) return _unwrap_helper(wf_res) @@ -717,7 +733,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.writeline(f"{acc_thread_var} = {inp_value};") wf_res = self.cse.generate( self.stores if self.multistage_reduction_entry else self.compute, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", dtype=torch.float32, ) return _unwrap_helper(wf_res) @@ -727,28 +743,51 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: index_expr = self.rename_indexing(entry.expr) index_str = self.sexpr(index_expr) # type: ignore[misc] - if not entry.is_reduction or entry.root.numel <= self.max_threadgroup_size: + if not entry.is_reduction or ( + isinstance(entry.root.numel, sympy.Integer) + and entry.root.numel <= self.max_threadgroup_size + ): self.indexing_code.writeline( f"{self.index_dtype} {entry.name} = {index_str};" ) return + + acc_size = ( + entry.root.numel + if isinstance(entry.root.numel, sympy.Integer) + else sympy.Symbol(f"{entry.root.prefix}numel", integer=True, positive=True) + ) + self.multistage_reduction_entry.append(entry) # When reducing the tensor whose size exceeds max threadgroup size # loop over extra indices per reduction thread and perform part of the operation # using values in the shared memory - loop_size = ( - entry.root.numel + self.max_threadgroup_size - 1 - ) // self.max_threadgroup_size + + # Use floats so that it doesn't do integer division + loop_size = (acc_size + float(self.max_threadgroup_size - 1)) // float( + self.max_threadgroup_size + ) + loop_size_str = self.sexpr(loop_size) + self.body.writeline( - f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size}; ++{entry.name}_cnt) {{" + f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size_str}; ++{entry.name}_cnt) {{" ) with self.body.indent(): - self.body.writeline( - f"{self.index_dtype} {entry.name} = {loop_size} * {index_str} + {entry.name}_cnt;" - ) + if isinstance(acc_size, sympy.Symbol): + self.body.writeline( + f"{self.index_dtype} {entry.name} = {self.max_threadgroup_size} * {entry.name}_cnt + {index_str};" + ) + else: + self.body.writeline( + f"{self.index_dtype} {entry.name} = {loop_size_str} * {index_str} + {entry.name}_cnt;" + ) + # Check that reduction is performed only within tensor boundary - if loop_size * self.max_threadgroup_size != entry.root.numel: - self.body.writeline(f"if ({entry.name} >= {entry.root.numel}) break;") + if ( + isinstance(acc_size, sympy.Symbol) + or loop_size * self.max_threadgroup_size != acc_size + ): + self.body.writeline(f"if ({entry.name} >= {acc_size}) break;") def codegen_body(self) -> None: """ @@ -817,7 +856,13 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: total_reduction_size = math.prod( t.numel for t in self.range_trees if t.is_reduction ) - threadgroup_size = min(total_reduction_size, self.max_threadgroup_size) + # If using dynamic shapes, set the threadgroup size to be the + # max possible size + threadgroup_size = ( + min(total_reduction_size, self.max_threadgroup_size) + if isinstance(total_reduction_size, sympy.Integer) + else self.max_threadgroup_size + ) code.writeline( f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" ) @@ -841,6 +886,14 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: code.writeline(f"constant {dtype_str}* {inner},") for outer, inner in self.args.sizevars.items(): code.writeline(f"constant long& {inner},") + + # Write dynamic values as inputs + for idx_var in idx_vars: + if isinstance(idx_var.numel, sympy.Integer): + pass + else: + code.writeline(f"constant long& {idx_var.prefix}numel,") + assert len(idx_vars) < 4, "Up to 3 index variables are supported" thread_pos_dtype = ( f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint" @@ -875,7 +928,9 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: return code.getvalue() def call_kernel(self, name: str, node: Any = None) -> None: - """Codegen a call to this kernel""" + """ + Codegens a call to this kernel + """ wrapper = V.graph.wrapper_code # Make sure sizevars has been computed for v in self.args.sizevars.keys(): @@ -889,8 +944,22 @@ def call_kernel(self, name: str, node: Any = None) -> None: args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] args = [arg for arg in args if arg not in self.removed_buffers] args += [str(v) for v in self.args.sizevars.keys()] - arg_types = [arg_name_to_type[arg] for arg in args] + + # Add any dynamic ints as inputs + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, int)): + # Don't need to pass in integers as inputs + continue + elif isinstance(tree.numel, sympy.Symbol): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner + + if not tree.is_reduction or self.inside_reduction: + args.append(str(expr)) + arg_types.append(int) + expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr def format_threads(threads: list[str], kwarg: str) -> str: @@ -983,11 +1052,7 @@ def define_kernel( # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" - if V.graph.cpp_wrapper: - kernel_name = f"{mps_lib_name}_func" - else: - kernel_name = f"{mps_lib_name}.generated_kernel" - + kernel_name = f"{mps_lib_name}" wrapper.src_to_kernel[src_code] = kernel_name if V.graph.cpp_wrapper: diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index 4a08773433c3..df4982988aa1 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -96,7 +96,7 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( - {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + dict.fromkeys(meta.name for meta in self.input_tensor_meta) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 7ac967bbe0b0..5b1350a9239e 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1991,7 +1991,7 @@ def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: @classmethod def create_tiling( cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] - ) -> dict[str, sympy.Expr]: + ) -> immutable_dict[str, sympy.Expr]: """ Create a tiling dict from pointwise and reduction splits. """ @@ -2006,7 +2006,7 @@ def create_partial_tiling( cls, tiling: Sequence[sympy.Expr], is_pointwise: bool, - ) -> dict[str, sympy.Expr]: + ) -> immutable_dict[str, sympy.Expr]: return cls.create_tiling( tiling if is_pointwise else [], tiling if not is_pointwise else [], @@ -2018,7 +2018,7 @@ def complete_partial_tiling( tiling: dict[str, sympy.Expr], numel: sympy.Expr, reduction_numel: sympy.Expr, - ) -> dict[str, sympy.Expr]: + ) -> immutable_dict[str, sympy.Expr]: """ Given a tiling for only pointwise or reduction dimensions, adds the missing one. """ @@ -2039,7 +2039,7 @@ def get_nd_tilings( node_schedule, pointwise_numel, reduction_numel, - ) -> list[dict[str, tuple[sympy.Expr]]]: + ) -> list[immutable_dict[str, sympy.Expr]]: """ Creates N-dimensional tiling candidates, attempting to simplify loads/stores by tiling the kernel into higher dimensions. @@ -2047,7 +2047,7 @@ def get_nd_tilings( Returns a list of tilings ranked by dimensionality. """ is_pointwise = reduction_numel == 1 - tilings = OrderedSet[dict[str, sympy.Expr]]() + tilings = OrderedSet[immutable_dict[str, sympy.Expr]]() for node in EnableReduction.filter(node_schedule): if not isinstance(node, scheduler.SchedulerNode): continue @@ -2312,7 +2312,7 @@ def process_node_vars( ) ) - tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [] + tilings: list[tuple[CandidateTiling, immutable_dict[str, sympy.Expr]]] = [] for (pw_split, pw_score), (red_split, red_score) in score_split: candidate = CandidateTiling( cls.create_tiling(pw_split, red_split), diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 49e10d7c0512..8e0831e3726f 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2001,14 +2001,12 @@ def should_use_persistent_reduction(self) -> bool: ) def want_no_x_dim(self): - if ( + return ( self.persistent_reduction and len(self.numels) == self.num_reduction_dims + 1 - ): - if self.fixed_config: - return self.fixed_config["XBLOCK"] == 1 - return V.choices.want_no_x_dim(self.features) - return False + and self.fixed_config + and self.fixed_config["XBLOCK"] == 1 + ) @property def assert_function(self) -> str: @@ -2671,6 +2669,18 @@ def guard_cooperative_store(self, name, buffer): buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) return buffer.indent() + def _combine_masks(self, *variables: Optional[CSEVariable]): + masks = None + for elem in variables: + if elem is None: + continue + if hasattr(elem, "mask_vars"): + if masks is None: + masks = elem.mask_vars + else: + masks = masks | elem.mask_vars + return masks + def bucketize( self, values: CSEVariable, @@ -2720,6 +2730,9 @@ def bucketize( dtype=indexing_dtype, # type: ignore[attr-defined] ) + masks = self._combine_masks(values, boundary_indices, sorter_indices) + result.mask_vars = masks # type: ignore[attr-defined] + return result def reduction_resize(self, value) -> str: @@ -3970,8 +3983,8 @@ def add_constexpr_arg(arg_name): optimize_mem = V.graph.is_inference or V.graph.is_backward inductor_meta = { - # Triton will not accept an OrderedSet for autotune_hints "grid_type": self._get_grid_type().__name__, + # Triton will not accept an OrderedSet for autotune_hints "autotune_hints": set(self.autotune_hints), # noqa: set_linter "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, @@ -4483,6 +4496,11 @@ def define_kernel(self, src_code, node_schedule, kernel): kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] ) + if config.aot_inductor.model_name_for_generated_files: + # When AOTI compiles multiple submodules, we need to use the model name to + # distinguish kernel related symbols. + kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" + # use the original src_code as the key wrapper.src_to_kernel[src_code] = kernel_name subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 7e509bf18ea1..9394c0e4a16d 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -50,6 +50,7 @@ get_benchmark_name, IndentedBuffer, is_codegen_graph_partition_subgraph, + is_using_cudagraph_partition, LineContext, sympy_product, sympy_str, @@ -963,9 +964,12 @@ def write_header(self) -> None: aot_config_comment = "" if context is not None and context.aot_graph_name is not None: aot_config_comment = f"# AOT ID: {context.aot_graph_name}" - aot_inductor_debug_utils = "" + inductor_debug_utils = "" if int(config.aot_inductor.debug_intermediate_value_printer) > 0: - aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + elif torch._inductor.config.test_configs.track_memory_lifecycle: + inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n" + self.imports.splice( f""" {aot_config_comment} @@ -983,7 +987,7 @@ def write_header(self) -> None: from torch import device, empty_strided from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels - {aot_inductor_debug_utils} + {inductor_debug_utils} """, strip=True, ) @@ -995,6 +999,7 @@ def write_header(self) -> None: assert_size_stride = torch._C._dynamo.guards.assert_size_stride assert_alignment = torch._C._dynamo.guards.assert_alignment empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia @@ -1193,7 +1198,14 @@ def write_prefix(self) -> None: self.write_args(graph_input_names) self.codegen_inputs() - self.codegen_input_size_and_nan_asserts() + + # avoid duplicating asserts for both partition functions and + # the call function when using cudagraph partition + if not ( + is_using_cudagraph_partition() + and (not is_codegen_graph_partition_subgraph(self)) + ): + self.codegen_input_size_and_nan_asserts() def codegen_input_size_and_nan_asserts(self) -> None: if config.size_asserts: @@ -1753,7 +1765,9 @@ def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: return self.codegen_python_shape_tuple(shape) - def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: return "alloc_from_pool({})".format( ", ".join( [ @@ -1764,7 +1778,7 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: self.codegen_python_shape_tuple(stride), ] ) - ) + ), [] def codegen_reinterpret_view( self, @@ -2550,10 +2564,16 @@ def _generate_kernel_call_helper( original_fxnode_name=None, ): device = device or V.graph.get_current_device_or_throw() - if not ( - triton or device.type not in ("cpu", "mps") - ): # TODO: Fix me, MPS does not expose streams now - self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + if not triton and device.type != "cuda": + if device.type == "cpu": + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + elif device.type == "mps": + # TODO: Fix me, MPS does not expose streams now + self.writeline( + self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args) + ) + else: + raise RuntimeError(f"device {device.type} nyi") return call_args_str = self.prepare_triton_kernel_call(call_args) @@ -2763,12 +2783,21 @@ def make_buffer_allocation(self, buffer: BufferLike): shape = tuple(buffer.get_size()) allocation_shape = tuple(V.graph.get_allocation_size(buffer)) stride = tuple(buffer.get_stride()) + is_pinned = buffer.get_is_pinned() return self.make_allocation( - buffer.get_name(), device, dtype, shape, stride, allocation_shape + buffer.get_name(), device, dtype, shape, stride, allocation_shape, is_pinned ) + @cache_on_self + def write_memory_track_allocation_once(self): + import_str = """ + from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor + """ + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + def make_allocation( - self, name, device, dtype, shape, stride, allocation_shape=None + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False ): if allocation_shape is None: allocation_shape = shape @@ -2778,7 +2807,23 @@ def make_allocation( allocation_shape ) codegen_stride_tuple = self.codegen_python_shape_tuple(stride) - if device.type in ("cpu", "cuda", "xpu", "mtia"): + if torch._inductor.config.test_configs.track_memory_lifecycle: + out = ( + f"{name} = tracked_empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"dtype={dtype}, " + f"device='{device.type}', " + f"name='{name}')" + ) + elif device.type == "cpu" and is_pinned: + out = ( + f"{name} = empty_strided_cpu_pinned(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) + elif device.type in ("cpu", "cuda", "xpu", "mtia"): # optimized path for faster allocations, saving ~2us versus the stuff below out = ( f"{name} = empty_strided_{device.type}(" diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index 6c70b1d175ac..1537d8267f0b 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -22,7 +22,7 @@ from torch._library.triton import wrap_triton from torch.fx import GraphModule from torch.utils import _pytree as pytree -from torch.utils._sympy.functions import FloorDiv +from torch.utils._sympy.functions import CeilDiv from .. import config, ir from ..utils import convert_shape_to_symint, convert_to_symint, LineContext @@ -581,8 +581,8 @@ def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: assert V.graph.sizevars.statically_known_equals(new_expr, expr), ( f"Unsound replacement: '{new_expr}' != '{expr}'" ) - - return FloorDiv(numerator, denominator) + # Undo the python division trick and replace with explicit CeilDiv + return -CeilDiv(-numerator, denominator) else: return sympy.floor(expr) diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 632cfd29f174..99502ca2dd97 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -58,8 +58,8 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "void *" - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: return None diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index b748f61f067b..e46909432f17 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -209,7 +209,9 @@ def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.Tensor inp.realize() V.graph.no_fuse_buffer_names.add(inp.get_name()) inp = ir.ExternKernel.require_contiguous(inp) - ir._AllReduceKernel.create_inplace( + # Because we are lowering as inplace c10d.all_reduce_, we should generate + # _AllReduce_Kernel instead of _AllReduceKernel. + ir._AllReduce_Kernel.create_inplace( c10d.all_reduce_.default, inp, # type: ignore[arg-type] reduce_op, diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 2fef157859d7..eaab9020f1e8 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1509,6 +1509,7 @@ def codegen_and_compile( compiled_module, "runner", None ) + node_runtimes = None if inductor_metrics_log.isEnabledFor(logging.INFO): num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() metrics.num_bytes_accessed += num_bytes @@ -1523,6 +1524,14 @@ def codegen_and_compile( }, ) + # Collect and dump op runtimes for TLParse + if config.log_tlparse: + _, _, node_runtimes = graph.count_bytes() + torch._inductor.debug.log_runtime_estimates(node_runtimes) + + # Collect and dump collective-op schedule for external diagnostics + torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes) + if ( cudagraphs and config.triton.cudagraph_skip_dynamic_graphs @@ -2043,6 +2052,34 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[ ) +def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, +) -> tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with cuda_context: + # We can skip the invoke_subgraph because the + # entire_partition_fn is called recursively for invoke_subgraph + # in partitioning. + _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) + + static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + + with dynamo_utils.dynamo_timed( + "min_cut_rematerialization_partition", log_pt2_compile_event=True + ): + return min_cut_rematerialization_partition( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + + def compile_fx( model_: GraphModule, example_inputs_: Sequence[InputType], @@ -2361,33 +2398,6 @@ def fw_compiler_base( OutputCode, inference_compiler ) - def partition_fn( - gm: GraphModule, - joint_inputs: Sequence[object], - **kwargs: object, - ) -> tuple[GraphModule, GraphModule]: - cuda_context = get_cuda_device_context(gm) - with cuda_context: - # We can skip the invoke_subgraph because the - # entire_partition_fn is called recursively for invoke_subgraph - # in partitioning. - _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) - - static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] - "static_lifetime_input_indices", None - ) - - with dynamo_utils.dynamo_timed( - "min_cut_rematerialization_partition", log_pt2_compile_event=True - ): - return min_cut_rematerialization_partition( - gm, - joint_inputs, - compiler="inductor", - static_lifetime_input_indices=static_lifetime_input_indices, - **kwargs, - ) - @compile_time_strobelight_meta(phase_name="backward") def bw_compiler( gm: GraphModule, example_inputs: Sequence[InputType] diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 0b670b268b37..7c05b01f45d7 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -145,10 +145,19 @@ def __init__( f"--write-fd={str(subproc_write_fd)}", f"--torch-key={torch_key_str}", ] - local = False + log_path = None + self.log_file = None + if config.worker_suppress_logging: + log_path = os.devnull log.info("Suppressing compile worker output due to config") - local = True + else: + log_path = config.torchinductor_worker_logpath + if not log_path: + log_path = config.get_worker_log_path() + + if log_path: + self.log_file = open(log_path, "w") self.process = subprocess.Popen( cmd, @@ -164,8 +173,8 @@ def __init__( "LD_LIBRARY_PATH": get_ld_library_path(), }, pass_fds=(subproc_read_fd, subproc_write_fd), - stdout=subprocess.DEVNULL if local else None, - stderr=subprocess.DEVNULL if local else None, + stdout=self.log_file, + stderr=self.log_file, ) self.write_lock = threading.Lock() self.read_thread = threading.Thread( @@ -262,6 +271,8 @@ def shutdown(self) -> None: _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) + if self.log_file: + self.log_file.close() except OSError as e: log.warning("Ignored OSError in pool shutdown: %s", e) finally: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 63dccf5d5d8a..deebfa273ba1 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -81,6 +81,11 @@ def prologue_fusion_enabled() -> bool: # Whether to enable printing the source code for each future verbose_progress = False +# Configurable compile worker logging path for subproc_pool +worker_log_path = ( + "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None +) + # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 @@ -91,6 +96,8 @@ def prologue_fusion_enabled() -> bool: default=True, ) +remote_gemm_autotune_cache: bool = False + # use remote fx aot graph codegen cache # False: Disables the cache # True: Enables the cache @@ -138,12 +145,8 @@ def prologue_fusion_enabled() -> bool: # None: Not set -- Off for OSS, JustKnobs based for internal bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() -# Force disabled all inductor level caching -- This will override any other caching flag -force_disable_caches: bool = Config( - justknob="pytorch/remote_cache:force_disable_caches", - env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES", - default=False, -) +# See torch.compiler.config.force_disable_caches +force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches") # Unsafe way to skip dynamic shape guards to get faster cache load unsafe_skip_cache_dynamic_shape_guards: bool = False @@ -428,8 +431,17 @@ def prologue_fusion_enabled() -> bool: # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 +# Report the autotune choices and their benchmark results. Default is True. +max_autotune_report_choices_stats = ( + os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1" +) + # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph -graph_partition = False +graph_partition: bool = ( + os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0") + == "1" +) + # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations @@ -740,6 +752,12 @@ def decide_worker_start_method() -> str: default=True, ) +# Log per-operation runtime estimates for TLParse analysis. +log_tlparse: bool = Config( + env_name_force="LOG_TLPARSE", + default=False, +) + # Flags to turn on all_reduce fusion. These 2 flags should be automatically turned # on by DDP and should not be set by the users. _fuse_ddp_communication = False @@ -1002,6 +1020,24 @@ def decide_compile_threads() -> int: autotune_lookup_table: dict[str, dict[str, Any]] = {} +def get_worker_log_path() -> Optional[str]: + log_loc = None + if is_fbcode(): + mast_job_name = os.environ.get("MAST_HPC_JOB_NAME", None) + global_rank = os.environ.get("ROLE_RANK", "0") + + if mast_job_name is not None: + log_loc = f"/logs/dedicated_log_torch_compile_worker_rank{global_rank}" + + return log_loc + + +torchinductor_worker_logpath: str = Config( + env_name_force="TORCHINDUCTOR_WORKER_LOGPATH", + default="", +) + + # config specific to codegen/cpp.py class cpp: """ @@ -1464,12 +1500,12 @@ class aot_inductor: precompile_headers: bool = not is_fbcode() # Embed generated kernel binary files into model.so - embed_kernel_binary: bool = False + embed_kernel_binary: Optional[bool] = None # Generate kernel files that support multiple archs # For CUDA, this means generating fatbin files for kernels, and the fatbin files # contains PTX and SASS for the current architecture. - emit_multi_arch_kernel: bool = False + emit_multi_arch_kernel: Optional[bool] = None # If not None, the generated files with use this name in file stem. # If None, we will use a hash to name files. @@ -1860,6 +1896,12 @@ class test_configs: graphsafe_rng_func_ignores_fallback_random = False + track_memory_lifecycle: Optional[Literal["assert", "log"]] = None + + # If set to True, AOTI-generated CMakelists.txt will still use libtorch + # for unit testing + use_libtorch = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 47820d3d7725..c58849f9bf5a 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -28,7 +28,6 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir -from torch._inductor.utils import aoti_model_name_from_config from torch.torch_version import TorchVersion @@ -602,43 +601,73 @@ def _get_ffast_math_flags() -> list[str]: return flags +def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + cflags: list[str] = [] + ldflags: list[str] = [] + + if _IS_WINDOWS: + cflags = ["ZI", "_DEBUG"] + ldflags = ["DEBUG", "ASSEMBLYDEBUG ", "OPT:REF", "OPT:ICF"] + else: + cflags.append("g") + + return cflags, ldflags + + def _get_optimization_cflags( cpp_compiler: str, min_optimize: bool = False -) -> list[str]: - if _IS_WINDOWS: - return ["O1" if min_optimize else "O2"] +) -> tuple[list[str], list[str]]: + cflags: list[str] = [] + ldflags: list[str] = [] + + b_debug_build = ( + config.aot_inductor.debug_compile + or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" + ) + wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level + + if b_debug_build: + cflags, ldflags = _get_inductor_debug_symbol_cflags() + if _IS_WINDOWS: + cflags += ["Od", "Ob0", "Oy-"] + else: + cflags.append("O0") else: - wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level - cflags = ( - ["O0", "g"] - if config.aot_inductor.debug_compile - else [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] - ) - cflags += _get_ffast_math_flags() - cflags.append("fno-finite-math-only") - if not config.cpp.enable_unsafe_math_opt_flag: - cflags.append("fno-unsafe-math-optimizations") - cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") - - if sys.platform != "darwin": - # on macos, unknown argument: '-fno-tree-loop-vectorize' - if _is_gcc(cpp_compiler): - cflags.append("fno-tree-loop-vectorize") - # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 - # `-march=native` is unrecognized option on M1 - if not config.is_fbcode(): - if platform.machine() == "ppc64le": - cflags.append("mcpu=native") - else: - cflags.append("march=native") - - if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): - cflags.append("flto=thin") - - return cflags - - -def _get_shared_cflag(do_link: bool) -> list[str]: + if _IS_WINDOWS: + cflags = ["O1" if min_optimize else "O2"] + else: + cflags = [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] + + cflags += _get_ffast_math_flags() + cflags.append("fno-finite-math-only") + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") + + if sys.platform != "darwin": + # on macos, unknown argument: '-fno-tree-loop-vectorize' + if _is_gcc(cpp_compiler): + cflags.append("fno-tree-loop-vectorize") + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): + cflags.append("flto=thin") + + return cflags, ldflags + + +def _get_shared_cflags(do_link: bool) -> list[str]: if _IS_WINDOWS: """ MSVC `/MD` using python `ucrtbase.dll` lib as runtime. @@ -668,9 +697,11 @@ def get_cpp_options( libraries: list[str] = [] passthrough_args: list[str] = [] + opt_cflags, opt_ldflags = _get_optimization_cflags(cpp_compiler, min_optimize) + cflags = ( - _get_shared_cflag(do_link) - + _get_optimization_cflags(cpp_compiler, min_optimize) + opt_cflags + + _get_shared_cflags(do_link) + _get_warning_all_cflag(warning_all) + _get_cpp_std_cflag() + _get_os_related_cpp_cflags(cpp_compiler) @@ -686,7 +717,7 @@ def get_cpp_options( definitions, include_dirs, cflags, - ldflags, + ldflags + opt_ldflags, libraries_dirs, libraries, passthrough_args, @@ -1545,7 +1576,9 @@ def __init__( self._aot_mode: bool = False self._name = name - self._target_name = aoti_model_name_from_config() + self._target_name = ( + config.aot_inductor.model_name_for_generated_files or "aoti_model" + ) # Code start here, initial self internal variables firstly. self._build_option = BuildOption @@ -1598,7 +1631,8 @@ def __init__( if isinstance(sources, str): sources = [sources] - if config.is_fbcode() and (not self._aot_mode or self._use_relative_path): + # Use relative paths only when requested (typically for remote builds) + if config.is_fbcode() and self._use_relative_path: # Will create another temp directory for building, so do NOT use the # absolute path. self._orig_source_paths = list(sources) @@ -1781,22 +1815,54 @@ def save_compile_cmd_to_cmake( project({self._target_name} LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) - # May need to point CMAKE_PREFIX_PATH to the right torch location - find_package(Torch REQUIRED) - - # Set a shared library target + # Set a library target add_library({self._target_name} {target_library_type}) - # Add macro definitions - target_compile_definitions({self._target_name} PRIVATE {definitions}) - - # Add compile flags - target_compile_options({self._target_name} PRIVATE {self._cflags_args}) - # Backend specific flags - target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) - """ ) + + if ( + not config.aot_inductor.compile_standalone + or config.test_configs.use_libtorch + ): + # When compile_standalone is True, the generated cpp project should + # not use Torch. But for unit testing purpose, we need to use Torch here. + contents += textwrap.dedent( + """ + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) + + """ + ) + # flags and macros here are mostly CPU specific. Not emitting them for GPU models + # will make the generated CMake file more portable and won't really hurt performance. + # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may + # be still needed. + contents += textwrap.dedent( + f""" + # Add macro definitions + target_compile_definitions({self._target_name} PRIVATE {definitions}) + + # Add compile flags + target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + + # Backend-specific flags + target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) + else: + # When compile_standalone is True, use TorchStandalone instead of Torch + contents += textwrap.dedent( + f""" + find_package(TorchStandalone REQUIRED) + # Set up include directories to find headers at the correct paths + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}) + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}/standalone) + + """ + ) + if device_type == "cuda" and torch.version.hip is None: from torch._inductor.codecache import _nvcc_arch_as_compile_option @@ -1804,7 +1870,11 @@ def save_compile_cmd_to_cmake( contents += textwrap.dedent( f""" enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED) + target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) + target_compile_definitions({self._target_name} PRIVATE USE_CUDA) + target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -1833,7 +1903,7 @@ def save_compile_cmd_to_cmake( add_custom_command( OUTPUT ${{FATBIN_FILE}} COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} - -gencode arch=compute_80,code=compute_80 + -gencode arch=compute_{current_arch},code=compute_{current_arch} -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} ) @@ -1882,12 +1952,20 @@ def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> Non """ ) f.write(contents) - f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") - f.write( - f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" - ) + if asm_files: + f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") + f.write( + f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" + ) def save_link_cmd_to_cmake(self, cmake_path: str) -> None: + if ( + config.aot_inductor.compile_standalone + and not config.test_configs.use_libtorch + ): + # When compile_standalone is True, do not link with libtorch + return + lflags = " ".join(self._build_option.get_ldflags()) libs = " ".join(self._build_option.get_libraries()) contents = textwrap.dedent( @@ -1905,3 +1983,33 @@ def save_link_cmd_to_cmake(self, cmake_path: str) -> None: ) with open(cmake_path, "a") as f: f.write(contents) + + +def run_asm_build_object(src: str, target: str, cwd: str) -> None: + def get_asm_compiler() -> str: + if _IS_WINDOWS: + ASM_CC = "ml64" + else: + ASM_CC = get_cpp_compiler() + # Intel compiler is not support to compile asm, switch to gcc. + if _is_intel_compiler(ASM_CC): + ASM_CC = "gcc" + return ASM_CC + + def get_command_line(asm_cc: str, src: str, target: str) -> str: + if _IS_WINDOWS: + # Format reference: + # https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170 + cmd = f"{asm_cc} {src} /c /Fo {target}" # codespell:ignore /Fo + else: + cmd = f"{asm_cc} -c {src} -o {target}" + + return cmd + + asm_cc = get_asm_compiler() + cmd = get_command_line( + asm_cc=asm_cc, + src=normalize_path_separator(src), + target=normalize_path_separator(target), + ) + run_compile_cmd(cmd, cwd=normalize_path_separator(cwd)) diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 2686d1d2ddde..7826c797d36b 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -10,6 +10,8 @@ from torch._inductor.utils import GraphPartitionMap, InputType from torch.utils._ordered_set import OrderedSet +from .utils import is_using_cudagraph_partition + if TYPE_CHECKING: from collections.abc import Sequence @@ -170,7 +172,8 @@ def check_multiple_devices_or_any_cpu_nodes( # meta tensors are supported since there is no compute device_node_mapping.pop(torch.device("meta"), None) - if torch._inductor.config.graph_partition: + # dynamo cudagraph does not support graph partition + if is_using_cudagraph_partition(): # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index de72edbe0175..71df3429bb01 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -22,7 +22,9 @@ from torch import fx as fx from torch._dynamo.repro.after_aot import save_graph_repro from torch._dynamo.utils import get_debug_dir +from torch._inductor import utils from torch._logging import getArtifactLogger +from torch._logging._internal import trace_structured from torch.fx.graph_module import GraphModule from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.fx.passes.tools_common import legalize_graph @@ -693,6 +695,57 @@ def log_ir_post_fusion(nodes: SchedulerNodeList) -> None: V.debug.ir_post_fusion(nodes) +def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None: + try: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_collective_schedule", + "encoding": "json", + }, + payload_fn=lambda: schedule, + ) + except Exception: + log.debug( + "Failed to log inductor_collective_schedule via structured logging", + exc_info=True, + ) + + +def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None: + schedule = [ + getattr(op, "python_kernel_name", None) + for node in nodes + if isinstance(op := getattr(node, "node", None), ir._CollectiveKernel) + ] + + # Only log when there is at least one collective op + if schedule: + _dump_collective_schedule(schedule) + + +def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None: + """Log per-operation runtime estimates for TLParse.""" + + ops = [ + { + "name": getattr(s.node, "python_kernel_name", s.get_name()), + "type": "collective" if utils.is_collective(s.node) else "compute", + "estimated_runtime_ns": runtime_ns, + } + for s, runtime_ns in node_runtimes + ] + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_tlparse_runtime", + "encoding": "json", + }, + payload_fn=lambda: {"ops": ops}, + ) + + @dataclasses.dataclass class TensorMetadataHolder: tensor_metadata: TensorMetadata diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 2622ab6b95e4..d903d851ee87 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -581,10 +581,8 @@ def view_copy_dtype( def _get_shape_permutation_like( - self: torch.Tensor, layout: torch.layout + self: torch.Tensor, ) -> tuple[utils.ShapeType, utils.StrideType]: - assert layout == torch.strided - physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self) shape = [self.shape[l] for l in physical_layout] @@ -624,7 +622,8 @@ def full_like( return result.to(memory_format=memory_format) else: - shape, permutation = _get_shape_permutation_like(self, layout) + assert layout == torch.strided + shape, permutation = _get_shape_permutation_like(self) result = torch.full( shape, fill_value, @@ -644,29 +643,25 @@ def _rand_like( self: torch.Tensor, *, dtype: Optional[torch.dtype] = None, - layout: Optional[torch.layout] = None, device: Optional[torch.device] = None, memory_format: torch.memory_format = torch.preserve_format, **kwargs: Any, ) -> torch.Tensor: dtype = self.dtype if dtype is None else dtype - layout = self.layout if layout is None else layout device = self.device if device is None else device if memory_format != torch.preserve_format: return rand_fn( self.shape, dtype=dtype, - layout=layout, device=device, **kwargs, ).to(memory_format=memory_format) - shape, permutation = _get_shape_permutation_like(self, layout) + shape, permutation = _get_shape_permutation_like(self) result = rand_fn( shape, dtype=dtype, - layout=layout, device=device, **kwargs, ) diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index ac321c9974ae..a46663ed8f8c 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -92,6 +92,9 @@ def __init__(self, cmd: list[str], output: str) -> None: if isinstance(output, bytes): output = output.decode("utf-8") + self.cmd = cmd + self.output = output + super().__init__( textwrap.dedent( """ @@ -108,6 +111,9 @@ def __init__(self, cmd: list[str], output: str) -> None: .format(cmd=" ".join(cmd), output=output) ) + def __reduce__(self) -> tuple[type, tuple[list[str], str]]: + return (self.__class__, (self.cmd, self.output)) + class CUDACompileError(CppCompileError): pass diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 13543b486199..3bf1ff9dab86 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -93,6 +93,12 @@ def greedy_bucket_collective_by_mb( node_group_key: Callable[[torch.fx.Node], Any], filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None, ) -> list[list[torch.fx.Node]]: + """ + Bucketing adjacent collectives with equal node_group_key. + We can not bucket non adjacent collectives, + as this will effectively change the order of collectives. + Reordering can lead to different order on different ranks. + """ g = gm.graph found_candidates = False for node in g.nodes: @@ -102,10 +108,12 @@ def greedy_bucket_collective_by_mb( if not found_candidates: return [] - nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list) nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict( OrderedSet ) + nodes_groups: list[list[torch.fx.Node]] = [] + cur_group: list[torch.fx.Node] = [] + cur_group_key = None for node in g.nodes: for n, successors in nodes_successors.items(): @@ -115,10 +123,19 @@ def greedy_bucket_collective_by_mb( if (filter_wait_node is None) or filter_wait_node(node): coll_node = node.args[0] group_key = node_group_key(coll_node) - nodes_groups[group_key].append(coll_node) + if group_key == cur_group_key: + cur_group.append(coll_node) + else: + if len(cur_group) > 1: + nodes_groups.append(cur_group) + cur_group = [coll_node] + cur_group_key = group_key + + if len(cur_group) > 1: + nodes_groups.append(cur_group) buckets: list[list[torch.fx.Node]] = [] - for nodes in nodes_groups.values(): + for nodes in nodes_groups: cur_bucket: list[torch.fx.Node] = [] cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet() cur_bucket_size_bytes: int = 0 @@ -128,15 +145,15 @@ def greedy_bucket_collective_by_mb( ) for node in nodes: if node in cur_bucket_successors: - # We can not bucket successors with the node + # We cannot bucket successors with the node continue assert "val" in node.meta n_val = node.meta["val"] out_size_bytes = n_val.numel() * n_val.element_size() - if ( - cur_bucket_size_bytes + out_size_bytes > bucket_size_bytes - and cur_bucket - ): + n_input_val = node.all_input_nodes[0].meta["val"] + in_size_bytes = n_input_val.numel() * n_input_val.element_size() + size_bytes = max(out_size_bytes, in_size_bytes) + if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket: # Current bucket is full, create new bucket if len(cur_bucket) > 1: buckets.append(cur_bucket) @@ -144,7 +161,7 @@ def greedy_bucket_collective_by_mb( cur_bucket_size_bytes = 0 cur_bucket_id += 1 cur_bucket_successors = OrderedSet() - cur_bucket_size_bytes += out_size_bytes + cur_bucket_size_bytes += size_bytes cur_bucket.append(node) cur_bucket_successors |= nodes_successors[node] if len(cur_bucket) > 1: @@ -163,7 +180,7 @@ def bucket_all_gather_by_mb( Args: gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers. - bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow to specify different sizes of the buckets at the start, as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx @@ -201,14 +218,14 @@ def bucket_reduce_scatter_by_mb( Args: gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters. - bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow to specify different sizes of the buckets. filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified, only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed. Returns: - list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes. + list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes. """ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: @@ -235,36 +252,20 @@ def reduce_scatter_merge_fn_to_trace( reduce_dtype: torch.dtype, # type: ignore[name-defined] device: torch.device, # type: ignore[name-defined] ) -> list[torch.Tensor]: # type: ignore[no-untyped-def] - rs_ins_flattened = [rs_in.view(-1) for rs_in in rs_ins] - - rs_ins_srcs = [ - rs_in_f.split([rs_in_f.numel() // group_size] * group_size) - for rs_in_f in rs_ins_flattened - ] + rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins] - foreach_copy_srcs = [] - for rank_idx in range(group_size): - for rs_in_idx in range(len(rs_ins)): - foreach_copy_srcs.append(rs_ins_srcs[rs_in_idx][rank_idx]) + new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins] + new_out_numels = [x.numel() // group_size for x in rs_ins] - new_rs_in = torch.cat(foreach_copy_srcs, dim=0) + new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten() - wait_tensor = torch.ops.c10d_functional.wait_tensor( + new_rs_out = torch.ops.c10d_functional.wait_tensor( torch.ops._c10d_functional.reduce_scatter_tensor.default( new_rs_in, reduce_op, group_size, group_name ) ) - new_rs_out = wait_tensor - - new_outs = [] - new_rs_out_offset = 0 - for rs_in in rs_ins: - new_out_size = torch.Size((rs_in.shape[0] // group_size,) + rs_in.shape[1:]) # type: ignore[attr-defined] - new_out = new_rs_out.narrow(0, new_rs_out_offset, new_out_size.numel()).reshape( - new_out_size - ) - new_outs.append(new_out) - new_rs_out_offset += new_out_size.numel() + new_out_flat = new_rs_out.split(new_out_numels, 0) + new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)] return new_outs diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7133d77740bc..db273b06c8e6 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1760,17 +1760,44 @@ def __call__(self, graph: fx.Graph) -> None: movable_constructors = self.find_movable_constructors(graph, constructors) target_device = next(iter(target_devices)) - for node in movable_constructors: - if node in cpu_placeholders: - with graph.inserting_after(node): - gpu_node = graph.call_function( - torch.ops.prims.device_put.default, (node, target_device) + movable_cpu_placeholders = movable_constructors & cpu_placeholders + if movable_cpu_placeholders: + node = next(iter(reversed(movable_cpu_placeholders))) + last_node = node + unsqueezed_nodes = [] + for elem in movable_cpu_placeholders: + with graph.inserting_after(last_node): + unsqueezed_nodes.append( + graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0)) ) - node.replace_all_uses_with( - gpu_node, - lambda x: x != gpu_node - and x.target != torch.ops.aten.copy_.default, + last_node = unsqueezed_nodes[-1] + with graph.inserting_after(last_node): + cpu_concat = graph.call_function( + torch.ops.aten.cat.default, (unsqueezed_nodes,) + ) + last_node = cpu_concat + with graph.inserting_after(last_node): + gpu_concat = graph.call_function( + torch.ops.prims.device_put.default, + (cpu_concat, target_device, True), ) + last_node = gpu_concat + with graph.inserting_after(last_node): + gpu_split = graph.call_function( + torch.ops.aten.unbind.int, (gpu_concat,) + ) + last_node = gpu_split + for idx, node in enumerate(movable_cpu_placeholders): + with graph.inserting_after(last_node): + gpu_node = graph.call_function(operator.getitem, (gpu_split, idx)) + node.replace_all_uses_with( + gpu_node, + lambda x: x + not in [cpu_concat, gpu_concat, gpu_split, gpu_node] + + unsqueezed_nodes + and x.target != torch.ops.aten.copy_.default, + ) + last_node = gpu_node # noop elimination if there are other device_put for gpu_node to # target device. Alternatively, we could just move the other device_put @@ -1784,10 +1811,12 @@ def __call__(self, graph: fx.Graph) -> None: for noop in noop_device_puts: noop.replace_all_uses_with(gpu_node) graph.erase_node(noop) - else: - kwargs = node.kwargs.copy() - kwargs["device"] = target_device - node.kwargs = kwargs + + movable_constructors -= movable_cpu_placeholders + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs def find_movable_constructors( self, graph: fx.Graph, constructors: list[fx.Node] diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2f0670c75725..31be050ab28d 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1205,7 +1205,9 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> error.operator_str(target, args, kwargs), ) - tag = get_layout_constraint_tag(target, with_default=False) + tag: Optional[torch._C.Tag] = get_layout_constraint_tag( + target, with_default=False + ) if ( tag is None and torch._library.utils.is_builtin(target) @@ -1222,8 +1224,10 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> # and identify them one by one. decided_constraint = require_contiguous # type: ignore[assignment] else: - tag = get_layout_constraint_tag(target, with_default=True) - decided_constraint = tag_to_layout_constraint(tag) + default_tag: torch._C.Tag = get_layout_constraint_tag( + target, with_default=True + ) + decided_constraint = tag_to_layout_constraint(default_tag) make_fallback(target, layout_constraint=decided_constraint) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 52cb8af69fdd..db62af361633 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,6 +6,7 @@ import itertools import logging import operator +import os import textwrap import traceback from collections.abc import Container, Generator, Iterable, Iterator, Sequence @@ -156,6 +157,9 @@ indent = functools.partial(textwrap.indent, prefix=" ") aten = torch.ops.aten +autotune_warmup = int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25)) +autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100)) + """ [Note: Inductor IR] Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each @@ -510,6 +514,7 @@ def try_match_insignificant_strides( old_layout.size, new_stride, old_layout.offset, + old_layout.is_pinned, ) return TensorBox(ReinterpretView(data=storage, layout=new_layout)) @@ -2906,6 +2911,7 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: list(new_size), new_stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -2952,6 +2958,7 @@ def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: [old_layout.size[i] for i in dims], [old_layout.stride[i] for i in dims], old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3013,6 +3020,7 @@ def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: new_size, new_stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3131,6 +3139,7 @@ def fake_reindex(index: Any) -> tuple[int, ...]: new_size, FlexibleLayout.contiguous_strides(new_size), old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3365,6 +3374,7 @@ def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView: old_layout.size, old_layout.stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) return DtypeView(data=x, target_dtype=new_dtype) @@ -3472,6 +3482,7 @@ def create( # type: ignore[override] new_size, new_stride, old_layout.offset + old_layout.stride[dim] * start, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3568,6 +3579,13 @@ def storage_size(self) -> int: @ir_dataclass class Layout(OutputSpec): + """ + Layout base class + + Carries tensor meta-information including offset and + whether it is pinned. + """ + def __init__( self, device: torch.device, @@ -3575,6 +3593,7 @@ def __init__( size: Sequence[Expr], stride: Optional[Sequence[Expr]] = None, offset: Expr = Integer(0), + is_pinned: bool = False, ) -> None: if stride is None: stride = FlexibleLayout.contiguous_strides(size) @@ -3585,6 +3604,9 @@ def __init__( self.size = size self.stride = stride self.offset = offset + self.is_pinned = is_pinned + # is_pinned implies cpu + assert (not self.is_pinned) or (self.device.type == "cpu") def __str__(self) -> str: offset = "" @@ -3592,9 +3614,12 @@ def __str__(self) -> str: offset = f", offset={self.offset}" device_index_str = "" if self.device.index is None else f":{self.device.index}" + is_pinned_str = "" + if self.is_pinned: + is_pinned_str = f", is_pinned={self.is_pinned}" return ( f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " - f"size={self.size}, stride={self.stride}{offset})" + f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})" ) __repr__ = __str__ @@ -3609,6 +3634,7 @@ def get_example(self) -> torch.Tensor: convert_shape_to_symint(self.stride), dtype=self.dtype, device=self.device, + pin_memory=self.is_pinned, ) def is_contiguous(self) -> bool: @@ -3707,10 +3733,8 @@ def _pad_strides( # do for dynamic shape. # # Skip padding the strides for dynamic shape for now. - if not all( - isinstance(s, (int, sympy.Integer)) - for s in itertools.chain(in_strides, size) - ): + # If outermost dim is dynamic, stride still can be fully static + if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides): return in_strides stride_order = get_stride_order(in_strides) @@ -3725,11 +3749,11 @@ def _pad_strides( for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - - if stride > config.padding_stride_threshold and stride % align != 0: - stride = ceildiv(stride, align) * align - padded = True - new_strides[idx] = stride + if isinstance(stride, (int, sympy.Integer)): + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride if not padded: # Consider a tensor with shape [256, 1, 5, 5] @@ -3760,6 +3784,7 @@ def as_fixed(self) -> FixedLayout: self.size, self.stride, self.offset, + self.is_pinned, ) def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: @@ -3776,6 +3801,7 @@ def __eq__(self, other: object) -> bool: and self.size == other.size and self.stride == other.stride and self.offset == other.offset + and self.is_pinned == other.is_pinned ) def storage_size(self) -> Expr: @@ -3889,6 +3915,7 @@ def as_stride_order( self.size, new_stride, self.offset, + self.is_pinned, ) def as_exact_strides( @@ -3904,6 +3931,7 @@ def as_exact_strides( self.size, new_stride, self.offset, + self.is_pinned, ) def as_fill_order(self, order: Sequence[int]) -> FixedLayout: @@ -3916,6 +3944,7 @@ def as_fill_order(self, order: Sequence[int]) -> FixedLayout: self.size, new_stride, self.offset, + self.is_pinned, ) def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: @@ -3928,6 +3957,7 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: self.size, new_stride, self.offset, + self.is_pinned, ) def __init__( @@ -3936,12 +3966,13 @@ def __init__( dtype: torch.dtype, size: Sequence[Expr], stride_order: Optional[Sequence[Union[int, Integer]]] = None, + is_pinned: bool = False, ) -> None: if stride_order: strides = FlexibleLayout.fill_ordered(size, stride_order) else: strides = FlexibleLayout.contiguous_strides(size) - super().__init__(device, dtype, size, strides) + super().__init__(device, dtype, size, strides, is_pinned=is_pinned) class NonOwningLayout(Layout): @@ -4007,6 +4038,7 @@ def __init__( size=fixed.size, stride=fixed.stride, offset=fixed.offset, + is_pinned=fixed.is_pinned, ) self.comm_buffer_type = comm_buffer_type self.group_name = group_name @@ -4181,6 +4213,9 @@ def get_output_spec(self) -> OutputSpec: def get_storage_numel(self) -> int: return self.get_numel() + def get_is_pinned(self) -> bool: + return self.get_layout().is_pinned + def freeze_layout(self) -> None: if isinstance(self.layout, Layout) and not isinstance( self.layout, NonOwningLayout @@ -4877,9 +4912,13 @@ def __init__( def benchmark(self, *args: Any, out: torch.Tensor) -> float: algo = self.to_callable() + benchmark_configs = { + "warmup": autotune_warmup, + "rep": autotune_rep, + } if config.profile_bandwidth_with_do_bench_using_profiling: - return do_bench_using_profiling(lambda: algo(*args)) - return benchmarker.benchmark(algo, args, {"out": out}) + return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs) + return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs) def call_name(self) -> str: raise NotImplementedError @@ -5148,6 +5187,9 @@ class ConcatKernel(NopKernel): @classmethod def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: + """ + Create the concat kernel from inputs + """ device = inputs[0].get_device() dtype = inputs[0].get_dtype() new_size = list(inputs[0].get_size()) @@ -5201,6 +5243,10 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: ): output_stride = make_channels_last_strides_for(new_size) + is_pinned = all( + is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs + ) + assert device is not None concat_kernel = ConcatKernel( name=None, @@ -5209,6 +5255,7 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: dtype=dtype, size=new_size, stride=output_stride, + is_pinned=is_pinned, ), inputs=[], ) @@ -5324,6 +5371,11 @@ def should_allocate(self) -> bool: @ir_dataclass(frozen=False) class ExternKernel(InputsKernel): + """ + A class that represents Kernels which are not directly lowered to Inductor + Loop Level IR, such as custom operators, or aten operators which we fallback to. + """ + constant_args: Sequence[Any] = () kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) output_view: Optional[ReinterpretView] = None @@ -5688,6 +5740,7 @@ def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: size=x.get_size(), stride=strides, offset=offset, + is_pinned=False, ), ) @@ -6120,6 +6173,17 @@ def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None: f"# buffer {name} (op: {op_name}) is assumed to be not aligned" ) + def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None: + """ + Track outputs of fallback operators if config.test_configs.track_memory_lifecycle + """ + if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper: + return + + wrapper.write_memory_track_allocation_once() + name = self.get_name() + wrapper.writeline(f"track_tensor({name}, '{name}')") + def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: """ get output sizes and strides, for template_codegen @@ -7005,11 +7069,27 @@ def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: developer_warning("DeviceCopy in input program") constant_args = (non_blocking,) + # Device Copy should keep the same layout as input + x = ExternKernel.require_contiguous(x) + stride = None + if x.get_size(): + # x.get_stride() may be unimplemented if x's size is empty + stride = x.get_stride() + is_destination_pinned = ( + is_gpu(x_device.type) and device.type == "cpu" and non_blocking + ) + is_source_pinned = ( + x_device.type == "cpu" and is_gpu(device.type) and non_blocking + ) + if is_source_pinned and is_storage_and_layout(x): + x.get_layout().is_pinned = True return DeviceCopy( - FlexibleLayout( - device=device, - dtype=x.get_dtype(), - size=x.get_size(), + FixedLayout( + device, + x.get_dtype(), + x.get_size(), + stride, + is_pinned=is_destination_pinned, ), [cls.realize_input(x)], constant_args, @@ -7572,16 +7652,24 @@ def is_number(t: torch.JitType) -> bool: if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) self.codegen_alignment_asserts(wrapper) + self.codegen_memory_tracking(wrapper) self.codegen_unbacked_symbol_defs(wrapper) @staticmethod def tensor_to_layout(output: torch.Tensor) -> FixedLayout: + is_pinned = False + try: + is_pinned = output.is_pinned() + except RuntimeError: + # dispatch not implemented + pass return FixedLayout( output.device, output.dtype, convert_shape_to_inductor(output.size()), convert_shape_to_inductor(output.stride()), + is_pinned=is_pinned, ) @classmethod @@ -7713,6 +7801,31 @@ def __init__( ) +class MemoryCheckKernel(FallbackKernel): + """ + Custom kernel for memory checking that generates direct function calls + + TODO - the custom op was erroring with str inputs. should be able to custom op directly. + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Override codegen to write direct function call""" + # Extract our arguments from nontensor_args + wrapper.write_memory_track_allocation_once() + alive_list, dead_list, is_final_step = self.constant_args + + alive_repr = repr(alive_list) + dead_repr = repr(dead_list) + if is_final_step: + wrapper.writeline( + "# note: dont currently distinguish between buffers returned and dealloc'd in last step" + ) + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})" + else: + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})" + wrapper.writeline(call) + + @ir_dataclass class MultiOutputLayout(OutputSpec): device: torch.device @@ -7957,6 +8070,7 @@ def realize(self) -> Optional[str]: device=device, dtype=self.data.get_dtype(), size=self.data.get_size(), + is_pinned=False, ), data=self.data, ) @@ -8137,6 +8251,7 @@ def create_output( size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), invoke_subgraph, # type: ignore[has-type] [(list, ind)], @@ -8266,6 +8381,7 @@ def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: size=[_maybe_expr(sz) for sz in merged_output.size()], stride=[_maybe_expr(sz) for sz in merged_output.stride()], offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), conditional, [(list, i)], @@ -8493,6 +8609,7 @@ def _guard_list_equals( size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), while_loop, [(list, idx)], diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index be782bac3a82..7375deff9a5f 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -6,6 +6,7 @@ from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from .. import ir, lowering as L +from ..kernel_inputs import MMKernelInputs from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, @@ -26,8 +27,6 @@ addmm_epilogue, is_batch_stride_largest, mm_args, - mm_config_kwargs, - mm_options, ) @@ -40,13 +39,6 @@ def bmm_grid(b, m, n, meta, *, cdiv): return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) -def _is_large_block_for_cpu(m, n, k): - # Thresholds are experimentally determined to reduce Triton CPU compile times - if m > 128 or n > 128 or k > 128: - return True - return m * n > 2**12 - - bmm_template = TritonTemplate( name="bmm", grid=bmm_grid, @@ -175,9 +167,14 @@ def may_require_contiguous(t, meta_t): meta_mat2 = V.graph.current_node.args[1] mat2 = may_require_contiguous(mat2, meta_mat2) + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args( mat1, mat2, layout=layout, out_dtype=out_dtype ) + name = "bmm" + + # Create MMKernelInputs for BMM at the top + kernel_inputs = MMKernelInputs([mat1, mat2]) # below is for getting an overview logging info of inductor mms batch_size = mat1.get_size()[0] # Extract batch dimension @@ -195,31 +192,27 @@ def may_require_contiguous(t, meta_t): if out_dtype: assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA" - aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype) + aten_func = aten_bmm_dtype.bind( + kernel_inputs.nodes(), layout, out_dtype=out_dtype + ) else: - aten_func = aten_bmm.bind((mat1, mat2), layout) + aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout) # options to tune from choices = [aten_func] if use_aten_gemm_kernels() else [] - device_type = ir.get_device_type(mat1) - bmm_configs = V.choices.get_base_mm_configs(device_type) - - dtype = mat1.get_dtype() if use_triton_template(layout): # TODO: add out_dtype support for Triton Template assert out_dtype is None, "out_dtype is not supported for Triton" - for config in bmm_configs( - m, - n, - k, - **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, bmm_template.name, name ): bmm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, ) _, is_nonzero = _is_static_problem(layout) batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout) @@ -227,11 +220,13 @@ def may_require_contiguous(t, meta_t): batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k) - and _use_cutlass_for_op("bmm") + and _use_cutlass_for_op(name) ): from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate - CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type] + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, kernel_inputs.nodes() + ) # type: ignore[arg-type] if use_cpp_bmm_template(layout, mat1, mat2): from ..codegen.cpp_bmm_template import CppBmmTemplate @@ -239,19 +234,23 @@ def may_require_contiguous(t, meta_t): CppBmmTemplate.add_choices( choices, layout, - [mat1, mat2], + kernel_inputs.nodes(), ) if use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) - return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) @L.register_lowering(aten.baddbmm) def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + # Create MMKernelInputs for BadDBMM at the top + kernel_inputs = MMKernelInputs([inp, mat1, mat2]) + # below is for getting an overview logging info of inductor mms batch_size = mat1.get_size()[0] counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 @@ -266,29 +265,26 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): inp.get_dtype(), layout, ) - + name = "baddbmm" # options to tune from choices = ( - [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + [aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)] if use_aten_gemm_kernels() else [] ) - device_type = ir.get_device_type(mat1) - bmm_configs = V.choices.get_base_mm_configs(device_type) - if use_triton_template(layout): - for config in bmm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, bmm_template.name, name ): bmm_template.maybe_append_choice( choices, - input_nodes=(inp, mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, prefix_args=1, epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), ) - return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index ba1dc4aa2c24..6b9e9a1a32e7 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -29,7 +29,6 @@ use_triton_template, ) from ..virtualized import V -from .mm_common import mm_config_kwargs if TYPE_CHECKING: @@ -61,13 +60,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv): ) -def _is_large_block_for_cpu(m, n, k): - # Thresholds are experimentally determined to reduce Triton CPU compile times - if m > 256 or n > 256 or k > 256: - return True - return m * n * k > 2**17 - - LOOP_BODY_2D = """ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W @@ -603,7 +595,6 @@ def channels_last_conv(): sympy_product([x.get_size()[0], *x.get_size()[2:]]), out_chan, in_chan, - **mm_config_kwargs(device_type, _is_large_block_for_cpu), ): if ndim == 2: conv2d_template.maybe_append_choice( diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index 8ee50753439e..6cc197a35b9c 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -3,6 +3,7 @@ import math from collections.abc import Sequence +from pathlib import Path from typing import Any, Optional, Union import sympy @@ -323,267 +324,13 @@ def next_power_of_two(n): return 2 ** math.ceil(math.log2(n)) -# ---- Common Template Strings ---- -compute_forward_block_mn = r""" -@triton.jit -def forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +_TEMPLATE_DIR = Path(__file__).parent / "templates" -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - # -- load k -- - # NB reversed order to since K is transposed - {%- if USE_TMA %} - k = tl.load_tensor_descriptor( - desc_k, - [kv_start + kv_offset, 0], - ) - {%- else %} - k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) - {%- endif %} - - if USE_TMA: - k = tl.trans(k) - # -- compute qk --- - qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. - if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, - # which is larger than the actual number of elements. To avoid access memory out of bound, - # we need to mask out the elements that are out of Q_LEN & KV_LEN. - m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) - n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=1, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) - # apply mask for partially unmasked blocks - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - # -- compute scaling constant --- - m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) - if not ROWS_GUARANTEED_SAFE: - masked_out_rows = (m_ij == float("-inf")) - m_ij_masked = tl.where(masked_out_rows, 0, m_ij) - else: - m_ij_masked = m_ij - - alpha = tl.math.exp2(m_i - m_ij_masked) - p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) - - # NB: l_i update is pulled up here since it's a bit faster - # NB: For headdim=256, it's faster to move it back down to after m_i = - # m_ij - l_i = l_i * alpha + tl.sum(p, 1) - # # -- scale and update acc -- - acc = acc * alpha[:, None] - {%- if USE_TMA %} - v = tl.load_tensor_descriptor( - desc_v, - [kv_start + kv_offset, 0], - ) - {%- else %} - v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) - - # -- update m_i - m_i = m_ij - - return acc, l_i, m_i - -""" - -compute_forward_inner = r""" -@triton.jit -def forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets used as inputs to score_mod & mask_mod - # of size [BLOCK_M, BLOCK_N] or scalar. - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - # blocksparse data - kv_indices, kv_num_blocks, - # start kv and end kv block - block_n_start, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - RCP_LN2: tl.constexpr = 1.44269504 - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - kv_offset = 0 - - # loop over k, v and update accumulator until block_n_end - for start_n in range(block_n_start, block_n_end): - # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. - if IS_DIVISIBLE: - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - else: - # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, - # it's on par or slightly faster than only applying to the last block in fwd. - # However, we choose different strategy for bwd, where we only apply mod & mask - # to the last block because it's faster a lot. - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - - - - offset = get_offset_for_next_block( - start_n, kv_indices, kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS - ) - offs_n = offs_n + offset - kv_offset += offset - if not USE_TMA: - K_block_ptr = tl.advance(K_block_ptr, (0, offset)) - V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) +def load_template(name: str) -> str: + """Load a template file and return its content.""" + with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f: + return f.read() - return acc, l_i, m_i - -""" - -# Inner Triton functions shared by flex_attention & split-k decoding kernels. -compute_next_offset_func = r""" -@triton.jit -def get_offset_for_next_block( - loop_iter, col_indices, total_blocks, - SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, - BLOCKS_ARE_CONTIGUOUS: tl.constexpr -): - if BLOCKS_ARE_CONTIGUOUS: - return BLOCK - cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE - cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") - next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) - needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 - jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK - offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK - return offset -""" - -get_bounded_indices_func = r""" -@triton.jit -def get_bounded_indices(indices, max_len=None): - return indices % max_len if max_len is not None else indices -""" - - -load_checked_block = r""" -@triton.jit -def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): - if IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr) - elif IS_DIVISIBLE and not SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") - elif not IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") - else: - return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") -""" - -load_checked_2d = r""" -@triton.jit -def load_checked_2d( - ptr, - offs_m, - offs_n, - stride_m, - stride_n, - IS_DIVISIBLE_M: tl.constexpr, - IS_DIVISIBLE_N: tl.constexpr, - M_LEN: tl.constexpr, - N_DIM: tl.constexpr, -): - # Calculate final pointer if strides are provided - if stride_m is not None and stride_n is not None: - ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n - - # Handle all masking cases - if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) - elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) - elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) - else: # Both divisible - return tl.load(ptr) -""" +# Template strings have been moved to templates/common.py.jinja diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 0553fd06755d..a3e441d033b3 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -22,17 +22,12 @@ ) from .common import ( build_subgraph_buffer, - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, create_placeholder, - get_bounded_indices_func, get_fwd_subgraph_outputs, infer_dense_strides, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, SubgraphResults, @@ -67,267 +62,12 @@ def get_float32_precision(): return "'tf32'" -compute_flex_attention = r""" -{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # M: Number of queries, N: Number of keys/values, D: Model dimension - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad - # - # (Modifiable) Performance tuning options - # BLOCK_M: The thread block size across the seqlen dim of Q. - # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. - - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are - # contiguous? If so, we don't need to do an indirect jump for every block - - tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - q_start = tl.program_id(0) - off_zq = tl.program_id(1) - off_hq = tl.program_id(2) - - # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. - # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. - off_zkv = off_zq % ZKV - off_hkv = off_hq // GQA_SHARED_HEADS - off_g = off_hq % GQA_SHARED_HEADS - - q_offset = off_zq * stride_qz + off_hq * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - Q = Q + q_offset - K = K + k_offset - V = V + v_offset - - # Setting up the TMA descriptors for Q, K, V - desc_q = None - desc_k = None - desc_v = None - {%- if USE_TMA %} - desc_q = tl.make_tensor_descriptor( - base=Q, - shape=[Q_LEN, QK_HEAD_DIM], - strides=[stride_qm, 1], - block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], - ) - - desc_k = tl.make_tensor_descriptor( - base=K, - shape=[KV_LEN, QK_HEAD_DIM], - strides=[stride_kn, 1], - block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], - ) - - desc_v = tl.make_tensor_descriptor( - base=V, - shape=[KV_LEN, V_HEAD_DIM], - strides=[stride_vn, 1], - block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], - ) - {%- endif %} - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - sparse_idx_hq = off_hq % SPARSE_HQ - - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) - - # KV_IDX and KV_NUM_BLKS are always contiguous. - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 - K_block_ptr = None - V_block_ptr = None - Q_block_ptr = None - - if not USE_TMA: - Q_block_ptr = tl.make_block_ptr( - base=Q , - shape=(Q_LEN, QK_HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(q_start * BLOCK_M, 0), - block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - {%- if USE_TMA %} - q = tl.load_tensor_descriptor( - desc_q, - [(q_start * BLOCK_M).to(tl.int32), 0], - ) - {%- else %} - q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We don't know anything "special" about these blocks, so we need to apply - # both score_mod and mask_mod to it - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - offs_n = kv_start + tl.arange(0, BLOCK_N) - - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = kv_start + tl.arange(0, BLOCK_N) - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - - # [Note] Handle fully masked out rows: - # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. - # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step - l_i = tl.where(l_i == 0.0, 1, l_i) - - acc = acc / l_i[:, None] - idx_zq = tl.program_id(1) - idx_hq = tl.program_id(2) - idx_m = offs_m[:, None] - idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - - {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - - if OUTPUT_LOGSUMEXP: - off_hz = off_zq * HQ + off_hq - l_ptrs = LSE + off_hz * Q_LEN + offs_m - lse = m_i + tl.math.log2(l_i) - if IS_DIVISIBLE: - tl.store(l_ptrs, lse) - else: - tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) - """ - - flex_attention_template = TritonTemplate( name="flex_attention", grid=flex_attention_grid, - source=compute_flex_attention - + compute_forward_inner - + compute_next_offset_func - + compute_forward_block_mn - + load_checked_block - + get_bounded_indices_func, + source=load_template("flex_attention") + + load_template("utilities") + + load_template("common"), ) @@ -361,7 +101,6 @@ def flex_attention( score_mod_other_buffers, mask_mod_other_buffers, ) - # below is cuda path if device is not cpu # tl.dot does not support embedding size less than 16 small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) @@ -685,693 +424,7 @@ def flex_attention_backward_grid( flex_attention_backward_template = TritonTemplate( name="flex_attention_backward", grid=flex_attention_backward_grid, - source=r""" -{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) - # DELTA: Precomputed sum(OUT*DO, axis=-1) - # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value - # DK: Derivative of Key, is the written to via the store_output call due to some limitations with - # inductor codegen - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # (Modifiable) Performance tuning options - # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. - # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. - # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. - # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. - # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. - # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. - - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} - stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} - - stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} - stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - HKV = {{size("K", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - pid = tl.program_id(0) - NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) - NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) - - off_zq = tl.program_id(1) # q batch idx - off_hkv = tl.program_id(2) # kv head idx - off_zkv = off_zq % ZKV # kv batch idx - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - - k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) - v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) - # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) - - # offset K, V, DV pointers for batch/kv-head - K += k_adj - V += v_adj - DV += dv_adj - - RCP_LN2 = 1.44269504 - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - if pid >= NUM_KV_BLOCKS: - off_pid = pid - NUM_KV_BLOCKS - # THIS BLOCK DOES DQ - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS - start_m2_block = off_pid % NUM_Q_BLOCKS - off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - sparse_idx_hq2 = off_hq2 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 - - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. - q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) - do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) - dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) - off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) - - Q2 = Q + q_adj2 - DO2 = DO + do_adj2 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - DQ2 = DQ + dq_adj2 - LSE2 = LSE + off_chz2 - DELTA2 = DELTA + off_chz2 - - # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) - dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_m2 = start_m2_block * BLOCK_M2 - offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - - # load Q and do: they stay in SRAM throughout the inner loop. - q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) - do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - if IS_DIVISIBLE: - Di = tl.load(DELTA2 + offs_m2) - lse = tl.load(LSE2 + offs_m2) - else: - Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - lse = lse[:, None] - - # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # KV_IDX and KV_NUM_BLKS are always contiguous. - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dQ. - dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd - dq *= SM_SCALE - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dq_ptrs, dq) - else: - tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) - else: - # THIS BLOCK DOES DK & DV - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) - - pid_mask = pid // SPARSE_KV_MULTIPLE - - stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} - stride_q_idx_h = {{stride("Q_IDX", 1)}} - stride_q_idx_n = {{stride("Q_IDX", 2)}} - - - dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_n1 = pid * BLOCK_N1 - offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - - # load K and V: they stay in SRAM throughout the inner loop. - k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) - v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - for off_g in range(0, GQA_SHARED_HEADS): - off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. - q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) - do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) - dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) - off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) - - Q1 = Q + q_adj1 - DO1 = DO + do_adj1 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - LSE1 = LSE + off_chz1 - DELTA1 = DELTA + off_chz1 - - sparse_idx_hq1 = off_hq1 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 - - sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask - sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 - - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Q_IDX and Q_NUM_BLKS are always contiguous. - q_indices = Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. - q_indices = FULL_Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dV and dK. - dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd - - index_n = offs_n1[:, None] - index_k = offs_k[None, :] - index_v = offs_v[None, :] - - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dv_ptrs, dv) - else: - tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) - - dk *= SM_SCALE - - if SAFE_HEAD_DIM: - mask = index_n < KV_LEN - else: - mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) - - # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - -@triton.jit -def bwd_dq_inner( - {{gen_argdefs()}}, - K, V, # pointers - dq, q, do, Di, lse, - off_z, off_hq, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd - vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - - hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) - if not IS_DIVISIBLE: - if hi >= 1: - for start_n in range(0, hi - 1): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_n in range(0, hi): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - return dq - - -@triton.jit -def bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1)}} - - # NB reversed order to since K is transposed - kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) - qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - pre_mod_scores = qk - n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim - # that the M reads out of bounds prior to the last loop - m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # apply mask for partial masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - p = tl.math.exp2(post_mod_scores - lse) - # Compute dP and dS. - # NB reversed order to since V is transposed - vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) - - dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) - ds = p * (dp - Di[:, None]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(1) }} - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if WRITE_DQ: - scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = grad_scores - - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - ds = tl.where(mask_mod_output, ds, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = ds.to(MATMUL_PRECISION) - # Compute dQ. - dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) - - return dq - - -@triton.jit -def bwd_dkdv_inner( - {{gen_argdefs()}}, - Q, DO, DELTA, LSE, # pointers - dk, dv, k, v, - off_z, off_hq, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd - do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) - - if not IS_DIVISIBLE: - if hi >= 1: - for start_m in range(0, hi - 1): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. - offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_m in range(0, hi): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. - offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - return dk, dv - - -@triton.jit -def bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1) }} - - # NB reversed order since Q is transposed - qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) - # Load LSE before computing qk to reduce pipeline stall. - if IS_DIVISIBLE: - lse = tl.load(LSE + offs_m1) - else: - lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qkT *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim - # that the n reads out of bounds prior to the last loop - n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - pre_mod_scores = qkT - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qkT" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for fully masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - pT = tl.math.exp2(post_mod_scores - lse[None, :]) - do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - # Compute dV. - ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) - if IS_DIVISIBLE: - Di = tl.load(DELTA + offs_m1) - else: - Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) - dsT = pT * (dpT - Di[None, :]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="dsT" - ) | indent_except_first(1) }} - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if not WRITE_DQ: - idx_b = off_z - idx_h = off_hq - idx_m = m - idx_n = n - scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="idx_b", - h="idx_h", - m="idx_m", - n="idx_n", - grad_score_mod="dsT" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) - - dsT = grad_scores - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - dsT = tl.where(mask_mod_output, dsT, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) - - return dk, dv - """ - + compute_next_offset_func - + get_bounded_indices_func - + load_checked_2d, + source=load_template("flex_backwards") + load_template("utilities"), ) @@ -1535,10 +588,12 @@ def flex_attention_backward(*args, **kwargs): for k, v in kernel_options.items() } kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) - if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: - kernel_options.setdefault("IS_DIVISIBLE", False) - else: + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) fwd_placeholder_inps = [ create_placeholder(name, dtype, device) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 83c6b59cec96..361729d44b99 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -18,15 +18,10 @@ TritonTemplate, ) from .common import ( - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, - get_bounded_indices_func, get_fwd_subgraph_outputs, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, ) @@ -90,266 +85,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me flex_decoding_template = TritonTemplate( name="flex_decoding", grid=flex_decoding_grid, - source=r""" - {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this kernel: - # Q: Query, K: Key, V: Value - # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits - # (Modifiable) Config options: - # SPLIT_KV: number of blocks K & V are split into - # TILE_KV: length of each local KV split - # BLOCK_M: block size that Q is padded along seqlen dim. - # BLOCK_N: block size of K & V along N dimension. - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. - # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. - - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. - # - # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # - # - # Output: ACC output accumulated across local KV split. - - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define Q Strides - stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} - stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} - - - Z = {{size("Q", 0)}} - ZKV = {{size("K", 0)}} - HKV = {{size("Q", 1)}} - G: tl.constexpr = GQA_SHARED_HEADS - HQ = HKV * G - Q_LEN = {{size("Q", 3)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - # Make sure each split is a multiple of BLOCK_N - TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) - TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N - TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) - - off_z = tl.program_id(0) // HKV - off_zkv = off_z % ZKV - off_hkv = tl.program_id(0) % HKV - off_t = tl.program_id(1) - - q_offset = off_z * stride_qz + off_hkv * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_z % SPARSE_Z - sparse_idx_h = off_hkv % SPARSE_HQ - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - # initialize offsets - tl.device_assert(BLOCK_M % G == 0) - BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G - off_g = tl.arange(0, G) # [G] - offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_hq = offs_g + off_hkv * G - off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] - offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) - - # Get HZ offsets for KV_NUM_BLKS and KV_IDX - stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} - sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h - stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} - sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h - - # Calculate KV blocks that belong this CTA. - block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block - block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N - - q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] - - if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) - elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) - elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) - else: - q = tl.load(Q + q_offset + q_range) - - q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) - - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Apply both score_mod and mask_mod - - # find first kv block we are loading and the number of blocks we are loading - # Offset the kv_indices tensor by the correct batch and head - kv_indices = KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - # first kv block we're loading - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - kv_indices = FULL_KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) - # Assign full block in a reverse order for off_t. Prioritize the last CTA. - block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE - block_n_end = block_n_start + TILE_KV_MULTIPLE - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - m_offset = off_t * stride_mt + off_z * stride_mz - l_offset = off_t * stride_lt + off_z * stride_lz - - M_block_ptr = tl.make_block_ptr( - base=M + m_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_mh, stride_mm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - L_block_ptr = tl.make_block_ptr( - base=L + l_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_lh, stride_lm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - - # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) - m_i = m_i.reshape(G, BLOCK_M_PER_HQ) - l_i = l_i.reshape(G, BLOCK_M_PER_HQ) - if SAFE_M_BOUNDARY: - tl.store(M_block_ptr, m_i) - tl.store(L_block_ptr, l_i) - else: - tl.store(M_block_ptr, m_i, boundary_check=(1,)) - tl.store(L_block_ptr, l_i, boundary_check=(1,)) - - # -- store output - idx_z = off_z - idx_t = off_t - idx_hq = off_hkv*G + off_g[:, None, None] - idx_m = off_m[None, :, None] - idx_d = offs_vd[None, None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) - {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - """ - + compute_forward_inner - + get_bounded_indices_func - + load_checked_block - + load_checked_2d - + compute_next_offset_func - + compute_forward_block_mn, + source=load_template("flex_decode") + + load_template("utilities") + + load_template("common"), ) @@ -410,11 +148,12 @@ def create_flex_decoding_kernel(*args, **kwargs): for k, v in kernel_options.items() } - # TODO: Fix flex decoding non-divisible case! - if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: - kernel_options.setdefault("IS_DIVISIBLE", False) - else: + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) # Calculate GQA head sharing gqa_shared_heads = Hq // Hkv diff --git a/torch/_inductor/kernel/flex/templates/common.py.jinja b/torch/_inductor/kernel/flex/templates/common.py.jinja new file mode 100644 index 000000000000..0e967570127d --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/common.py.jinja @@ -0,0 +1,193 @@ + + +# Common Imports +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( + desc_k, + [kv_start + kv_offset, 0], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start + kv_offset, 0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i diff --git a/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja new file mode 100644 index 000000000000..79410fb50046 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja @@ -0,0 +1,248 @@ +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) + off_hq = tl.program_id(2) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN, QK_HEAD_DIM], + strides=[stride_qm, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) + idx_hq = tl.program_id(2) + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) diff --git a/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja new file mode 100644 index 000000000000..1775833b8e68 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja @@ -0,0 +1,682 @@ +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1) # q batch idx + off_hkv = tl.program_id(2) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja new file mode 100644 index 000000000000..f4596070c833 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja @@ -0,0 +1,252 @@ + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/utilities.py.jinja b/torch/_inductor/kernel/flex/templates/utilities.py.jinja new file mode 100644 index 000000000000..7e2367e4f269 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/utilities.py.jinja @@ -0,0 +1,59 @@ + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index c5c4269cb403..e68a76174c73 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -15,16 +15,18 @@ mm_operations, ) from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.remote_gemm_autotune_cache import gen_best_config from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.torch_version import TorchVersion -from .. import config as inductor_config, ir +from .. import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.subgraph import SubgraphTemplate from ..ir import FlexibleLayout, is_triton +from ..kernel_inputs import MMKernelInputs from ..lowering import ( add_layout_constraint, constrain_to_fx_strides, @@ -54,13 +56,9 @@ _is_static_problem, addmm_epilogue, mm_args, - mm_config_kwargs, mm_grid, - mm_options, persistent_mm_grid, - persistent_mm_options, scale_mm_epilogue, - scaled_mm_options, ) @@ -587,11 +585,6 @@ def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) -def _is_large_block_for_cpu(m, n, k): - # Thresholds are experimentally determined to reduce Triton CPU compile times - return m * n > 2**13 - - @functools.lru_cache def using_b200() -> bool: """Returns true if the device is a NVIDIA B200, otherwise returns false.""" @@ -661,10 +654,14 @@ def tuned_mm(mat1, mat2, *, layout=None): """ Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) """ + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) - device_type = ir.get_device_type(mat1) + static_shape, is_nonzero = _is_static_problem(layout) name = "mm" + # Create MMKernelInputs for standard MM at the top + kernel_inputs = MMKernelInputs([mat1, mat2]) + # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1 log.info( @@ -685,48 +682,38 @@ def tuned_mm(mat1, mat2, *, layout=None): # options to tune from choices = ( - [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + [aten_mm.bind(kernel_inputs.nodes(), aten_layout)] + if use_aten_gemm_kernels() + else [] ) static_shape, is_nonzero = _is_static_problem(layout) - mm_configs = V.choices.get_base_mm_configs(device_type) - persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) - extra_mm_configs = V.choices.get_extra_mm_configs(device_type) - - dtype = mat1.get_dtype() if is_nonzero and use_triton_template(layout): - for config in mm_configs( - m, - n, - k, - **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + # Get template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, mm_template.name, "mm" ): mm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, ) if use_triton_tma_template(mat1, mat2): - for config in persistent_mm_configs( - m, - n, - k, - **mm_config_kwargs( - device_type, _is_large_block_for_cpu, dtype.itemsize - ), + # Get TMA template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, persistent_tma_mm_template.name, "mm" ): persistent_tma_mm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, workspace_arg=get_tma_workspace_arg( num_tma_descriptors=2, device=mat1.get_device(), ), - **mm_options(config, m, n, k, layout), - **persistent_mm_options(mat1, mat2), + **kwargs, ) from torch._inductor.ir import get_free_symbols @@ -776,18 +763,20 @@ def tuned_mm(mat1, mat2, *, layout=None): and use_cutlass_template(layout, m, n, k) and _use_cutlass_for_op("mm") ): - CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, kernel_inputs.nodes() + ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): - CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2]) + CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes()) if use_cpp_gemm_template(layout, mat1, mat2): CppGemmTemplate.add_choices( choices, layout, - [mat1, mat2], + kernel_inputs.nodes(), ) input_nodes = [mat1, mat2] @@ -801,14 +790,20 @@ def tuned_mm(mat1, mat2, *, layout=None): if use_aten_gemm_kernels(): always_included.append("extern_mm") num_choices_before_extra_configs = len(choices) - for config in extra_mm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + for kwargs in V.choices.get_mm_configs( + # TODO(coconutruben): remove once we deprecate ah + # mm-extra is a hack to keep the ah functionality alive + # while we transition to the unified kwargs retrieval + kernel_inputs, + layout, + "mm-ah", + "mm", ): mm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, ) # using AutoHeuristic for ranking @@ -838,13 +833,28 @@ def tuned_mm(mat1, mat2, *, layout=None): choices = choices[:num_choices_before_extra_configs] for k in inductor_config.external_matmul: - choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout)) + choices.append( + lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout) + ) + + best_config_future = None + # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time + # The future will be awaited at scheduling time in select_algorithm.py + if torch._inductor.config.remote_gemm_autotune_cache: + best_config_future = gen_best_config(mat1, mat2) - return autotune_select_algorithm(name, choices, [mat1, mat2], layout) + return autotune_select_algorithm( + name, + choices, + kernel_inputs.nodes(), + layout, + best_config_future=best_config_future, + ) @register_lowering(aten._int_mm, type_promotion_kind=None) def tuned_int_mm(mat1, mat2, *, layout=None): + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args( mat1, mat2, layout=layout, out_dtype=torch.int32 ) @@ -861,8 +871,6 @@ def tuned_int_mm(mat1, mat2, *, layout=None): layout, ) - device_type = ir.get_device_type(mat1) - static_shape, is_nonzero = _is_static_problem(layout) use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) @@ -870,33 +878,37 @@ def tuned_int_mm(mat1, mat2, *, layout=None): [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] ) + # Create MMKernelInputs for Int MM + kernel_inputs = MMKernelInputs([mat1, mat2]) + if use_cutlass and _use_cutlass_for_op("int_mm"): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( - choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True ) - int8_mm_configs = V.choices.get_int8_mm_configs(device_type) - if is_nonzero and use_triton_template(layout, enable_int32=True): - for config in int8_mm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, mm_template.name, "int_mm" ): mm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, ) - return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout) @register_lowering(aten.addmm, type_promotion_kind=None) def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): - device_type = ir.get_device_type(mat1) + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) static_shape, is_nonzero = _is_static_problem(layout) + # Create MMKernelInputs for AddMM at the top + kernel_inputs = MMKernelInputs([inp_expanded, mat1, mat2]) + # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1 log.info( @@ -923,7 +935,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): choices = ( [ aten_addmm.bind( - (inp, mat1, mat2), + # TODO(coconutruben): replace with kernel_inputs.nodes() + # once that supports the unexpanded nodes as well + [inp, mat1, mat2], layout, alpha=alpha, beta=beta, @@ -932,12 +946,19 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if use_aten_gemm_kernels() else [] ) - return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + return autotune_select_algorithm( + # TODO(coconutruben): replace with kernel_inputs.nodes() + # once that supports the unexpanded nodes as well + "addmm", + choices, + [inp, mat1, mat2], + layout, + ) choices = ( [ aten_addmm.bind( - (inp_expanded, mat1, mat2), + kernel_inputs.nodes(), layout, alpha=alpha, beta=beta, @@ -957,50 +978,42 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): choices.insert( 0, aten_bias_addmm.bind( - (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + kernel_inputs.nodes(), + layout, + alpha=alpha, + beta=beta, ), ) - mm_configs = V.choices.get_base_mm_configs(device_type) - persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) - - dtype = mat1.get_dtype() if is_nonzero and use_triton_template(layout): - for config in mm_configs( - m, - n, - k, - **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + # Get template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, mm_template.name, "addmm" ): mm_template.maybe_append_choice( choices, - input_nodes=(inp_expanded, mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, - **mm_options(config, m, n, k, layout), + **kwargs, prefix_args=1, epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), ) if use_triton_tma_template(mat1, mat2): - for config in persistent_mm_configs( - m, - n, - k, - **mm_config_kwargs( - device_type, _is_large_block_for_cpu, dtype.itemsize - ), + # Get TMA template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, persistent_tma_mm_template.name, "addmm" ): persistent_tma_mm_template.maybe_append_choice( choices, - input_nodes=(inp_expanded, mat1, mat2), + input_nodes=kernel_inputs.nodes(), layout=layout, workspace_arg=get_tma_workspace_arg( num_tma_descriptors=2, device=mat1.get_device(), ), - **mm_options(config, m, n, k, layout), - **persistent_mm_options(mat1, mat2), + **kwargs, prefix_args=1, epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), ) @@ -1013,17 +1026,20 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, - [mat1, mat2, inp_expanded], + # reorder here because CUTLASS expects (x, w, bias) but torch + # is bias, x, w + kernel_inputs.nodes(reorder=[1, 2, 0]), alpha=alpha, beta=beta, - input_reorder=[2, 0, 1], ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices( choices, layout, - [mat1, mat2, inp_expanded], + # reorder here because CK expects (x, w, bias) but torch + # is bias, x, w + kernel_inputs.nodes(reorder=[1, 2, 0]), alpha=alpha, beta=beta, input_reorder=[2, 0, 1], @@ -1033,15 +1049,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): CppGemmTemplate.add_choices( choices, layout, - [inp_expanded, mat1, mat2], + kernel_inputs.nodes(), alpha=alpha, beta=beta, has_bias=True, ) - return autotune_select_algorithm( - "addmm", choices, [inp_expanded, mat1, mat2], layout - ) + return autotune_select_algorithm("addmm", choices, kernel_inputs.nodes(), layout) @register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) @@ -1089,7 +1103,7 @@ def tuned_sparse_semi_structured_mm( ) return autotune_select_algorithm( - "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout + "sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout ) @@ -1123,6 +1137,7 @@ def tuned_scaled_mm( Returns: Tensor: The result of the scaled matrix multiplication """ + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat_a, mat_b = mm_args( mat_a, mat_b, layout=layout, out_dtype=out_dtype ) @@ -1138,7 +1153,6 @@ def tuned_scaled_mm( layout, ) - device_type = ir.get_device_type(mat_a) check_supported_striding(mat_a, mat_b) scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) @@ -1165,59 +1179,51 @@ def tuned_scaled_mm( _, is_nonzero = _is_static_problem(layout) - scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) - scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( - device_type - ) + # Prepare triton input nodes and create kernel_inputs at the top + triton_input_nodes: list[Any] + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias - if is_nonzero and use_triton_template(layout, enable_float8=True): - triton_input_nodes: tuple[Any, ...] - if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: - # Need to unsqueeze bias from [N] -> [1, N] - triton_bias = L[aten.unsqueeze](bias, 0) - else: - triton_bias = bias + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = [ + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ] + suffix_args = 3 + else: + triton_input_nodes = [mat_a, mat_b, triton_scale_a, triton_scale_b] + suffix_args = 2 - if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: - assert len(scale_a.get_size()) == len(scale_b.get_size()) - # Need to unsqueeze scale from [] -> [1, 1] - triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) - triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) - else: - triton_scale_a = scale_a - triton_scale_b = scale_b - - if bias: - triton_input_nodes = ( - mat_a, - mat_b, - triton_scale_a, - triton_scale_b, - triton_bias, - ) - suffix_args = 3 - else: - triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) - suffix_args = 2 + # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1) + kernel_inputs = MMKernelInputs(triton_input_nodes, mat1_idx=0, mat2_idx=1) + if is_nonzero and use_triton_template(layout, enable_float8=True): # TODO (paulzhan): There is no template that exists for bias and TMA # Don't run tma template currently if bias exists if use_triton_tma_template(mat_a, mat_b) and not bias: - for config in scaled_persistent_mm_configs(m, n, k): - kwargs = scaled_mm_options( - config, - m, - n, - k, - layout, - scale_a, - scale_b, - use_fast_accum, - device_tma=True, - ) + # Get TMA template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, scaled_mm_device_tma_template.name, "scaled_mm" + ): + kwargs["USE_FAST_ACCUM"] = use_fast_accum scaled_mm_device_tma_template.maybe_append_choice( choices, - input_nodes=triton_input_nodes, + input_nodes=kernel_inputs.nodes(), layout=layout, workspace_arg=get_tma_workspace_arg( num_tma_descriptors=2, @@ -1226,7 +1232,11 @@ def tuned_scaled_mm( **kwargs, ) - for config in scaled_mm_configs(m, n, k): + # Get template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout, mm_template.name, "scaled_mm" + ): + kwargs["USE_FAST_ACCUM"] = use_fast_accum if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): # Triton crashes however uncommon for real workloads continue @@ -1236,13 +1246,10 @@ def tuned_scaled_mm( if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): continue - kwargs = scaled_mm_options( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) # possibly appends a TritonTemplateCaller to choices mm_template.maybe_append_choice( choices, - input_nodes=triton_input_nodes, + input_nodes=kernel_inputs.nodes(), layout=layout, **kwargs, suffix_args=suffix_args, @@ -1258,12 +1265,12 @@ def tuned_scaled_mm( CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, - input_nodes, # type: ignore[arg-type] + kernel_inputs.nodes(), # type: ignore[arg-type] use_fast_accum=use_fast_accum, # type: ignore[arg-type] ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 2beb91f0bbb2..dee1c1ac9c35 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -3,17 +3,13 @@ from collections.abc import Sequence from typing import Any -import sympy - import torch from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn from torch._inductor.utils import sympy_product from torch._inductor.virtualized import V -from .. import config as inductor_config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE log = logging.getLogger(__name__) @@ -49,96 +45,6 @@ def acc_type(dtype): return f"tl.{dtype}".replace("torch.", "") -def mm_options(config, sym_m, sym_n, sym_k, layout): - """ - Common options to matmul triton templates. - """ - even_k_symbolic = ( - # it isn't worth guarding on this - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( - not inductor_config.force_same_precision - or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) - ) - options_dict = dict( - EVEN_K=even_k_symbolic, - ALLOW_TF32=allow_tf32, - USE_FAST_ACCUM=False, # Option for _scaled_mm - ACC_TYPE=acc_type(layout.dtype), - num_stages=config.num_stages, - num_warps=config.num_warps, - **config.kwargs, - ) - - # If GROUP_M not specified then default to 8 - if "GROUP_M" not in config.kwargs: - group_m = config.kwargs.get("GROUP_M", 8) - options_dict["GROUP_M"] = group_m - - return options_dict - - -def tma_options() -> dict[str, Any]: - from torch.utils._triton import has_triton_stable_tma_api - - return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()} - - -def persistent_mm_options(mat1, mat2): - res = { - "A_ROW_MAJOR": not mat1.layout.is_transposed(), - "B_ROW_MAJOR": not mat2.layout.is_transposed(), - "NUM_SMS": get_num_sms(), - "TMA_SIZE": TMA_DESCRIPTOR_SIZE, - } - res.update(tma_options()) - return res - - -def scaled_mm_options( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a, - scale_b, - use_fast_accum: bool, - device_tma: bool = False, -) -> dict[str, Any]: - def are_compatible_scales(size_a, size_b) -> bool: - # Same sized scales are compatible - if len(size_a) == len(size_b): - return True - - # Both need to be scalars or len(1) tensors - if len(size_a) <= 1 and len(size_b) <= 1: - return True - - return False - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - - mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) - - mm_template_options["ACC_TYPE"] = "tl.float32" - mm_template_options["USE_FAST_ACCUM"] = use_fast_accum - mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 - - if device_tma: - mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE - mm_template_options["NUM_SMS"] = get_num_sms() - - mm_template_options.update(tma_options()) - - return mm_template_options - - def mm_args( mat1, mat2, @@ -181,20 +87,6 @@ def mm_args( return [m, n, k, layout, mat1, mat2, *others] -def mm_config_kwargs(device, exclude_condition, dtype_size=None): - if device == "cpu": - return { - "scale": 0.5, - "exclude": exclude_condition, - } - - if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE": - return { - "dtype_size": dtype_size, - } - return {} - - def addmm_epilogue(dtype, alpha, beta): def epilogue(acc, bias): if alpha != 1: diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 19ca389c2a53..3424585e1214 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -316,28 +316,28 @@ def early_config_prune(g, m, configs, named_args): {%- else %} offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) - offs_k = k_start_offset + tl.arange(0, BLOCK_K) - a_ptrs = ( - a_ptr + for k_offset in range(0, k_size, BLOCK_K): + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + offs_k = group_offs_k + k_start_offset + a_ptrs = ( + a_ptr {%- if not A_IS_2D %} - + g * A_STRIDE_G + + g * A_STRIDE_G {%- endif %} - + (m_start_offset + offs_am[:, None]) * A_STRIDE_M - + offs_k[None, :] * A_STRIDE_K - ) - b_ptrs = ( - b_ptr + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr {%- if not B_IS_2D %} - + g * B_STRIDE_G + + g * B_STRIDE_G {%- endif %} - + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N - + offs_k[None, :] * B_STRIDE_K - ) - for k_offset in range(0, k_size, BLOCK_K): + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) if k_offset + BLOCK_K > k_size: - group_offs_k = k_offset + tl.arange(0, BLOCK_K) a = tl.where(group_offs_k < k_size, a, 0) b = tl.where(group_offs_k < k_size, b, 0) {%- if USE_FAST_ACCUM %} @@ -387,7 +387,7 @@ def early_config_prune(g, m, configs, named_args): {%- else %} idx_n = offs_bn[None, :] {%- endif %} - mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size + mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size) {%- if M_IS_VARYING or N_IS_VARYING %} {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} {%- else %} diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index 64249e6fb57a..df3e8fcf1e65 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -1,8 +1,10 @@ # mypy: allow-untyped-defs +import logging + import torch -from .. import ir +from ..kernel_inputs import MMKernelInputs from ..lowering import lowerings from ..select_algorithm import ( autotune_select_algorithm, @@ -11,8 +13,10 @@ ) from ..utils import use_aten_gemm_kernels, use_triton_template from ..virtualized import V -from .mm_common import mm_args, mm_grid, mm_options +from .mm_common import mm_args, mm_grid + +log = logging.getLogger(__name__) aten = torch.ops.aten @@ -119,9 +123,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): """ Computes mm(mat1, mat2) + mm(mat3, mat4) """ + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) - device_type = ir.get_device_type(mat1) # Optimization is optional, because we can always just not do the fusion if ( @@ -140,27 +144,34 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) ) + # Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair) + # Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction + kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1) + assert layout1 == layout2 # options to tune from choices = ( - [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + [aten_mm_plus_mm.bind(kernel_inputs.nodes(), layout1)] if use_aten_gemm_kernels() else [] ) - mm_configs = V.choices.get_mm_plus_mm_configs(device_type) if use_triton_template(layout1): - for config in mm_configs(): + # Get template params using the new unified function + for kwargs in V.choices.get_mm_configs( + kernel_inputs, layout1, mm_plus_mm_template.name, "mm_plus_mm" + ): + # Apply BLOCK_K constraint specific to mm_plus_mm # see https://github.com/triton-lang/triton/issues/1298 # BLOCK_K = K causes llvm error - if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1): + if V.graph.sizevars.statically_known_lt(kwargs.get("BLOCK_K", k1), k1): mm_plus_mm_template.maybe_append_choice( choices, - input_nodes=(mat1, mat2, mat3, mat4), + input_nodes=kernel_inputs.nodes(), layout=layout1, - **mm_options(config, m1, n1, k1, layout1), + **kwargs, ) return autotune_select_algorithm( - "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + "mm_plus_mm", choices, kernel_inputs.nodes(), layout1 ) diff --git a/torch/_inductor/kernel_inputs.py b/torch/_inductor/kernel_inputs.py new file mode 100644 index 000000000000..6c66c1161900 --- /dev/null +++ b/torch/_inductor/kernel_inputs.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch._inductor.config +from torch._inductor import ir +from torch._inductor.virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import sympy + + +class KernelInputs: + """ + Class to store and provide access to input nodes for kernels. + This class takes in a tuple of input nodes and provides methods to access + information about these nodes, such as their device type and device. + """ + + def __init__(self, input_nodes: list[Any]): + """ + Initialize with a tuple of input nodes. + + Args: + input_nodes: A tuple of input nodes to store + """ + self._input_nodes = input_nodes + self._device_name: Optional[str] = None + assert len(input_nodes) > 0, "Expected at least one input node" + + def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]: + """ + Return the stored input nodes, optionally reordered. + + Args: + reorder: Optional sequence of indices to reorder the nodes. + For example, (2, 0, 1) would return nodes in that order. + + Returns: + The tuple of input nodes, optionally reordered + """ + if reorder is None: + return self._input_nodes + assert len(self._input_nodes) == len(reorder), ( + f"Reorder length mismatch: {len(self._input_nodes)} vs {len(reorder)}" + ) + return [self._input_nodes[i] for i in reorder] + + @property + def device_type(self) -> Optional[str]: + """ + Get the device type of the first node. + + Returns: + The device type (e.g., 'cuda', 'cpu') + """ + + return ir.get_device_type(self._input_nodes[0]) + + def device(self) -> torch.device: + """ + Get the device of the first node. + + Returns: + The device of the first node + """ + return self._input_nodes[0].get_device() + + def device_name(self) -> Optional[str]: + """ + Get the device name information. + + Returns: + A tuple of (gpu_name, vendor, model) + """ + if self._device_name is None: + device = self.device() + if self.device_type == "cuda": + device_properties = torch.cuda.get_device_properties(device) + self._device_name = device_properties.gcnArchName + return self._device_name + + def shapes_symbolic(self) -> tuple[tuple[Any, ...], ...]: + """ + Get the symbolic shapes of all input nodes. + + Returns: + A tuple of shape tuples for each input node + """ + return tuple(node.get_size() for node in self._input_nodes) + + def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: + """ + Get the size hints for shapes of all input nodes. + + Returns: + A tuple of shape tuples with integer hints for each input node + """ + return tuple( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + for node in self._input_nodes + ) + + def strides_symbolic(self) -> tuple[tuple[sympy.Integer, ...], ...]: + """ + Get the symbolic strides of all input nodes. + + Returns: + A tuple of stride tuples for each input node + """ + return tuple(node.get_stride() for node in self._input_nodes) + + def strides_hinted(self) -> tuple[tuple[int, ...], ...]: + """ + Get the size hints for strides of all input nodes. + + Returns: + A tuple of stride tuples with integer hints for each input node + """ + return tuple( + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + for node in self._input_nodes + ) + + def dtypes(self) -> tuple[torch.dtype, ...]: + """ + Get the dtypes of all input nodes. + + Returns: + A tuple of dtypes for each input node + """ + return tuple(node.get_dtype() for node in self._input_nodes) + + def dtype(self, idx: int = 0) -> torch.dtype: + """ + Get the dtype of a specific input node. + + Args: + idx: Index of the node to get the dtype from (default: 0) + + Returns: + The dtype of the specified input node + """ + return self._input_nodes[idx].get_dtype() + + +class MMKernelInputs(KernelInputs): + """ + Specialized KernelInputs for matrix multiplication operations. + Provides additional methods to access M, N, K dimensions. + """ + + def __init__(self, input_nodes: list[Any], mat1_idx: int = -2, mat2_idx: int = -1): + """ + Initialize with a tuple of input nodes. + + By default, we assume the last 2 input nodes are mat1 and mat2, but + the caller can adjust when necessary + """ + super().__init__(input_nodes) + # for mm, we need at least 2 nodes, and we need to know which nodes + # are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others + # might be (mat1, mat2, scale), etc. + assert len(self._input_nodes) >= 2, "Expected at least 2 input nodes" + + # Adjust assertions to handle negative indices + m1_idx, m2_idx = mat1_idx, mat2_idx + if mat1_idx < 0: + m1_idx += len(input_nodes) + if mat2_idx < 0: + m2_idx += len(input_nodes) + + assert 0 <= m1_idx < len(input_nodes), f"Invalid mat1_idx: {mat1_idx}" + assert 0 <= m1_idx < len(input_nodes), f"Invalid mat2_idx: {mat2_idx}" + + self._mat1_idx = mat1_idx + self._mat2_idx = mat2_idx + + def mnk_symbolic( + self, + ) -> tuple[sympy.Integer, sympy.Integer, sympy.Integer]: + """ + Get the symbolic M, N, K dimensions for matrix multiplication. + Handles both 2D (MM) and 3D (BMM) tensors. + + M is extracted from the second-to-last dimension of the first operand (mat1). + N is extracted from the last dimension of the second operand (mat2). + K is extracted from the last dimension of the first operand (mat1). + + Returns: + A tuple of (M, N, K) dimensions + """ + mat1 = self.nodes()[self._mat1_idx] + mat2 = self.nodes()[self._mat2_idx] + + m = mat1.get_size()[-2] # M from second-to-last dimension of mat1 + k = mat1.get_size()[-1] # K from last dimension of mat1 + n = mat2.get_size()[-1] # N from last dimension of mat2 + + # Ensure K dimensions match between operands + k0 = mat2.get_size()[-2] # K from second-to-last dimension of mat2 + V.graph.sizevars.check_equals(k, k0) + return (m, n, k) + + def mnk_hinted(self) -> tuple[int, int, int]: + """ + Get the hinted M, N, K dimensions for matrix multiplication. + Handles both 2D (MM) and 3D (BMM) tensors. + + Uses shapes_hinted from the base class to get integer hints for dimensions. + + Returns: + A tuple of (M, N, K) dimensions as integers + """ + hinted_shapes = self.shapes_hinted() + mat1_shape = hinted_shapes[self._mat1_idx] + mat2_shape = hinted_shapes[self._mat2_idx] + + m = mat1_shape[-2] # M from second-to-last dimension of mat1 + k = mat1_shape[-1] # K from last dimension of mat1 + n = mat2_shape[-1] # N from last dimension of mat2 + + # Ensure K dimensions match between operands + k_check = mat2_shape[-2] # K from second-to-last dimension of mat2 + assert k == k_check, f"K dimensions don't match: {k} vs {k_check}" + + return (m, n, k) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 5c8110f6c399..74a562365b69 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -26,6 +26,7 @@ from torch._dynamo.utils import counters from torch._higher_order_ops.associative_scan import associative_scan_op from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._library.utils import get_layout_constraint_tag from torch._prims_common import ( canonicalize_dim, canonicalize_dims, @@ -163,6 +164,10 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A if not isinstance(fn, torch._ops.OpOverload): # Only OpOverloads have layout constraints. return None + + if maybe_layout_tag := get_layout_constraint_tag(fn, with_default=False): + return tag_to_layout_constraint(maybe_layout_tag) + if fn in _maybe_layout_constraints: return _maybe_layout_constraints[fn] return None diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index d287208419a9..0967bb553e04 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -124,6 +124,28 @@ def compute_size_for_scheduler_buffer( buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + When an operation mutates a buffer in-place, the scheduler creates a new buffer name + to track the "before" and "after" states, even though they share the same memory. + + The mutated buffer represents a rename with zero allocation and deallocation cost. + During dependency tracking, we transfer dependencies from the mutated name back to + the original buffer, ensuring the original memory is only freed when all aliases + are done. + + This handles cases where a buffer has multiple non-overlapping aliases - rather than + trying to assign free costs to individual aliases, we forward all alias dependencies + to the original buffer. + + Consider: + buf0 = op0() + buf1 = mutation_op_(buf0) + del buf0 + ... + op(buf1) + del buf1 + + The only memory events are the creation prior to op0, and the deletion following buf1. + Returns: A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). """ @@ -135,18 +157,11 @@ def compute_size_for_scheduler_buffer( def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False ) -> int: - if isinstance(sched_buf.node.layout, NoneLayout): - # mutations should inherit the size of the mutated buffer - if sched_buf.get_mutations(): - mutated_buf_name = sched_buf.get_mutations()[0] - if mutated_buf_name in sched_buf_to_size: - (_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name] - else: - (_size_alloc, _size_free) = (0, 0) - sched_buf_to_size[sched_buf.get_name()] = (0, _size_free) - sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0) - else: - sched_buf_to_size[sched_buf.get_name()] = (0, 0) + if sched_buf.get_name() in V.graph.scheduler.mutation_real_name: + sched_buf_to_size[sched_buf.get_name()] = (0, 0) + return 0 + elif isinstance(sched_buf.node.layout, NoneLayout): + sched_buf_to_size[sched_buf.get_name()] = (0, 0) return 0 elif isinstance(sched_buf.node.layout, MultiOutputLayout): size_alloc = 0 @@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers( for dep in node.unmet_dependencies: dep_name_to_succ_nodes[dep.name].add(node) + # iterate in reverse, so dependencies are picked up transitively. + for mutating_buf_name, real_buf_name in reversed( + V.graph.scheduler.mutation_real_name.items() + ): + dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[ + mutating_buf_name + ] + # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) for buf_name in name_to_buf.keys(): @@ -219,58 +242,72 @@ def assign_memory_planning_info_for_scheduler_nodes( """ Assign to each scheduler node its predecessor and successor nodes. """ - from .scheduler import SchedulerBuffer - for index, node in enumerate(nodes): - size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) - pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]() - for dep in node.read_writes.reads: - if dep.name in name_to_buf and dep in node.unmet_dependencies: - pred_buffers.add(name_to_buf[dep.name]) - elif dep.name in name_to_freeable_input_buf: - pred_buffers.add(name_to_freeable_input_buf[dep.name]) - pred_nodes = OrderedSet( - name_to_fused_node[pred_buffer.defining_op_name()] - for pred_buffer in pred_buffers - if (isinstance(pred_buffer, SchedulerBuffer)) - ) + node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {} + node_to_pred_buffers: dict[ + BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer] + ] = collections.defaultdict(OrderedSet) + + # collect all predecessors using existing successor mappings + for node in nodes: succ_nodes = OrderedSet( succ_node for buffer in node.get_outputs() for succ_node in buffer.mpi_buffer.succ_nodes ) + node_to_succ_nodes[node] = succ_nodes + + # For each successor, add current node as its predecessor + for succ_node in succ_nodes: + node_to_pred_nodes[succ_node].add(node) + + # For each output buffer, add it as predecessor to its successor nodes + # TODO - is pred buffers needed ? + for buffer in node.get_outputs(): + for succ_node in buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(buffer) + + for freeable_buffer in name_to_freeable_input_buf.values(): + for succ_node in freeable_buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(freeable_buffer) + + # Second pass: assign memory planning info using completed predecessor mappings + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + succ_nodes = node_to_succ_nodes[node] + node.mpi_node = MemoryPlanningInfoForNode( index=index, size=size_alloc, - pred_buffers=pred_buffers, - pred_nodes=pred_nodes, + pred_buffers=node_to_pred_buffers[node], + pred_nodes=node_to_pred_nodes[node], succ_nodes=succ_nodes, ) -def estimate_peak_memory( +# map each scheduler buffer to its size, start step, and end step +@dataclasses.dataclass +class BufferInfo: + buffer: Union[SchedulerBuffer, FreeableInputBuffer] + size_alloc: int + size_free: int + start_step: int + end_step: int + + +def compute_memory_timeline( nodes: list[BaseSchedulerNode], name_to_freeable_input_buf: dict[str, FreeableInputBuffer], graph_outputs: OrderedSet[str], -) -> tuple[int, list[int]]: +) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]: """ - Given a list of nodes in their execution order, estimate the peak memory, by - keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. - - Returns: - int: peak memory - List[int]: memory usage at each node (or each step). + Compute buffer allocation and deallocation sizes and map their + lifetime to the node schedule """ - # map each scheduler buffer to its size, start step, and end step - @dataclasses.dataclass - class BufferInfo: - buffer: Union[SchedulerBuffer, FreeableInputBuffer] - size_alloc: int - size_free: int - start_step: int - end_step: int - # get the execution step of each node, this will be used to determine # the end_step of buffers node_to_step: dict[BaseSchedulerNode, int] = { @@ -325,6 +362,27 @@ class BufferInfo: ) ) + return buf_info_list, node_to_step + + +def estimate_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[int, list[int]]: + """ + Given a list of nodes in their execution order, estimate the peak memory, by + keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. + + Returns: + int: peak memory + List[int]: memory usage at each node (or each step). + """ + + buf_info_list, _ = compute_memory_timeline( + nodes, name_to_freeable_input_buf, graph_outputs + ) + # incremental memory changes at each step memory = [0 for _ in range(len(nodes) + 1)] diff --git a/torch/_inductor/remote_gemm_autotune_cache.py b/torch/_inductor/remote_gemm_autotune_cache.py new file mode 100644 index 000000000000..0ef026269b10 --- /dev/null +++ b/torch/_inductor/remote_gemm_autotune_cache.py @@ -0,0 +1,20 @@ +import asyncio +from typing import TypeVar + +import torch._inductor.config as config +from torch._inductor import ir + + +_T = TypeVar("_T") + + +def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]: + """ + Generate the best GEMM autotune config for the given matrices. + """ + if config.is_fbcode(): + from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config + + return gen_best_config(mat1, mat2) + else: + raise NotImplementedError("Function gen_best_config is not yet implemented") diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 67140369faac..850c7660d5d9 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -10,6 +10,8 @@ from types import ModuleType from typing import Any, Callable, TYPE_CHECKING +from torch._utils_internal import log_triton_builds + if TYPE_CHECKING: from torch._inductor.runtime.triton_heuristics import CachingAutotuner @@ -57,11 +59,18 @@ def _worker_compile_triton( from torch._inductor import config with config.patch(extra_config): - start_ns = time.time_ns() - kernel = load_kernel() - kernel.precompile(warm_cache_only=True) - elapsed_ns = time.time_ns() - start_ns - kernel.prepare_for_pickle() - # We can release this memory in the compile subprocesses: - linecache.clearcache() - return kernel, elapsed_ns // 1000 + fail = None + try: + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) diff --git a/torch/_inductor/runtime/debug_utils.py b/torch/_inductor/runtime/debug_utils.py new file mode 100644 index 000000000000..9c15ff890dda --- /dev/null +++ b/torch/_inductor/runtime/debug_utils.py @@ -0,0 +1,138 @@ +import functools +import logging +import threading +import weakref + +import torch +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + +local = threading.local() +local.memory_tracker = None + + +class BufferMemoryTracker: + """ + Tracks inductor runtime allocations and deallocations to compare against + expected behavior. + """ + + def __init__(self) -> None: + self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = ( + weakref.WeakValueDictionary() # type: ignore[assignment] + ) + self.died_since_last_step: OrderedSet[str] = OrderedSet() + self.added_since_last_step: OrderedSet[str] = OrderedSet() + self.error = ( + torch._inductor.config.test_configs.track_memory_lifecycle == "assert" + ) + + def set_tensor(self, name: str, tensor: torch.Tensor) -> None: + storage = tensor.untyped_storage() + + self.added_since_last_step.add(name) + self.tensor_tracker[name] = storage + + def on_tensor_death() -> None: + self.died_since_last_step.add(name) + + weakref.finalize(storage, on_tensor_death) + + def advance_step(self) -> None: + self.died_since_last_step.clear() + self.added_since_last_step.clear() + + def log_or_raise(self, msg: str) -> None: + if self.error: + raise RuntimeError(msg) + else: + log.info(msg) + + def check_step_delta( + self, + expected_allocated: list[str], + expected_freed: list[str], + is_final_step: bool, + ) -> None: + """Check only the delta changes since last step""" + + # Check expected deaths - we dont currently distinguish between nodes which die in last step + # and are returned as outputs, so skip if final_step. + if not is_final_step: + missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step + if missing_deaths: + self.log_or_raise( + f"Expected tensors to die but still alive: {missing_deaths}" + ) + + # Check for unexpected deaths + unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed) + if unexpected_deaths: + self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}") + + # Check newly alive tensors - separate messages like deaths + actual_allocated = self.added_since_last_step + expected_allocated_set = OrderedSet(expected_allocated) + + extra_alive = actual_allocated - expected_allocated_set + if extra_alive: + self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}") + + missing_alive = expected_allocated_set - actual_allocated + if missing_alive: + self.log_or_raise( + f"Expected allocated tensors but missing: {missing_alive}" + ) + + # Reset for next step + self.advance_step() + + if is_final_step: + local.memory_tracker = None + + +def get_mem_tracker() -> BufferMemoryTracker: + if local.memory_tracker is None: + local.memory_tracker = BufferMemoryTracker() + return local.memory_tracker + + +def track_tensor(tensor: torch.Tensor, name: str) -> None: + get_mem_tracker().set_tensor(name, tensor) + + +def tracked_empty_strided( + size: list[int], + stride: list[int], + *, + dtype: torch.dtype, + device: torch.device, + name: str, +) -> torch.Tensor: + o = torch.empty_strided(size, stride, dtype=dtype, device=device) + track_tensor(o, name) + return o + + +def check_memory_step( + allocated: list[str], freed: list[str], is_final_step: bool = False +) -> None: + tracker = get_mem_tracker() + tracker.check_step_delta(allocated, freed, is_final_step) + + +@functools.lru_cache(None) +def register_check_mem_op() -> None: + lib = torch.library.Library("_inductor_debug", "FRAGMENT") # noqa: TOR901 + lib.define( + "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()" + ) + lib.impl("check_memory_step", check_memory_step, "BackendSelect") + from torch._higher_order_ops.effects import _EffectType, _register_effectful_op + + _register_effectful_op( + torch.ops._inductor_debug.check_memory_step.default, + _EffectType.ORDERED, + ) diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index a52df4745f59..3290e25eeae4 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -63,16 +63,21 @@ def __init__(self, kernel: CompiledKernel) -> None: kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared ) + def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: + if hasattr(kernel.metadata, param_name): + if getattr(kernel.metadata, param_name) > 0: + raise NotImplementedError( + f"{scratch_name} scratch not yet supported" + ) + return True + return False + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. # Inductor never uses this field or enables it, but we still have to pass # an extra None into the set of params if its enabled - if hasattr(kernel.metadata, "global_scratch_size"): - if kernel.metadata.global_scratch_size > 0: - raise NotImplementedError("Global scratch not yet supported") - else: - self.has_global_scratch = True - else: - self.has_global_scratch = False + self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size") + # same situation for profile scratch - triton-lang/triton#7258 + self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size") self.arg_tys = self.arg_ty_from_signature(kernel.src) self.function: Optional[int] = ( @@ -214,12 +219,12 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher - # Add a None if triton wants an extra parameter to the cubin - if self.has_global_scratch: - arg_tys = self.arg_tys + "O" - args = (*args, None) - else: - arg_tys = self.arg_tys + # Add a None if triton wants extra parameters for scratch spaces + arg_tys = self.arg_tys + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) assert len(args) == len(arg_tys) # TODO: can handle grid functions here or in C++, so diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index a824c94ab65f..b61baa66281f 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -80,7 +80,7 @@ def div_floor_integer(a, b): def remainder_integer(a, b): # NOTE: a % b matches C division, not floor division remainder = a % b - return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder) + return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder) @triton.jit @@ -131,9 +131,9 @@ def minimum_with_index(a_value, a_index, b_value, b_index): if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value - mask |= a_isnan and not b_isnan + mask |= a_isnan & (not b_isnan) # Consider NaNs as equal - equal |= a_isnan and b_isnan + equal |= a_isnan & b_isnan # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) @@ -147,9 +147,9 @@ def maximum_with_index(a_value, a_index, b_value, b_index): if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value - mask |= a_isnan and not b_isnan + mask |= a_isnan & (not b_isnan) # Consider NaNs as equal - equal |= a_isnan and b_isnan + equal |= a_isnan & b_isnan # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 196dc329e80b..8425cba55795 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1061,6 +1061,7 @@ def save_gpu_kernel(self, stream, launcher): "def_args": launcher.def_args, "call_args": launcher.call_args, "global_scratch": launcher.global_scratch, + "profile_scratch": launcher.profile_scratch, } from torch._inductor.codecache import CudaKernelParamCache @@ -1086,6 +1087,13 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1; while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + with dynamo_timed( "CachingAutotuner.coordinate_descent_tuning", # These generate too many pt2_compile_event logs: @@ -1100,13 +1108,6 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): return self._coordinate_descent_tuning(launcher, *args, **kwargs) def _coordinate_descent_tuning(self, launcher, *args, **kwargs): - if ( - self.heuristic_type == HeuristicType.TEMPLATE - or self.heuristic_type == HeuristicType.USER_AUTOTUNE - ): - # skip triton template - return launcher - config2launcher = {launcher.config: launcher} # TODO: should we just load the kernels ahead of time if we know we're going to call this? @@ -1754,9 +1755,23 @@ def make_launcher(self) -> LauncherType: launcher.def_args = def_args launcher.call_args = call_args kernel_metadata = getattr(self.kernel, "metadata", None) - launcher.global_scratch = getattr( - kernel_metadata, "global_scratch_size", None + + # for the scratch arguments: None indicates that the kernel doesn't + # take any scratch argument; otherwise a number indicates the number + # of bytes of scratch that need to be provided. + + # in AMD's Triton backend, the global scratch size is never provided + # (but for AMD it's safe to pass an extra null arg, so always include it) + global_scratch: Optional[int] = getattr( + kernel_metadata, + "global_scratch_size", + (0 if torch.version.hip else None), + ) + profile_scratch: Optional[int] = getattr( + kernel_metadata, "profile_scratch_size", None ) + launcher.global_scratch = global_scratch + launcher.profile_scratch = profile_scratch return launcher diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 951f07ab7a5b..d8a96c573b32 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -605,7 +605,7 @@ def codegen_originating_info( out_lines.append(op_info_str) if "stack_trace" in o.meta: stack_trace = f"{o.meta['stack_trace']}" - stack_trace_last_line = stack_trace.split("|")[-1] + stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1] out_lines.append( "#pragma CMT " + stack_trace_last_line.replace("{", "{{") @@ -2179,11 +2179,18 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) self.process_grouped_nodes() - if torch._inductor.config.graph_partition: + if ( + torch._inductor.config.graph_partition + and torch._inductor.config.triton.cudagraphs + ): self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) self.compute_last_usage() + + if torch._inductor.config.test_configs.track_memory_lifecycle: + self.insert_memory_check_nodes() + log_ir_post_fusion(self.nodes) V.debug.graph_diagram(self.nodes) self.debug_draw_graph() @@ -2518,6 +2525,83 @@ def add_user( compute_dependencies_log.debug("BUFFER USER LIST\n") compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str) + def insert_memory_check_nodes(self) -> None: + from .memory import ( + assign_memory_planning_info_for_scheduler_buffers, + compute_memory_timeline, + FreeableInputBuffer, + get_freeable_input_buf, + ) + + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = ( + get_freeable_input_buf(self.nodes, graph_inputs) + ) + + if not torch._inductor.config.reorder_for_peak_memory: + assign_memory_planning_info_for_scheduler_buffers( + self.nodes, self.name_to_buf + ) + + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + buf_info_list, _ = compute_memory_timeline( + self.nodes, + name_to_freeable_input_buf, + graph_outputs, + ) + + step_allocs_deallocs: list[tuple[list[str], list[str]]] = [ + ([], []) for _ in range(len(self.nodes)) + ] + for buf_info in buf_info_list: + # Skip zero-size buffers + if buf_info.size_alloc == 0 and buf_info.size_free == 0: + continue + + buf_name = buf_info.buffer.get_name() + + step_allocs_deallocs[buf_info.start_step][0].append(buf_name) + step_allocs_deallocs[buf_info.end_step][1].append(buf_name) + + from torch._inductor.runtime.debug_utils import register_check_mem_op + + register_check_mem_op() + + def construct_mem_check_node( + step_idx: int, is_final_step: bool + ) -> ExternKernelSchedulerNode: + expected_newly_alive = step_allocs_deallocs[step_idx][0] + expected_newly_dead = step_allocs_deallocs[step_idx][1] + + nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step] + + node = ir.MemoryCheckKernel( + layout=NoneLayout(device=torch.device("cpu")), + kernel=torch.ops._inductor_debug.check_memory_step.default, + tensor_args=[], + nontensor_args=nontensor_args, + unflatten_args=lambda tensor_args, constant_args: ( + tensor_args, + { + "alive": constant_args[0], + "dead": constant_args[1], + "is_final_step": constant_args[2], + }, + ), + ) + node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}" + return ExternKernelSchedulerNode(self, node) + + new_nodes = [] + + for i, node in enumerate(self.nodes): + new_nodes.append(node) + new_nodes.append( + construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1)) + ) + + self.nodes = new_nodes + def dead_node_elimination(self) -> None: """ Remove any nodes without users @@ -4231,6 +4315,12 @@ def should_partition( ) -> bool: """Return True if we should partition the inductor graph on this node""" + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if not torch._inductor.config.triton.cudagraphs: + return True + # avoid duplicating logs when should_partition is called multiple times # on the same node def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 903d616bb91e..01337fc0d30b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -27,7 +27,14 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided -from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state +from torch._dynamo.utils import ( + counters, + dynamo_timed, + get_chromium_event_logger, + identity, + preserve_rng_state, +) +from torch._inductor.await_utils import await_sync from torch._inductor.utils import clear_on_fresh_cache from torch.utils._filelock import FileLock from torch.utils._ordered_set import OrderedSet @@ -2274,6 +2281,7 @@ def __call__( input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, + best_config_future=None, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2339,7 +2347,12 @@ def autotune(choices, hint_override: Optional[int] = None): dynamo_compile_column_us="compile_time_autotune_time_us", metadata=_autotune_metadata(input_nodes), ): - return benchmark(choices, hint_override=hint_override) + benchmark_results = benchmark(choices, hint_override=hint_override) + if config.max_autotune_report_choices_stats: + _log_autotune_choices_stats( + f"{name}_template_autotuning", benchmark_results + ) + return benchmark_results if config.autotune_in_subproc: # Initialize the suprocess pool so it will warmup early. @@ -2376,6 +2389,35 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) autotune_start_ts = time.time() + + if best_config_future is not None: + best_config = await_sync(best_config_future) + + important_keys = [ + "ACC_TYPE", + "ALLOW_TF32", + "BLOCK_K", + "BLOCK_M", + "BLOCK_N", + "EVEN_K", + "GROUP_M", + "USE_FAST_ACCUM", + "num_stages", + "num_warps", + "num_consumer_groups", + "num_buffers_warp_spec", + ] + choices = [ + choice + for choice in choices + if all( + f"{k}={best_config[k]}" in choice.description + for k in important_keys + ) + for k in important_keys + ] + log.info("Filtered to %d choices based on best_config", len(choices)) + timings = self.lookup( choices, name, @@ -2639,11 +2681,13 @@ def on_complete(future): def wait_on_futures(): log.debug("Waiting on futures") counters["inductor"]["select_algorithm_precompile"] += 1 + exceptions: list[tuple[ChoiceCaller, BaseException]] = [] for future in as_completed( futures, timeout=precompilation_timeout_seconds, ): if e := future.exception(): + exceptions.append((futures[future], e)) from torch._inductor.codegen.cuda.cuda_kernel import ( CUDATemplateCaller, ) @@ -2671,6 +2715,8 @@ def wait_on_futures(): futures.get(future), elapsed_times.get(future), ) + if exceptions: + _log_autotune_exceptions(exceptions) executor.shutdown(wait=True) @@ -3396,5 +3442,106 @@ def _autotune_metadata(input_nodes): } +def _log_autotune_choices_stats( + event_name: str, timings: dict[ChoiceCaller, float] +) -> None: + """Helper function to extract autotune metadata from benchmark results.""" + if not timings: + return None + + metadata: dict[str, Union[int, float, str]] = { + "num_choices": len(timings), + "num_triton_choices": len( + [c for c in timings if isinstance(c, TritonTemplateCaller)] + ), + } + + sorted_choices = sorted(timings, key=timings.__getitem__) + best_choice = sorted_choices[0] + metadata["best_kernel"] = best_choice.name + if best_choice.description: + metadata["best_kernel_desc"] = best_choice.description + metadata["best_time"] = timings[best_choice] + + best_triton_pos = next( + ( + i + for i, choice in enumerate(sorted_choices) + if isinstance(choice, TritonTemplateCaller) + ), + None, + ) + if best_triton_pos is not None: + metadata["best_triton_pos"] = best_triton_pos + best_triton_kernel = sorted_choices[best_triton_pos] + if best_triton_pos != 0: + metadata["best_triton_time"] = timings[best_triton_kernel] + metadata["best_triton_kernel"] = best_triton_kernel.name + if best_triton_kernel.description: + metadata["best_triton_kernel_desc"] = best_triton_kernel.description + + payload = json.dumps(metadata) + get_chromium_event_logger().add_event_data( + event_name, autotune_choices_stats=payload + ) + sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n") + + +def _log_autotune_exceptions( + exceptions: list[tuple[ChoiceCaller, BaseException]], +) -> None: + """Log autotune exceptions to chromium event logger.""" + if not exceptions: + return + + try: + pt2_compile_substack = get_chromium_event_logger().get_pt2_compile_substack() + if not pt2_compile_substack: + return + + current_event = pt2_compile_substack[-1] + if not current_event.endswith("_template_precompiling"): + return + + exception_details = [] + for choice, exc in exceptions: + try: + choice_type = ( + "triton" if isinstance(choice, TritonTemplateCaller) else "other" + ) + data = { + "choice_type": choice_type, + "choice": choice.description, + "exception_message": str(exc), + } + + exc_type_match = re.search(r"(\w+):", str(exc)) + if exc_type_match: + data["exception"] = exc_type_match.group(1) + + if "OutOfMemoryError" in str(exc): + required_match = re.search(r"Required: (\d+)", str(exc)) + if required_match: + data["required_memory"] = required_match.group(1) + + limit_match = re.search(r"Hardware limit:\s*(\d+)", str(exc)) + if limit_match: + data["hardware_limit"] = limit_match.group(1) + + exception_details.append(data) + except Exception: + # Don't let logging errors break the main flow + continue + + if exception_details: + metadata = json.dumps({"exceptions": exception_details}) + get_chromium_event_logger().try_add_event_data( + current_event, metadata=metadata + ) + except Exception: + # Silently ignore logging errors to avoid breaking autotune + pass + + # ensure lowering is imported so that `extern_kernels.*` is populated from . import lowering # noqa: F401 diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index a26a578755f6..88f635426bfd 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -10,6 +10,7 @@ import torch.fx from torch._dynamo.utils import dynamo_timed +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir from torch._inductor.utils import BoxedBool, InputType @@ -116,6 +117,7 @@ def save( def load( *, path: str, format: Literal["binary", "unpacked"] = "binary" ) -> CompiledArtifact: + path = normalize_path_separator(path) with dynamo_timed("CompiledArtifact.load"): if format == "binary": # can't assert that it is a file since it might not exist yet diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 62f4bf2618da..57eaef9b4dbb 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -7,11 +7,16 @@ from threading import Lock from typing import Any, Callable, Optional, TYPE_CHECKING +import sympy + import torch from torch.utils._ordered_set import OrderedSet +from torch.utils._triton import has_triton_stable_tma_api -from . import config -from .utils import get_backend_num_stages +from . import config, config as inductor_config +from .kernel_inputs import KernelInputs, MMKernelInputs +from .template_registry import register_template_heuristic +from .utils import get_backend_num_stages, get_num_sms, TMA_DESCRIPTOR_SIZE from .virtualized import V @@ -147,6 +152,12 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): """ def __init__(self) -> None: + # Whether the heuristic is used for int8. Use this when the heuristic is int8 exclusive + # but prefer the preprocess_mm_configs argument when it's used for both + self.has_int8_tensor: bool = False + # Whether to scale configs at all + # TODO(coconutruben): remove this once mm_plus_mm and tests support scaling + self.should_scale_configs: bool = True # List of dictionaries to store the kernel configs. Configs that evaluate to true # will be utilised on the target platform. The configs are as follows: # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) @@ -467,16 +478,18 @@ def _scale_mm_configs( configs: list[BaseConfig], scale: float, has_int8_tensor: bool, - exclude: Callable[[int, int, int], bool], + exclude: Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool], hint_override: Optional[int] = None, ) -> list[BaseConfig]: """ Scales and filters matrix multiplication configs based on input size. """ + if not self.should_scale_configs: + return configs from .runtime.runtime_utils import next_power_of_2 min_block_size = 16 - min_block_size_k = 32 if has_int8_tensor else 16 + min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16 scaled_configs = [] for hint_override in [None] + config.multi_kernel_hints: @@ -561,6 +574,13 @@ def _prune_exhaustive_configs( return pruned_configs + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter configs based on specific requirements. + Subclasses can override this to implement custom filtering logic. + """ + return configs + def preprocess_mm_configs( self, m: int, @@ -568,14 +588,17 @@ def preprocess_mm_configs( k: int, configs: list[BaseConfig], has_int8_tensor: bool = False, - scale: int = 1, - exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU ) -> Generator[TritonConfig, None, None]: + configs = self._filter_configs(configs) scaled_configs = self._scale_mm_configs( m, n, k, configs, scale, has_int8_tensor, exclude ) - if config.max_autotune_gemm_search_space == "EXHAUSTIVE": assert dtype_size > 0, "dtype_size must be provided for exhaustive search" scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) @@ -594,49 +617,11 @@ def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) - def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs) - - def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs) - - def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - mm_configs = ( - self.mm_configs + self.mixed_mm_configs - if config.max_autotune_gemm_search_space == "EXHAUSTIVE" - else self.mm_configs - ) - return partial(self.preprocess_mm_configs, configs=mm_configs) - - def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - persistent_mm_configs = ( - self.exhaustive_configs - if config.max_autotune_gemm_search_space == "EXHAUSTIVE" - else self.persistent_mm_configs - ) - - # num_warps=2 not safe for TMA - persistent_mm_configs = [ - config for config in persistent_mm_configs if config.num_warps != 2 - ] - return partial(self.preprocess_mm_configs, configs=persistent_mm_configs) - - def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs) - - def get_scaled_persistent_mm_configs( - self, - ) -> partial[Generator[TritonConfig, None, None]]: + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: return partial( - self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs + self.preprocess_mm_configs, configs=self.conv_configs, op_name="conv" ) - def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs) - - def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self.preprocess_mm_configs, configs=self.conv_configs) - # Flex attn helpers def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: flex_attn_fwd_configs: list[FlexConfig] = [] @@ -696,7 +681,80 @@ def get_flex_decode_configs( class CPUConfigHeuristic(BaseConfigHeuristic): - pass + """ + CPU-specific config heuristic with CPU-specific optimizations. + """ + + def _get_cpu_exclude_function( + self, method: str = "bmm" + ) -> Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool]: + """ + Get CPU-specific exclude function based on method type. + Returns a function that can be used as exclude condition. + Moved from mm_common._is_large_block_for_cpu and refactored to return a function. + """ + if method in ("conv"): + + def exclude_conv( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + return exclude_conv + elif method in ("mm", "addmm", "int_mm"): + + def exclude_mm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + return m * n > 2**13 + + return exclude_mm + else: # Default to bmm implementation for unknown methods + + def exclude_bmm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + return exclude_bmm + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU + ) -> Generator[TritonConfig, None, None]: + """ + CPU-specific preprocessing that applies CPU-specific scaling (0.5) and exclusion logic. + """ + # Get CPU-specific exclude function based on operation type + cpu_exclude_fn = self._get_cpu_exclude_function(op_name) + + # Apply CPU-specific scaling (0.5) and exclusion logic + return super().preprocess_mm_configs( + m, + n, + k, + configs=configs, + has_int8_tensor=has_int8_tensor, + scale=0.5, + exclude=cpu_exclude_fn, + dtype_size=dtype_size, + op_name=op_name, + ) class CUDAConfigHeuristic(BaseConfigHeuristic): @@ -1002,14 +1060,13 @@ def __init__(self) -> None: for wpeu in [0, int(8 // num_warps)] ] - def _filter_configs( - self, configs: list[BaseConfig], new_num_stages: int - ) -> list[BaseConfig]: - # TODO: _filter_configs can be removed once backend specific configs are added - # for all methods + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + ROCm specific filtering + """ for c in configs: c.num_stages = self.default_num_stages - return configs + return super()._filter_configs(configs) def _finalize_mm_configs( self, @@ -1076,57 +1133,6 @@ def _finalize_mm_configs( kwargs["GROUP_M"] = group_m yield self.triton_config(**kwargs) - def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.extra_mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.int8_mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - mm_configs = ( - self.mm_configs + self.mixed_mm_configs - if config.max_autotune_gemm_search_space == "EXHAUSTIVE" - else self.mm_configs - ) - filtered_configs = self._filter_configs(mm_configs, self.default_num_stages) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.persistent_mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.scaled_mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_scaled_persistent_mm_configs( - self, - ) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.scaled_persistent_mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - - def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1) - return partial(self._finalize_mm_configs, configs=filtered_configs) - - def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.conv_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) - def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: flex_attn_fwd_configs: list[FlexConfig] = [] @@ -1207,3 +1213,669 @@ class MTIAConfigHeuristic(BaseConfigHeuristic): """ Placeholder child class for MTIA specific overrides. """ + + +# Template-specific mixin classes + + +class TemplateConfigHeuristics: + def get_template_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get template configs for the given inputs. + This is the main entry point for template-specific logic. + """ + # NOTE: not an abstract class, because that clashed below for the mixin + # functionality. Can be adjusted, but not a high priority + yield from {} + + +class MMTemplateConfigMixin(TemplateConfigHeuristics): + """ + Mixin class that converts config lists to template kwargs. + This handles the logic that was previously in choices.get_mm_configs. + + This mixin expects to be used with BaseConfigHeuristic or its subclasses. + """ + + # Type annotations to ensure the mixin works with BaseConfigHeuristic + get_mm_configs: Callable[[], partial[Generator[TritonConfig, None, None]]] + get_exhaustive_mm_configs: Callable[ + [], partial[Generator[TritonConfig, None, None]] + ] + _filter_configs: Callable[[list[BaseConfig]], list[BaseConfig]] + + def _get_config_generator( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + """ + Get the appropriate config generator based on search space. + Can be overridden by subclasses for template-specific behavior. + """ + # Handle exhaustive search case + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return self.get_exhaustive_mm_configs() + else: + return self.get_mm_configs() + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Convert config lists to template kwargs. + This replaces the logic from choices.get_mm_configs and inlines mm_options. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + input_nodes = kernel_inputs.nodes() + if len(input_nodes) < 2: + raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}") + + # Extract M, N, K from kernel_inputs + m, n, k = kernel_inputs.mnk_symbolic() + + # Extract dtype and device_type from kernel_inputs + dtype = kernel_inputs.dtype() + + # Get the appropriate config generator + configs = self._get_config_generator() + + # Generate and process configs + for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name): + template_kwargs = self._convert_config_to_template_kwargs( + c, m, n, k, layout + ) + yield template_kwargs + + def _convert_config_to_template_kwargs( + self, + triton_config: TritonConfig, + m: sympy.Integer, + n: sympy.Integer, + k: sympy.Integer, + layout: Any, + ) -> dict[str, Any]: + """ + Convert triton config to template kwargs. + Moved from mm_common.mm_options. + """ + # Calculate EVEN_K symbolic + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(k, triton_config.kwargs["BLOCK_K"]) + == triton_config.kwargs["BLOCK_K"] + ) + + # Calculate allow_tf32 + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0) + ) + + # Build options dict + options_dict = dict( + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=self._get_acc_type(layout.dtype), + num_stages=triton_config.num_stages, + num_warps=triton_config.num_warps, + **triton_config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in triton_config.kwargs: + group_m = triton_config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + def _get_acc_type(self, dtype: torch.dtype) -> str: + """ + Get accumulator type for the given dtype. + Moved from mm_common.acc_type. + """ + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +# INT8 specific mixin to filter correctly +class INT8MMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that we feed in has_int8_tensor=True + """ + + def __init__(self) -> None: + super().__init__() + self.has_int8_tensor = True + + +# MMPlusMM specific mixin to avoid running _scale_mm_configs +class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that _should_scale_configs is False + """ + + # TODO(coconutruben): remove this once all tests work + # with proper scaling on mm_plus_mm + def __init__(self) -> None: + super().__init__() + self.should_scale_configs = False + + +# TMA-specific mixin for TMA templates +class TMAConfigMixin(MMTemplateConfigMixin): + """ + TMA-specific mixin that uses persistent configs and adds TMA options. + This inherits from MMTemplateConfigMixin and overrides config generation. + """ + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering, as num_warps=2 not safe for TMA + """ + configs = [c for c in configs if c.num_warps != 2] + return super()._filter_configs(configs) + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate TMA template configs by calling super and adding TMA-specific options. + """ + # Get base template configs from superclass + for template_kwargs in super().get_template_configs( + kernel_inputs, layout, op_name + ): + # Add TMA-specific options (moved from mm_common.persistent_mm_options) + input_nodes = kernel_inputs.nodes() + self._add_tma_options(template_kwargs, input_nodes) + yield template_kwargs + + def _add_tma_options( + self, template_kwargs: dict[str, Any], input_nodes: list[Any] + ) -> None: + """ + Add TMA-specific options to template kwargs. + Moved from mm_common.persistent_mm_options and mm_common.tma_options. + """ + # For TMA templates, we need the actual matrix tensors + mat1 = input_nodes[-2] + mat2 = input_nodes[-1] + + tma_opts = { + "A_ROW_MAJOR": not mat1.layout.is_transposed(), + "B_ROW_MAJOR": not mat2.layout.is_transposed(), + "NUM_SMS": get_num_sms(), + "TMA_SIZE": TMA_DESCRIPTOR_SIZE, + "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), + } + template_kwargs.update(tma_opts) + + +# Scaled MM-specific mixin for scaled MM templates (non-TMA) +class ScaledMMConfigMixin(MMTemplateConfigMixin): + """ + Scaled MM-specific mixin that uses scaled configs and adds scaled MM options. + This is for non-TMA scaled MM templates only. + This inherits from MMTemplateConfigMixin and overrides config generation. + """ + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled MM template configs with scaled MM-specific options. + Handles the remaining logic from mm_common including assertions and SCALING_ROWWISE. + """ + input_nodes = kernel_inputs.nodes() + + # Initial assertion from mm_common.scaled_mm_options + assert len(input_nodes) >= 4, ( + f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}" + ) + + # Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3]) + scale_a = input_nodes[2] + scale_b = input_nodes[3] + + # Scale compatibility assertion from mm_common.scaled_mm_options + def are_compatible_scales(size_a: Any, size_b: Any) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + # Get base template configs from superclass + for template_kwargs in super().get_template_configs( + kernel_inputs, layout, op_name + ): + # Add scaled MM-specific options (moved from mm_common.scaled_mm_options) + # Override accumulator type for scaled MM + template_kwargs["ACC_TYPE"] = "tl.float32" + # Add SCALING_ROWWISE attribute based on scale_a tensor shape + template_kwargs["SCALING_ROWWISE"] = len(size_a) == 2 + + yield template_kwargs + + +# Scaled TMA-specific mixin for scaled MM templates with TMA +class ScaledTMAConfigMixin(ScaledMMConfigMixin): + """ + Scaled TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality. + This is for scaled MM templates that use device TMA. + This inherits from ScaledMMConfigMixin and adds TMA-specific options. + """ + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering, as num_warps=2 not safe for TMA + """ + configs = [c for c in configs if c.num_warps != 2] + return super()._filter_configs(configs) + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + layout: Any, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled TMA template configs with both scaled MM and TMA-specific options. + """ + # Get base scaled MM template configs from superclass + for template_kwargs in super().get_template_configs( + kernel_inputs, layout, op_name + ): + # Add TMA-specific options for device TMA scaled MM + template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + template_kwargs["NUM_SMS"] = get_num_sms() + template_kwargs["TMA_EXPERIMENTAL_API"] = not has_triton_stable_tma_api() + + yield template_kwargs + + +# Template-specific heuristic classes using multiple inheritance + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("mm", "cuda", register=torch.version.hip is None) +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("bmm", "cuda", register=torch.version.hip is None) +class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is None) +class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm_persistent_tma", "cuda", register=torch.version.hip is None +) +class CUDAPersistentTMATemplateConfigHeuristic(TMAConfigMixin, CUDAConfigHeuristic): + """Persistent TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use persistent_mm_configs + self.mm_configs = self.persistent_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm" +) +class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic): + """Scaled MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "scaled_mm_device_tma", "cuda", register=torch.version.hip is None +) +class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic): + """Scaled TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + self.mm_configs = self.scaled_persistent_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_persistent_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("mm_plus_mm", "cuda", register=torch.version.hip is None) +class CUDAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CUDAConfigHeuristic +): + """MM Plus MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm", "cuda", register=torch.version.hip is None, op_name="int_mm" +) +class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic): + """Int8 MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +# ROCm template-specific classes + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("mm", "cuda", register=torch.version.hip is not None) +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("bmm", "cuda", register=torch.version.hip is not None) +class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None) +class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm", "cuda", register=torch.version.hip is not None, op_name="scaled_mm" +) +class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic): + """Scaled MM template heuristic for ROCm (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm", "cuda", register=torch.version.hip is not None, op_name="int_mm" +) +class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic): + """Int8 MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + "mm_plus_mm", "cuda", register=torch.version.hip is not None +) +class ROCmMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, ROCmConfigHeuristic +): + """MM Plus MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # self.default_num_stages is used to make sure all configs have that in ROCm land + # for mm_plus_mm, we actually just want stages = 1, as pipelining brings no benefits + self.default_num_stages = 1 + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# CPU template-specific classes + + +@register_template_heuristic("mm", "cpu") +@register_template_heuristic("bmm", "cpu") +class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic): + """Standard MM template heuristic for CPU""" + + +@register_template_heuristic("mm", "cpu", op_name="scaled_mm") +class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic): + """Scaled MM template heuristic for CPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic("mm", "cpu", op_name="int_mm") +class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic): + """Int8 MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic("mm_plus_mm", "cpu") +class CPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CPUConfigHeuristic +): + """MM Plus MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# XPU template-specific classes + + +@register_template_heuristic("mm", "xpu") +@register_template_heuristic("bmm", "xpu") +class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic): + """Standard MM template heuristic for XPU""" + + +@register_template_heuristic("mm", "xpu", op_name="scaled_mm") +class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic): + """Scaled MM template heuristic for XPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic("mm", "xpu", op_name="int_mm") +class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic): + """Int8 MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic("mm_plus_mm", "xpu") +class XPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, XPUConfigHeuristic +): + """MM Plus MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# MTIA template-specific classes + + +@register_template_heuristic("mm", "mtia") +@register_template_heuristic("bmm", "mtia") +class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic): + """Standard MM template heuristic for MTIA""" + + +@register_template_heuristic("mm", "mtia", op_name="scaled_mm") +class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic): + """Scaled MM template heuristic for MTIA (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic("mm", "mtia", op_name="int_mm") +class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic): + """Int8 MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic("mm_plus_mm", "mtia") +class MTIAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, MTIAConfigHeuristic +): + """MM Plus MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs diff --git a/torch/_inductor/template_registry.py b/torch/_inductor/template_registry.py new file mode 100644 index 000000000000..d11343e63f0f --- /dev/null +++ b/torch/_inductor/template_registry.py @@ -0,0 +1,98 @@ +""" +Template heuristic registry system for PyTorch Inductor. + +This module provides a centralized registration system for template heuristics, +allowing automatic registration based on device type and conditional registration +for CUDA vs ROCm based on torch.version.hip. +""" + +from __future__ import annotations + +import logging +from functools import cache +from typing import Any, Optional, TYPE_CHECKING + + +if TYPE_CHECKING: + from .template_heuristics import TemplateConfigHeuristics + +# Module-wide registry for template heuristics +_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {} + +log = logging.getLogger(__name__) + + +def register_template_heuristic( + template_name: str, + device_type: str, + register: bool = True, + op_name: Optional[str] = None, +) -> Any: + """ + Decorator to register template heuristic classes. + + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + register: Whether to register this heuristic. Caller should pass the condition directly. + op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional + and is only used when a template uses different heuristics for different ops + + Returns: + Decorator function that registers the class if conditions are met. + + Example: + @register_template_heuristic("mm", "cuda", register=torch.version.hip is None) + class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + pass + """ + + def decorator( + cls: type[TemplateConfigHeuristics], + ) -> type[TemplateConfigHeuristics]: + if register: + key: tuple[str, ...] = (device_type, template_name) + if op_name is not None: + key = (device_type, template_name, op_name) + _TEMPLATE_HEURISTIC_REGISTRY[key] = cls + log.info( + f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'" # noqa: G004 + ) + return cls + + return decorator + + +@cache +def get_template_heuristic( + template_name: str, device_type: str, op_name: str +) -> TemplateConfigHeuristics: + """ + Retrieve a template heuristic instance for the given template and device type. + + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + + Returns: + Template heuristic instance. + + Raises: + ValueError: If no heuristic is found for the given combination. + """ + # First check the more specific key + keys = [(device_type, template_name, op_name), (device_type, template_name)] + + # Look up in registry + heuristic_class = None + for key in keys: + if key in _TEMPLATE_HEURISTIC_REGISTRY: + heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key] + break + if heuristic_class is None: + raise ValueError( + f"No template heuristic found for '{template_name=}', " + f"'{device_type=}', '{op_name=}'. " + f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}" + ) + return heuristic_class() diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index cd69e3950626..0418edb2a115 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -856,7 +856,9 @@ def stringfy_layout(layout: ir.Layout | None) -> str: all_writes.append("%" + output_name) for node in inductor_nodes: - detailed_metadata.append(f"{wrapper.comment} {node.format_node()}") + detailed_metadata.append( + f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}" + ) detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}") @@ -893,7 +895,15 @@ def is_unrealized_node(n: IRNode) -> bool: return is_unrealized_node(n.data) if isinstance(n, ir.StorageBox): return is_unrealized_node(n.data) - return isinstance(n, ir.IRNode) and not ir.IRNode.is_realized_node(n) + return isinstance(n, ir.IRNode) and not isinstance( + n, + ( + ir.ComputedBuffer, + ir.InputsKernel, + ir.InputBuffer, + ir.TemplateBuffer, + ), + ) # kwargs and args may include a container of node, for example torch.cat([t1, t2]) # flatten them before search the unrealized nodes @@ -2965,6 +2975,26 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: # (e.g., via ValueRanges) that it is still in bounds if V.graph.sizevars.statically_known_true(e <= int_max): return True + + # AOTI doesn't guard on < 2**32, so checking hints isn't a viable option, + # in case the hinted value is < 2**32, but the allowed range is larger. + # However, to prevent possible perf regressions on pre-existing AOTI models + # which don't set an upper bound on the valid range, we'll skip the check. + # To recap: + # - If using AOTI: + # - If allowed range has no upper bound, then check the hint to determine + # whether this fits in int32 + # - If allowed range does have an upper bound, then obey the upper bound + # (check whether upper bound < int32_max) without checking the hint. + + if V.aot_compilation: + # check whether value has an upper bound (1e20 is > INT64_MAX, assume + # there is no upper bound if it can be larger than 1e20) + if V.graph.sizevars.statically_known_true(e < 1e20): + # if so, then assume int_max < upper bound < inf + # so this could potentially have int64 values + return False + # Otherwise, the hint MUST exist and be in range return has_hint(e) and size_hint(e) <= int_max @@ -3299,6 +3329,13 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: ) +def is_using_cudagraph_partition() -> bool: + return ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + + def dtype_from_size(size: int) -> torch.dtype: from .virtualized import V @@ -3344,13 +3381,12 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str: for i, e in enumerate(row): widths[i] = max(widths[i], len(str(e))) lines = [] - # Need nested {} for string formatting; ignore SET_LINTER here - lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter + lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # widths whitespace horizontal separators total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1) lines.append("-" * total_width) for row in elements: - lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter + lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) return "\n".join(lines) @@ -3405,20 +3441,36 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An Returns: dict[str, Any]: The possibly-updated `config_patches` dictionary. """ + + def patch_config( + config_patches: dict[str, Any], config_name: str, config_value: Any + ) -> None: + value = config_patches.get(config_name, getattr(config, config_name)) + if value is None: + config_patches[config_name] = config_value + elif not value and value != config_value: + raise RuntimeError( + f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True." + ) + compile_standalone = config_patches.get( "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone ) + # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing + config_patches = config_patches.copy() if compile_standalone: - package_cpp_only = config_patches.get( - "aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only + # Standlaone AOTInductor means only generate cpp project for building a standalone binary + patch_config(config_patches, "aot_inductor.package_cpp_only", True) + # Standlaone AOTInductor needs to embed the kernel code in the binary + patch_config(config_patches, "aot_inductor.embed_kernel_binary", True) + # Default to use multi-arch kernel codegen for non-rocm GPU + patch_config( + config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip ) - if package_cpp_only is None: - config_patches = {**config_patches, "aot_inductor.package_cpp_only": True} - elif not package_cpp_only: - raise RuntimeError( - "compile_standalone=True requires package_cpp_only=True. " - "Please set aot_inductor.package_cpp_only=True in your inductor config." - ) + patch_config( + config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model" + ) + return config_patches @@ -3449,14 +3501,6 @@ def is_valid_aoti_model_name() -> bool: return True -def aoti_model_name_from_config() -> str: - from torch._inductor import config - - model_name = config.aot_inductor.model_name_for_generated_files - model_name = "aoti_model" if model_name is None else model_name - return model_name - - def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: if unbacked_only: return free_unbacked_symbols(x) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index bd8acb2789e1..251cdefe0f05 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -210,6 +210,7 @@ def __init__( self._lib = get_library_allowing_overwrite(self._namespace, self._name) self._register_to_dispatcher(self._tags) self._disabled_kernel: set = set() + self._used_triton_kernels: list[Any] = list() OPDEFS[self._qualname] = self @property diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 72805c765d86..741b341f7e21 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -1,4 +1,6 @@ +import ast import contextlib +import inspect import threading from collections.abc import Generator, Iterable from typing import Any, Callable, Optional, Union @@ -9,6 +11,79 @@ from .infer_schema import infer_schema +triton_ops_to_kernels: dict[str, list[object]] = {} + + +def get_triton_kernels_for_op(name: str) -> list[object]: + return triton_ops_to_kernels.get(name, []) + + +def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]: + """ + Inspect the source of an arbitrary callable passed to torch._library.triton_op, + and grab all of the triton kernels that are wrapped inside of it. + + TODO: This check is best effort. It does *not* handle the case where the triton + kernel is hidden behind recursive function calls. + """ + + def find_triton_kernels(fn: Callable[..., Any]) -> list[object]: + try: + source = inspect.getsource(fn) + except (OSError, TypeError): + return [] # Source code not available + + from torch._inductor.utils import IndentedBuffer + + buffer = IndentedBuffer() + buffer.splice(source, strip=True) + tree = ast.parse(buffer.getrawvalue()) + + # Visitor to collect function calls and triton kernels + class Visitor(ast.NodeVisitor): + def __init__(self) -> None: + self.triton_kernels: list[Any] = [] + + def visit_Call(self, node: ast.Call) -> None: + triton_func_names = ("capture_triton", "wrap_triton") + if isinstance(node.func, ast.Attribute): + attr = node.func + if ( + isinstance(attr.value, ast.Attribute) + and isinstance(attr.value.value, ast.Name) + and attr.value.value.id == "torch" + and attr.value.attr == "_library" + and attr.attr in triton_func_names + ): + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + # Catch capture_triton, wrap_triton that's been + # imported directly + elif isinstance(node.func, ast.Name): + if node.func.id in triton_func_names: + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + self.generic_visit(node) + + collector = Visitor() + collector.visit(tree) + closure_vars = inspect.getclosurevars(fn) + resolved = [] + # First, resolve triton kernel names + for name in collector.triton_kernels: + if name in closure_vars.nonlocals: + resolved.append(closure_vars.nonlocals[name]) + elif name in closure_vars.globals: + resolved.append(closure_vars.globals[name]) + elif name in closure_vars.builtins: + resolved.append(closure_vars.builtins[name]) + return resolved + + return find_triton_kernels(fn) + + @exposed_in("torch.library") def triton_op( name: str, @@ -155,9 +230,28 @@ def functional_decomp( # type: ignore[no-untyped-def] if custom_triton_ops_decomposition_disabled(): return mode.__torch_dispatch__(op, types, args, kwargs) else: + # TODO: https://github.com/pytorch/pytorch/issues/160333 + # We should deduplicate the unrecognized_types logic. + import torch._subclasses + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t + not in [ + torch.Tensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ] + ] + + if unrecognized_types: + return NotImplemented with mode: return fn(*args, **kwargs) + triton_kernels = get_inner_triton_kernels(fn) + triton_ops_to_kernels[name] = triton_kernels result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) return result diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 940318520452..59a316acc69a 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -3,7 +3,7 @@ import inspect import sys from collections.abc import Iterable, Iterator -from typing import Any, Callable, Union +from typing import Any, Callable, Literal, Optional, overload, Union import torch import torch.utils._pytree as pytree @@ -501,6 +501,20 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str] ] +# Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag. +@overload +def get_layout_constraint_tag( + fn: Any, *, with_default: Literal[True] = True +) -> _C.Tag: ... + + +# Case 2: with_default=False. Return type can be a Tag or None. +@overload +def get_layout_constraint_tag( + fn: Any, *, with_default: Literal[False] +) -> Optional[_C.Tag]: ... + + def get_layout_constraint_tag(fn, *, with_default=True): for tag in tags_by_priority: if tag in fn.tags: diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index ffd3160b47ee..c4bdeceeb494 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -726,8 +726,49 @@ def _invalid_settings_err_msg(settings, verbose=False): return msg +def process_env_var_string_for_windows(env_var_str: str) -> str: + """ + When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html + Such as: + TORCH_LOGS="+schedule,+inductor,+output_code" + + On Linux, it shows as: + declare -x SSH_TTY="/dev/pts/0" + declare -x TERM="xterm" + declare -x TORCH_LOGS="+schedule,+inductor,+output_code" + declare -x USER="xu" + + On Windows, it shows as: + TORCHINDUCTOR_WINDOWS_TESTS=1 + TORCH_LOGS="+schedule,+inductor,+output_code" + UCRTVersion=10.0.22000.0 + + For Linux, it shows quotes by default, And Windows is not shows quotes. + Besides that, Windows would auto assemble quotes when env var processing. + On Linux, we will get variable: "+schedule,+inductor,+output_code" + On Windows, we will get variable: '"+schedule,+inductor,+output_code"' + + So, we need remove the outer quotes for Windows. + """ + _IS_WINDOWS = sys.platform == "win32" + + def remove_outer_quotes(s: str) -> str: + if len(s) >= 2 and ( + (s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'") + ): + return s[1:-1] + return s + + if _IS_WINDOWS: + env_var_str = remove_outer_quotes(env_var_str) + + return env_var_str + + @functools.lru_cache def _parse_log_settings(settings): + settings = process_env_var_string_for_windows(settings) + if settings == "": return {} diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index a61fc8559357..fc16cf58c640 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5861,6 +5861,61 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( return grad_q, grad_k, grad_v +@register_meta([aten._scaled_dot_product_attention_math_for_mps]) +def meta__scaled_dot_product_attention_math_for_mps( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + dropout_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> tuple[Tensor, Tensor]: + def ensure_4d(x): + if x.dim() == 3: + return x.unsqueeze(0), True + elif x.dim() > 4: + batch_size = 1 + for i in range(x.dim() - 3): + batch_size *= x.shape[i] + return x.view(batch_size, x.size(-3), x.size(-2), x.size(-1)), True + else: + return x, False + + q_, unsqueezed = ensure_4d(query) + k_, _ = ensure_4d(key) + v_, _ = ensure_4d(value) + + batch_size, num_head, q_size, head_size = q_.shape + _, k_size, max_seq_length, _ = k_.shape + + def sdpa_vector_fast_mps(): + out = q_.new_empty(q_.shape) + if unsqueezed: + out = out.view_as(query) + + attn = q_.new_empty((batch_size, num_head, q_size, max_seq_length)) + if unsqueezed: + if query.dim() == 3: + attn = attn.squeeze(0) + else: + shape = list(query.shape[:-3]) + attn.shape[1:4] + attn = attn.view(shape) + return out, attn + + def sdpa_vector_2pass_mps(): + blocks = 32 + out = q_.new_empty(q_.shape) + intermediate = q_.new_empty((batch_size, num_head, q_size, blocks, head_size)) + return out, intermediate + + if (max_seq_length >= 1024) or (k_size < q_size and max_seq_length >= 4096): + return sdpa_vector_2pass_mps() + else: + return sdpa_vector_fast_mps() + + @register_meta([aten._scaled_dot_product_efficient_attention]) def meta__scaled_dot_product_efficient_attention( query: Tensor, @@ -7025,8 +7080,7 @@ def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): @register_meta([aten.nan_to_num.default, aten.nan_to_num.out]) @out_wrapper() def nan_to_num(self, nan=None, posinf=None, neginf=None): - result_size = list(self.size()) - return self.new_empty(result_size) + return torch.empty_like(self) @register_meta(torch.ops.aten.transpose_) @@ -7369,6 +7423,12 @@ def _meta_grouped_mm_common( mat_a_is_2d = mat_a.dim() == 2 mat_b_is_2d = mat_b.dim() == 2 + if not mat_a_is_2d or not mat_b_is_2d: + torch._check( + mat_a.size(-1) == mat_b.size(-2), + "contraction dimension of mat_a and mat_b must match", + ) + if scaled: def is_row_major(mat): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 6739b334c116..bb26bbb508bd 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -302,7 +302,7 @@ def _backend_select_impl(*args, **kwargs): else: return _prim_impl(*args, **kwargs) - name = schema.split("(")[0] + name = schema.split("(", maxsplit=1)[0] schema = schema[len(name) :] # register non-functional ops with old custom ops API diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 4d33280f7ac8..7ebd2ec92d12 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -514,6 +514,8 @@ def maybe_guard_or_true(x): def _view_has_unbacked_input(a, shape): from torch.fx.experimental.symbolic_shapes import has_hint + shape = utils.extract_shape_from_varargs(shape, validate=False) + return ( any(not has_hint(s) for s in a.size()) or any(not has_hint(s) for s in a.stride()) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e7d9e1fc23b4..52b776946b36 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -940,6 +940,21 @@ def merge_devices(t: object) -> None: if any(map(check_cpu_device, (common_device, t.device))): return + # if prefer_device_type is set, prefer that device type over others + prefer_device_type = torch._functorch.config.fake_tensor_prefer_device_type + if prefer_device_type is not None: + common_has_preferred = prefer_device_type in common_device.type + t_has_preferred = prefer_device_type in t.device.type + + if not common_has_preferred and t_has_preferred: + # Switch to the preferred device type + common_device = t.device + is_cpu_zero_dim = t_is_cpu_zero_dim + return + elif common_has_preferred and not t_has_preferred: + # Keep the existing preferred device type + return + # mismatching devices of non-zero dim tensors, throw # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as raise RuntimeError( diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 921e97be233a..c9262e1b2ee0 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -178,14 +178,18 @@ def __init__(self, tensor): self.int_mode = False break + self.sci_mode = ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + if PRINT_OPTS.sci_mode is None + else PRINT_OPTS.sci_mode + ) + if self.int_mode: # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites # to indicate that the tensor is of floating type. add 1 to the len to account for this. - if ( - nonzero_finite_max / nonzero_finite_min > 1000.0 - or nonzero_finite_max > 1.0e8 - ): - self.sci_mode = True + if self.sci_mode: for value in nonzero_finite_vals: value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) self.max_width = max(self.max_width, len(value_str)) @@ -195,12 +199,7 @@ def __init__(self, tensor): self.max_width = max(self.max_width, len(value_str) + 1) else: # Check if scientific representation should be used. - if ( - nonzero_finite_max / nonzero_finite_min > 1000.0 - or nonzero_finite_max > 1.0e8 - or nonzero_finite_min < 1.0e-4 - ): - self.sci_mode = True + if self.sci_mode: for value in nonzero_finite_vals: value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) self.max_width = max(self.max_width, len(value_str)) @@ -209,9 +208,6 @@ def __init__(self, tensor): value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value) self.max_width = max(self.max_width, len(value_str)) - if PRINT_OPTS.sci_mode is not None: - self.sci_mode = PRINT_OPTS.sci_mode - def width(self): return self.max_width diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 1980a6801f62..1713c39e39b1 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5653,11 +5653,11 @@ def merge_dicts(*dicts): >>> torch.is_nonzero(torch.tensor([1, 3, 5])) Traceback (most recent call last): ... - RuntimeError: bool value of Tensor with more than one value is ambiguous + RuntimeError: Boolean value of Tensor with more than one value is ambiguous >>> torch.is_nonzero(torch.tensor([])) Traceback (most recent call last): ... - RuntimeError: bool value of Tensor with no values is ambiguous + RuntimeError: Boolean value of Tensor with no values is ambiguous """.format(**common_args), ) @@ -12556,8 +12556,8 @@ def merge_dicts(*dicts): add_docstr( torch.hamming_window, """ -hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ -layout=torch.strided, device=None, requires_grad=False) -> Tensor +hamming_window(window_length, *, dtype=None, layout=None, device=None, pin_memory=False, \ +requires_grad=False) -> Tensor """ + r""" Hamming window function. @@ -12585,16 +12585,82 @@ def merge_dicts(*dicts): + r""" Arguments: window_length (int): the size of returned window - periodic (bool, optional): If True, returns a window to be used as periodic + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {pin_memory} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. + +.. function:: hamming_window(window_length, periodic, *, dtype=None, layout=None, device=None, \ + pin_memory=False, requires_grad=False) -> Tensor + :noindex: + +Hamming window function with periodic specified. + +Arguments: + window_length (int): the size of returned window + periodic (bool): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {pin_memory} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. + +.. function:: hamming_window(window_length, periodic, float alpha, *, dtype=None, layout=None, device=None, \ + pin_memory=False, requires_grad=False) -> Tensor + :noindex: + +Hamming window function with periodic and alpha specified. + +Arguments: + window_length (int): the size of returned window + periodic (bool): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float): The coefficient :math:`\alpha` in the equation above + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {pin_memory} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. + +.. function:: hamming_window(window_length, periodic, float alpha, float beta, *, dtype=None, layout=None, \ + device=None, pin_memory=False, requires_grad=False) -> Tensor + :noindex: + +Hamming window function with periodic, alpha and beta specified. + +Arguments: + window_length (int): the size of returned window + periodic (bool): If True, returns a window to be used as periodic function. If False, return a symmetric window. - alpha (float, optional): The coefficient :math:`\alpha` in the equation above - beta (float, optional): The coefficient :math:`\beta` in the equation above + alpha (float): The coefficient :math:`\alpha` in the equation above + beta (float): The coefficient :math:`\beta` in the equation above Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. {device} + {pin_memory} {requires_grad} Returns: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 8c448adb0c6a..f2613e734bbf 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -117,6 +117,10 @@ def signpost_event(category: str, name: str, parameters: dict[str, Any]): log.info("%s %s: %r", category, name, parameters) +def add_mlhub_insight(category: str, insight: str, insight_description: str): + pass + + def log_compilation_event(metrics): log.info("%s", metrics) @@ -350,3 +354,7 @@ def get_default_numa_options(): Must return None or NumaOptions, but not specifying to avoid circular import. """ return None + + +def log_triton_builds(fail: Optional[str]): + pass diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 2352bb836a9d..745cdd315a63 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -403,9 +403,12 @@ def load(self): func not in _get_allowed_globals().values() and func not in _get_user_allowed_globals().values() ): - raise UnpicklingError( + error_msg = ( f"Trying to call reduce for unrecognized function {func}" ) + if hasattr(func, "__self__"): + error_msg += f" which belongs to {func.__self__}" + raise UnpicklingError(error_msg) result = func(*args) if func in torch._tensor_classes and "sparse" in func.__module__: _sparse_tensors_to_validate.append(result) diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e9e48f1cf306..4d1a78df1f74 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,16 @@ import torch from ._utils import _device_t, _get_device_index +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) __all__ = [ @@ -15,9 +25,17 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", + "empty_cache", "device_count", "device_index", "is_available", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py new file mode 100644 index 000000000000..d34a11a3a02e --- /dev/null +++ b/torch/accelerator/memory.py @@ -0,0 +1,201 @@ +from collections import OrderedDict +from typing import Any + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other application. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return + torch._C._accelerator_emptyCache() + + +def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: + r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from device memory allocation. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of June 2025, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of June 2025, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed device memory allocation calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of device memory allocation calls. + - ``"num_device_free"``: number of device memory free calls. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return OrderedDict() + device_index = _get_device_index(device_index, optional=True) + stats = torch._C._accelerator_getDeviceStats(device_index) + flat_stats = [] + + def flatten(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for k, v in value.items(): + nested_prefix = f"{prefix}.{k}" if prefix else k + flatten(nested_prefix, v) + else: + flat_stats.append((prefix, value)) + + flatten("", stats) + flat_stats.sort() + return OrderedDict(flat_stats) + + +def memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory occupied by tensors + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors + in bytes for a given device index. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory managed by the caching allocator + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator + in bytes for a given device index. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.peak", 0) + + +def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetAccumulatedStats(device_index) + + +def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "peak" stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 3e81f46f5c23..f93c050f4508 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -397,7 +397,10 @@ def __enter__(self): self._enabled, self._cache_enabled, ) - return mode.__torch_function__(torch.amp._enter_autocast, (), args) + mode.__torch_function__(torch.amp._enter_autocast, (), args) + return self + + return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] if torch._jit_internal.is_scripting(): @@ -420,7 +423,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[ov mode, torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, ): - return mode.__torch_function__(torch.amp._exit_autocast, (), ()) + mode.__torch_function__(torch.amp._exit_autocast, (), ()) + # This is very important because the above line actually doesn't + # run exit code so it end up swallowing exceptions. + return False return False def __call__(self, func): diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 302f7e0b0b7c..47185aeea527 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -98,7 +98,7 @@ def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, # string manip to split tensor_fqn into module_fqn and tensor_name # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' - tensor_name = tensor_fqn.split(".")[-1] + tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1] module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] module = fqn_to_module(model, module_fqn) diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index bf643a97f60f..4b2707b65d0f 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -194,6 +194,9 @@ class GradientEdge(NamedTuple): node: Node output_nr: int + # This token can be used to ensure the graph stays alive when it cannot be + # done via the node field + ownership_token: Optional[Node] = None def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: @@ -209,9 +212,18 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: ) grad_fn = _get_grad_fn_or_grad_acc(tensor) + # Python-based Node are owned by the C++ side meaning the python grad_fn + # object we hold here does NOT keep the C++ graph alive. + # Create an ownership token by creating a new C++ node that own the graph + # we care about here. + token = None + if isinstance(grad_fn, torch._C._FunctionBase): + with torch.enable_grad(): + token = tensor.view_as(tensor).grad_fn + # Note that output_nr default to 0 which is the right value # for the AccumulateGrad node. - return GradientEdge(grad_fn, tensor.output_nr) + return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token) def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None: diff --git a/torch/compiler/config.py b/torch/compiler/config.py index dc9c82a5200e..ceb8f41db844 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -66,6 +66,18 @@ A common use case for such a tag is to break caches. """ +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force=[ + "TORCHINDUCTOR_FORCE_DISABLE_CACHES", + "TORCH_COMPILE_FORCE_DISABLE_CACHES", + ], + default=False, +) +""" +Force disables all caching -- This will take precedence over and override any other caching flag +""" + dynamic_sources: str = Config( env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" ) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 3a97c0794684..59cb8047467c 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -77,6 +77,70 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); + + m.def("_accelerator_isAllocatorInitialized", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->initialized(); + }); + + m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { + using c10::CachingAllocator::Stat; + using c10::CachingAllocator::StatArray; + using c10::CachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + + const auto stats = at::accelerator::getDeviceStats(device_index); + const auto stat_to_dict = [](const Stat& stat) -> py::dict { + py::dict dict; + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { + const std::array(StatType::NUM_TYPES)> + kStatTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(kStatTypeNames.size())) { + dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); + } + return dict; + }; + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); + result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); + result["active_bytes"] = stat_array_to_dict(stats.active_bytes); + result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); + result["allocation"] = stat_array_to_dict(stats.allocation); + result["segment"] = stat_array_to_dict(stats.segment); + result["active"] = stat_array_to_dict(stats.active); + result["inactive_split"] = stat_array_to_dict(stats.inactive_split); + result["inactive_split_bytes"] = + stat_array_to_dict(stats.inactive_split_bytes); + result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); + result["oversize_segments"] = stat_to_dict(stats.oversize_segments); + return result; + }); + + m.def( + "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetAccumulatedStats(device_index); + }); + + m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetPeakStats(device_index); + }); } } // namespace torch::accelerator diff --git a/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h b/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h new file mode 100644 index 000000000000..866e09b13407 --- /dev/null +++ b/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +// We don't want to forward declare in general but including ModelRunner will +// pollute the public API namespace too much. Therefore, we just use pimpl an +// incomplete ModelRunner here. +class ModelRunner; + +class TORCH_API ModelRunnerHandle { + public: + ModelRunnerHandle( + const std::string& packagePath, + const std::string& modelName); + + ModelRunnerHandle(ModelRunnerHandle&&) = default; + ModelRunnerHandle& operator=(ModelRunnerHandle&&) = default; + ModelRunnerHandle(const ModelRunnerHandle&) = delete; + ModelRunnerHandle& operator=(const ModelRunnerHandle&) = delete; + ~ModelRunnerHandle(); + + c10::IValue run( + const std::vector& args, + const std::unordered_map& kwargs); + + /** + * A low level API which expects user to always pass in flattened inputs. + * The ownership of the entire input list must be transferred to the + * executor via std::move or in-place construction. + */ + std::vector runWithFlatInputsAndOutputs( + std::vector flatInputs); + + private: + std::unique_ptr impl_; +}; + +} // namespace torch::nativert diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 4e8cb2efca0e..f0024f8f0b07 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -979,13 +979,13 @@ static void validate_outputs_impl( } if (grad.device() != metadata.device()) { - // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but - // should be eventually removed - if (!(metadata.is_tensor_subclass() || - grad.unsafeGetTensorImpl()->is_python_dispatch())) { - if (grad.dim() == 0) { - grad = grad.to(metadata.device()); - } else { + if (grad.dim() == 0) { + grad = grad.to(metadata.device()); + } else { + // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but + // should be eventually removed + if (!(metadata.is_tensor_subclass() || + grad.unsafeGetTensorImpl()->is_python_dispatch())) { std::stringstream ss; ss << "invalid gradient at index " << i << " - expected device "; ss << metadata.device() << " but got " << grad.device(); diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index fd672a48502a..7c6792f5e698 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -674,6 +674,9 @@ struct ThreadLocalResults { CallTypeHelper::tuple_type trace_keys_; AppendOnlyList exit_times_; AppendOnlyList c_exit_times_; + + int active_frames_{0}; + int remaining_start_frames_{0}; }; // ============================================================================ @@ -999,7 +1002,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) PyThreadState_Swap(thread_state); thread_local_results_.emplace_back(thread_state, &value_cache_, this); - auto* ctx = thread_local_results_.back().ctx_; + auto& tls = thread_local_results_.back(); + auto* ctx = tls.ctx_; // When we begin profiling there are already frames on the Python // interpreter stack. To ensure a complete trace, we must push calls @@ -1021,7 +1025,7 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) } for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { - recordPyCall(thread_local_results_.back(), it->get(), true); + recordPyCall(tls, it->get(), true); auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -1029,6 +1033,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount); } + tls.remaining_start_frames_ = tls.active_frames_; + // Note: // This profile will not compose with other CPython profilers, and // cannot be round tripped via `sys.settrace(sys.gettrace())` @@ -1141,6 +1147,7 @@ void PythonTracer::recordPyCall( const auto time = c10::getApproximateTime(); is_startup_frame ? start_frames_.push_back({key, time}) : queue_->getSubqueue()->emplace_py_call(key, time); + ++tls.active_frames_; } void PythonTracer::recordCCall( @@ -1160,6 +1167,7 @@ void PythonTracer::recordCCall( auto key = tls.intern( arg, (void*)(fn->m_ml), frame); queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); + ++tls.active_frames_; } // ============================================================================ @@ -1457,11 +1465,20 @@ int PythonTracer::pyProfileFn( case PyTrace_RETURN: local_results.exit_times_.emplace_back(c10::getApproximateTime()); + local_results.active_frames_--; + if (local_results.active_frames_ < + local_results.remaining_start_frames_) { + local_results.remaining_start_frames_ = local_results.active_frames_; + } break; case PyTrace_C_EXCEPTION: case PyTrace_C_RETURN: - local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); + if (local_results.active_frames_ > + local_results.remaining_start_frames_) { + local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); + local_results.active_frames_--; + } break; } return 0; diff --git a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp index e8cdbfbbe8c8..dc3c4889057c 100644 --- a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp @@ -20,7 +20,25 @@ class FakeWork : public Work { class FakeProcessGroup : public Backend { public: - FakeProcessGroup(int rank, int size) : Backend(rank, size) {} + struct Options : Backend::Options { + explicit Options() : Backend::Options("fake") {} + + int fake_option = 0; + }; + + FakeProcessGroup( + int rank, + int size, + c10::intrusive_ptr options = c10::make_intrusive()) + : Backend(rank, size), options_(std::move(options)) {} + + const std::string getBackendName() const override { + return "fake"; + } + + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } c10::intrusive_ptr broadcast( std::vector& /* tensors */, @@ -194,6 +212,9 @@ class FakeProcessGroup : public Backend { const BarrierOptions& /* opts */ = BarrierOptions()) override { return c10::make_intrusive(); } + + private: + c10::intrusive_ptr options_; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index f16a3eadcb25..655193e8f318 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1091,18 +1091,15 @@ ErrorType ProcessGroupNCCL::getError() { void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { const auto key = std::to_string(pool->device()); - auto device = at::Device(at::DeviceType::CUDA, pool->device()); LOG(INFO) << logPrefix() << "Performing NCCL user buffer registration for all buffers in " << "MemPool: " << pool->id() << ", device index: " << key << ", i am " << this; auto ncclComm = getNCCLComm(key); if (ncclComm == nullptr) { - // HACK: currently we are using this function for NVLS - // reductions, and that's why using OpType::ALLREDUCE. - // If we end up using this API for zero-copy P2P, we might - // need to refactor and account for different OpType. - ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + C10_THROW_ERROR( + DistBackendError, + "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue"); } TORCH_INTERNAL_ASSERT(ncclComm != nullptr); { @@ -2284,6 +2281,10 @@ void ProcessGroupNCCL::Watchdog::runLoop() { // Work status logging for desync debug desyncDebugger_.logWorkStart(work); + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start @@ -2295,10 +2296,6 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; } - // allow watchdog to do an event query on a side thread - at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); - at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; - // Clean up completed work if (work.isCompleted()) { // In case user didn't call `work.wait()` with async collectives, @@ -5430,6 +5427,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto inputTensor = inputTensors.back(); + check_gpu_single_tensor(inputTensor); std::vector outputs; diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 0b4a2f956840..973197ded14f 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -4,7 +4,10 @@ #include #include #include +#include +#include #include +#include namespace c10d::control_plane { diff --git a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu index 3b7effb3a7d6..3049464d96ee 100644 --- a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu +++ b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu @@ -8,6 +8,7 @@ // Two warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") #if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \ CUDA_VERSION >= 12000 @@ -37,6 +38,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") #include +C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 824f26414c9f..c39957c2e838 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3776,14 +3776,27 @@ such as `dist.all_reduce(tensor, async_op=True)`. auto fakeProcessGroup = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>( - module, "FakeProcessGroup", backend) - .def( - py::init([](int rank, int size) { - return c10::make_intrusive<::c10d::FakeProcessGroup>( - rank, size); - }), - py::arg("rank"), - py::arg("world_size")); + module, "FakeProcessGroup", backend); + intrusive_ptr_class_<::c10d::FakeProcessGroup::Options>( + fakeProcessGroup, "Options", backendOptions) + .def(py::init()) + .def_readwrite( + "fake_option", &::c10d::FakeProcessGroup::Options::fake_option); + fakeProcessGroup + .def( + py::init([](int rank, + int size, + c10::intrusive_ptr<::c10d::FakeProcessGroup::Options> + options) { + return c10::make_intrusive<::c10d::FakeProcessGroup>( + rank, size, std::move(options)); + }), + py::arg("rank"), + py::arg("world_size"), + py::arg("options") = + c10::make_intrusive<::c10d::FakeProcessGroup::Options>()) + .def_property_readonly( + "options", &::c10d::FakeProcessGroup::getBackendOptions); auto fakeWork = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeWork>( module, "FakeWork", work) diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index 52e4b9fe56ef..b23722ec384a 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -193,38 +193,50 @@ class SocketImpl { }; std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { - char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + // It can be be very slow to repeatedly hit DNS resolution failure, but its + // very helpful to have DNS names in logs by default. So we try to use DNS but + // if we hit a transient failure we just disable it for the remainder of the + // job, logging IP addresses instead. See + // https://github.com/pytorch/pytorch/issues/159007 + static bool disable_getnameinfo = false; - if (int err = ::getnameinfo( - addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV)) { - C10D_WARNING( - "The hostname of the client socket cannot be retrieved. err={}", err); + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT - // if we can't resolve the hostname, display the IP address + if (!disable_getnameinfo) { + int err = ::getnameinfo( + addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV); + if (err != 0) { + C10D_WARNING( + "The hostname of the client socket cannot be retrieved. err={}", err); + disable_getnameinfo = true; + } + } + // if getnameinfo failed, disable would be set + if (!disable_getnameinfo) { if (addr->sa_family == AF_INET) { - struct sockaddr_in* psai = (struct sockaddr_in*)&addr; - // NOLINTNEXTLINE(*array*) - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != - nullptr) { - return fmt::format("{}:{}", ip, psai->sin_port); - } - } else if (addr->sa_family == AF_INET6) { - struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; - // NOLINTNEXTLINE(*array*) - char ip[INET6_ADDRSTRLEN]; - if (inet_ntop( - addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != - nullptr) { - return fmt::format("[{}]:{}", ip, psai->sin6_port); - } + return fmt::format("{}:{}", host, port); } - return "?UNKNOWN?"; + return fmt::format("[{}]:{}", host, port); } + // if we can't resolve the hostname, display the IP address if (addr->sa_family == AF_INET) { - return fmt::format("{}:{}", host, port); + struct sockaddr_in* psai = (struct sockaddr_in*)&addr; + // NOLINTNEXTLINE(*array*) + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + nullptr) { + return fmt::format("{}:{}", ip, psai->sin_port); + } + } else if (addr->sa_family == AF_INET6) { + struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; + // NOLINTNEXTLINE(*array*) + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + nullptr) { + return fmt::format("[{}]:{}", ip, psai->sin6_port); + } } - return fmt::format("[{}]:{}", host, port); + return "?UNKNOWN?"; } } // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index 44b1575cd800..b2f216335bb1 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -22,6 +22,18 @@ #define CUDART_SUPPORTS_MULTICAST #endif +// add these definitions so that we can compile with CUDA < 12.3 +// borrowed from +// https://github.com/NVIDIA/nccl/blob/3ea7eedf3b9b94f1d9f99f4e55536dfcbd23c1ca/src/include/p2p.h#L20 +#if CUDA_VERSION < 12030 +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + namespace c10d { namespace symmetric_memory { @@ -34,11 +46,13 @@ AllocationRef::AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx) + int device_idx, + bool is_multicast) : ptr(ptr), handle(handle), block_size(block_size), - device_idx(device_idx) {} + device_idx(device_idx), + is_multicast(is_multicast) {} AllocationRef::~AllocationRef() { if (is_finalizing()) { @@ -51,6 +65,10 @@ AllocationRef::~AllocationRef() { auto driver_api = c10::cuda::DriverAPI::get(); C10_CUDA_DRIVER_CHECK( driver_api->cuMemUnmap_(reinterpret_cast(ptr), block_size)); + if (is_multicast) { + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size)); + } C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle)); #elif defined(USE_ROCM) C10_HIP_CHECK(hipMemUnmap(reinterpret_cast(ptr), block_size)); @@ -400,6 +418,23 @@ void* CUDASymmetricMemoryAllocator::alloc( prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; // NOLINTNEXTLINE(bugprone-signed-char-misuse) prop.location.id = device_idx; + const auto driver_api = c10::cuda::DriverAPI::get(); + + if (handle_type_ == Expandable_Segments_Handle_Type::UNSPECIFIED) { + // Initialize NVML + if (driver_api->nvmlInit_v2_() == NVML_SUCCESS) { + // Get the driver version + int version = -1; + const auto res = driver_api->nvmlSystemGetCudaDriverVersion_v2_(&version); + if (res == NVML_SUCCESS) { + // Check if driver is sufficiently new + if (version < 12040) { + handle_type_ = Expandable_Segments_Handle_Type::POSIX_FD; + } + } + } + } + if (handle_type_ == Expandable_Segments_Handle_Type::POSIX_FD) { prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; } else { @@ -407,7 +442,6 @@ void* CUDASymmetricMemoryAllocator::alloc( } size_t granularity; - auto driver_api = c10::cuda::DriverAPI::get(); C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); block_size = at::round_up(block_size, granularity); @@ -769,6 +803,10 @@ c10::intrusive_ptr make_symm_mem( for (int r = 0; r < world_size; ++r) { if (r == rank) { alloc_refs.emplace_back(block->alloc_ref); + if (mc_addr != nullptr) { + alloc_refs.push_back(c10::make_intrusive( + mc_addr, mc_handle, block->block_size, block->device_idx, true)); + } continue; } alloc_refs.push_back(c10::make_intrusive( diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index a5340ffc9806..f61d8f9622a7 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target { HandleType handle; size_t block_size; int device_idx; + bool is_multicast; AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx); + int device_idx, + bool is_multicast = false); ~AllocationRef(); }; diff --git a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu index 5a6b4ce8e81d..1c513c66fae6 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu @@ -155,7 +155,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { int rank, c10::IntArrayRef sizes, c10::ScalarType dtype, - int64_t storage_offset) { + int64_t storage_offset) override { // TODO: deduplicate const size_t numel = std::accumulate( sizes.begin(), diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu index 7f97f3d38a33..55ebebb28e24 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu @@ -28,12 +28,20 @@ namespace c10d::nvshmem_extension { constexpr int MiB = 1024 * 1024; +extern "C" void nvshmem_init() __attribute__((weak)); + // Check if NVSHMEM is available bool is_nvshmem_available() { // Runtime check static std::mutex mutex; static int is_available = -2; std::lock_guard lock(mutex); + + // Checked if the symbol is statically linked + if(is_available == -2 && nvshmem_init) { + is_available = 1; + } + if (is_available == -2) { void* handle{}; // Open the shared library, RTLD_LAZY defers symbol resolution until needed diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 190752070250..c25e83c07c6d 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -8,10 +8,8 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 03b43184d143..4c326b6a0e27 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -7,12 +7,10 @@ #include #include -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") #include #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index f28aefc06dee..86308ae6cdf3 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -6,10 +6,8 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index da14865e584a..c8e0ae9c2736 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { static std::unordered_map dict_version_map; static int dict_version_watcher_id; +static int dict_recursive_tag_watcher_id; static uint64_t global_dict_version_id = 1; static int dict_version_watch_callback( PyDict_WatchEvent event, @@ -1042,7 +1043,8 @@ static void _parse_empty_strided_args( static PyObject* _empty_strided_device( PyObject* dummy, PyObject* args, - c10::DeviceType device_type) { + c10::DeviceType device_type, + bool is_pinned = false) { HANDLE_TH_ERRORS; at::SmallVector sizes; at::SmallVector strides; @@ -1050,7 +1052,7 @@ static PyObject* _empty_strided_device( _parse_empty_strided_args(args, sizes, strides, dtype); if (device_type == c10::DeviceType::CPU) { return THPVariable_Wrap( - at::detail::empty_strided_cpu(sizes, strides, dtype)); + at::detail::empty_strided_cpu(sizes, strides, dtype, is_pinned)); } #ifdef USE_CUDA else if (device_type == c10::DeviceType::CUDA) { @@ -1084,6 +1086,13 @@ static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) { return _empty_strided_device(dummy, args, c10::DeviceType::CPU); } +static PyObject* _empty_strided_cpu_pinned(PyObject* dummy, PyObject* args) { + // at::empty_strided is surprising slow. This is a lower-overhead + // version that saves ~2us on every allocation. + return _empty_strided_device( + dummy, args, c10::DeviceType::CPU, /*is_pinned=*/true); +} + static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) { // at::empty_strided is surprising slow. This is lower-overhead. return _empty_strided_device(dummy, args, c10::DeviceType::CUDA); @@ -1127,6 +1136,10 @@ static PyMethodDef _methods[] = { {"assert_alignment", assert_alignment, METH_VARARGS, nullptr}, {"dict_version", dict_version, METH_VARARGS, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, + {"_empty_strided_cpu_pinned", + _empty_strided_cpu_pinned, + METH_VARARGS, + nullptr}, {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr}, {"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr}, @@ -1154,22 +1167,6 @@ std::string get_exception_message() { return exc_message; } -bool is_nn_module(py::handle example_value) { - py::object torch_module_cls = py::module_::import("torch.nn").attr("Module"); - return py::isinstance(example_value, torch_module_cls); -} - -std::string get_type_str(py::handle example_value) { - std::string type_name; - try { - type_name = py::str(py::type::of(example_value)).cast(); - } catch (const py::error_already_set& e) { - // Fallback that never throws in release builds - type_name = ""; - } - return type_name; -} - bool is_immutable_object(py::handle example_value) { py::object config_module = py::module_::import("torch._dynamo.config"); @@ -1561,6 +1558,37 @@ class GuardManager; class RootGuardManager; class DictGuardManager; +// Global registry used by the *recursive-dict-tag* optimisation. +// +// Key : `PyObject*` pointing to a watched `dict` +// Value : list of `GuardManager*` instances that have recorded that dict +// +// Why is this global? +// ------------------- +// * CPython allows only a small, fixed number of dict-watcher IDs (ā‰ˆ64). +// All `GuardManager`s therefore share a single watcher callback. +// * Different guard managers (possibly across different frames) can end up +// watching the same dictionary pointer. Therefore, we have a list of guard +// managers for each dict pointer. +// +// When is watch registered? +// * During the recording phase of recursive dict tag matching in GuardManager. +// +// When are they watched? +// * In the dict_recursive_tag_watch_callback function. +// +// When are the dict pointers unwatched? +// * If a dict is mutated or the guard manager deallocates. +// * Read `unwatch_all_saved_dict_pointers` docstring for more details. +// +// Expected size +// ------------- +// Every compilation frame contributes its tag-safe dicts to this registry, so +// the container can grow large over the lifetime of the process. That’s +// acceptable: lookup is by pointer (hash/equals = identity) and each entry +// stores only lightweight pointers. +std::unordered_map> dict_to_guard_managers; + /** * Base class for the leaf guard in the GuardManager hierarchy. */ @@ -2611,15 +2639,13 @@ class GuardManager { : _root(root), _source(std::move(source)), _is_dict(py::isinstance(example_value)), - _is_immutable(is_immutable_object(example_value)), - _is_nn_module(is_nn_module(example_value)), - _is_tensor(THPVariable_Check(example_value.ptr())), - _type_str(get_type_str(example_value)) { + _is_immutable(is_immutable_object(example_value)) { if (_is_dict) { _dict_tag = get_dict_version_unchecked(example_value.ptr()); - _is_empty_dict = PyDict_Size(example_value.ptr()) == 0; } - + py::object typ = py::type::of(example_value); + py::object weakref_mod = py::module_::import("weakref"); + _weak_type = weakref_mod.attr("ref")(typ); py::object config_module = py::module_::import("torch._dynamo.config"); _max_saved_pointers_for_recursive_dict_tags_check = config_module.attr("max_saved_pointers_for_recursive_dict_tags_check") @@ -2631,6 +2657,7 @@ class GuardManager { virtual ~GuardManager() { cleanup_tag_safe_entries(); + disable_recursive_dict_tag_optimization(); } void cleanup_tag_safe_entries() { @@ -2681,28 +2708,19 @@ class GuardManager { return _is_immutable; } - bool is_guarded_value_nn_module() { - return _is_nn_module; - } - - bool is_guarded_value_dict() { - return _is_dict; - } - - bool is_guarded_value_empty_dict() { - return _is_empty_dict; - } - - bool is_guarded_value_tensor() { - return _is_tensor; + bool is_recursive_dict_tag_matching_disabled() { + return _disable_dict_tag_matching; } - std::string type_of_guarded_value() { - return _type_str; - } + py::object get_type_of_guarded_value() { + if (!_weak_type || _weak_type.is_none()) { + return py::type::of(py::none()); + } - bool is_recursive_dict_tag_matching_disabled() { - return _disable_dict_tag_matching; + if (!PyCallable_Check(_weak_type.ptr())) { + throw std::runtime_error("_weak_type is not callable"); + } + return _weak_type(); } public: @@ -2742,25 +2760,24 @@ class GuardManager { _tensor_pointers[value] = tensor_pointers; } + void disable_recursive_dict_tag_optimization() { + unwatch_all_saved_dict_pointers(); + _disable_dict_tag_matching = true; + } + public: // For cloning GuardManager( RootGuardManager* root, std::string source, bool is_dict, - bool is_empty_dict, bool is_immutable, - bool is_nn_module, - bool is_tensor, - std::string type_str) + py::object weak_type) : _root(root), _source(std::move(source)), _is_dict(is_dict), - _is_empty_dict(is_empty_dict), _is_immutable(is_immutable), - _is_nn_module(is_nn_module), - _is_tensor(is_tensor), - _type_str(std::move(type_str)) {} + _weak_type(weak_type) {} void clone_common( RootGuardManager* cloned_root, @@ -2792,14 +2809,7 @@ class GuardManager { return nullptr; } GuardManager* cloned_mgr = new GuardManager( - cloned_root, - _source, - _is_dict, - _is_empty_dict, - _is_immutable, - _is_nn_module, - _is_tensor, - _type_str); + cloned_root, _source, _is_dict, _is_immutable, _weak_type); if (is_tag_safe()) { cloned_mgr->mark_tag_safe(); if (is_tag_safe_root()) { @@ -2861,6 +2871,10 @@ class GuardManager { } bool check_dict_pointer_tags(PyObject* value) { + if (_dict_callback_installed) { + // This means that for 3.12+, there are callbacks watching dict pointers. + return true; + } for (auto& kv : _dict_pointers[value]) { PyObject* dict_pointer = kv.first; uint64_t old_tag = kv.second; @@ -2975,7 +2989,7 @@ class GuardManager { // This is a tag safe node, record the dict pointer if (_is_dict) { record_dict_pointer(_root, value); - } else if (_is_tensor && _has_no_tensor_aliasing_guard) { + } else if (_has_no_tensor_aliasing_guard) { record_tensor_pointer(_root, value); } } @@ -2991,6 +3005,11 @@ class GuardManager { throw std::runtime_error( "Could not register a callback for recursive dict tag optimization"); } +#if IS_PYTHON_3_12_PLUS + // Ideally we don't need to even register a weakref callback for value. + // But it does not hurt to be more cautious + _dict_callback_installed = watch_dict_pointers(value); +#endif } } if (!result) { @@ -3007,8 +3026,9 @@ class GuardManager { } GuardManager* guard_manager = static_cast( PyCapsule_GetPointer(self_capsule, "GuardManager*")); - if (guard_manager) - guard_manager->_disable_dict_tag_matching = true; + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } Py_RETURN_NONE; } @@ -3059,6 +3079,81 @@ class GuardManager { return true; } + bool watch_dict_pointers(PyObject* value) { +#if IS_PYTHON_3_12_PLUS + // ----------------------------------------------------------------------------- + // CPython 3.12 dict-watcher integration + // ----------------------------------------------------------------------------- + // + // We register a single watcher on all every dictionary pointer recorded by + // a tag-safe root. The watcher callback fires *once* for any structural + // change to those dictionaries + // + // Fast-path benefit + // ----------------- + // In steady state we no longer need to iterate over the recorded + // dictionaries and compare their `ma_version_tag`s (the + // ā€œare-tags-unchangedā€ loop that used to dominate the fast-path guard + // evaluation). The presence of an *active watcher* is itself a guarantee + // that none of the dicts has mutated; if one **does** mutate, the callback + // simply flips `_disable_dict_tag_matching = true`, causing the next guard + // evaluation to skip the recursive-dict-tag optimisation entirely. + for (auto& kv : _dict_pointers[value]) { + PyObject* dict_pointer = kv.first; + int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer); + if (rc != 0) { + PyErr_Clear(); + return false; + } + dict_to_guard_managers[dict_pointer].push_back(this); + } +#endif + return true; + } + + void unwatch_all_saved_dict_pointers() { + /* + We may have recorded hundreds/thousands of dict pointers for the recursive + dict-tag optimisation. If any of those dicts mutates, we want to disable the + optimisation and then unwatch as many dict pointers as we can. + + Be careful: the same dict pointer can be recorded by multiple GuardManagers. + So the flow is: + + 1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer]. + 2) If the list for that dict becomes empty, then: + - PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) + - erase the dict_pointer entry from dict_to_guard_managers. + */ +#if IS_PYTHON_3_12_PLUS + if (!_disable_dict_tag_matching) { + for (auto& value_stashed_pointers : _dict_pointers) { + auto stashed_pointers = value_stashed_pointers.second; + + for (auto& stashed_pointer : stashed_pointers) { + PyObject* dict_pointer = stashed_pointer.first; + + // Delete the guard manager from the dict_to_guard_managers + auto it = std::find( + dict_to_guard_managers[dict_pointer].begin(), + dict_to_guard_managers[dict_pointer].end(), + this); + if (it != dict_to_guard_managers[dict_pointer].end()) { + dict_to_guard_managers[dict_pointer].erase(it); + } + + // Unwatch the dict pointer if this was the last guard manager + // watching it. + if (dict_to_guard_managers[dict_pointer].empty()) { + PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer); + dict_to_guard_managers.erase(dict_pointer); + } + } + } + } +#endif + } + virtual bool check_nopybind(FrameLocalsMapping* value) { return check_nopybind_template(value); } @@ -3285,11 +3380,7 @@ class GuardManager { bool _has_no_tensor_aliasing_guard = false; bool _is_dict = false; - bool _is_empty_dict = false; bool _is_immutable = false; - bool _is_nn_module = false; - bool _is_tensor = false; - std::string _type_str; uint64_t _dict_tag{0}; uint64_t _max_saved_pointers_for_recursive_dict_tags_check = 0; @@ -3301,6 +3392,14 @@ class GuardManager { _dict_pointers; std::unordered_map> _tensor_pointers; std::vector _tag_safe_entries; + + // 3.12+ related helper + bool _dict_callback_installed = false; + + protected: + // weakref to the type of guarded value + // protected because it is used for cloning by DictGuardManager + py::object _weak_type; }; GuardAccessor::GuardAccessor( @@ -3873,17 +3972,13 @@ class DictGuardManager : public GuardManager { PyTypeObject* expected_type, bool is_exact_dict_type, std::vector indices, - std::string type_of, - bool is_empty_dict) + py::object weak_type) : GuardManager( cloned_root, std::move(source), true, // _is_dict - is_empty_dict, false, // _is_immutable - false, // _is_nn_module - false, // _is_tensor - std::move(type_of)), + weak_type), _size(size), _expected_type(expected_type), _is_exact_dict_type(is_exact_dict_type), @@ -3903,8 +3998,7 @@ class DictGuardManager : public GuardManager { _expected_type, _is_exact_dict_type, _indices, - type_of_guarded_value(), - is_guarded_value_empty_dict()); + _weak_type); if (is_tag_safe()) { cloned_mgr->mark_tag_safe(); if (is_tag_safe_root()) { @@ -3989,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root( root->add_relational_guard_resetter(std::move(guard)); } +#if IS_PYTHON_3_12_PLUS +static int dict_recursive_tag_watch_callback( + PyDict_WatchEvent event, + PyObject* dict, + PyObject* key, + PyObject* new_value) noexcept { + if (event != PyDict_EVENT_CLONED) { + auto it = dict_to_guard_managers.find(dict); + if (it != dict_to_guard_managers.end()) { + auto guard_managers = it->second; + for (auto& guard_manager : guard_managers) { + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } + } + } + } + return 0; // keep watching +} +#endif + std::unique_ptr make_guard_manager( RootGuardManager* root, std::string source, @@ -4288,6 +4403,10 @@ class GetAttrGuardAccessor : public GuardAccessor { ")"; } + std::string get_attr_name() { + return py::str(_attr_name).cast(); + } + public: // cloning functions GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from) : GuardAccessor(guard_manager, from) { @@ -5511,6 +5630,118 @@ class TypeGuardAccessor : public GuardAccessor { void clone_visitor(TypeGuardAccessor* to) {} }; +/** + * Represent x.__dict__ accessor, where x is type object. + */ +class TypeDictGuardAccessor : public GuardAccessor { + public: + // name = __type_dict_accessor__, a unique string used as attribute name. + TypeDictGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref + if (x == nullptr) { + return false; + } + return _guard_manager->check_nopybind(x); + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref + if (x == nullptr) { + return GuardDebugInfo(false, "null type dict on " + repr(), 0); + } + return _guard_manager->check_verbose_nopybind(x); + } + + std::string repr() const override { + return "TypeDictGuardAccessor"; + } + + public: // cloning functions + TypeDictGuardAccessor( + GuardManager* guard_manager, + TypeDictGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(TypeDictGuardAccessor* to) {} +}; + +/** + * Represent x.__mro__ accessor, where x is type object. + */ +class TypeMROGuardAccessor : public GuardAccessor { + public: + // name = __type_mro_accessor__, a unique string used as attribute name. + TypeMROGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref + return _guard_manager->check_nopybind(x); + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref + return _guard_manager->check_verbose_nopybind(x); + } + + std::string repr() const override { + return "TypeMROGuardAccessor"; + } + + public: // cloning functions + TypeMROGuardAccessor(GuardManager* guard_manager, TypeMROGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(TypeMROGuardAccessor* to) {} +}; + /** * Getitem tuple_iterator accessor. */ @@ -5786,6 +6017,158 @@ class WeakRefCallGuardAccessor : public GuardAccessor { void clone_visitor(WeakRefCallGuardAccessor* to) {} }; +/** + * Represent x.__code__ + */ +class CodeGuardAccessor : public GuardAccessor { + public: + // name = __type_mro_accessor__, a unique string used as attribute name. + CodeGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + PyObject* func = obj; + if (PyMethod_Check(obj)) { + func = PyMethod_GET_FUNCTION(obj); // borrowed ref + } else if (PyInstanceMethod_Check(obj)) { + func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref + } + PyObject* x = PyFunction_GetCode(func); // borrowed ref + if (x == nullptr) { + PyErr_Clear(); + return false; + } + return _guard_manager->check_nopybind(x); + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + PyObject* func = obj; + if (PyMethod_Check(obj)) { + func = PyMethod_GET_FUNCTION(obj); // borrowed ref + } else if (PyInstanceMethod_Check(obj)) { + func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref + } + PyObject* x = PyFunction_GetCode(func); + if (x == nullptr) { + PyErr_Clear(); + return GuardDebugInfo( + false, + std::string(repr() + ": Not a function on ") + get_source(), + 0); + } + + return _guard_manager->check_verbose_nopybind(x); + } + + std::string repr() const override { + return "CodeGuardAccessor"; + } + + public: // cloning functions + CodeGuardAccessor(GuardManager* guard_manager, CodeGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(CodeGuardAccessor* to) {} +}; + +/** + * Represent x.__closure__ + */ +class ClosureGuardAccessor : public GuardAccessor { + public: + // name = __type_mro_accessor__, a unique string used as attribute name. + ClosureGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + PyObject* func = obj; + if (PyMethod_Check(obj)) { + func = PyMethod_GET_FUNCTION(obj); // borrowed ref + } else if (PyInstanceMethod_Check(obj)) { + func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref + } + PyObject* x = PyFunction_GetClosure(func); // borrowed ref + if (x == nullptr) { + PyErr_Clear(); + return false; + } + return _guard_manager->check_nopybind(x); + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + PyObject* func = obj; + if (PyMethod_Check(obj)) { + func = PyMethod_GET_FUNCTION(obj); // borrowed ref + } else if (PyInstanceMethod_Check(obj)) { + func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref + } + PyObject* x = PyFunction_GetClosure(func); + if (x == nullptr) { + PyErr_Clear(); + return GuardDebugInfo( + false, + std::string(repr() + ": Not a function on ") + get_source(), + 0); + } + + return _guard_manager->check_verbose_nopybind(x); + } + + std::string repr() const override { + return "ClosureGuardAccessor"; + } + + public: // cloning functions + ClosureGuardAccessor(GuardManager* guard_manager, ClosureGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(ClosureGuardAccessor* to) {} +}; + /** * Implements function call no args - e.g, torch.cuda.current_device() */ @@ -6364,11 +6747,11 @@ PyObject* torch_c_dynamo_guards_init() { py::class_>( py_m, "GuardAccessor") .def("repr", &GuardAccessor::repr); - // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< GetAttrGuardAccessor, GuardAccessor, - std::unique_ptr>(py_m, "GetAttrGuardAccessor"); + std::unique_ptr>(py_m, "GetAttrGuardAccessor") + .def("get_attr_name", &GetAttrGuardAccessor::get_attr_name); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< GenericGetAttrGuardAccessor, @@ -6433,6 +6816,16 @@ PyObject* torch_c_dynamo_guards_init() { GuardAccessor, std::unique_ptr>(py_m, "TypeGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + TypeDictGuardAccessor, + GuardAccessor, + std::unique_ptr>(py_m, "TypeDictGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + TypeMROGuardAccessor, + GuardAccessor, + std::unique_ptr>(py_m, "TypeMROGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< WeakRefCallGuardAccessor, GuardAccessor, @@ -6451,6 +6844,16 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>( py_m, "TupleIteratorGetItemAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + CodeGuardAccessor, + GuardAccessor, + std::unique_ptr>(py_m, "CodeGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + ClosureGuardAccessor, + GuardAccessor, + std::unique_ptr>(py_m, "ClosureGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< GlobalWeakRefGuardAccessor, GuardAccessor, @@ -6468,23 +6871,16 @@ PyObject* torch_c_dynamo_guards_init() { .def( "is_guarded_value_immutable", &GuardManager::is_guarded_value_immutable) - .def( - "is_guarded_value_nn_module", - &GuardManager::is_guarded_value_nn_module) - .def("is_guarded_value_dict", &GuardManager::is_guarded_value_dict) - .def( - "is_guarded_value_empty_dict", - &GuardManager::is_guarded_value_empty_dict) - .def("is_guarded_value_tensor", &GuardManager::is_guarded_value_tensor) .def("has_no_accessors", &GuardManager::has_no_accessors) .def("mark_tag_safe", &GuardManager::mark_tag_safe) .def("mark_tag_safe_root", &GuardManager::mark_tag_safe_root) .def("is_tag_safe", &GuardManager::is_tag_safe) .def("is_tag_safe_root", &GuardManager::is_tag_safe_root) - .def("type_of_guarded_value", &GuardManager::type_of_guarded_value) .def( "is_recursive_dict_tag_matching_disabled", &GuardManager::is_recursive_dict_tag_matching_disabled) + .def( + "get_type_of_guarded_value", &GuardManager::get_type_of_guarded_value) .def( "get_accessors", &GuardManager::get_accessors, @@ -6913,6 +7309,46 @@ PyObject* torch_c_dynamo_guards_init() { py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers + .def( + "type_dict_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__type_dict_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers + .def( + "type_mro_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__type_mro_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers .def( "weakref_call_manager", [](GuardManager& self, @@ -6971,6 +7407,46 @@ PyObject* torch_c_dynamo_guards_init() { py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers + .def( + "code_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__code_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers + .def( + "closure_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__closure_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers .def( "global_weakref_manager", &GuardManager::get_child_manager, @@ -7229,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() { throw std::runtime_error("Failed to install dict_version_watch_callback"); } + dict_recursive_tag_watcher_id = + PyDict_AddWatcher(dict_recursive_tag_watch_callback); + if (dict_recursive_tag_watcher_id == -1) { + throw std::runtime_error( + "Failed to install dict_recursive_tag_watch_callback"); + } + #endif return m; diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index b6c009805c71..8d1dd116afe5 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -12,6 +12,7 @@ // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +#include #if defined(__GNUC__) || defined(__clang__) #define AOTI_NOINLINE __attribute__((noinline)) @@ -21,27 +22,18 @@ #define AOTI_NOINLINE #endif -AOTI_NOINLINE static void throw_exception( - const char* call, - const char* file, - int64_t line) { - std::stringstream ss; - ss << call << " API call failed at " << file << ", line " << line; - throw std::runtime_error(ss.str()); -} - -#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ - if ((call) != AOTI_TORCH_SUCCESS) { \ - throw_exception(#call, __FILE__, __LINE__); \ +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ } using AOTIRuntimeError = int32_t; #define AOTI_RUNTIME_SUCCESS 0 #define AOTI_RUNTIME_FAILURE 1 -#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ - if ((call) != AOTI_RUNTIME_SUCCESS) { \ - throw_exception(#call, __FILE__, __LINE__); \ +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ } namespace torch::aot_inductor { diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 9d512ce1f481..b1446318dd34 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle); @@ -267,6 +270,16 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( AtenTensorHandle* ret_new_tensor // returns new reference ); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided_pinned( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor // returns new reference +); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_as_strided( AtenTensorHandle self, const int64_t* sizes_ptr, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h index cc2dcdf4c75e..d5bc50750fc7 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h @@ -15,6 +15,8 @@ extern "C" { #endif AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); #ifdef __cplusplus } // extern "C" diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 92d30ded855f..470919cf389c 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -51,6 +51,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTenso AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_abs(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index c486888877a6..56bd07115858 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -39,6 +39,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, Ate AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); @@ -69,6 +70,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 6fc51bd0c8f8..09ebbb76d0b2 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -14,6 +14,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index a33198fd1ba0..868da9831e76 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous( }); } +AOTITorchError aoti_torch_is_defined( + AtenTensorHandle tensor, + bool* ret_is_defined) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); + *ret_is_defined = t->defined(); + }); +} + AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle) { @@ -452,6 +461,28 @@ AOTITorchError aoti_torch_empty_strided( }); } +AOTITorchError aoti_torch_empty_strided_pinned( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + TORCH_CHECK( + c10::DeviceType(device_type) == c10::DeviceType::CPU, + "only CPU tensors can be pinned"); + *ret_new_tensor = new_tensor_handle(at::detail::empty_strided_cpu( + sizes, + strides, + static_cast(dtype), + /*is_pinned=*/true)); + }); +} + AOTITorchError aoti_torch_create_tensor_from_blob( void* data, int64_t ndim, @@ -1182,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) { if (msg) { std::cout << " " << msg; } - std::cout << " " - << "]:" << '\n'; + std::cout << " " << "]:" << '\n'; // Print exact tensor values for small size tensors const int64_t numel = t->numel(); diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index 5a3ef9865b7c..f98da60a1049 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -74,6 +74,65 @@ template struct IsVecMaskType> : std::true_type {}; #endif +template +struct CascadeSumHelper { + // A data struct to help cascade summation: + std::vector sum_stk{}; + uint64_t depth{0}; // depth of sum_stk. + uint64_t num_chunks{0}; // number of chunks stored in sum_stk. + uint64_t index{0}; // index of the current data. + CascadeSumHelper() = default; + CascadeSumHelper(uint64_t N) { + uint64_t m = (N + kChunkSize - 1) / kChunkSize; // div up + depth = m > 0 + ? static_cast(ceil(log2(static_cast(m)))) + : 0; + if constexpr (IsVecType::value) { + sum_stk.assign( + std::max(depth, static_cast(1)), + T(typename T::value_type(0))); + } else { + sum_stk.assign(std::max(depth, static_cast(1)), T(0)); + } + } +}; + +template +inline T cascade_sum_combine(T& data, CascadeSumHelper* c) { + // Note: In order to be consistent with other reductions in inductor, + // the returned value may be wrong and cascade_sum_final must be executed to + // get the final correct result. Inductor uses the reduction suffix to ensure + // that cascade_sum_final is called in the end. + c->sum_stk[0] = c->sum_stk[0] + data; + // Use cascade summation to improve numerical stability. + // https://en.wikipedia.org/wiki/Pairwise_summation + if (c->depth > 0) { + c->index++; + if (c->index == kChunkSize) { + c->num_chunks += 1; + c->index = 0; + uint64_t mask = c->num_chunks; + uint64_t j = 1; + for (; j < c->depth && (mask & 1) == 0; ++j) { + c->sum_stk[j] = c->sum_stk[j] + c->sum_stk[j - 1]; + c->sum_stk[j - 1] = T(0); + mask >>= 1; + } + return c->sum_stk[j - 1]; + } + } + return c->sum_stk[0]; +} + +template +inline T cascade_sum_final(CascadeSumHelper* c) { + T result = c->sum_stk[0]; + for (const auto i : c10::irange(1, c->depth)) { + result = result + c->sum_stk[i]; + } + return result; +} + template struct WelfordHelper { // A data struct to help welford reduction: @@ -211,6 +270,31 @@ Welford welford_combine( out.index}; } +template +inline T cascade_sum_combine( + T& data, + int64_t tail_size, + CascadeSumHelper* c) { + auto out = c->sum_stk[0] + data; + c->sum_stk[0] = T::set(c->sum_stk[0], out, tail_size); + if (c->depth > 0) { + c->index++; + if (c->index == kChunkSize) { + c->num_chunks += 1; + c->index = 0; + uint64_t mask = c->num_chunks; + uint64_t j = 1; + for (; j < c->depth && (mask & 1) == 0; ++j) { + c->sum_stk[j] = c->sum_stk[j] + c->sum_stk[j - 1]; + c->sum_stk[j - 1] = T(0); + mask >>= 1; + } + return c->sum_stk[j - 1]; + } + } + return c->sum_stk[0]; +} + template T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = at::vec::maximum(a, b); diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index ff2ef1f2377c..9728d27d4d79 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -12,7 +12,7 @@ cases*/ static constexpr auto bfloat16_type_string = "__nv_bfloat16"; -#if defined(USE_ROCM) +#if defined(USE_ROCM) && ROCM_VERSION < 70000 static auto type_declarations_template = at::jit::CodeTemplate(R"( ${HalfHeader} ${BFloat16Header} diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index d642746abaaa..d5a8408e971c 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -6,7 +6,34 @@ namespace torch::jit { // Avoid storing objects with destructor in thread_local for mobile build. #ifndef C10_MOBILE -static thread_local std::vector calls; +// [NOTE: Thread-safe CallStack] +// `calls` maintains a stack of Python calls that resulted in the +// currently compiled TorchScript code. RAII ErrorReport::CallStack +// push and pop from the `calls` object during compilation to track +// these stacks so that they can be used to report compilation errors +// +// Q: Why can't this just be a thread_local vector (as it was previously)? +// +// A: Sometimes a CallStack RAII guard is created in Python in a given +// thread (say, thread A). Then later, someone can call +// sys._current_frames() from another thread (thread B), which causes +// thread B to hold references to the CallStack guard. e.g. +// 1. CallStack RAII guard created by thread A +// 2. CallStack guard now has a reference from thread B +// 3. thread A releases guard, but thread B still holds a reference +// 4. thread B releases guard, refcount goes to 0, and we +// call the destructor +// under this situation, **we pop an element off the wrong `call` +// object (from the wrong thread!) +// +// To fix this: +// * in CallStack, store a reference to which thread's `calls` +// the CallStack corresponds to, so you can pop from the correct +// `calls` object. +// * make it a shared_ptr and add a mutex to make this thread safe +// (since now multiple threads access a given thread_local calls object) +static thread_local std::shared_ptr calls = + std::make_shared(); #endif // C10_MOBILE ErrorReport::ErrorReport(const ErrorReport& e) @@ -17,20 +44,23 @@ ErrorReport::ErrorReport(const ErrorReport& e) #ifndef C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) - : context(r), error_stack(calls.begin(), calls.end()) {} + : context(r), error_stack(calls->get_stack()) {} void ErrorReport::CallStack::update_pending_range(const SourceRange& range) { - calls.back().caller_range = range; + calls->update_pending_range(range); } ErrorReport::CallStack::CallStack( const std::string& name, const SourceRange& range) { - calls.push_back({name, range}); + source_callstack_ = calls; + source_callstack_->push_back({name, range}); } ErrorReport::CallStack::~CallStack() { - calls.pop_back(); + if (source_callstack_) { + source_callstack_->pop_back(); + } } #else // defined C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) : context(r) {} @@ -61,7 +91,7 @@ static std::string get_stacked_errors(const std::vector& error_stack) { std::string ErrorReport::current_call_stack() { #ifndef C10_MOBILE - return get_stacked_errors(calls); + return get_stacked_errors(calls->get_stack()); #else TORCH_CHECK(false, "Call stack not supported on mobile"); #endif // C10_MOBILE diff --git a/torch/csrc/jit/frontend/error_report.h b/torch/csrc/jit/frontend/error_report.h index 635dd35468e3..9f5ad9bf3bb6 100644 --- a/torch/csrc/jit/frontend/error_report.h +++ b/torch/csrc/jit/frontend/error_report.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace torch::jit { @@ -18,6 +19,38 @@ struct TORCH_API ErrorReport : public std::exception { const char* what() const noexcept override; + class TORCH_API Calls { + private: + std::vector calls_; + mutable std::mutex mutex_; + + public: + void push_back(Call call) { + std::lock_guard lock(mutex_); + calls_.push_back(std::move(call)); + } + + void pop_back() { + std::lock_guard lock(mutex_); + calls_.pop_back(); + } + + bool empty() const { + std::lock_guard lock(mutex_); + return calls_.empty(); + } + + void update_pending_range(const SourceRange& range) { + std::lock_guard lock(mutex_); + calls_.back().caller_range = range; + } + + std::vector get_stack() const { + std::lock_guard lock(mutex_); + return calls_; + } + }; + struct TORCH_API CallStack { // These functions are used to report why a function was being compiled // (i.e. what was the call stack of user functions at compilation time that @@ -28,6 +61,9 @@ struct TORCH_API ErrorReport : public std::exception { // Change the range that is relevant for the current function (i.e. after // each successful expression compilation, change it to the next expression) static void update_pending_range(const SourceRange& range); + + private: + std::shared_ptr source_callstack_; }; static std::string current_call_stack(); diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 5f1a3e798bf9..0e9f0c9c2178 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -359,8 +359,8 @@ void SimpleValue::setAttr( throw( ErrorReport(loc) << "Assignment to attribute '" << field - << "' cannot be of a type that contains class " - << "'" << classType->repr_str() << "'.\n" + << "' cannot be of a type that contains class " << "'" + << classType->repr_str() << "'.\n" << "Classes that recursively contain instances of themselves" << " are not yet supported"); } @@ -826,4 +826,82 @@ SugaredValuePtr SugaredEnumClass::iter( return enum_values_list_constant; } +std::shared_ptr TorchCheckValue::call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) { + if (args.size() + kwargs.size() < 1 || args.size() + kwargs.size() > 2) { + throw( + ErrorReport(loc) << "torch._check() expects 1 or 2 arguments, got " + << (args.size() + kwargs.size())); + } + + NamedValue* cond_arg = nullptr; + NamedValue* message_arg = nullptr; + bool found_cond_kwarg = false; + bool found_message_kwarg = false; + + for (const auto& kwarg : kwargs) { + if (kwarg.name() == "cond") { + if (found_cond_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'cond'"); + } + cond_arg = const_cast(&kwarg); + found_cond_kwarg = true; + } else if (kwarg.name() == "message") { + if (found_message_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'message'"); + } + message_arg = const_cast(&kwarg); + found_message_kwarg = true; + } else { + throw( + ErrorReport(loc) << "torch._check() got unexpected keyword argument '" + << kwarg.name() << "'"); + } + } + + if (args.size() >= 1) { + if (found_cond_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'cond'"); + } + cond_arg = const_cast(&args[0]); + } + + if (args.size() >= 2) { + if (found_message_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'message'"); + } + message_arg = const_cast(&args[1]); + } + + if (!cond_arg) { + throw( + ErrorReport(loc) << "torch._check() missing required argument 'cond'"); + } + + std::vector assert_args; + assert_args.push_back(*cond_arg); + + if (message_arg) { + assert_args.push_back(*message_arg); + } else { + Value* default_msg = insertConstant(*m.graph(), std::string(""), loc); + assert_args.emplace_back(loc, "message", default_msg); + } + + emitBuiltinCall(loc, *m.graph(), Symbol::aten("_assert"), assert_args, {}); + return std::make_shared(); +} + } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index d88e77b16cd1..59ddea774d5d 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -136,8 +136,7 @@ struct TORCH_API SugaredValue // Value * virtual Value* len(const SourceRange& loc, GraphFunction& m) { throw( - ErrorReport(loc) << "'" << kind() << "'" - << " object is not iterable"); + ErrorReport(loc) << "'" << kind() << "'" << " object is not iterable"); } // expression for ith element for iterable value @@ -858,4 +857,19 @@ struct TORCH_API SliceValue : public SugaredValue { Value* step_; }; +struct TORCH_API TorchCheckValue : public SugaredValue { + explicit TorchCheckValue() = default; + + std::string kind() const override { + return "torch._check sugared value"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; +}; + } // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 9cf12ffde38a..0ac07adf0d45 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -76,8 +76,8 @@ static std::optional runTorchSlice_opset9( if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { return std::nullopt; } - auto startsAttr = node->is(attr::starts); - auto endsAttr = node->is(attr::ends); + auto const& startsAttr = node->is(attr::starts); + auto const& endsAttr = node->is(attr::ends); if (startsAttr.size() != endsAttr.size()) { return std::nullopt; } diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index ece03b19e961..32c0e1b77c2c 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -216,7 +216,7 @@ void FunctionExtractor::FunctionContext::SetAttrName( TORCH_INTERNAL_ASSERT( v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end()); auto* n_in_def = v_it->second->node(); - auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; + node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; } std::optional FunctionExtractor::FunctionContext::FindAttrName( @@ -405,7 +405,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { auto common_ancestor = FindCommonAncestor(scopes); if (common_ancestor.has_value() && IsValidScope(common_ancestor.value())) { - return common_ancestor.value(); + return common_ancestor; } } } diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 73106ba0ef3c..71595b769ac1 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -35,8 +35,8 @@ static bool isRNN(const Node* node) { } static bool isNopTranspose(const std::vector& perm) { - for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) { - if (perm[i] != i) { + for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) { + if (perm[i] != static_cast(i)) { return false; } } diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 7a28f1e41c1b..966388278a32 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -10,8 +10,6 @@ #include -#include - namespace torch::jit { namespace { @@ -344,7 +342,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { auto it = std::find(node->inputs().begin(), node->inputs().end(), input); if (it != node->inputs().end()) { - int index = std::distance(node->inputs().begin(), it); + auto index = std::distance(node->inputs().begin(), it); TORCH_WARN( "ONNX Preprocess - Removing mutation from node ", node->kind().toQualString(), diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 086e50ae6a7a..452b18f3efc3 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -282,7 +282,7 @@ Value* CloneValueFromListConstruct( auto input = n_graph->addInput(); if (scalar_type) { auto v_type = TensorType::create( - scalar_type.value(), + scalar_type, at::kCPU, c10::SymbolicShape(), c10::VaryingShape{}, @@ -411,7 +411,9 @@ void ConvertGraphToONNXProto( } } -std::optional ComputeConstantFolding(Node* n, int opset_version) { +std::optional ComputeConstantFolding( + const Node* n, + int opset_version) { if (n->inputs().empty()) { return std::nullopt; } @@ -463,7 +465,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero); bool shape_has_zero = it_0 != shape_vector.end(); - int minus_one_pos = -1; + int64_t minus_one_pos = -1; for (auto i : c10::irange(shape_vector.size())) { if (shape_vector[i].value() == -1) { minus_one_pos = i; @@ -773,7 +775,7 @@ void ProcessBroadcastNode(Node* n) { } void ProcessShapeForConcatNode(Node* n) { - int axis = n->i(attr::axis); + auto axis = n->i(attr::axis); if (ConstantValueMap::HasRank(n->input(0)->debugName())) { auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value(); size_t axis_adjust = 0; @@ -1244,7 +1246,7 @@ void ProcessUnsqueezeNode(Node* n) { void ComputeConstant(Node* n, int opset_version) { if (n->kind() == ::c10::onnx::Constant) { if (n->kindOf(attr::value) == AttributeKind::t) { - at::Tensor const_val = n->t(attr::value); + const at::Tensor& const_val = n->t(attr::value); at::Tensor const_val_copy = at::empty(const_val.sizes(), const_val.options()); const_val_copy.copy_(const_val); @@ -1381,7 +1383,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); if (ConstantValueMap::HasValue(n->input(1)->debugName())) { // When value of `shape` is statically known, // output shape can be computed. @@ -1474,7 +1476,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); int64_t total_size = 1; auto is_full_static = true; for (const auto i : c10::irange(input0_shape_value.size())) { @@ -1510,7 +1512,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); if (ConstantValueMap::HasValue(n->input(1)->debugName())) { auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector( n->input(1)->debugName()); @@ -1659,10 +1661,10 @@ void SpecialPostProcess(Node* n) { }; auto find_sequence_empty = [](Value* input, - TensorTypePtr t_type) -> Node* { + const TensorTypePtr& t_type) -> Node* { auto find_sequence_empty_impl = [](Value* input, - TensorTypePtr t_type, + const TensorTypePtr& t_type, auto& find_sequence_empty_ref) -> Node* { auto input_node = input->node(); TORCH_INTERNAL_ASSERT(input_node); @@ -1708,7 +1710,7 @@ void SpecialPostProcess(Node* n) { return nullptr; }; return find_sequence_empty_impl( - input, std::move(t_type), find_sequence_empty_impl); + input, t_type, find_sequence_empty_impl); }; if (seq_node && t_type && t_type->scalarType()) { @@ -2255,7 +2257,7 @@ void ONNXSetDynamicInputShape( } } -static bool HasSequenceTypeOutput(Node* node) { +static bool HasSequenceTypeOutput(const Node* node) { if (node->kind() == ::c10::onnx::SplitToSequence || node->kind() == ::c10::onnx::SequenceInsert || node->kind() == ::c10::onnx::SequenceEmpty || diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 3116c0721a6c..63e6804c97eb 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -21,83 +21,6 @@ using namespace ::c10::onnx; } -// Get the scale of the input to quantized op. There are two cases here -// 1. For ops with output_scale specified in op signature, we get the output -// scale -// 2. For ops with no output scale in op signature (like quantized::relu) -// we traverse up the graph to get the scale from its input until we hit a node -// where scale is explicitly specified. -double getScaleFromInput(Node* input_node) { - std::optional scale; - std::string input_name = input_node->kind().toQualString(); - std::unordered_set noscale_ops = { - "quantized::max_pool2d", - "aten::max_pool2d", - "aten::relu", - "prim::ListUnpack", - "aten::split_with_sizes", - "quantized::nchw2nhwc", - "quantized::nhwc2nchw", - "aten::slice", - "aten::avg_pool2d", - "quantized::cat", - "prim::ListConstruct", - "aten::upsample_nearest2d", - "aten::sigmoid", - "aten::reshape"}; - if (input_name == "aten::quantize_per_tensor") { - TORCH_CHECK( - input_node->inputs().size() > 1, - "aten::quantize_per_tensor expected scale to be 2nd input"); - scale = toIValue(input_node->inputs()[1]); - return scale.value().toDouble(); - } else if (input_name == "quantized::linear") { - // %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::linear expected scale to be 3rd input"); - scale = toIValue(input_node->inputs()[2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::conv2d") { - // %r = quantized::conv2d(%input, %packed_weight, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::conv2d expected scale to be 3rd input"); - auto num_inputs = input_node->inputs().size(); - scale = toIValue(input_node->inputs()[num_inputs - 2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::conv2d_relu") { - // %r = quantized::conv2d_relu(%input, %packed_weight, %w_scale, - // %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::conv2d_relu expected scale to be 3rd input"); - auto num_inputs = input_node->inputs().size(); - scale = toIValue(input_node->inputs()[num_inputs - 2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::add") { - // %r = quantized::add(%input_a, %input_b, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::add expected scale to be 3rd input"); - scale = toIValue(input_node->inputs()[2]); - return scale.value().toDouble(); - } else if (input_name == "aten::sigmoid") { - // For the _caffe2::Int8Sigmoid op output scale is 1.0/256 - // And output zero_point is set to 0 (quint8 type). - return 1.0L / 256; - } - // For the ops below the scale is not part of the op signature, so we traverse - // up the graph to get the scale from its input when defined in the graph. - else if (noscale_ops.find(input_name) != noscale_ops.end()) { - return getScaleFromInput(input_node->inputs()[0]->node()); - } - TORCH_INTERNAL_ASSERT( - false, - "Unrecognized quantized operator while trying to compute q_scale for operator ", - input_name); -} - static std::vector CreateQuantizedWeights( std::shared_ptr& graph, const at::Tensor& weight, @@ -315,7 +238,7 @@ static void unpackQuantizedWeightsHelper( auto config_vals = elements[1].to>(); auto tensors = elements[2].to>>(); - std::optional weight = tensors[1]; + const std::optional& weight = tensors[1]; TORCH_INTERNAL_ASSERT( weight, "Weight should always be present in serialized qconv."); unpacked_weight = *weight; @@ -373,7 +296,7 @@ static void unpackQuantizedWeightsHelper( TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version"); std::vector non_optional = elements[1].toTensorVector(); - at::Tensor conv_params_packed = non_optional[0]; + const at::Tensor& conv_params_packed = non_optional[0]; unpacked_weight = non_optional[1]; const int64_t kSpatialDim = conv_params_packed[0].item(); diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 3f2708619be8..e30648399c5a 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -90,7 +90,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (PyBool_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackBool(obj.ptr())); } else if (THPUtils_checkLong(obj.ptr())) { - scalar = at::Scalar(THPUtils_unpackLong(obj.ptr())); + scalar = THPUtils_unpackInteger(obj.ptr()); } else if (PyComplex_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); } else if (THPUtils_checkDouble(obj.ptr())) { @@ -512,7 +512,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (py::isinstance(obj)) { return py::cast(obj); } else if (py::isinstance(obj)) { - return py::cast(obj); + return THPUtils_unpackInteger(obj.ptr()); } else if (py::isinstance(obj)) { return py::cast(obj); } else if (PyComplex_CheckExact(obj.ptr())) { @@ -598,6 +598,8 @@ py::object toPyObject(IValue ivalue) { return py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: return py::cast(*tensor.const_data_ptr()); + case at::ScalarType::UInt64: + return py::cast(*tensor.const_data_ptr()); case at::ScalarType::Double: return py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: @@ -763,6 +765,8 @@ py::object toPyObject(IValue ivalue) { return py::cast(std::move(ivalue).toSymFloat()); } else if (ivalue.isSymBool()) { return py::cast(std::move(ivalue).toSymBool()); + } else if (ivalue.isUnsigned()) { + return py::cast(std::move(ivalue).toUInt()); } else { TORCH_CHECK( false, diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index b9db0be814e4..8b16e089aa50 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1222,6 +1222,8 @@ std::shared_ptr toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { return SpecialFormValue::create(prim::isinstance); + } else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) { + return std::make_shared(); #ifdef USE_RPC // RPC module is only available when build flag "USE_DISTRIBUTED" is on. } else if ( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index d5586a5b9cd7..9e408682ca6c 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1910,7 +1910,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { } auto& out_t = p_node->Output(0).toTensor(); - if (in0_t.sizes() == in1_t.sizes() && + if (te && te->checkInput(in0_t) && in0_t.sizes() == in1_t.sizes() && in0_t.scalar_type() == in1_t.scalar_type() && in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() && in0_t.scalar_type() == at::kFloat) { diff --git a/torch/csrc/profiler/README.md b/torch/csrc/profiler/README.md index 339c84c0a08e..dc27337349dd 100644 --- a/torch/csrc/profiler/README.md +++ b/torch/csrc/profiler/README.md @@ -13,14 +13,49 @@ The profiler instruments PyTorch to collect information about the model's execut - [Codebase Structure](#codebase-structure) - [`RecordFunction`](#recordfunction) - [Autograd Integration](#autograd-integration) -- [Collection and Post-Processing](#collection-and-post-processing) +- [Torch Operation Collection](#torch-operation-collection) +- [Allocation Event Collection](#allocation-event-collection) - [Kineto Integration](#kineto-integration) - [Python Tracing](#python-tracing) +- [Clock Alignment](#clock-alignment) ## Codebase Structure ## -TODO - +This section highlights directories an files that are significant to the profiler. Lesser relevant files, directories, and modules are omitted. +``` +torch/ +│ +ā”œā”€ā”€ profiler/ # Main package containing the core frontend logic +│ ā”œā”€ā”€ __init__.py # Initialization file for profiler package +│ ā”œā”€ā”€ profiler.py # Main profiler frontend class +│ └── _utils.py # FunctionEvent utils +│ +ā”œā”€ā”€ autograd/ # Autograd package +│ ā”œā”€ā”€ __init__.py # Initialization file for autograd package +│ ā”œā”€ā”€ profiler.py # Main profiler backend class +│ └── profiler_utils.py # FunctionEvent utils +│ +ā”œā”€ā”€ csrc/ # C and C++ source code +│ └── profiler/ # Profiler C++ source code +│ ā”œā”€ā”€ collection.cpp # Main collection logic +│ ā”œā”€ā”€ collection.h # Collection definitions +│ ā”œā”€ā”€ kineto_client_interface.cpp # Interface to call Profiler from kineto (on-demand only) +│ ā”œā”€ā”€ kineto_client_interface.h # Client interface definitions +│ ā”œā”€ā”€ kineto_shim.cpp # Shim to call kineto from profiler +│ ā”œā”€ā”€ kineto_shim.h # Shim definitions +│ ā”œā”€ā”€ util.cpp # utils for handling args in profiler events +│ ā”œā”€ā”€ util.h # util definitions +│ └── README.md # This file +│ └── autograd/ # Autograd C++ source code +│ ā”œā”€ā”€ profiler_python.cpp # Main python stack collection logic +│ ā”œā”€ā”€ profiler_python.h # Python stack collection definitions +│ ā”œā”€ā”€ profiler_kineto.cpp # Profiler backend logic for starting collection/kineto +│ └── profiler_kineto.h # Profiler backend definitions for starting collection/kineto +│ └── ATen/ # ATen C++ source code +│ ā”œā”€ā”€ record_function.cpp # RecordFunction collection logic +│ └── record_function.h # RecordFunction definitions +└── LICENSE # License information +``` ## `RecordFunction` ## [aten/src/ATen/record_function.h](../../../aten/src/ATen/record_function.h) @@ -43,14 +78,39 @@ The profiler records two pieces of information from the autograd engine: (\*) Note that only op invocations whose inputs require gradients are assigned a sequence number -## Collection and Post-Processing ## +## Torch Operation Collection ## +This section describes the general flow for collecting torch operations during auto-trace (in-process, synchronous tracing). For details on on-demand tracing (out-of-process, asynchronous), please refer to the Libkineto README. + +When a trace begins, the autograd/profiler backend calls into `profiler_kineto.cpp` to prepare, start, or stop collection. At the start of tracing, the `onFunctionEnter` and `onFunctionExit` callbacks defined in `profiler_kineto.cpp` are registered. + +Callback registration can be either global or local, depending on the `ExperimentalConfig` used: +- **Global:** The callback is registered to all threads throughout execution. +- **Local:** The callback is registered only to threads present *at the start* of tracing. +Within `onFunctionEnter`, the profiler creates a `ThreadLocalSubqueue` instance for each thread, ensuring that each CPU operation is associated with the thread on which it was executed. When a torch operation is entered, the profiler calls `begin_op` (defined in `collection.cpp`) to record the necessary information. The `begin_op` routine is intentionally lightweight, as it is on the "hot path" during profiling. Excessive overhead here would distort the profile and reduce its usefulness. Therefore, only minimal information is collected during the callback; most logic occurs during post-processing. -TODO +## Allocation Event Collection ## + +Unlike torch operations, which have a start and stop, allocation events are represented as `cpu_instant_event` (zero duration). As a result, `RecordFunction` is bypassed for these events. Instead, `emplace_allocation_event` is called directly to enqueue the event into the appropriate `ThreadLocalSubqueue`. ## Kineto Integration ## -TODO +Kineto serves as an abstraction layer for collecting events across multiple architectures. It interacts with libraries such as CUPTI to receive GPU and accelerator events, which are then forwarded to the frontend profiler. Kineto requires time to "prepare" (also referred to as "warmup") these third-party modules to avoid distorting the profile with initialization routines. While this could theoretically be done at job startup, keeping a heavy library like CUPTI running unnecessarily introduces significant overhead. +As previously mentioned, `profiler_kineto.cpp` is used in the backend to invoke the appropriate profiler stage. It also calls into `kineto_shim.cpp`, which triggers the corresponding routines in Kineto. Once a trace is complete, all events collected by Kineto are forwarded to the profiler for two main reasons: +1. To coalesce all data and complete any post-processing between profiler and Kineto events. +2. To forward these events to the Python frontend as `FunctionEvents`. +The final step in integration is file export. After all events have been collected and post-processed, they can be exported to a JSON file for visualization in Perfetto or Chrome Tracer. This is done by calling Kineto's `ActivityTraceInterface::save`, which writes all event information to disk. ## Python Tracing ## -TODO +When `with_stack=True` is set in the profiler, the Python stack tracer is generated using the `make` function defined in `PythonTracerBase`. The implementation resides in `profiler_python.cpp`. +To profile the stack, `PyEval_SetProfile` is used to trace and handle various execution events within a Python program. This enables comprehensive profiling by monitoring and responding to specific cases: +- **Python Function Calls (`PyTrace_CALL`):** The `recordPyCall` method logs each Python function call, capturing essential details for later analysis. +- **C Function Calls (`PyTrace_C_CALL`):** The `recordCCall` method documents calls to C functions, including relevant arguments, providing a complete view of the program's execution flow. +- **Python Function Returns (`PyTrace_RETURN`):** Exit times of Python functions are recorded, enabling precise measurement of function execution durations. +- **C Function Returns and Exceptions (`PyTrace_C_RETURN` and `PyTrace_C_EXCEPTION`):** Exit times for C functions are tracked, whether they conclude normally or due to an exception, ensuring all execution paths are accounted for. +This setup allows for detailed and accurate data collection on both Python and C function executions, facilitating thorough post-processing and analysis. After profiling, the accumulated event stacks are processed to match entrances and exits, constructing complete events for further analysis by the profiler. +**Note:** For Python 3.12.0–3.12.4, a bug in CPython requires the use of `sys.monitoring` as a workaround. + +## Clock Alignment ## + +Depending on the system environment, the profiler will use the most efficient clock when creating a timestamp. The default for most Linux systems is TSC, which records time in the form of CPU cycles. To convert from this time to the unix time in nanoseconds, we create a clock converter. If Kineto is included in the profiler, this converter will also be passed into Kineto as well to ensure alignment. diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index ef7f1d0784ea..ec779fd67fb0 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -44,7 +44,19 @@ struct FromImpl { // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a // static_assert above that T is trivially copyable, which should be // enough. +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ std::memcpy(&result, reinterpret_cast(&val), sizeof(val)); +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + // if value has size less than sizeof(StableIValue), then only lowest bytes + // have to be updated + std::memcpy( + reinterpret_cast(&result) + sizeof(StableIValue) - + sizeof(val), + reinterpret_cast(&val), + sizeof(val)); +#else +#error Unexpected or undefined __BYTE_ORDER__ +#endif return result; } }; @@ -127,7 +139,22 @@ struct ToImpl { }; Result result; // See NOTE[ -Wclass-memaccess ] above. +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ std::memcpy(reinterpret_cast(&result.t), &val, sizeof(result)); +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + static_assert( + sizeof(T) <= sizeof(StableIValue), + "StableLibrary stack does not support parameter types larger than 64 bits."); + // if value has size less than sizeof(StableIValue), then only lowest bytes + // have to be updated + std::memcpy( + reinterpret_cast(&result.t), + reinterpret_cast(&val) + sizeof(StableIValue) - + sizeof(result), + sizeof(result)); +#else +#error Unexpected or undefined __BYTE_ORDER__ +#endif return result.t; } }; diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index a8f68f4a5e3a..7ce25af14d3f 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -4,11 +4,15 @@ #include #include #include +#include +#include #include using torch::stable::Tensor; +namespace torch::stable { + // We expect this to be the stable version of the empty_like op that takes in // no kwargs (device, dtype, layout, memory_format). We will add kwargs // support in the future. @@ -21,7 +25,7 @@ inline Tensor empty_like(const Tensor& self) { from(std::nullopt), from(std::nullopt), from(std::nullopt)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::empty_like", "", stack.data())); return to(stack[0]); } @@ -32,16 +36,44 @@ inline Tensor empty_like(const Tensor& self) { // actually a Scalar. This is because Scalar.h is currently not // header-only. inline Tensor fill_(const Tensor& self, double value) { - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value)); + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value)); return self; } +// We expect this to be the stable version of the narrow.default op. +// narrow takes in a SymInt for start and length, but these are typed as +// int64_t as SymInt is not yet header-only. +inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) { + AtenTensorHandle ret0 = nullptr; + + TORCH_ERROR_CODE_CHECK( + aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0)); + return Tensor(ret0); +} + +// We expect this to be the stable version of the pad.default op. +// pad.default takes in a SymInt[] as the pad argument however pad is typed as +// use std::vector because +// (1) IntArrayRef is not yet header-only +// (2) SymInt is not yet header-only +inline Tensor pad( + const Tensor& self, + std::vector pad, + const std::string& mode = "constant", + double value = 0.0) { + AtenTensorHandle ret0 = nullptr; + + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad( + self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0)); + return Tensor(ret0); +} + // We expect this to be the stable version of the transpose op with identical // semantics to the existing transpose.int op. inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { const auto num_args = 3; std::array stack{from(self), from(dim0), from(dim1)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::transpose", "int", stack.data())); return to(stack[0]); } @@ -52,7 +84,9 @@ inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { inline Tensor zero_(Tensor& self) { const auto num_args = 1; std::array stack{from(self)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::zero_", "", stack.data())); return to(stack[0]); } + +} // namespace torch::stable diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index 1b9b3fecb417..d02763923a5f 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -1,10 +1,8 @@ #pragma once -// TODO ASAP: THIS FILE SHOULD BE HEADER ONLY BUT ISN'T ENFORCED: -// I only need it for AOTI_TORCH_ERROR_CODE_CHECK, see #154908 -#include - #include +#include +#include namespace torch::stable { @@ -31,13 +29,21 @@ class Tensor { std::shared_ptr ath_; public: - Tensor() = delete; + // Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH) + // Steals ownership from the ATH + Tensor() { + AtenTensorHandle ret; + TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret)); + ath_ = std::shared_ptr(ret, [](AtenTensorHandle ath) { + TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + }); + } // Construct a stable::Tensor from an AtenTensorHandle (ATH) // Steals ownership from the ATH explicit Tensor(AtenTensorHandle ath) : ath_(ath, [](AtenTensorHandle ath) { - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); }) {} // Copy and move constructors can be default cuz the underlying handle is a @@ -65,19 +71,19 @@ class Tensor { void* data_ptr() const { void* data_ptr; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); return data_ptr; } int64_t dim() const { int64_t dim; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); return dim; } int64_t numel() const { int64_t numel; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); return numel; } @@ -86,38 +92,43 @@ class Tensor { // Here, we assume the default contiguous memory format. bool is_contiguous() const { bool is_contiguous; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_is_contiguous(ath_.get(), &is_contiguous)); return is_contiguous; } int64_t stride(int64_t dim) const { int64_t stride; - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_get_stride(ath_.get(), dim, &stride)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride)); return stride; } DeviceIndex get_device() const { int32_t device_index; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(ath_.get(), &device_index)); return static_cast(device_index); } bool is_cuda() const { int32_t device_type; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_type(ath_.get(), &device_type)); return device_type == aoti_torch_device_type_cuda(); } int64_t size(int64_t dim) const { int64_t size; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); return size; } + bool defined() const { + bool defined; + TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined)); + return defined; + } + // ============================================================================= // END of C-shimified TensorBase APIs // ============================================================================= diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 98803390e510..62c8390f7c9b 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> +// checksum<> // clang-format off #pragma once @@ -61,6 +61,7 @@ class ForwardRef { ptr_ = std::make_unique(*other.ptr_); return *this; } + ~ForwardRef(); const T& operator*() const { return *ptr_; } @@ -157,7 +158,6 @@ class Node; class OptionalTensorArgument; class OutputSpec; class OutputTokenSpec; -class Program; class RangeConstraint; class SchemaVersion; class SymBool; @@ -3013,6 +3013,8 @@ class ExportedProgram { SchemaVersion schema_version; std::vector verifiers = {}; std::string torch_version = "<=2.4"; + std::unordered_map tensor_paths = {}; + std::unordered_map constant_paths = {}; public: @@ -3064,36 +3066,31 @@ class ExportedProgram { torch_version = std::move(def); } - friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); - friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); -}; - -class Program { - private: - std::unordered_map methods; + const std::unordered_map& get_tensor_paths() const { + return tensor_paths; + } - public: + void set_tensor_paths(std::unordered_map def) { + tensor_paths = std::move(def); + } - const std::unordered_map& get_methods() const { - return methods; + const std::unordered_map& get_constant_paths() const { + return constant_paths; } - void set_methods(std::unordered_map def) { - methods = std::move(def); + void set_constant_paths(std::unordered_map def) { + constant_paths = std::move(def); } - friend void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t); - friend void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t); + friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); }; class Model { private: std::string name; - std::unordered_map tensorPaths; - Program program; - std::unordered_map delegates; - std::unordered_map deviceAllocationMap; - std::unordered_map constantPaths; + ExportedProgram program; + std::unordered_map variants; public: @@ -3105,44 +3102,20 @@ class Model { name = std::move(def); } - const std::unordered_map& get_tensorPaths() const { - return tensorPaths; - } - - void set_tensorPaths(std::unordered_map def) { - tensorPaths = std::move(def); - } - - const Program& get_program() const { + const ExportedProgram& get_program() const { return program; } - void set_program(Program def) { + void set_program(ExportedProgram def) { program = std::move(def); } - const std::unordered_map& get_delegates() const { - return delegates; - } - - void set_delegates(std::unordered_map def) { - delegates = std::move(def); - } - - const std::unordered_map& get_deviceAllocationMap() const { - return deviceAllocationMap; - } - - void set_deviceAllocationMap(std::unordered_map def) { - deviceAllocationMap = std::move(def); + const std::unordered_map& get_variants() const { + return variants; } - const std::unordered_map& get_constantPaths() const { - return constantPaths; - } - - void set_constantPaths(std::unordered_map def) { - constantPaths = std::move(def); + void set_variants(std::unordered_map def) { + variants = std::move(def); } friend void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t); @@ -3316,6 +3289,8 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nloh nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; + nlohmann_json_j["tensor_paths"] = nlohmann_json_t.tensor_paths; + nlohmann_json_j["constant_paths"] = nlohmann_json_t.constant_paths; } inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { @@ -3326,6 +3301,8 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nl nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); + nlohmann_json_t.tensor_paths = nlohmann_json_j.value("tensor_paths", nlohmann_json_default_obj.tensor_paths); + nlohmann_json_t.constant_paths = nlohmann_json_j.value("constant_paths", nlohmann_json_default_obj.constant_paths); } inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) { @@ -3511,21 +3488,15 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlo inline void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t) { nlohmann_json_j["name"] = nlohmann_json_t.name; - nlohmann_json_j["tensorPaths"] = nlohmann_json_t.tensorPaths; nlohmann_json_j["program"] = nlohmann_json_t.program; - nlohmann_json_j["delegates"] = nlohmann_json_t.delegates; - nlohmann_json_j["deviceAllocationMap"] = nlohmann_json_t.deviceAllocationMap; - nlohmann_json_j["constantPaths"] = nlohmann_json_t.constantPaths; + nlohmann_json_j["variants"] = nlohmann_json_t.variants; } inline void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t) { Model nlohmann_json_default_obj; nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); - nlohmann_json_t.tensorPaths = nlohmann_json_j.value("tensorPaths", nlohmann_json_default_obj.tensorPaths); nlohmann_json_t.program = nlohmann_json_j.value("program", nlohmann_json_default_obj.program); - nlohmann_json_t.delegates = nlohmann_json_j.value("delegates", nlohmann_json_default_obj.delegates); - nlohmann_json_t.deviceAllocationMap = nlohmann_json_j.value("deviceAllocationMap", nlohmann_json_default_obj.deviceAllocationMap); - nlohmann_json_t.constantPaths = nlohmann_json_j.value("constantPaths", nlohmann_json_default_obj.constantPaths); + nlohmann_json_t.variants = nlohmann_json_j.value("variants", nlohmann_json_default_obj.variants); } inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { @@ -3604,15 +3575,6 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nl nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } -inline void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t) { - nlohmann_json_j["methods"] = nlohmann_json_t.methods; -} - -inline void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t) { - Program nlohmann_json_default_obj; - nlohmann_json_t.methods = nlohmann_json_j.value("methods", nlohmann_json_default_obj.methods); -} - inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { nlohmann_json_j["min_val"] = nlohmann_json_t.min_val; nlohmann_json_j["max_val"] = nlohmann_json_t.max_val; @@ -3717,6 +3679,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; } // namespace _export } // namespace torch diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 8a16b0211dce..1ae03f91f218 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -938,6 +938,27 @@ auto FunctionParameter::check( std::vector& overloaded_args, int argnum, int64_t* failed_idx) -> bool { + if (_check(obj, overloaded_args, argnum, failed_idx)) { + return true; + } + // NB: This will not detect torch function inside elements of a list. So + // you still have to handle that manually + // NB: torch function on Tensor subclasses NOT eligible here, you handled + // that internally + if (check_has_torch_function(obj, /*ignore_mode*/ true) && + !THPVariable_Check(obj)) { + // unrelated objects with __torch_function__ + append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false); + return true; + } + return false; +} + +auto FunctionParameter::_check( + PyObject* obj, + std::vector& overloaded_args, + int argnum, + int64_t* failed_idx) -> bool { switch (type_) { case ParameterType::TENSOR: { if (is_tensor_and_append_overloaded(obj, &overloaded_args)) { @@ -1013,15 +1034,7 @@ auto FunctionParameter::check( case ParameterType::PYOBJECT: return true; case ParameterType::SCALARTYPE: - if (THPDtype_Check(obj) || THPPythonScalarType_Check(obj)) { - return true; - } - if (check_has_torch_function(obj, /*ignore_mode*/ true)) { - // tensor subclasses and unrelated objects with __torch_function__ - append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false); - return true; - } - return false; + return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); case ParameterType::LAYOUT: return THPLayout_Check(obj); case ParameterType::MEMORY_FORMAT: @@ -1788,21 +1801,7 @@ at::Tensor PythonArgs::tensor_slow(int i) { if (PyBool_Check(obj)) { scalar = at::Scalar(THPUtils_unpackBool(obj)); } else if (THPUtils_checkLong(obj)) { - int overflow = -1; - long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); - if (value == -1 && PyErr_Occurred()) { - throw python_error(); - } - if (overflow != 0) { - // try unsigned - unsigned long long value = PyLong_AsUnsignedLongLong(obj); - if (value == static_cast(-1) && PyErr_Occurred()) { - throw python_error(); - } - scalar = at::Scalar(static_cast(value)); - } else { - scalar = at::Scalar(static_cast(value)); - } + scalar = THPUtils_unpackInteger(obj); } else if (PyComplex_Check(obj)) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj)); } else if (THPUtils_checkDouble(obj)) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index bc281f2512a5..2c1373921e57 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -322,6 +322,12 @@ struct FunctionParameter { int argnum, int64_t* failed_idx = nullptr); + bool _check( + PyObject* obj, + std::vector& overloaded_args, + int argnum, + int64_t* failed_idx = nullptr); + void set_default_str(const std::string& str); TORCH_PYTHON_API std::string type_name() const; diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 25ca2692b329..a8b9b8632a00 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -208,3 +208,22 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) { } return (c10::DeviceIndex)value; } + +template +inline T THPUtils_unpackInteger(PyObject* obj) { + int overflow = -1; + const auto value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (!overflow) { + return static_cast(value); + } + // try unsigned + const auto uvalue = PyLong_AsUnsignedLongLong(obj); + if (uvalue == static_cast>(-1) && + PyErr_Occurred()) { + throw python_error(); + } + return static_cast(uvalue); +} diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 63e59096162f..1bd6f9edc031 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of October 2019, for size >= 1MB allocations). + (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of October 2019, for size < 1MB allocations). + (as of June 2025, for size < 1MB allocations). Metric type: diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index b49d240e4d75..219501a0a708 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -43,7 +43,7 @@ from torch.distributed.tensor import Shard -cls_to_fsdp_cls: dict[type, type] = {} +cls_to_replicate_cls: dict[type, type] = {} _ROOT_MODULE_PREFIX = "" @@ -51,10 +51,10 @@ class _ReplicateStateContext: - """This has state shared across FSDP states.""" + """This has state shared across Replicate states.""" def __init__(self) -> None: - # All FSDP states in the root state's module tree + # All Replicate states in the root state's module tree self.all_states: list[_ReplicateState] = [] # Iteration's forward root runs the once-per-forward logic; this root # may not be the overall root set by lazy initialization in cases where @@ -173,7 +173,7 @@ def replicate_impl( offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: Optional[set[nn.Parameter]] = None, ): - torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") + torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") if isinstance(module, (nn.ModuleList, nn.ModuleDict)): raise ValueError( f"replicate does not support containers that do not implement forward: {module}" @@ -224,11 +224,11 @@ def replicate_impl( # Place Replicate leftmost for highest priority in the method resolution order for module in modules: cls = module.__class__ - new_cls = cls_to_fsdp_cls.get(cls, None) + new_cls = cls_to_replicate_cls.get(cls, None) if not new_cls: dct = {"__deepcopy__": _unimplemented_deepcopy} - new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) - cls_to_fsdp_cls[cls] = new_cls + new_cls = type(f"Replicate{cls.__name__}", (FSDPModule, cls), dct) + cls_to_replicate_cls[cls] = new_cls module.__class__ = new_cls return arg_module @@ -262,27 +262,7 @@ def replicate( ) device_mesh = kwargs.pop("device_mesh", None) - if device_mesh is not None: - from torch.distributed.device_mesh import _mesh_resources - - root_mesh = _mesh_resources.get_root_mesh(device_mesh) - # if a root mesh is not the same as device_mesh, - # meaning the device_mesh is sliced out from the root mesh. - if root_mesh != device_mesh: - # TODO: This is a temporary work around to enable DDP + TP. - # We should do the logic in DDP so that the 2D implementation is - # sound and the state_dict works out of the box. - # - # This won't conflict with what is done in DDP class as the module - # replicate is going to pass is NOT the original module. - from torch.distributed.tensor.parallel.ddp import ( - _localize_dtensor, - _reconstruct_dtensor, - ) - - module.register_forward_pre_hook(_reconstruct_dtensor) - module.register_forward_hook(_localize_dtensor) - else: + if device_mesh is None: device_mesh = replicate_mesh() module = replicate_impl(module, mesh=device_mesh, **kwargs) diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 8472f0d9dd04..0b53da3988bd 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -815,10 +815,6 @@ def _are_we_tracing() -> bool: # If fake mode is turned on, we are almost definitely compiling/tracing. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: return True - - if torch._dynamo.compiled_autograd.in_compiled_autograd_initial_trace: - return True - return get_proxy_mode() is not None diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index d050c8b40c6c..4b0e9acc19bd 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -1270,6 +1270,11 @@ def _fused_scaled_matmul_reduce_scatter_impl( .flatten(0, -2) ) A_scale_shards = list(A_scale.chunk(group.size())) + # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. + # When we slice them we might break this and need to reallocate them. + A_scale_shards = [ + t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards + ] else: raise ValueError("A_scale cannot be none for scaled_mm") diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index dda1885a8e16..c543fdffc1c7 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -1,15 +1,69 @@ import os +import subprocess import sysconfig -from typing import Optional +from typing import Any, Optional from torch.utils._triton import has_triton +def _find_nvshmem_device_library() -> str: + paths = [os.path.join(sysconfig.get_path("purelib"), "nvidia", "nvshmem", "lib")] + + # Add common system installation paths + common_paths = [ + "/usr/local/lib", + "/usr/lib", + "/opt/nvidia/nvshmem/lib", + ] + paths.extend(common_paths) + + try: + import torch + + torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib") + so_path = os.path.join(torch_lib, "libtorch_nvshmem.so") + + if os.path.exists(so_path): + try: + result = subprocess.run( + ["readelf", "-d", so_path], + capture_output=True, + text=True, + check=True, + ) + + for line in result.stdout.splitlines(): + if ("RPATH" in line or "RUNPATH" in line) and "[" in line: + rpath = line.split("[", 1)[1].split("]", 1)[0] + for p in rpath.split(":"): + p = p.strip().replace("$ORIGIN", torch_lib) + if p and p not in paths: + paths.append(p) + except subprocess.CalledProcessError: + pass + + except ImportError: + pass + + for path in paths: + device_lib = os.path.join(path, "libnvshmem_device.bc") + if os.path.exists(device_lib): + return device_lib + + raise RuntimeError(f"NVSHMEM device library not found. Searched: {paths}") + + def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: """ Enable NVSHMEM device functions for Triton. It performs a NVSHMEM device-side initialization on the kernel module created by Triton. + This function sets a global hook that initializes NVSHMEM for Triton + kernels. To avoid unnecessary initializations, the hook only acts on + kernels that have "nvshmem" in their function name. Therefore, it is + required that all Triton kernels using NVSHMEM primitives follow this + naming convention. + Args: lib_dir (Optional[str]): The directory where the NVSHMEM device library is located. If not provided, it will use the default path where NVSHMEM @@ -19,92 +73,210 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: dict[str, str]: A dictionary containing the NVSHMEM device library name and path. """ - from triton.runtime.jit import JITFunction + import triton from torch._C._distributed_c10d import _nvshmemx_cumodule_init - # Detect NVSHMEM device library path from python library path - if lib_dir is None: - py_lib_path = sysconfig.get_path("purelib") - lib_dir = py_lib_path + "/nvidia/nvshmem/lib" - - lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") - if not os.path.exists(lib_path): - raise RuntimeError("NVSHMEM device library not found") + if lib_dir is not None: + lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError( + f"NVSHMEM device library not found at specified path: {lib_path}" + ) + else: + # Otherwise, search for the library automatically. + lib_path = _find_nvshmem_device_library() extern_libs = {"libnvshmem_device": lib_path} # A hook function to initialize NVSHMEM in Triton def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] - key = kwargs["key"] - device = kwargs["compile"]["device"] jit_function = kwargs["fn"].jit_function - kernel_cache, _, _, _ = jit_function.device_caches[device] - kernel = kernel_cache.get(key, None) - kernel.run - _nvshmemx_cumodule_init(kernel.module) + # Only initialize NVSHMEM module for kernels containing "nvshmem" in their name + if "nvshmem" in jit_function.fn.__name__: + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache, _, _, _ = jit_function.device_caches[device] + kernel = kernel_cache.get(key, None) + if kernel is not None: + kernel.run + _nvshmemx_cumodule_init(kernel.module) # Register the function as a post-compile hook - JITFunction.compiled_hook = nvshmem_init_hook + triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook # Return to user so that they can use it in Triton kernel invocation return extern_libs if has_triton(): + import triton + import triton.language as tl from triton.language import core + @triton.jit # type: ignore[misc] + def put(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Put tensor data from local PE to a remote PE. + + This high-level function provides a tensor-aware interface for NVSHMEM put + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the remote PE. Type must match source. + source: Source tensor on the local PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been copied out + of the source array on the local PE. + - The operation does not guarantee delivery to the destination PE. + Use nvshmem_fence() for ordering or nvshmem_quiet() for completion. + + Example: + ``` + # Transfer 100 elements to PE 1 + nvshmem.put(dest_tensor, src_tensor, 100, 1) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return putmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes, pe + ) + @core.extern - def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + def putmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM put""" return core.extern_elementwise( "", "", - [dst, src, nelems, pe], + [dest, source, size_bytes, pe], { ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int64"), # pe number ): ("nvshmemx_putmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def get(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Get tensor data from a remote PE to local PE. + + This high-level function provides a tensor-aware interface for NVSHMEM get + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been delivered + to the destination array on the local PE. + - The destination data is guaranteed to be available for use after the call returns. + + Example: + ``` + # Get 100 elements from PE 0 + nvshmem.get(dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes, pe ) @core.extern - def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + def getmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" return core.extern_elementwise( "", "", - [dst, src, nelems, pe], + [dest, source, size_bytes, pe], { ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int64"), # pe number ): ("nvshmemx_getmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern def putmem_signal_block( # type: ignore[no-untyped-def] dst, src, - nelems, + size_bytes, sig_addr, signal, sig_op, pe, - _builder=None, + _semantic=None, ): # type: ignore[no-untyped-def] + """ + Put data to remote PE with atomic signal operation using block-scoped operation. + + This function copies data from the local PE to the remote PE and then + atomically updates a signal variable on the remote PE to indicate completion. + This enables efficient point-to-point synchronization between PEs. + + Args: + dst (int64): Symmetric address of the destination data object on the remote PE. + src (int64): Local address of the source data object containing data to be copied. + size_bytes (int64): Number of bytes to transfer. Must be positive. + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int64): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomic set operation + - NVSHMEM_SIGNAL_ADD (5): Atomic add operation + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been copied out + of the source array and the signal has been updated on the remote PE. + - The signal update is performed atomically with respect to other signal + operations and synchronization routines. + - The signal variable must be of type uint64_t in symmetric memory. + - Use with nvshmem_signal_wait_until() for synchronization. + + Example: + ``` + # Transfer data and set completion flag to 1 + NVSHMEM_SIGNAL_SET = 0 + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, 1024, sig_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe + ) + ``` + """ return core.extern_elementwise( "", "", - [dst, src, nelems, sig_addr, signal, sig_op, pe], + [dst, src, size_bytes, sig_addr, signal, sig_op, pe], { ( core.dtype("int64"), @@ -117,11 +289,51 @@ def putmem_signal_block( # type: ignore[no-untyped-def] ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) + # Wait and Signal Operations + + @triton.jit # type: ignore[misc] + def wait_until(ivar, cmp_op, cmp_val): # type: ignore[no-untyped-def] + """ + Wait until a tensor variable meets a specified condition. + + This high-level function provides a tensor-aware interface for NVSHMEM wait_until + operations. It automatically handles tensor address extraction, making + the API more ergonomic and type-safe. + + Args: + ivar_tensor: Tensor to monitor (typically int64/uint64) in symmetric memory. + cmp: Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val + - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val + - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val + - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val + - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val + cmp_val: Value to compare against. + + Notes: + - This is a blocking operation that will wait indefinitely until the + condition is satisfied. + - The tensor must be in symmetric memory and accessible from other PEs. + + Example: + ``` + # Wait until flag tensor becomes 1 (set by another PE) + NVSHMEM_CMP_EQ = 0 + nvshmem.wait_until_tensor(flag_tensor, NVSHMEM_CMP_EQ, 1) + ``` + """ + tl.static_assert( + ivar.type.element_ty.itemsize == 8, + "wait_until expects a 64-bit type for the synchronization variable", + ) + return wait_until_extern_wrapper(ivar.to(tl.int64), cmp_op, cmp_val) + @core.extern - def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] return core.extern_elementwise( "", "", @@ -134,11 +346,49 @@ def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-de ): ("nvshmem_longlong_wait_until", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + def signal_wait_until(sig_addr, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + """ + Wait until a signal variable meets a specified condition. + + This function blocks the calling thread until the value at the specified + signal variable satisfies the given comparison condition. Signal variables + are special uint64_t symmetric objects used for efficient synchronization + with signal operations. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t). + Must be 8-byte aligned symmetric memory. + cmp (int64): Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until signal == cmp_val + - NVSHMEM_CMP_NE (1): Wait until signal != cmp_val + - NVSHMEM_CMP_GT (2): Wait until signal > cmp_val + - NVSHMEM_CMP_GE (3): Wait until signal >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until signal < cmp_val + - NVSHMEM_CMP_LE (5): Wait until signal <= cmp_val + cmp_val (int64): Value to compare against. + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation designed specifically for signal variables. + - Signal variables are updated atomically by putmem_signal operations. + - More efficient than wait_until for signal-based synchronization patterns. + - Ensures the signal update is fully complete before returning. + - Commonly used with putmem_signal_block for producer-consumer patterns. + + Example: + ``` + # Wait for signal to be set to completion value + NVSHMEM_CMP_EQ = 0 + nvshmem.signal_wait_until(signal_ptr, NVSHMEM_CMP_EQ, 42) + ``` + """ return core.extern_elementwise( "", "", @@ -151,11 +401,45 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no ): ("nvshmem_signal_wait_until", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def] + def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no-untyped-def] + """ + Perform an atomic signal operation on a remote PE. + + This function atomically updates a signal variable on the specified remote PE + using the given operation and value. This enables efficient point-to-point + synchronization and notification between PEs. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int64): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomically set sig_addr = signal + - NVSHMEM_SIGNAL_ADD (5): Atomically set sig_addr += signal + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a one-sided operation - the remote PE does not need to participate. + - The signal operation is performed atomically on the remote PE. + - Can be used with signal_wait_until() on the remote PE for synchronization. + - Provides low-overhead notification mechanism between PEs. + - The signal variable must be of type uint64_t in symmetric memory. + + Example: + ```python + # Atomically set remote signal to 1 to notify completion + NVSHMEM_SIGNAL_SET = 0 + nvshmem.signal_op(remote_signal_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe) + ``` + """ return core.extern_elementwise( "", "", @@ -169,11 +453,47 @@ def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-u ): ("nvshmemx_signal_op", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) + # Memory Ordering Operations @core.extern - def fence(_builder=None): # type: ignore[no-untyped-def] + def fence(_semantic=None): # type: ignore[no-untyped-def] + """ + Ensure ordering of put operations to each remote PE. + + This function provides a memory fence that ensures point-to-point ordering + of remote memory operations. Put operations issued before the fence are + guaranteed to be ordered before put operations issued after the fence, + when targeting the same remote PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This provides weaker ordering guarantees than quiet(). + - Operations to each PE are ordered, but operations to different PEs + may still be reordered relative to each other. + - Does not guarantee completion of operations, only ordering. + - Non-blocking operations are not ordered by fence - use quiet() instead. + - Essential for ensuring correct ordering in communication patterns. + + Memory Ordering Guarantees: + - Put operations before fence() → ordered before → Put operations after fence() + - Ordering is maintained per-destination-PE basis + - Remote PEs can observe the enforced ordering + + Example: + ``` + # Ensure first put completes before second put to same PE + nvshmem.put(dst, src, nelems, target_pe) + nvshmem.fence() # Enforce ordering + nvshmem.put(dst2, src2, nelems, target_pe) + ``` + """ return core.extern_elementwise( "", "", @@ -182,11 +502,46 @@ def fence(_builder=None): # type: ignore[no-untyped-def] (): ("nvshmem_fence", core.dtype("int32")), }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def quiet(_builder=None): # type: ignore[no-untyped-def] + def quiet(_semantic=None): # type: ignore[no-untyped-def] + """ + Wait for completion of all outstanding put operations. + + This function blocks until all outstanding remote memory operations issued + by the calling PE have completed. It provides stronger guarantees than + fence() by ensuring both ordering and completion of all operations. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that waits for completion. + - Ensures all previous put operations have been delivered to their destinations. + - Provides global ordering - operations to ALL PEs are ordered. + - Required to complete non-blocking operations. + - More expensive than fence() but provides stronger guarantees. + + Memory Ordering Guarantees: + - All put operations before quiet() are completed before any operations after quiet() + - Operations are visible to all PEs as having occurred before subsequent operations + - Both blocking and non-blocking operations are completed + + Example: + ``` + # Ensure all data transfers complete before setting completion flag + nvshmem.putmem_block(data_ptr, src_ptr, data_size, target_pe) + nvshmem.quiet() # Wait for data transfer completion + nvshmem.putmem_block( + flag_ptr, flag_src_ptr, 8, target_pe + ) # Signal completion + ``` + """ return core.extern_elementwise( "", "", @@ -195,88 +550,433 @@ def quiet(_builder=None): # type: ignore[no-untyped-def] (): ("nvshmem_quiet", core.dtype("int32")), }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) + # PE Information Operations @core.extern - def my_pe(_builder=None): # type: ignore[no-untyped-def] + def my_pe(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the PE number of the calling PE. + + This function returns the unique identifier (PE number) of the current + processing element within the NVSHMEM job. PE numbers range from 0 to + nvshmem_n_pes() - 1. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: PE number of the calling PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - This is a pure function that returns the same value throughout execution. + - PE numbering starts from 0 and is contiguous. + - Each PE has a unique identifier within the NVSHMEM job. + - Can be called from both host and device code. + - Essential for implementing PE-specific logic and communication patterns. + + Example: + ``` + # Get current PE number for conditional logic + pe = nvshmem.my_pe() + if pe == 0: + # Root PE logic + pass + else: + # Non-root PE logic + pass + ``` + """ return core.extern_elementwise( "", "", [], {(): ("nvshmem_my_pe", core.dtype("int32"))}, is_pure=True, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def n_pes(_builder=None): # type: ignore[no-untyped-def] + def n_pes(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the total number of PEs in the NVSHMEM job. + + This function returns the total count of processing elements (PEs) + participating in the current NVSHMEM job. This value remains constant + throughout the execution of the program. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Total number of PEs in the job (always ≄ 1). + + Notes: + - This is a pure function that returns the same value throughout execution. + - The value is determined at NVSHMEM initialization and never changes. + - Valid PE numbers range from 0 to n_pes() - 1. + - Can be called from both host and device code. + - Essential for implementing collective operations and communication patterns. + + Example: + ``` + # Broadcast from root to all other PEs + total_pes = nvshmem.n_pes() + my_rank = nvshmem.my_pe() + + if my_rank == 0: + # Send to all other PEs + for peer in range(1, total_pes): + nvshmem.putmem_block(dst_ptr, src_ptr, size, peer) + ``` + """ return core.extern_elementwise( "", "", [], {(): ("nvshmem_n_pes", core.dtype("int32"))}, is_pure=True, - _builder=_builder, + _semantic=_semantic, ) + # Synchronization Operations @core.extern - def barrier_all(_builder=None): # type: ignore[no-untyped-def] + def barrier_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with completion guarantee. + + This function creates a barrier across all PEs in the NVSHMEM job. It ensures + that all local and remote memory updates issued before the barrier by any PE + are completed before any PE exits the barrier. This provides both + synchronization and memory consistency. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Stronger guarantee than sync_all() - ensures completion of remote operations. + - Blocks until all PEs reach the barrier AND all memory operations complete. + - Must be called from kernels launched with cooperative launch. + - Provides full memory consistency across all PEs. + - More expensive than sync_all() due to completion guarantees. + + Memory Consistency Guarantees: + - All memory updates before barrier_all() are visible to all PEs + - All remote memory operations are completed before any PE continues + - Provides a global synchronization point with memory ordering + + Example: + ``` + # Ensure all PEs complete their work before proceeding + # All PEs execute this - it's a collective operation + nvshmem.barrier_all() + # At this point, all previous operations are complete on all PEs + ``` + """ return core.extern_elementwise( "", "", [], {(): ("nvshmem_barrier_all", core.dtype("int32"))}, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def sync_all(_builder=None): # type: ignore[no-untyped-def] + def sync_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with local completion guarantee. + + This function creates a lightweight synchronization barrier across all PEs. + It ensures that all local store operations issued before the sync are + visible to other PEs, but does not guarantee completion of remote memory + operations initiated by the calling PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Lighter weight than barrier_all() - only ensures local store visibility. + - Does not guarantee completion of remote memory updates initiated locally. + - Must be called from kernels launched with cooperative launch. + - Suitable when only synchronization (not completion) is needed. + - More efficient than barrier_all() for synchronization-only patterns. + + Memory Consistency Guarantees: + - Local store operations are visible to other PEs + - Does NOT ensure completion of outgoing remote operations + - Provides synchronization point without full completion overhead + + Example: + ``` + # Lightweight synchronization between PEs + # All PEs execute this - it's a collective operation + nvshmem.sync_all() + # Local stores are visible, but remote ops may still be in flight + ``` + """ return core.extern_elementwise( "", "", [], {(): ("nvshmem_sync_all", core.dtype("int32"))}, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) - @core.extern - def alltoall(team, dest, source, nelems, _builder=None): # type: ignore[no-untyped-def] - """Perform alltoall operation on NVSHMEM symmetric memory""" + # Collective Operations (mem-based APIs - sizes in bytes) + @triton.jit # type: ignore[misc] + def alltoall(team, dest, source, nelems_per_pe): # type: ignore[no-untyped-def] + """ + All-to-all tensor exchange between PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM alltoall + operations. Each PE sends nelems_per_pe elements to every other PE and receives + the same amount from every other PE. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor. Must be large enough for nelems_per_pe * n_pes elements. + source: Source tensor containing data for all PEs. Must contain nelems_per_pe * n_pes elements. + nelems_per_pe: Number of elements to exchange with each PE. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Data layout: source=[data_for_pe0, data_for_pe1, ...], dest=[data_from_pe0, data_from_pe1, ...] + + Example: + ``` + # Each PE exchanges 10 elements with every other PE + nvshmem.alltoall(0, dest_tensor, src_tensor, 10) + ``` + """ + tl.static_assert(dest.type == source.type) + size_bytes_per_pe = nelems_per_pe * dest.type.element_ty.itemsize + return alltoallmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), size_bytes_per_pe + ) + + @core.extern # type: ignore[misc] + def alltoallmem_block_extern_wrapper( + team: Any, dest: Any, source: Any, size_bytes: Any, _semantic: Any = None + ) -> None: + """Low-level extern wrapper for NVSHMEM alltoall""" return core.extern_elementwise( "", "", - [team, dest, source, nelems], + [team, dest, source, size_bytes], { ( core.dtype("int64"), # team handle core.dtype("int64"), # dest ptr core.dtype("int64"), # source ptr - core.dtype("int64"), # nelems - ): ("nvshmem_longlong_alltoall", core.dtype("int32")) + core.dtype("int64"), # size in bytes + ): ("nvshmemx_alltoallmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) - @core.extern - def broadcast(team, dest, source, nelems, pe_root, _builder=None): # type: ignore[no-untyped-def] - """Broadcasts data from a root PE to all other PEs in a team""" + @triton.jit # type: ignore[misc] + def broadcast(team, dest, source, nelems, pe_root): # type: ignore[no-untyped-def] + """ + Broadcast tensor data from a root PE to all other PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM broadcast + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor with type information. All PEs receive data here. + source: Source tensor on the root PE. Type must match dest. + nelems: Number of elements to broadcast. + pe_root: PE number of the root PE that provides the source data. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Must be called from kernels launched with cooperative launch. + + Example: + ``` + # Broadcast 100 elements from PE 0 to all PEs + nvshmem.broadcast(0, dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return broadcastmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), nbytes, pe_root + ) + + @core.extern # type: ignore[misc] + def broadcastmem_block_extern_wrapper( + team: Any, + dest: Any, + source: Any, + size_bytes: Any, + pe_root: Any, + _semantic: Any = None, + ) -> None: + """Low-level extern wrapper for NVSHMEM broadcast""" return core.extern_elementwise( "", "", - [team, dest, source, nelems, pe_root], + [team, dest, source, size_bytes, pe_root], { ( core.dtype("int64"), # team handle core.dtype("int64"), # dest ptr core.dtype("int64"), # source ptr - core.dtype("int64"), # nelems + core.dtype("int64"), # size in bytes core.dtype("int64"), # pe_root - ): ("nvshmem_longlong_broadcast", core.dtype("int32")) + ): ("nvshmemx_broadcastmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, + ) + + # Reduction Operation + @triton.jit # type: ignore[misc] + def reduce(team, dest, source, nreduce, operation: tl.constexpr): # type: ignore[no-untyped-def] + """ + Performs a collective reduction on tensors across a team of PEs. + + This high-level function provides a tensor-aware interface for NVSHMEM + reduction operations. It automatically infers the data type from the + input tensors and calls the appropriate underlying NVSHMEM function. + + Args: + team: The team handle for the collective (0 for NVSHMEM_TEAM_WORLD). + dest: Destination tensor for the reduction results. + source: Source tensor containing data to be reduced. Must be the same type as dest. + nreduce: The number of elements in the source tensor to reduce. + operation: The reduction operation to perform ("sum", "max", "min", "prod"). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - This is a collective operation that must be called by all PEs in the team. + - Requires a cooperative grid launch. + + Example: + ``` + # Perform a sum reduction on two tensors + nvshmem.reduce(0, dest_tensor, src_tensor, 100, "sum") + ``` + """ + tl.static_assert(dest.type == source.type) + dtype = dest.type.element_ty + return reduce_extern_wrapper( + team, + dest.to(tl.int64), + source.to(tl.int64), + nreduce, + operation, + dtype, + ) + + @core.extern # type: ignore[misc] + def reduce_extern_wrapper( + team: Any, + dest: Any, + source: Any, + nreduce: Any, + operation: str, + dtype: Any, + _semantic: Any = None, + ) -> None: + """ + Low-level extern wrapper for NVSHMEM reduction operations. + + This function provides a generic interface to NVSHMEM reduction operations, + automatically selecting the appropriate NVSHMEM function based on the data type + and operation specified. + Args: + team (int64): The team handle (0 for NVSHMEM_TEAM_WORLD). + dest (pointer): Destination pointer where reduction results are stored. + source (pointer): Source pointer containing data to be reduced. + nreduce (int64): Number of elements to reduce. + operation (str): Reduction operation ("sum", "max", "min", "prod"). + dtype: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. + _semantic: Optional semantic information for Triton compilation. + + Raises: + ValueError: If the operation is not supported. + TypeError: If the data type is not supported. + + Example: + nvshmem.reduce(0, dest_ptr, src_ptr, 100, "sum", torch.float32) + """ + # Mapping from Triton dtype names to NVSHMEM typenames + DTYPE_TO_NVSHMEM_MAP = { + "int8": "int8", + "int16": "int16", + "int32": "int32", + "int64": "int64", + "uint8": "uint8", + "uint16": "uint16", + "uint32": "uint32", + "uint64": "uint64", + "fp16": "half", + "bf16": "bfloat16", + "fp32": "float", + "fp64": "double", + } + + # Triton dtype names are standardized as fp16, bf16, fp32, etc. + dtype_name = str(dtype).replace("tl.", "") + + if dtype_name not in DTYPE_TO_NVSHMEM_MAP: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes: {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Extract operation name from constexpr if needed + op_name = operation.value if hasattr(operation, "value") else operation + + # Validate operation is supported + supported_ops = {"sum", "max", "min", "prod"} + if op_name not in supported_ops: + raise ValueError( + f"Unsupported reduction operation: '{op_name}'. Supported ops are {supported_ops}" + ) + + # Map to NVSHMEM typename and validate dtype is supported + nvshmem_typename = DTYPE_TO_NVSHMEM_MAP.get(dtype_name) + if nvshmem_typename is None: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes are {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Generate NVSHMEM function name + nvshmem_func = f"nvshmem_{nvshmem_typename}_{op_name}_reduce" + + # Define function signature - all parameters are int64 in Triton (they are just ptrs) + signature = ( + core.dtype("int64"), # team handle + core.dtype("int64"), # destination pointer + core.dtype("int64"), # source pointer + core.dtype("int64"), # number of elements + ) + + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + {signature: (nvshmem_func, core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, ) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 2a08212dfa9c..6153d8e03fdf 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -232,13 +232,11 @@ def hook_with_zero_step( ) ddp_ref = weakref.ref(ddp) - # NOTE: Gloo may hang with this overlapping approach, so we require - # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if (pg != dist.Backend.NCCL) and (pg != "hccl"): + if pg == dist.Backend.GLOO: raise RuntimeError( - "Overlapping DDP with ZeRO using this approach currently requires " - "NCCL/HCCL backend to avoid hangs" + "Gloo backend using Overlapping DDP with ZeRO may meet hangs" ) if shard_buckets: @@ -394,13 +392,11 @@ def hook_with_zero_step_interleaved( ) ddp_ref = weakref.ref(ddp) - # NOTE: Gloo may hang with this overlapping approach, so we require - # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if (pg != dist.Backend.NCCL) and (pg != "hccl"): + if pg == dist.Backend.GLOO: raise RuntimeError( - "Overlapping DDP with ZeRO using this approach currently requires " - "NCCL/HCCL backend to avoid hangs" + "Gloo backend using Overlapping DDP with ZeRO may meet hangs" ) if shard_buckets: diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index dc988e999c4e..a0d205f80821 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -1,33 +1,26 @@ # pyre-strict import concurrent.futures +import glob import json import logging import math import mmap import os -import shutil import struct -import tempfile import time from dataclasses import dataclass, field from typing import Any, Optional -import fsspec # type: ignore[import-untyped] -from fsspec.core import url_to_fs # type: ignore[import-untyped] -from fsspec.implementations.local import LocalFileSystem # type: ignore[import-untyped] - import torch from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, _get_dcp_custom_metadata, - _get_dtype, _get_safetensors_file_metadata, _metadata_fn, DATA_OFFSETS_KEY, DEFAULT_EXTRA_METADATA_KEY, DTYPE_KEY, - FILE_NAME, SAVED_OFFSETS_KEY, SHAPE_KEY, SUFFIX, @@ -100,6 +93,9 @@ def _parse_input_metadata( Raises: ValueError: If no DCP custom metadata is found in a safetensors file """ + + from safetensors.torch import _getdtype # type: ignore[import] + # Dictionary to track the full size of each tensor across all shards fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} @@ -138,14 +134,13 @@ def _parse_input_metadata( if fqn in output_data.fqn_data or len(output_files_data) == 1: output_data.fqn_data[fqn] = _FqnData( shape_in_file=tensor_size, - dtype_size=torch.finfo(_get_dtype(dtype_str)).bits + dtype_size=torch.finfo(_getdtype(dtype_str)).bits // 8, # Convert bits to bytes dtype_str=dtype_str, ) def _write_metadata( - fs: fsspec.AbstractFileSystem, output_files_data: dict[str, _OutputFileData], ) -> None: """ @@ -156,12 +151,11 @@ def _write_metadata( field for each tensor in the output_files_data. Args: - fs: Filesystem interface for file operations output_files_data: Dictionary mapping output file paths to their metadata """ # Process each output file for file_path, output_data in output_files_data.items(): - with fs.open(file_path, "wb") as f: + with open(file_path, "wb") as f: metadata = {} curr_offset = 0 @@ -205,7 +199,6 @@ def _write_metadata( def _read_tensor_data_mmap( - input_fs: fsspec.AbstractFileSystem, file_path: str, start_offset: int, end_offset: int, @@ -215,7 +208,6 @@ def _read_tensor_data_mmap( Read tensor data from a safetensors file using memory mapping for efficiency. Args: - input_fs: Filesystem interface for input file operations file_path: Path to the safetensors file start_offset: Start offset of tensor data within the data section end_offset: End offset of tensor data within the data section @@ -224,24 +216,15 @@ def _read_tensor_data_mmap( Returns: Raw tensor data as bytes """ - # For local files, use mmap for efficient access - if isinstance(input_fs, LocalFileSystem): - # Local file - use mmap - with open(file_path, "rb") as f: - with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: - absolute_start = metadata_size + start_offset - absolute_end = metadata_size + end_offset - return bytes(mm[absolute_start:absolute_end]) - else: - # Remote file - fall back to regular read - with input_fs.open(file_path, "rb") as f: - f.seek(metadata_size + start_offset) - return f.read(end_offset - start_offset) + # Use mmap for efficient access + with open(file_path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + absolute_start = metadata_size + start_offset + absolute_end = metadata_size + end_offset + return bytes(mm[absolute_start:absolute_end]) def _process_output_file( - input_fs: fsspec.AbstractFileSystem, - output_fs: fsspec.AbstractFileSystem, output_file: str, output_data: _OutputFileData, input_files_data: dict[str, _InputFileData], @@ -252,8 +235,6 @@ def _process_output_file( This function is designed to be run in parallel for different output files. Args: - input_fs: Filesystem interface for input file operations - output_fs: Filesystem interface for output file operations output_file: Path to the output file output_data: Metadata for the output file input_files_data: Dictionary mapping input file paths to their metadata @@ -275,7 +256,6 @@ def _process_output_file( # Use memory mapping to read tensor data efficiently data_to_write = _read_tensor_data_mmap( - input_fs, safetensors_file, data_offsets[0], data_offsets[1], @@ -291,7 +271,6 @@ def _process_output_file( # Write this tensor shard to the appropriate position in the output file _write_sub_tensor_to_file_optimized( - output_fs, data_to_write, fqn_data.dtype_size, # Size of each element in bytes fqn_data.shape_in_file, # Full tensor shape @@ -304,8 +283,6 @@ def _process_output_file( def _write_data( - input_fs: fsspec.AbstractFileSystem, - output_fs: fsspec.AbstractFileSystem, input_files_data: dict[str, _InputFileData], output_files_data: dict[str, _OutputFileData], num_threads: int = 1, @@ -318,8 +295,6 @@ def _write_data( the work is split across threads with each thread handling a different output file. Args: - input_fs: Filesystem interface for input file operations - output_fs: Filesystem interface for output file operations input_files_data: Dictionary mapping input file paths to their metadata output_files_data: Dictionary mapping output file paths to their metadata num_threads: Number of threads to use for parallel processing @@ -327,9 +302,7 @@ def _write_data( if num_threads <= 1 or len(output_files_data) <= 1: # Sequential processing for output_file, output_data in output_files_data.items(): - _process_output_file( - input_fs, output_fs, output_file, output_data, input_files_data - ) + _process_output_file(output_file, output_data, input_files_data) else: # Parallel processing with ThreadPoolExecutor with concurrent.futures.ThreadPoolExecutor( @@ -340,8 +313,6 @@ def _write_data( futures.append( executor.submit( _process_output_file, - input_fs, - output_fs, output_file, output_data, input_files_data, @@ -358,191 +329,7 @@ def _write_data( raise -def _write_row_wise_tensor( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - full_tensor_strides: list[int], - sub_tensor_strides: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a row-wise sharded tensor to the output file. - - This is an optimized path for tensors that are sharded along the first dimension, - with all other dimensions being complete. This allows writing entire rows at once. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - full_tensor_strides: Strides of the full tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # Calculate the number of elements in each row - elements_per_row = full_tensor_strides[ - 0 - ] # This is the stride of the first dimension - - # For each row in the sub-tensor - for row_idx in range(sub_tensor_shape[0]): - # Calculate the row index in the full tensor - full_row_idx = sub_tensor_offsets[0] + row_idx - - # Calculate the position in the full tensor - full_pos = full_row_idx * full_tensor_strides[0] - full_byte_offset = output_start_byte + full_pos * element_size - - # Calculate the position in the sub-tensor - sub_pos = row_idx * sub_tensor_strides[0] - sub_byte_offset = sub_pos * element_size - - # Extract the row data from the sub-tensor - row_size = elements_per_row * element_size - row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(row_data) - - -def _write_column_wise_tensor( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a column-wise sharded 2D tensor to the output file. - - This is an optimized path for 2D tensors that are sharded along the second dimension, - with the first dimension being complete. This requires writing column by column. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # For each column in the sub-tensor - for col_idx in range(sub_tensor_shape[1]): - # Calculate the column index in the full tensor - full_col_idx = sub_tensor_offsets[1] + col_idx - - # For each row in the column - for row_idx in range(sub_tensor_shape[0]): - # Calculate the position in the full tensor - full_pos = row_idx * tensor_shape[1] + full_col_idx - full_byte_offset = output_start_byte + full_pos * element_size - - # Calculate the position in the sub-tensor - sub_pos = row_idx * sub_tensor_shape[1] + col_idx - sub_byte_offset = sub_pos * element_size - - # Extract the element data from the sub-tensor - element_data = sub_tensor_bytes[ - sub_byte_offset : sub_byte_offset + element_size - ] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(element_data) - - -def _write_element_by_element( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - full_tensor_strides: list[int], - sub_tensor_strides: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a sub-tensor to the output file using a general element-by-element approach. - - This is a general approach that works for any sharding pattern, but is less efficient - than the specialized approaches for row-wise or column-wise sharding. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor - full_tensor_strides: Strides of the full tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # Create a list to hold the current indices for each dimension - indices = [0] * len(tensor_shape) - - # Calculate the total number of elements in the sub-tensor - total_elements = 1 - for dim_size in sub_tensor_shape: - total_elements *= dim_size - - # Process each element in the sub-tensor - for element_idx in range(total_elements): - # Calculate the indices for this element in the sub-tensor - sub_idx = element_idx - for dim in range(len(sub_tensor_shape) - 1, -1, -1): - indices[dim] = sub_idx % sub_tensor_shape[dim] - sub_idx //= sub_tensor_shape[dim] - - # Calculate the position of this element in the sub-tensor's byte array - sub_pos = 0 - for dim in range(len(sub_tensor_shape)): - sub_pos += indices[dim] * sub_tensor_strides[dim] - sub_byte_offset = sub_pos * element_size - - # Calculate the position of this element in the full tensor - full_pos = 0 - for dim in range(len(tensor_shape)): - # The global index is the local index plus the offset for this dimension - global_idx = indices[dim] + sub_tensor_offsets[dim] - full_pos += global_idx * full_tensor_strides[dim] - full_byte_offset = output_start_byte + full_pos * element_size - - # Extract the element data from the sub-tensor - element_data = sub_tensor_bytes[ - sub_byte_offset : sub_byte_offset + element_size - ] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(element_data) - - def _write_sub_tensor_to_file_optimized( - fs: fsspec.AbstractFileSystem, sub_tensor_bytes: bytes, element_size: int, tensor_shape: list[int], @@ -552,15 +339,16 @@ def _write_sub_tensor_to_file_optimized( output_start_byte: int, ) -> None: """ - Optimized version of _write_sub_tensor_to_file with enhanced sharding pattern detection. + Optimized version that writes the maximum number of contiguous bytes possible. - Uses advanced pattern detection to optimize common sharding patterns: - - Row-wise sharding with memory-efficient bulk copying - - Contiguous chunk detection for direct memory operations - - General fallback for arbitrary patterns + Uses a unified algorithm that calculates the maximum contiguous bytes that can be + written in each iteration and continues until the entire subtensor is written. + Handles all sharding patterns efficiently: + - Full sub-tensor at once for row-wise sharding + - Row-by-row for column-wise sharding + - Optimized chunks for other patterns Args: - fs: Filesystem interface for file operations sub_tensor_bytes: Raw tensor data as bytes element_size: Size of each element in bytes tensor_shape: Shape of the full tensor @@ -573,191 +361,151 @@ def _write_sub_tensor_to_file_optimized( if not tensor_shape or not sub_tensor_shape: return - # Enhanced row-wise sharding detection - if len(tensor_shape) >= 2 and len(sub_tensor_shape) >= 2: - # Check if this is a row-wise chunk (all dims except first are complete) - is_row_wise = all( - sub_tensor_shape[i] == tensor_shape[i] and sub_tensor_offsets[i] == 0 - for i in range(1, len(tensor_shape)) - ) + # Calculate tensor strides for efficient indexing + tensor_strides = [1] + for i in range(len(tensor_shape) - 1, 0, -1): + tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) - if is_row_wise: - # Optimized row-wise copy using bulk memory operations - _write_row_wise_tensor_optimized( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - return - - # Fall back to the original implementation for complex patterns - _write_sub_tensor_to_file( - fs, - bytearray(sub_tensor_bytes), - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) + sub_tensor_strides = [1] + for i in range(len(sub_tensor_shape) - 1, 0, -1): + sub_tensor_strides.insert(0, sub_tensor_strides[0] * sub_tensor_shape[i]) + total_elements = math.prod(sub_tensor_shape) -def _write_row_wise_tensor_optimized( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytes, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Optimized row-wise tensor writing using bulk memory operations. + with open(output_file_path, "r+b") as out_f: + elements_written = 0 - This function an optimization strategy: - - Direct memory copy for contiguous rows - - Minimal file seeking operations - - Bulk data transfer instead of element-by-element - """ - with fs.open(output_file_path, "r+b") as out_f: - # Optimized row-wise copy - elements_per_row = math.prod(tensor_shape[1:]) - bytes_per_row = elements_per_row * element_size + while elements_written < total_elements: + # Convert linear index to multi-dimensional indices + temp_idx = elements_written + indices = [] + for dim_size in reversed(sub_tensor_shape): + indices.append(temp_idx % dim_size) + temp_idx //= dim_size + indices.reverse() - start_row = sub_tensor_offsets[0] - num_rows = sub_tensor_shape[0] + # Calculate maximum contiguous elements we can write from this position + max_contiguous = _calculate_max_contiguous_elements( + indices, sub_tensor_shape, tensor_shape + ) - # Calculate byte positions - tensor_start_byte = output_start_byte + start_row * bytes_per_row - chunk_size_bytes = num_rows * bytes_per_row + # Calculate source position in bytes + src_pos = sum( + idx * stride for idx, stride in zip(indices, sub_tensor_strides) + ) + src_byte_offset = src_pos * element_size - # Direct memory copy for contiguous rows - out_f.seek(tensor_start_byte) - out_f.write(sub_tensor_bytes[:chunk_size_bytes]) + # Calculate destination position in bytes + dest_indices = [ + idx + offset for idx, offset in zip(indices, sub_tensor_offsets) + ] + dest_pos = sum( + idx * stride for idx, stride in zip(dest_indices, tensor_strides) + ) + dest_byte_offset = output_start_byte + dest_pos * element_size + # Write the contiguous chunk + bytes_to_write = max_contiguous * element_size + out_f.seek(dest_byte_offset) + chunk_data = sub_tensor_bytes[ + src_byte_offset : src_byte_offset + bytes_to_write + ] + out_f.write(chunk_data) -def _write_sub_tensor_to_file( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], + elements_written += max_contiguous + + +def _calculate_max_contiguous_elements( + indices: list[int], sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: + tensor_shape: list[int], +) -> int: """ - Original implementation - writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets. + Calculate the maximum number of contiguous elements that can be written from current position. - This function handles the complex task of placing a tensor shard (sub-tensor) at the correct - position within the consolidated tensor file. It works by calculating the exact byte offsets - for each slice of data and writing them to the appropriate positions. This implementation - supports tensors of any dimensionality with optimized paths for common sharding patterns: - - Row-wise sharding (optimized path) - - Column-wise sharding for 2D tensors (optimized path) - - Any other arbitrary sharding pattern (general element-by-element approach) + This determines the largest chunk by checking how elements are laid out in memory + and finding natural boundaries where contiguity breaks. Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor (list) - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list) - sub_tensor_shape: The shape of the sub-tensor (list) - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file + indices: Current position indices in the sub-tensor + sub_tensor_shape: Shape of the sub-tensor being written + tensor_shape: Shape of the full tensor + + Raises: + ValueError: If input lists are empty, have mismatched lengths, or contain invalid values """ - # Handle the case of empty tensors - if not tensor_shape or not sub_tensor_shape: - return + # Validate input lists are not empty + if not indices or not sub_tensor_shape or not tensor_shape: + raise ValueError("Input lists cannot be empty") - # Calculate strides for the full tensor (row-major order, C-style) - # Stride is the number of elements to skip to move to the next element in that dimension - full_tensor_strides = [1] * len(tensor_shape) - for i in range(len(tensor_shape) - 2, -1, -1): - full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1] - - # Calculate strides for the sub-tensor (row-major order, C-style) - sub_tensor_strides = [1] * len(sub_tensor_shape) - for i in range(len(sub_tensor_shape) - 2, -1, -1): - sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1] - - # Check if this is a row-wise sharded tensor - # Row-wise sharding is detected when the last dimension is complete - # and only the first dimension is partial - is_row_wise = False - if len(tensor_shape) >= 2: - # Check if all dimensions except the first are complete - all_other_dims_complete = True - for i in range(1, len(tensor_shape)): - if sub_tensor_shape[i] != tensor_shape[i]: - all_other_dims_complete = False - break - - # Row-wise sharding: first dimension is partial, all others are complete - is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0] - - # Check if this is a column-wise sharded 2D tensor - # Column-wise sharding is detected when the first dimension is complete - # and the second dimension is partial (only for 2D tensors) - is_column_wise = False - if len(tensor_shape) == 2: - is_column_wise = ( - sub_tensor_shape[0] == tensor_shape[0] - and sub_tensor_shape[1] < tensor_shape[1] + # Validate all lists have the same length (same number of dimensions) + if not (len(indices) == len(sub_tensor_shape) == len(tensor_shape)): + raise ValueError( + f"All input lists must have the same length. Got indices: {len(indices)}, " + f"sub_tensor_shape: {len(sub_tensor_shape)}, tensor_shape: {len(tensor_shape)}" ) - # Call the appropriate function based on the sharding pattern - if is_row_wise: - _write_row_wise_tensor( - fs, - sub_tensor_bytes, - element_size, - full_tensor_strides, - sub_tensor_strides, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - elif is_column_wise: - _write_column_wise_tensor( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - else: - _write_element_by_element( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - full_tensor_strides, - sub_tensor_strides, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) + # Validate indices are within bounds of sub_tensor_shape + for i, (idx, sub_dim) in enumerate(zip(indices, sub_tensor_shape)): + if idx >= sub_dim: + raise ValueError( + f"Index {idx} at dimension {i} is out of bounds for sub-tensor shape {sub_tensor_shape}" + ) + + # Validate sub_tensor dimensions don't exceed tensor dimensions + for i, (sub_dim, tensor_dim) in enumerate(zip(sub_tensor_shape, tensor_shape)): + if sub_dim > tensor_dim: + raise ValueError( + f"Sub-tensor dimension {sub_dim} at position {i} exceeds tensor dimension {tensor_dim}" + ) + + # Start with elements remaining in the last dimension + max_contiguous = sub_tensor_shape[-1] - indices[-1] + + # Check if we can extend across multiple dimensions + # We can write across dimension boundaries if we're writing complete "rows" + # and the layout in destination tensor maintains contiguity + + # For 2D case: check if we can write multiple complete rows + if len(sub_tensor_shape) >= 2: + # If we're at the start of a row and can write complete rows + if indices[-1] == 0: # At start of last dimension (column) + rows_remaining = sub_tensor_shape[-2] - indices[-2] # Rows left to write + + # Check if writing complete rows maintains contiguity in destination + # This is true for row-wise sharding or when sub-tensor spans full width + if sub_tensor_shape[-1] == tensor_shape[-1]: # Full width + max_contiguous = rows_remaining * sub_tensor_shape[-1] + + # For higher dimensions, check if we can extend further + if len(sub_tensor_shape) >= 3 and indices[-2] == 0: + # Check if we can write complete 2D slices + remaining_in_dim = sub_tensor_shape[-3] - indices[-3] + if ( + sub_tensor_shape[-1] == tensor_shape[-1] + and sub_tensor_shape[-2] == tensor_shape[-2] + ): + max_contiguous = ( + remaining_in_dim * sub_tensor_shape[-2] * sub_tensor_shape[-1] + ) + + return max_contiguous def _write_overall_metadata_file( - fs: fsspec.AbstractFileSystem, output_dir: str, output_files_data: dict[str, _OutputFileData], ) -> None: + """ + Write the overall metadata file that maps tensor names to their file locations. + + This creates a model.safetensors.index.json file that HuggingFace models use + to locate tensors across multiple files. + + Args: + output_dir: Directory where the metadata file will be written + output_files_data: Dictionary mapping output file paths to their metadata + """ total_size = 0 weight_map = {} for output_path, value in output_files_data.items(): @@ -770,32 +518,10 @@ def _write_overall_metadata_file( metadata_to_write["weight_map"] = weight_map metadata_path = os.path.join(output_dir, f"{_metadata_fn}") - with fs.open(metadata_path, "w") as metadata_file: + with open(metadata_path, "w") as metadata_file: json.dump(metadata_to_write, metadata_file, indent=2) -def _upload_files_to_remote_fs( - local_fs: fsspec.AbstractFileSystem, - local_dir: str, - output_fs: fsspec.AbstractFileSystem, - output_dir: str, -) -> None: - """ - Uploads the consolidated files to the remote filesystem. - """ - for path in local_fs.ls(local_dir, detail=False): - file = os.path.basename(path) - model_str = FILE_NAME.split("-")[0] - # Upload only the consolidated files with full tensors or the metadata file. - # The check for file.startwith(model_str) is to ensure that we only upload - # the consolidated files in the format "model-0000n-of-0000m.safetensors" - # and not the files with sharded tensors. - if file.endswith(SUFFIX) and file.startswith(model_str) or file == _metadata_fn: - local_path = os.path.join(local_dir, file) - remote_path = os.path.join(output_dir, file) - output_fs.put_file(local_path, remote_path) - - def consolidate_safetensors_files( input_dir: str, output_dir: str, @@ -827,17 +553,6 @@ def consolidate_safetensors_files( output_dir, start_time, ) - # Create filesystem using fsspec for file operations - input_fs, _ = url_to_fs(input_dir) - output_fs, _ = url_to_fs(output_dir) - - if not isinstance(output_fs, LocalFileSystem): - local_output_dir = tempfile.mkdtemp() - logger.info("Created temporary directory %s", local_output_dir) - local_output_fs, _ = url_to_fs(local_output_dir) - else: - local_output_fs = output_fs - local_output_dir = output_dir # Initialize the output file structure output_files_data: dict[str, _OutputFileData] = {} @@ -846,7 +561,7 @@ def consolidate_safetensors_files( for fqn, index in fqn_to_index_mapping.items(): # Generate names like "model-00001-of-00005.safetensors" file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) - output_path = f"{local_output_dir}/{file_name}" + output_path = os.path.join(output_dir, file_name) if output_path not in output_files_data: output_files_data[output_path] = _OutputFileData( @@ -857,19 +572,16 @@ def consolidate_safetensors_files( else: # If no mapping is provided, create a single output file file_name = _gen_file_name(1, 1) - output_path = f"{local_output_dir}/{file_name}" + output_path = os.path.join(output_dir, file_name) output_files_data[output_path] = _OutputFileData() # Find all safetensors files in the input directory - safetensors_files = [] - for file in input_fs.ls(input_dir, detail=False): - if file.endswith(SUFFIX): - safetensors_files.append(file) + safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) # Read metadata from all input files input_files_data: dict[str, _InputFileData] = {} for safetensor_file in safetensors_files: - with input_fs.open(safetensor_file, "rb") as f: + with open(safetensor_file, "rb") as f: metadata, size = _get_safetensors_file_metadata(f) input_files_data[safetensor_file] = _InputFileData( metadata_size=size, metadata=metadata @@ -879,22 +591,12 @@ def consolidate_safetensors_files( _parse_input_metadata(input_files_data, output_files_data) # Step 2: Write metadata headers to output files - _write_metadata(local_output_fs, output_files_data) + _write_metadata(output_files_data) # Step 3: Write actual tensor data from input files to output files - _write_data( - input_fs, local_output_fs, input_files_data, output_files_data, num_threads - ) + _write_data(input_files_data, output_files_data, num_threads) # Step 4: Write overall model.index.safetensors.json file with weight map - _write_overall_metadata_file(local_output_fs, local_output_dir, output_files_data) + _write_overall_metadata_file(output_dir, output_files_data) logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) - - if local_output_dir != output_dir: - logger.info("Copying consolidated files to remote storage %s", output_dir) - _upload_files_to_remote_fs( - local_output_fs, local_output_dir, output_fs, output_dir - ) - shutil.rmtree(local_output_dir) - logger.info("Deleting temporary directory %s", local_output_dir) diff --git a/torch/distributed/checkpoint/_hf_utils.py b/torch/distributed/checkpoint/_hf_utils.py index 1a3f627fd69b..0d14229b7f8c 100644 --- a/torch/distributed/checkpoint/_hf_utils.py +++ b/torch/distributed/checkpoint/_hf_utils.py @@ -51,8 +51,6 @@ class _HFStorageInfo: """This is the per entry storage info.""" relative_path: str - offset: int - length: int shape: torch.Size dtype: torch.dtype diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 13fd61910dd2..542203ed82cf 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -6,23 +6,16 @@ from typing import Any, Optional import torch -from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter from torch.distributed.checkpoint._consolidate_hf_safetensors import ( consolidate_safetensors_files, ) -from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, - _get_dtype, - _get_safetensors_file_metadata, _HFStorageInfo, _metadata_fn, CUSTOM_METADATA_KEY, - DATA_OFFSETS_KEY, - DEFAULT_EXTRA_METADATA_KEY, - DTYPE_KEY, SAVED_OFFSETS_KEY, - SHAPE_KEY, SHARDED_DIR_NAME, SUFFIX, ) @@ -52,11 +45,9 @@ __all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] -class HuggingFaceStorageWriter(FsspecWriter): +class HuggingFaceStorageWriter(FileSystemWriter): """ - A writer that writes to a huggingface repository in the huggingface format. - Uses Fsspec back-end to communicate with back-end storage. - Fsspec registration of the storage solution is required. + A writer that writes to storage in the huggingface safetensors format. """ def __init__( @@ -64,26 +55,20 @@ def __init__( path: str, fqn_to_index_mapping: Optional[dict[str, int]] = None, thread_count: int = 1, - token: Optional[str] = None, save_distributed: bool = False, enable_consolidation: bool = False, - consolidated_output_path: Optional[str] = None, thread_count_consolidation: int = 1, ) -> None: """ Initialize the huggingface writer pointing to path. Args: - path: hf directory where the checkpoint will be read from. - Needs to have .safetensors files, but can be from any fsspec supported storage, - including localFS and hf://. - This needs to be a remote path if you want to enable consolidation after saving. + path: directory where the checkpoint will be read from. fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. Indices are from 1 to N, where N is the number of files. If not provided, the tensors will be written to a single file. If none, then all the tensors on the same rank will be written to the same file. thread_count: Number of threads to use to write distributed checkpoint. Default to 1. - token: The token to use to authenticate with huggingface hub. save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. Default is False which assumes rank-0 checkpointing of the full state_dict. enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be @@ -92,19 +77,11 @@ def __init__( to consolidated output files. Default to 1. """ - if token is not None: - super().__init__( - path=path, - token=token, - serialization_format=SerializationFormat.SAFETENSORS, - thread_count=thread_count, - ) - else: - super().__init__( - path=path, - serialization_format=SerializationFormat.SAFETENSORS, - thread_count=thread_count, - ) + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping self.save_distributed: bool = save_distributed self.enable_consolidation: bool = enable_consolidation @@ -215,30 +192,24 @@ def metadata_path(self) -> str: return _metadata_fn -class HuggingFaceStorageReader(FsspecReader): +class HuggingFaceStorageReader(FileSystemReader): """ - A reader that reads from a huggingface repository in the huggingface format. - Uses in Fsspec back-end to communicate with storage. - Fsspec registration of the storage solution is required. + A reader that reads a checkpoint in the huggingface safetensors format. """ - def __init__(self, path: str, token: Optional[str] = None) -> None: + def __init__(self, path: str) -> None: """ Initialize the huggingface reader pointing to path. Args: - path: hf directory where the checkpoint will be read from. - Needs to have .safetensors file, but can be from any fsspec supported storage, - including localFS and hf://. - token: The token to use to authenticate with huggingface hub. + path: directory where the checkpoint will be read from. """ - if token is not None: - super().__init__(path=path, token=token) - else: - super().__init__(path=path) + super().__init__(path=path) def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import safe_open # type: ignore[import] + per_file: dict[str, list[ReadItem]] = {} for read_item in plan.items: @@ -247,21 +218,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: per_file.setdefault(file_name, []).append(read_item) for file_name, reqs in per_file.items(): - with self.fs.create_stream(file_name, "rb") as stream: + with safe_open(filename=file_name, framework="pt") as f: for req in reqs: item_md = self.storage_data[req.storage_index] - stream.seek(item_md.offset) - tensor_bytes = stream.read(item_md.length) - - tensor = torch.frombuffer( - tensor_bytes, - dtype=item_md.dtype, - ) - tensor = tensor.reshape(item_md.shape) - tensor = narrow_tensor_by_index( - tensor, req.storage_offsets, req.lengths + # Create slices for each dimension based on offsets and lengths + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) ) + tensor = f.get_slice(req.storage_index.fqn)[slices] target_tensor = planner.resolve_tensor(req).detach() assert target_tensor.size() == tensor.size(), ( @@ -276,6 +242,9 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: return fut def read_metadata(self) -> Metadata: + from safetensors import safe_open # type: ignore[import] + from safetensors.torch import _getdtype # type: ignore[import] + state_dict_metadata: dict[str, TensorStorageMetadata] = {} storage_data: dict[MetadataIndex, _HFStorageInfo] = {} @@ -285,53 +254,47 @@ def read_metadata(self) -> Metadata: safetensors_files.append(file) for safetensor_file in safetensors_files: - with self.fs.create_stream(safetensor_file, "rb") as f: - safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f) - custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) + with safe_open(safetensor_file, framework="pt") as f: + keys = f.keys() + extra_metadata = f.metadata() dcp_sharding_info = None - if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): + if extra_metadata and extra_metadata.get(CUSTOM_METADATA_KEY): dcp_sharding_info = json.loads( - custom_metadata.get(CUSTOM_METADATA_KEY) + extra_metadata.get(CUSTOM_METADATA_KEY) ) - for key, val in safetensors_metadata.items(): - if key == DEFAULT_EXTRA_METADATA_KEY: - continue - + for key in keys: + shape = f.get_slice(key).get_shape() + dtype = f.get_slice(key).get_dtype() # construct state_dict_metadata if dcp_sharding_info is not None: offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] else: - offset = [0] * len(val[SHAPE_KEY]) + offset = [0] * len(shape) if key not in state_dict_metadata: state_dict_metadata[key] = TensorStorageMetadata( - properties=TensorProperties( - dtype=_get_dtype(val[DTYPE_KEY]) - ), + properties=TensorProperties(dtype=_getdtype(dtype)), size=torch.Size( - [ - saved + offset - for saved, offset in zip(val[SHAPE_KEY], offset) - ] + [saved + offset for saved, offset in zip(shape, offset)] ), chunks=[ ChunkStorageMetadata( offsets=torch.Size(offset), - sizes=torch.Size(val[SHAPE_KEY]), + sizes=torch.Size(shape), ) ], ) else: state_dict_metadata[key].chunks.append( ChunkStorageMetadata( - torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]) + torch.Size(offset), sizes=torch.Size(shape) ) ) size = list(state_dict_metadata[key].size) for i in range(len(size)): - size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) + size[i] = max(size[i], shape[i] + offset[i]) state_dict_metadata[key].size = torch.Size(size) # construct storage data @@ -340,15 +303,11 @@ def read_metadata(self) -> Metadata: fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] ) else: - metadata_index = MetadataIndex( - fqn=key, offset=[0] * len(val[SHAPE_KEY]) - ) + metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape)) storage_data[metadata_index] = _HFStorageInfo( relative_path=safetensor_file, - offset=val[DATA_OFFSETS_KEY][0] + metadata_size, - length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], - shape=torch.Size(val[SHAPE_KEY]), - dtype=_get_dtype(val[DTYPE_KEY]), + shape=torch.Size(shape), + dtype=_getdtype(dtype), ) metadata = Metadata( diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index 9e1031c7fdda..e7acf4975173 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -1,11 +1,17 @@ +import os +import tempfile from concurrent.futures import Future, ThreadPoolExecutor from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Optional, Union +from datetime import timedelta +from typing import Any, cast, Optional, Union from typing_extensions import deprecated, Protocol, runtime_checkable import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._pg_transport import PGTransport from torch.distributed.checkpoint._state_dict_stager import StateDictStager from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE @@ -315,3 +321,146 @@ def synchronize_staging(self) -> None: def close(self) -> None: pass + + +class _ReplicationStager(AsyncStager): + """ + An AsyncStager implementation that replicates state_dict across training ranks + using PGTransport. + + Args: + pg: ProcessGroup for distributed communication + timeout: Timeout for communication operations + device: Device to use for tensor operations + storage_dir: Directory to store persisted state_dicts + + Warning: This is experimental and subject to change. + """ + + _synchronize_after_execute: bool = False + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta = timedelta(minutes=30), + device: torch.device = torch.device("cpu"), + storage_dir: Optional[str] = None, + ): + self._pg = pg + self._timeout = timeout + self._device = device + self._transport = PGTransport(pg, timeout, device, None) + + # Set up storage directory for persisting exchanged state_dicts + if storage_dir is None: + self._storage_dir = tempfile.mkdtemp(prefix="replication_stager_") + else: + self._storage_dir = storage_dir + os.makedirs(self._storage_dir, exist_ok=True) + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Stage the state_dict by replicating it across ranks. Returns a state_dict representing + the received replica. + + Perform the actual replication logic. Creates bidirectional pairs where each rank exchanges + state_dict with its partner at (rank + world_size//2) % world_size. + Uses simple rank-based ordering to prevent deadlocks. + + Assumes world_size is always even. + """ + if not dist.is_initialized(): + return state_dict + + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Calculate partner rank using half-world offset + # creates bidirectional pairs for replication. + offset = world_size // 2 + partner_rank = (current_rank + offset) % world_size + + # Use simple rank-based ordering to prevent deadlocks. + # Lower-numbered rank sends first, higher-numbered rank receives first. + if current_rank < partner_rank: + # Send first, then receive + self._transport.send_checkpoint([partner_rank], state_dict) + received_state_dict = self._transport.recv_checkpoint(partner_rank) + else: + # Receive first, then send + received_state_dict = self._transport.recv_checkpoint(partner_rank) + self._transport.send_checkpoint([partner_rank], state_dict) + + # Persist the received state_dict for future discoverability + received_state_dict = cast(STATE_DICT_TYPE, received_state_dict) + self._persist_state_dict(received_state_dict, current_rank, partner_rank) + + return received_state_dict + + def _persist_state_dict( + self, state_dict: STATE_DICT_TYPE, current_rank: int, partner_rank: int + ) -> None: + """ + Persist the received state_dict to disk for future discoverability. + Only keeps one replica per rank, overwriting any previous replica. + Uses atomic write pattern (temp file + rename). + + Args: + state_dict: The state_dict received from partner rank + current_rank: Current rank that received the state_dict + partner_rank: Rank that sent the state_dict + """ + final_path = self._get_persisted_path(current_rank, partner_rank) + temp_path = final_path + ".tmp" + + try: + # Ensure parent directory exists and is writable + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Write to temporary file with explicit flushing + with open(temp_path, "wb") as f: + torch.save(state_dict, f) + # Flush application buffers to OS buffers + f.flush() + # Force OS buffers to disk for durability + os.fsync(f.fileno()) + + # Atomic rename to final location + os.rename(temp_path, final_path) + except Exception as e: + # Clean up temp file if it exists + try: + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception: + pass # Ignore cleanup errors + # Re-raise the original exception with more context + raise RuntimeError( + f"Failed to persist state_dict from rank {partner_rank} to rank {current_rank}: {e}" + ) from e + + def _get_persisted_path(self, current_rank: int, partner_rank: int) -> str: + """ + Get the file path where a state_dict would be persisted. + + Args: + current_rank: Current rank + + Returns: + File path for the persisted state_dict + """ + filename = f"rank_{current_rank}_replica_partner_{partner_rank}.pt" + return os.path.join(self._storage_dir, filename) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + """ + Clean up resources. Persisted files are intentionally left for future discovery. + """ diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 85f2fff4f831..e7d1e053fbfd 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -5,8 +5,9 @@ import os import threading import warnings +from collections.abc import Iterator from functools import reduce -from itertools import chain +from itertools import chain, zip_longest from typing import Optional, TYPE_CHECKING, Union import torch @@ -69,7 +70,7 @@ def __init__(self) -> None: self.mesh_stack: list[DeviceMesh] = [] self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: dict[ - int, tuple[str, Optional[C10dBackend.Options]] + int, tuple[Optional[str], Optional[C10dBackend.Options]] ] = {} self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -166,7 +167,13 @@ def create_sub_mesh( return res_submesh def create_flatten_mesh( - self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None + self, + device_mesh: "DeviceMesh", + mesh_dim_name: Optional[str] = None, + backend_override: tuple[Optional[str], Optional[C10dBackend.Options]] = ( + None, + None, + ), ) -> "DeviceMesh": root_mesh = _mesh_resources.get_root_mesh(device_mesh) @@ -217,6 +224,7 @@ def create_flatten_mesh( root_mesh.device_type, mesh_nd, mesh_dim_names=(mesh_dim_name,), + backend_override=(backend_override,), ) if cur_rank in mesh_nd: res_flattened_mesh = flattened_mesh @@ -283,7 +291,7 @@ def get_mesh_dim_by_name( def _set_mesh_dim_group_options( self, dim: int, - backend: str, + backend: Optional[str], pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) @@ -439,6 +447,9 @@ def __init__( mesh: Union[torch.Tensor, "ArrayLike"], *, mesh_dim_names: Optional[tuple[str, ...]] = None, + backend_override: Optional[ + tuple[tuple[Optional[str], Optional[C10dBackend.Options]], ...] + ] = None, _init_backend: bool = True, ) -> None: self.device_type = device_type @@ -450,6 +461,8 @@ def __init__( else torch.tensor(mesh, device="cpu", dtype=torch.int) ) self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + if backend_override is None: + backend_override = ((None, None),) * self.mesh.ndim # private field to pre-generate DeviceMesh's hash self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) @@ -463,7 +476,7 @@ def __init__( # process (we need to know if the current global rank is in the mesh or not). if _init_backend: self._setup_world_group_and_device() - self._init_process_groups() + self._init_process_groups(backend_override) if is_initialized() and get_backend() == "threaded": self._thread_id = threading.get_ident() @@ -525,7 +538,12 @@ def _setup_world_group_and_device(self): return _get_default_group() - def _init_process_groups(self): + def _init_process_groups( + self, + backend_override: tuple[ + tuple[Optional[str], Optional[C10dBackend.Options]], ... + ], + ): # group_name associated with each mesh dimension, each # mesh dimension should have one sub-group per rank # @@ -535,7 +553,9 @@ def _init_process_groups(self): if ( self.mesh.ndim == 1 and self.mesh.numel() == get_world_size() - and 0 not in _mesh_resources.mesh_dim_group_options + and _mesh_resources.mesh_dim_group_options.get(0, (None, None)) + == (None, None) + and backend_override[0] == (None, None) ): # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. # Otherwise, create new pg. @@ -563,12 +583,17 @@ def _init_process_groups(self): # Respect dim group options specified via _MeshEnv.set_dim_group_options(). # Inherit from the parent group if no options are specified for the group. if dim in _mesh_resources.mesh_dim_group_options: + if backend_override[dim] != (None, None): + raise RuntimeError( + f"Dimension {dim} present both in the backend_override argument " + "and via _mesh_resources._set_mesh_dim_group_options" + ) ( backend, pg_options, ) = _mesh_resources.mesh_dim_group_options[dim] else: - backend, pg_options = None, None + backend, pg_options = backend_override[dim] # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. @@ -591,10 +616,19 @@ def _init_process_groups(self): dim_group = None has_split_group = False if ( - bound_device_id := getattr( - default_group, "bound_device_id", None + ( + bound_device_id := getattr( + default_group, "bound_device_id", None + ) + ) + is not None + and torch.cuda.is_available() + and ( + backend is None + or default_group._get_backend(torch.device("cuda")).name() + == backend ) - ) is not None and torch.cuda.is_available(): + ): dim_group = split_group( parent_pg=default_group, pg_options=pg_options, @@ -968,7 +1002,13 @@ def get_coordinate(self) -> Optional[list[int]]: """ return self._coordinate_on_dim if self._coordinate_on_dim else None - def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": + def _flatten( + self, + mesh_dim_name: Optional[str] = None, + backend_override: Union[ + None, str, C10dBackend.Options, tuple[str, C10dBackend.Options] + ] = None, + ) -> "DeviceMesh": """ Returns a 1D DeviceMesh by flattening the current DeviceMesh. @@ -986,13 +1026,65 @@ def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": "Cannot flatten a DeviceMesh without mesh_dim_names!" ) - return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) + if backend_override is not None: + (backend_override_tuple,) = _normalize_backend_override( + {0: backend_override}, 1 + ) + else: + backend_override_tuple = (None, None) + + return _mesh_resources.create_flatten_mesh( + self, mesh_dim_name, backend_override_tuple + ) + + def _normalize_backend_override( + backend_override: dict[ + Union[int, str], + Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], + ], + ndim: int, + mesh_dim_names: Optional[tuple[str, ...]] = None, + ) -> Iterator[tuple[Optional[str], Optional[C10dBackend.Options]]]: + if mesh_dim_names is None: + mesh_dim_names = () + for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names): + if dim_name is not None and dim_name in backend_override: + if dim_idx in backend_override: + raise RuntimeError( + f"Found redundant dim index {dim_idx} and " + f"name {dim_name} in backend_override" + ) + val = backend_override.pop(dim_name) + elif dim_idx in backend_override: + val = backend_override.pop(dim_idx) + else: + yield (None, None) + continue + + if isinstance(val, str): + yield (val, None) + elif isinstance(val, C10dBackend.Options): + yield (None, val) + else: + yield val + + if backend_override: + raise RuntimeError( + f"Found invalid keys in backend_override: got {list(backend_override.keys())}, " + f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}" + ) def init_device_mesh( device_type: str, mesh_shape: tuple[int, ...], *, mesh_dim_names: Optional[tuple[str, ...]] = None, + backend_override: Optional[ + dict[ + Union[int, str], + Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], + ] + ] = None, ) -> DeviceMesh: """ Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. @@ -1017,6 +1109,11 @@ def init_device_mesh( mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of + the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a + dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name + of the backend and its options, or just one of these two components (in which case the other will be + set to its default value). Returns: DeviceMesh: A :class:`DeviceMesh` object representing the device layout. @@ -1043,6 +1140,15 @@ def init_device_mesh( f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", ) + if backend_override is not None: + backend_override_tuple = tuple( + _normalize_backend_override( + backend_override, len(mesh_shape), mesh_dim_names + ) + ) + else: + backend_override_tuple = None + # assume valid device types are all letters if device_type and not device_type.isalpha(): raise RuntimeError( @@ -1058,6 +1164,7 @@ def init_device_mesh( device_type=device_type, mesh=mesh, mesh_dim_names=mesh_dim_names, + backend_override=backend_override_tuple, ) return device_mesh diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 800020d68686..a7ca2453b251 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -4780,6 +4780,11 @@ def barrier( None, if not async_op or if not part of the group .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. + .. note:: `ProcessGroupNCCL` implements barrier as an all_reduce of a 1-element tensor. A device must be chosen + for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to + `device_ids` arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device + that was first used with this process group, if another collective with tensor inputs has been performed, (4) + the device index indicated by the global rank mod local device count. """ group = group or _get_default_group() diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 2759f20bd277..1175da3b91b7 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -27,7 +27,7 @@ from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = [ @@ -104,13 +104,6 @@ def __post_init__(self): self.entrypoint = self.fn assert self.entrypoint - if ( - self.numa_options is not None - and not self.numa_options.should_fall_back_if_binding_fails - and not isinstance(self.entrypoint, str) - ): - raise ValueError("numa_options is only supported for str entrypoints.") - def get_entrypoint_name(self): """Get the entry point name. diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index d283e0129f0a..7e293ce47cb7 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -80,7 +80,7 @@ def trainer(a, b, c): to_map, ) from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = [ @@ -227,6 +227,7 @@ def start_processes( log_line_prefixes=log_line_prefixes, start_method=start_method, logs_specs=logs_specs, + numa_options=numa_options, ) try: diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 6cd8d2a12f35..ed3ea86b0f2a 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -37,7 +37,7 @@ SubprocessHandler, ) from torch.distributed.elastic.multiprocessing.tail_log import TailLog -from torch.distributed.numa.binding import maybe_wrap_with_numa_bindings, NumaOptions +from torch.numa.binding import NumaOptions IS_WINDOWS = sys.platform == "win32" @@ -631,6 +631,7 @@ def __init__( start_method: str, logs_specs: LogsSpecs, log_line_prefixes: Optional[dict[int, str]] = None, + numa_options: Optional[NumaOptions] = None, ): super().__init__( name, @@ -655,6 +656,8 @@ def __init__( # successfully. If any process died on event.wait() calling set() method will deadlock. self._worker_finished_event = mp.get_context(self.start_method).Event() + self._numa_options: Optional[NumaOptions] = numa_options + def _start(self): if self._pc: raise ValueError( @@ -676,6 +679,7 @@ def _start(self): join=False, daemon=False, start_method=self.start_method, + numa_options=self._numa_options, ) def _is_done(self) -> bool: @@ -814,10 +818,6 @@ def __init__( log_line_prefixes: Optional[dict[int, str]] = None, numa_options: Optional[NumaOptions] = None, ): - entrypoint, args = maybe_wrap_with_numa_bindings( - entrypoint=entrypoint, local_rank_to_args=args, numa_options=numa_options - ) - super().__init__( name, entrypoint, @@ -831,6 +831,7 @@ def __init__( self._running_local_ranks: set[int] = set(range(self.nprocs)) self._failures: dict[int, ProcessFailure] = {} self.subprocess_handlers: dict[int, SubprocessHandler] = {} + self._numa_options: Optional[NumaOptions] = numa_options def _start(self): if self.subprocess_handlers: @@ -845,6 +846,7 @@ def _start(self): stdout=self.stdouts[local_rank], stderr=self.stderrs[local_rank], local_rank_id=local_rank, + numa_options=self._numa_options, ) for local_rank in range(self.nprocs) } diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index fea707a3c3ab..947ce7b001ef 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -3,10 +3,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( SubprocessHandler, ) +from torch.numa.binding import NumaOptions __all__ = ["get_subprocess_handler"] @@ -19,6 +21,7 @@ def get_subprocess_handler( stdout: str, stderr: str, local_rank_id: int, + numa_options: Optional[NumaOptions] = None, ) -> SubprocessHandler: return SubprocessHandler( entrypoint=entrypoint, @@ -27,4 +30,5 @@ def get_subprocess_handler( stdout=stdout, stderr=stderr, local_rank_id=local_rank_id, + numa_options=numa_options, ) diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 6b927fcd6a67..c2327e1cd3cf 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -11,6 +11,8 @@ from subprocess import Popen from typing import Any, Optional +from torch.numa.binding import maybe_wrap_command_with_numa_bindings, NumaOptions + __all__ = ["SubprocessHandler"] @@ -39,6 +41,7 @@ def __init__( stdout: Optional[str], stderr: Optional[str], local_rank_id: int, + numa_options: Optional[NumaOptions], ): self._stdout = open(stdout, "w") if stdout else None self._stderr = open(stderr, "w") if stderr else None @@ -47,6 +50,15 @@ def __init__( env_vars.update(env) args_str = (entrypoint, *[str(e) for e in args]) + args_str = ( + maybe_wrap_command_with_numa_bindings( + command_args=args_str, + gpu_index=local_rank_id, + numa_options=numa_options, + ) + or args_str + ) + self.local_rank_id = local_rank_id self.proc: Popen = self._popen(args_str, env_vars) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 121f3d4c1388..554367e8705c 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -32,7 +32,7 @@ HSDPMeshInfo, TrainingState, ) -from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState +from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState logger = logging.getLogger("torch.distributed.fsdp.fully_shard") @@ -166,6 +166,7 @@ def __init__( self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None self._all_gather_comm: AllGather = DefaultAllGather() + self._all_gather_output = torch.empty(0, device=self.device) self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter() # Optional stream to run the user-defined all-reduce hook in # Saved here and not in the comm. context because we allow the user to @@ -310,6 +311,22 @@ def unshard(self, async_op: bool = False): # used in the all-gather streams self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) self._reshard_after_forward_event = None + + world_size = self._all_gather_process_group.size() + if world_size == 1: + # can't skip due to early return in wait_for_unshard if + # no self._all_gather_result + self._all_gather_result = AllGatherResult( + all_gather_output=self._all_gather_output, + all_gather_event=self.device_handle.Event().record(), + all_gather_work=None, + param_all_gather_input_dtypes=[], + param_all_gather_input_numels=[], + all_gather_input_split_sizes=[], + ) + + return + with record_function(self._with_fqn("FSDP::all_gather")): self._all_gather_result = foreach_all_gather( self.fsdp_params, @@ -336,18 +353,52 @@ def wait_for_unshard(self): if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) self.comm_ctx.all_gather_state = None # free the all-gather result - with record_function(self._with_fqn("FSDP::all_gather_copy_out")): - foreach_all_gather_copy_out( - self._all_gather_result, - self.fsdp_params, - self._all_gather_process_group, - ) + world_size = self._all_gather_process_group.size() + if world_size == 1: + # directly initialize unsharded parameters from sharded parameters + + for fsdp_param in self.fsdp_params: + # Use all_gather_inputs which already handles conversion to param_dtype + # This is consistent with the world_size > 1 path + all_gather_input = fsdp_param.all_gather_inputs[0] + + # Make sure the all_gather_outputs has proper storage size before using it + # First ensure we have at least one tensor in all_gather_outputs + fsdp_param.init_all_gather_outputs( + [all_gather_input.numel()], + [all_gather_input.dtype], + world_size, + self.device, + force_recreate=False, + ) + + tensor = fsdp_param.all_gather_outputs[0] + alloc_storage(tensor) + + # find alternative way to check if tensor.is_inference + with torch.autograd._unsafe_preserve_version_counter(tensor): + tensor.copy_(all_gather_input) + + else: + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + for fsdp_param in self.fsdp_params: fsdp_param.init_unsharded_param() + self._to_unsharded() all_gather_copy_out_event = self.device_handle.Event() all_gather_copy_out_event.record() - if not async_op and self._training_state == TrainingState.FORWARD: + + if ( + not async_op + and self._training_state == TrainingState.FORWARD + and world_size > 1 + ): # Defer free to allow for overlap of this copy-out with next # all-gather collective self.comm_ctx.all_gather_state = AllGatherState( @@ -355,6 +406,7 @@ def wait_for_unshard(self): ) else: self._wait_all_gather_streams_on_event(all_gather_copy_out_event) + self._all_gather_result = None # free unless saved in `all_gather_state` def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index d788ad568bd5..ef6e75c8dde3 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -26,7 +26,7 @@ from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] @@ -107,7 +107,13 @@ def __post_init__(self): if self.logs_specs is None: self.logs_specs = DefaultLogsSpecs() - if self.numa_options is None and torch.cuda.is_available(): + if ( + self.numa_options is None + # NOTE: This filter isn't relevant for str entrypoints, + # but it's the default anyway. + and self.start_method == "spawn" + and torch.cuda.is_available() + ): self.numa_options = get_default_numa_options() logger.info("Using default numa options = %r", self.numa_options) diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index b39a806fa776..38ba1241c4e5 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -24,7 +24,7 @@ def get_schedule_ops( - schedule: Union[str, _PipelineSchedule], + schedule: Union[str, type[_PipelineSchedule]], pp_degree: int, num_microbatches: int, num_stages_per_rank: Optional[int] = None, @@ -38,7 +38,7 @@ def get_schedule_ops( if isinstance(schedule, str): schedule_class = get_schedule_class(schedule) - elif type(schedule) == _PipelineSchedule: + elif issubclass(schedule, _PipelineSchedule): schedule_class = schedule else: raise ValueError(f"Invalid schedule: {schedule}") @@ -98,6 +98,7 @@ def __init__( _ComputationType.BACKWARD_INPUT: _ComputationTypeColor("teal", "Backward Input"), _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor("green", "Backward Weight"), _ComputationType.FULL_BACKWARD: _ComputationTypeColor("orange", "Full Backward", 2), + _ComputationType.OVERLAP_F_B: _ComputationTypeColor("purple", "Overlap F+B", 3), } @@ -136,6 +137,15 @@ def visualize_schedule( used_computation.add(action.computation_type) color = comp_type_color.color width = comp_type_color.width + + # Check if action has sub_actions to determine styling + if action.sub_actions is not None: + linewidth = 2 # Thicker border for compound actions + text_weight = "normal" # Bold text for compound actions + else: + linewidth = 1 # Default linewidth for regular actions + text_weight = "normal" # Default text weight + # Draw the rectangle to represent the action duration rect = Rectangle( (draw_position, num_ranks - rank_idx - 1), @@ -143,8 +153,10 @@ def visualize_schedule( 1, facecolor=color, edgecolor="black", + linewidth=linewidth, ) ax.add_patch(rect) + # Draw the text centered within the rectangle ax.text( draw_position + width / 2, @@ -154,8 +166,9 @@ def visualize_schedule( va="center", fontsize=font_size, color="white", + weight=text_weight, ) - # Increment the drawing position by the width of the current action + draw_position += width else: draw_position += 1 # Move to the next diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 0a4da5c098b3..2f0472211b8c 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates + import logging from dataclasses import dataclass from typing import Union @@ -122,6 +123,32 @@ def generate_stage_to_rank_mapping( return mapping +def generate_rank_to_stage_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, list[int]]: + """ + Compute the rank to stage id mapping for either a looped or V-style schedule. + + This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank. + + Returns a dictionary mapping rank -> list of stage indices assigned to that rank. + """ + stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style) + + # Invert the mapping: rank -> list of stages + rank_to_stages: dict[int, list[int]] = {} + for stage_id, rank in stage_to_rank.items(): + if rank not in rank_to_stages: + rank_to_stages[rank] = [] + rank_to_stages[rank].append(stage_id) + + # Sort the stage lists for each rank to ensure consistent ordering + for stages in rank_to_stages.values(): + stages.sort() + + return rank_to_stages + + @dataclass class PipeInfo: """ diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 2c0700afab52..1c0f4d27a638 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -50,6 +50,7 @@ class _ComputationType(Enum): SEND_B = 8 RECV_B = 9 FULL_BACKWARD = 10 + OVERLAP_F_B = 11 def __str__(self): str_map = { @@ -63,6 +64,7 @@ def __str__(self): _ComputationType.SEND_B: "SEND_B", _ComputationType.RECV_B: "RECV_B", _ComputationType.FULL_BACKWARD: "B", + _ComputationType.OVERLAP_F_B: "OVERLAP_F_B", } return str_map[self] @@ -88,6 +90,8 @@ def from_str(action): return _ComputationType.RECV_B elif action == "B": return _ComputationType.FULL_BACKWARD + elif action == "OVERLAP_F_B": + return _ComputationType.OVERLAP_F_B else: raise RuntimeError(f"Invalid computation type {action}") @@ -102,6 +106,7 @@ def from_str(action): SEND_B = _ComputationType.SEND_B RECV_B = _ComputationType.RECV_B FULL_BACKWARD = _ComputationType.FULL_BACKWARD +OVERLAP_F_B = _ComputationType.OVERLAP_F_B # Convenience shorthand for compute actions only since they are used in 'simple schedule format' F = FORWARD @@ -119,13 +124,22 @@ class _Action(NamedTuple): stage_index: int computation_type: _ComputationType microbatch_index: Optional[int] = None + sub_actions: Optional[tuple["_Action", ...]] = None + + def __str__(self): + return self.__repr__() def __repr__(self): - repr = str(self.stage_index) - repr += str(self.computation_type) - if self.microbatch_index is not None: - repr += str(self.microbatch_index) - return repr + if self.sub_actions is not None: + # Use recursive repr for sub_actions + sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions] + return f"({';'.join(sub_action_reprs)}){self.computation_type}" + else: + repr_str = str(self.stage_index) + repr_str += str(self.computation_type) + if self.microbatch_index is not None: + repr_str += str(self.microbatch_index) + return repr_str @staticmethod def from_str(action_string: str): @@ -136,6 +150,38 @@ def from_str(action_string: str): e.g. `2F0`, `1UNSHARD`, `3SEND_F1` """ action_string = action_string.strip() + if action_string == "": + return None + + # Check for sub_actions format: [sub_action1;sub_action2;...]ComputationType + if action_string.startswith("(") and ")" in action_string: + # Find the closing bracket to separate sub_actions from computation type + bracket_end = action_string.find(")") + sub_part = action_string[ + 1:bracket_end + ] # Remove '[' and get content before ']' + computation_type_part = action_string[ + bracket_end + 1 : + ] # Get part after ']' + + # Parse sub_actions + sub_actions = [] + if sub_part.strip(): + for sub_str in sub_part.split(";"): + sub_action = _Action.from_str(sub_str.strip()) + if sub_action is not None: + sub_actions.append(sub_action) + + # For sub_actions format, we create an action with just the computation type + # The stage_index and microbatch_index are not meaningful for the container action + return _Action( + stage_index=-1, # Placeholder, not meaningful for sub_actions container + computation_type=_ComputationType.from_str(computation_type_part), + microbatch_index=None, + sub_actions=tuple(sub_actions) if sub_actions else None, + ) + + # Handle regular single action format if match := _action_regex.match(action_string): stage_index, computation_type, microbatch_index = match.groups() return _Action( @@ -508,6 +554,13 @@ def __init__( ) def _initialize_stage(self, args, kwargs): + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) if self._has_backward: self._stage._prepare_backward_infra(self._n_microbatches) @@ -963,7 +1016,7 @@ def _add_unshard_reshard( compute_actions: list[Optional[_Action]], max_active_stages: int = 3, ) -> list[_Action]: - """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. + """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP. UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. RESHARD does the opposite, releasing memory (but doing no communication) @@ -983,11 +1036,24 @@ def next_stage_indices( ret: list[int] = [] for a in next_actions: - if a is not None and a.stage_index not in seen: - seen.add(a.stage_index) - ret.append(a.stage_index) - if len(ret) == count: - break + if a is not None: + # Handle OVERLAP_F_B actions by checking their sub_actions + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + if sub_action.stage_index not in seen: + seen.add(sub_action.stage_index) + ret.append(sub_action.stage_index) + if len(ret) == count: + break + if len(ret) == count: + break + else: + # Regular action + if a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break return ret active_stages: set[int] = set() @@ -1044,10 +1110,13 @@ def _merge_bw( if action is None: continue - while len(compute_actions) and (next_action := compute_actions[0]) is None: - # remove any None actions between 'action' and 'next_action' + # Remove any None actions and find the next non-None action + while len(compute_actions) and compute_actions[0] is None: compute_actions.pop(0) + # Get the next action if it exists + next_action = compute_actions[0] if len(compute_actions) > 0 else None + if ( action.computation_type == BACKWARD_INPUT and next_action is not None @@ -1069,6 +1138,9 @@ def _add_send_recv( stage_to_rank: Callable[[int], int], num_stages: int, ) -> dict[int, list[_Action]]: + """ + Transforms a compute-only schedule into a complete schedule with communication actions. + """ comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} @@ -1137,6 +1209,19 @@ def _ready_to_schedule( else: return True + # TODO: For now we are splitting OVERLAP_F_B into replacing it to + # its forward and backward components + # We need to figure out how to do the communication + for rank in compute_actions: + new_actions: list[_Action] = [] + for action in compute_actions[rank]: + if action is not None and action.sub_actions is not None: + # Replace OVERLAP_F_B action with its sub_actions + new_actions.extend(action.sub_actions) + else: + new_actions.append(action) + compute_actions[rank] = new_actions + while compute_actions: progress = False # go in order of ranks even if dict keys aren't ordered @@ -1193,40 +1278,82 @@ def _validate_schedule( for stage_id in range(num_stages) } stage_index_to_rank_mapping = {} + + def _process_action(action: _Action, rank: int, step: int): + """Process a single action and update stage_actions and stage_index_to_rank_mapping""" + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Full Backward for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Input for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][I].add(mb_id) + elif ctype == W: + if mb_id not in stage_actions[s_id][I]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Weight for stage {s_id}, " + f"microbatch {mb_id} without first running Backward Input" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][W].add(mb_id) + + if s_id not in stage_index_to_rank_mapping: + stage_index_to_rank_mapping[s_id] = rank + else: + existing_rank = stage_index_to_rank_mapping[s_id] + assert rank == existing_rank, ( + f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + ) + for rank in actions: - for action in actions[rank]: + for step, action in enumerate(actions[rank]): if action is None: continue assert isinstance(action, _Action), ( - f"Got an invalid action: {action}, expected instance of _Action" + f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action" ) - s_id = action.stage_index - ctype = action.computation_type - mb_id = action.microbatch_index - if ctype == F: - stage_actions[s_id][F].add(mb_id) - elif ctype == B: - assert mb_id in stage_actions[s_id][F], ( - f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" - ) - stage_actions[s_id][B].add(mb_id) - elif ctype == I: - assert mb_id in stage_actions[s_id][F], ( - f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" - ) - stage_actions[s_id][I].add(mb_id) - elif ctype == W: - assert mb_id in stage_actions[s_id][I], ( - f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" - ) - stage_actions[s_id][W].add(mb_id) - if s_id not in stage_index_to_rank_mapping: - stage_index_to_rank_mapping[s_id] = rank + + # Check if action has sub_actions + if action.sub_actions is not None: + # Process each sub_action instead of the main action + for sub_action in action.sub_actions: + _process_action(sub_action, rank, step) else: - existing_rank = stage_index_to_rank_mapping[s_id] - assert rank == existing_rank, ( - f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" - ) + # Process the main action normally + _process_action(action, rank, step) for s_id in stage_actions: f_mb = len(stage_actions[s_id][F]) @@ -1238,6 +1365,11 @@ def _validate_schedule( f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" ) + assert i_mb == w_mb, ( + f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \ + but got I={i_mb}, W={w_mb}" + ) + assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ but got B={b_mb}, I={i_mb}, W={w_mb}" @@ -1303,6 +1435,14 @@ def __init__( ) def _initialize_stages(self, args: tuple[Any, ...], kwargs): + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + for stage in self._stages: + all_ops.extend(stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) # or real value (if this stage and next stage are on the same device) next_stage_args: tuple[Any, ...] = tuple() @@ -2403,7 +2543,7 @@ def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): if actions[rank][timestamp] is not None: temp_action = actions[rank][timestamp] assert temp_action is not None - stage_index, op, microbatch = temp_action + stage_index, op, microbatch, _ = temp_action if not need_bubble( stage_index, op, microbatch, num_stages_global, seen_ops ): diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index e4de0ddd03ab..c1abebde5b85 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -935,6 +935,60 @@ def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs ) + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + """ + Get the operations to initialize the p2p communicators between previous and next stages. + This is done so by creating a dummy tensor and sending it to the next stage and receiving + from the previous stage. + """ + ops: list[dist.P2POp] = [] + next_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index + 1) + prev_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index - 1) + + recv_tensor = torch.zeros(1, device=self.device) + send_tensor = torch.tensor(self.stage_index, device=self.device) + # forward + if not self.is_first: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + # backward + if not self.is_first: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + return ops + class _PipelineStage(_PipelineStageBase): def __init__( diff --git a/torch/distributed/run.py b/torch/distributed/run.py index c37ecd8f72d8..2738191f0e37 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -382,7 +382,7 @@ def main(): from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.launcher.api import elastic_launch, LaunchConfig -from torch.distributed.numa.binding import ( +from torch.numa.binding import ( AffinityMode as _AffinityMode, # Signify as private with _ NumaOptions as _NumaOptions, ) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 346e2966b15b..b562153ad507 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -23,6 +23,7 @@ ) from torch.distributed.tensor._utils import try_find_mesh_from_args from torch.distributed.tensor.placement_types import Partial, Placement, Replicate +from torch.utils._python_dispatch import return_and_correct_aliasing try: @@ -138,7 +139,6 @@ def dispatch( (2) registered sharding strategy, then rule (3) composite implicit autograd decomposition """ - if op_call in self._custom_op_handlers: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] @@ -165,7 +165,8 @@ def dispatch( assert output_sharding is not None, "output sharding should not be None" mesh = op_info.compute_mesh - if mesh.get_coordinate() is not None: + participating = mesh.get_coordinate() is not None + if participating: # computation that happens in the current rank of the mesh, normal case if output_sharding.needs_redistribute: # If sharding propagation decision needs redistribute, perform redistribute @@ -197,8 +198,19 @@ def dispatch( cast(dtensor.DTensor, args[0]), cast(torch.Tensor, local_tensor_args[0]), ) + + # If the user provided a generator, we hook it up to our RNG manager, but we also pop it from kwargs + # so the op_call does not directly use it (we want op_call to fall back to the 'default' which is + # our RNG manager) + maybe_user_generator = op_info.local_kwargs.pop("generator", None) + assert maybe_user_generator is None or isinstance( + maybe_user_generator, torch.Generator + ) + # maybe_user_generator = None rng_context = ( - random._rng_tracker._distribute_region(first_arg._spec) + random._rng_tracker._distribute_region( + first_arg._spec, generator=maybe_user_generator + ) if random._rng_tracker and not first_local_arg.is_meta else contextlib.nullcontext() ) @@ -289,7 +301,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: assert len(out_dts) >= 1, "out variant should have at least one out arg" return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] else: - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + if participating and op_info.schema.is_view_op(): + return return_and_correct_aliasing(op_call, args, kwargs, ret) + else: + return ret @staticmethod def redistribute_local_args( diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index eb528ee4f9af..bffb399b2bca 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -40,6 +40,16 @@ def __setattr__(self, attr: str, value: Any) -> None: # change (though we do not expect `mesh` or `placements` to change) if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): self._hash = None + # This assert was triggered by buggy handling for dict outputs in some + # FX passes, where you accidentally iterate over a dict and try to put + # keys into TensorMeta. See https://github.com/pytorch/pytorch/issues/157919 + if attr == "tensor_meta" and value is not None: + from torch.fx.passes.shape_prop import TensorMetadata + + # TODO: the TensorMetadata arises from + # test/distributed/tensor/experimental/test_tp_transform.py::TensorParallelTest::test_tp_transform_e2e + # but I actually can't reproduce it, maybe it is also a bug! + assert isinstance(value, (TensorMeta, TensorMetadata)), value def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index d19049269251..b60373ea6f83 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -359,6 +359,15 @@ def __str__(self) -> str: args_schema.append(str(arg)) return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})" + def __post_init__(self) -> None: + has_symints = False + for a in self.args_schema: + if isinstance(a, DTensorSpec) and a.tensor_meta is not None: + if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): + has_symints = True + break + self.has_symints = has_symints + def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: arg = self.args_schema[arg_idx] is_tensor = isinstance(arg, DTensorSpec) @@ -441,6 +450,12 @@ def is_out_variant_op(self) -> bool: # be entirely correct, but it's good enough for now. return "out" in self.op._schema.overload_name + def is_view_op(self) -> bool: + return any( + a.alias_info is not None and not a.alias_info.is_write + for a in self.op._schema.arguments + ) + def __hash__(self) -> int: # Only hash args and kwargs that op indicates to hash if not self.schema_info: diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 03057e8c4c5b..1e6eb40939e4 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -818,27 +818,38 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: return grad_in_strategy -@register_op_strategy( - [aten.native_layer_norm.default], - schema_info=RuntimeSchemaInfo(1), -) -def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: +def _common_norm_forward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common forward strategy logic for layer_norm and rms_norm.""" mesh = op_schema.get_mesh_from_args() - # args must be: input, normalized_shape, weight, bias, eps - # for None weight and bias, their corresponding objects will - # be None as well. layer_norm_strategy returns one OpStrategy - # for the triple return values (out, mean, rstd). - assert len(op_schema.args_schema) == 5 - ( - input_strategy, - normalized_shape, - weight_strategy, - bias_strategy, - _, - ) = op_schema.args_schema + if not rms_norm: + # layer_norm args: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + else: + # rms_norm args: input, normalized_shape, weight, eps + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + normalized_shape, + weight_strategy, + _, + ) = op_schema.args_schema + bias_strategy = None - # the current layer norm implementation requires that all + # the current norm implementation requires that all # input DTensor's sharding must be in form of OpStrategy assert isinstance(input_strategy, OpStrategy) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) @@ -847,7 +858,7 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: input_ndim = input_strategy.ndim axis = input_ndim - len(normalized_size) - # we use OpStrategy because the output (out, mean, rstd) + # we use OpStrategy because the output values (out, mean, rstd) # should have the same placements output_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): @@ -915,6 +926,22 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: return output_strategy +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema, rms_norm=True) + + def _common_norm_backward_strategy( op_schema: OpSchema, rms_norm: bool = False, @@ -1114,34 +1141,63 @@ def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: return _common_norm_backward_strategy(op_schema, rms_norm=True) +def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + sort_dim = normalize_dim(sort_dim, input_strategy.ndim) + single_mesh_dim_strategies = [] + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + for dim in range(input_strategy.ndim): + if dim != sort_dim: + dim_shardings: PlacementList = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) + + @register_op_strategy( [aten.topk.default], schema_info=RuntimeSchemaInfo(2), ) def topk_strategy(op_schema: OpSchema) -> OpStrategy: - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) topk_dim = ( cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 ) - topk_dim = normalize_dim(topk_dim, input_strategy.ndim) + return sort_strategy(op_schema, topk_dim) - single_mesh_dim_strategies = [] - # two outputs (values, indices), 1 input - # replicate always works - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) +@register_op_strategy( + aten.sort.default, + schema_info=RuntimeSchemaInfo( + 1, + ), +) +def sort_default_strategy(op_schema: OpSchema) -> OpStrategy: + # mostly copy paste from topk_strategy + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + sort_dim = -1 + if len(op_schema.args_schema) > 1: + sort_dim = cast(int, op_schema.args_schema[1]) + return sort_strategy(op_schema, sort_dim) - # every dim except topk dim should work - for dim in range(input_strategy.ndim): - if dim != topk_dim: - dim_shardings: PlacementList = [Shard(dim)] * 3 - single_mesh_dim_strategies.append(dim_shardings) - # TODO: topk on sharded dim requires non-trival reduction, address it later - return expand_to_full_mesh_op_strategy( - input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 - ) +@register_op_strategy( + aten.sort.stable, + schema_info=RuntimeSchemaInfo( + 1, + static_kwargkey=["dim", "descending", "stable"], + ), +) +def sort_stable_strategy(op_schema: OpSchema) -> OpStrategy: + # mostly copy paste from topk_strategy + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + sort_dim = -1 + if "dim" in op_schema.kwargs_schema: + sort_dim = cast(int, op_schema.kwargs_schema["dim"]) + return sort_strategy(op_schema, sort_dim) @register_op_strategy( diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index c2242a2cb93b..a5a037a3c73e 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -27,7 +27,6 @@ normalize_dim, register_op_strategy, register_prop_rule, - replicate_op_strategy, ) from torch.distributed.tensor.placement_types import ( Partial, @@ -565,20 +564,12 @@ def replica_only_strategy(op_schema: OpSchema) -> StrategyType: return OpStrategy([OpSpec(replicate_spec)]) -@register_op_strategy( - [aten.sort.stable, aten.sort.default], schema_info=RuntimeSchemaInfo(1) -) -def sort_strategy(op_schema: OpSchema): - return cast(TupleStrategy, replicate_op_strategy(op_schema)) - - @register_op_strategy( [ aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src, - aten.scatter_add.default, ], schema_info=RuntimeSchemaInfo(1), ) @@ -605,11 +596,44 @@ def scatter_strategy(op_schema: OpSchema) -> StrategyType: return op_strategy -@register_op_strategy(aten.gather.default) +@register_op_strategy(aten.scatter_add.default, schema_info=RuntimeSchemaInfo(1)) +def scatter_add_strategy(op_schema: OpSchema) -> StrategyType: + input_strategy = op_schema.args_schema[0] + dim = op_schema.args_schema[1] + index_strategy = op_schema.args_schema[2] + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(index_strategy, OpStrategy) + assert isinstance(dim, int) + dim = normalize_dim(dim, input_strategy.ndim) + mesh = input_strategy.mesh + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if d != dim and input_shape[d] == index_shape[d]: + sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +@register_op_strategy(aten.gather.default, schema_info=RuntimeSchemaInfo(1)) def gather_strategy(op_schema: OpSchema) -> StrategyType: mesh = op_schema.get_mesh_from_args() input_strategy = cast(OpStrategy, op_schema.args_schema[0]) dim = cast(int, op_schema.args_schema[1]) + dim = normalize_dim(dim, input_strategy.ndim) index_strategy = cast(OpStrategy, op_schema.args_schema[2]) input_shape = input_strategy.shape @@ -625,7 +649,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: # input sharding, input sharded, index accepts mask partial, output follows index # this only works when the input is sharded on the gather dimension, and # index has size 1 on the gather dimension - if index_shape[dim] == 1: + if dim < len(index_shape) and index_shape[dim] == 1: index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) input_sharding: PlacementList = [ index_partial_placement, @@ -639,6 +663,12 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] single_mesh_dim_strategies.append(index_sharding) + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if d != dim: + sharding: PlacementList = [Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=1 ) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index c942da67cd8a..1f0906b0beff 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -22,7 +22,12 @@ prod, register_op_strategy, ) -from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Placement, + Replicate, + Shard, +) aten = torch.ops.aten @@ -605,8 +610,30 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: ) for mesh_dim, p in enumerate(input_src_placements) ] + + def _rewrite_shard_dim(p: Shard): + """ + Rewrite the shard dim to the corresponding tensor dim in output. + For ``_StridedShard``, we can safely keep the placement type and + ``split_factor`` unchanged and only rewrite the ``dim`` because: + 1. ``_StridedShard`` has no impact on sharding (i.e. how + tensor is partitioned) compared to ``Shard``. It only changes + how shards permute across the devices. + 2. ``view()`` op on DTensor strictly forbids shard redistribution + which means if ``view()`` may cause shard permutation across + devices, it should be rejected. This is enforced in today's + sharding prop for ``view()``. + 3. Since DTensor ``view()`` won't introduce any redistribution, + it's certain that ``placements`` won't change except the + inner ``dim`` attribute of ``Shard`` or ``_StridedShard``. + """ + if isinstance(p, _StridedShard): + return _StridedShard(shard_dim_map[p.dim], split_factor=p.split_factor) + else: + return Shard(shard_dim_map[p.dim]) + output_placements = [ - Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + _rewrite_shard_dim(p) if isinstance(p, Shard) else p for p in input_tgt_placements ] diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 082805db7fde..70ea7e9ce97a 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -146,7 +146,9 @@ def set_seed(self, name: str, seed: int) -> None: ) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) - def _distribute_region(self, spec: DTensorSpec): + def _distribute_region( + self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + ): pass def _manual_seed(self, parallel_seed: int) -> None: @@ -191,7 +193,17 @@ def _manual_seed(self, parallel_seed: int) -> None: self.set_seed("parallel-rng", parallel_seed) @contextlib.contextmanager - def _distribute_region(self, spec: DTensorSpec): + def _distribute_region( + self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + ): + g_name = "parallel-rng" + if generator is not None: + # This is a little hacky, but for any user-passed generator, we store its state under a unique key, + # not because we need to keep a copy of it but because its the easiest way to make it work with the + # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region. + g_name = "user-passed-generator" + assert g_name not in self.rng_states + self.rng_states[g_name] = generator.get_state() # check if the parallel rng state has been synchronized or not if not self.rng_state_is_sync("parallel-rng"): raise RuntimeError( @@ -202,23 +214,29 @@ def _distribute_region(self, spec: DTensorSpec): if self.distribute_region_enabled: if self._device.type == "hpu": self._device_handle.set_rng_ctx("philox") - old_offset = self.get_offset("parallel-rng") - self._set_pre_op_offset(spec) + old_offset = self.get_offset(g_name) + self._set_pre_op_offset(g_name, spec) with torch.random.fork_rng( devices=[self._device], device_type=self._device.type ): assert self._device_handle is not None - self._device_handle.set_rng_state(self.rng_states["parallel-rng"]) + self._device_handle.set_rng_state(self.rng_states[g_name]) try: yield # execute the region code finally: # update offset to synchronize among ranks - self._set_post_op_offset(spec, old_offset) + self._set_post_op_offset(g_name, spec, old_offset) if self._device.type == "hpu": self._device_handle.unset_rng_ctx("philox") else: yield + if generator is not None: + # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future + # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates + # the seed value in their rng and uses it with DTensor again, we always use the latest value + generator.set_state(self.rng_states.pop(g_name)) + def get_offset(self, name: str) -> int: if name not in self.rng_states: raise RuntimeError( @@ -240,7 +258,7 @@ def set_offset(self, name: str, offset: int) -> None: ) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) - def _set_pre_op_offset(self, spec: DTensorSpec) -> None: + def _set_pre_op_offset(self, name: str, spec: DTensorSpec) -> None: """Set the starting RNG offset for current device's local shard before actual op execution. The pre_op_offset value should start from the current RNG offset and increment by the size of local shard until it reaches the size of the whole @@ -248,6 +266,7 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: will be the same. Args: + name (str): The name of the generator to use (should be a key in self.rng_states) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we prepare the offset for running random ops. @@ -350,20 +369,23 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: local_size = prod(local_size_on_rank_0) # get current RNG offset - current_offset = self.get_offset("parallel-rng") + current_offset = self.get_offset(name) # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 - self.set_offset("parallel-rng", current_offset + offset_incr) + self.set_offset(name, current_offset + offset_incr) - def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: + def _set_post_op_offset( + self, name: str, spec: DTensorSpec, old_offset: int + ) -> None: """Sets the RNG to a synchronized state after running the local random op. Every rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor random ops. Args: + name (str): The name of the generator to use (should be a key in self.rng_states) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we post-process the offset for running random ops. @@ -378,7 +400,7 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp numel = (numel + 3) // 4 * 4 - self.set_offset("parallel-rng", old_offset + numel) + self.set_offset(name, old_offset + numel) def _calc_shard_linear_idx( self, shard_coord: list[int], shard_size: list[int] diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 3355d730dfa7..1ccb42c47bfe 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -8,7 +8,6 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode -from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpInfo, @@ -302,9 +301,8 @@ def propagate(self, op_info: OpInfo) -> None: # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, - # and compile autograd initial tracing, which do not need to be as fast as - # eager mode DTensor usages. - if _are_we_tracing(): + # and tracing does not need to be as fast as eagermode DTensor usages. + if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: output_sharding = cast( diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index afc008372c3a..f33a52c495a4 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -34,11 +34,6 @@ __all__ = ["context_parallel", "set_rotate_method"] -compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True -) - - class _CausalBehavior(Enum): SKIP = None NOT_IS_CAUSAL = False @@ -1179,6 +1174,10 @@ def create_cp_block_mask( """ from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE + compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True + ) + def _rewrite_mask_mod( mask_mod: _mask_mod_signature, rank: int, diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 3ed8a6c37883..51f0865f4304 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -1,3 +1,4 @@ +import logging import os import warnings import zipfile @@ -52,6 +53,8 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +log: logging.Logger = logging.getLogger(__name__) + @deprecated( "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. " @@ -440,7 +443,8 @@ def load( f, expected_opset_version=expected_opset_version, ) - except RuntimeError: + except RuntimeError as e: + log.warning("Ran into the following error when deserializing: %s", e) pt2_contents = PT2ArchiveContents({}, {}, {}) if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: @@ -450,10 +454,18 @@ def load( return pt2_contents.exported_programs["model"] # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) - warnings.warn( - "This version of file is deprecated. Please generate a new pt2 saved file." - ) with zipfile.ZipFile(f, "r") as zipf: + if "version" not in zipf.namelist(): + raise RuntimeError( + "We ran into an error when deserializing the saved file. " + "Please check the warnings above for possible errors. " + ) + + log.warning( + "Trying to deserialize for the older format. This version of file is " + "deprecated. Please generate a new pt2 saved file." + ) + # Check the version version = zipf.read("version").decode().split(".") from torch._export.serde.schema import ( diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index e009c03d4f09..1c87bb29bfe9 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -10,10 +10,7 @@ import torch from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file -from torch.export.exported_program import ( - _copy_graph_module_and_signature, - _decompose_exported_program, -) +from torch.export.exported_program import _decompose_exported_program _InputT = typing_extensions.ParamSpec("_InputT") @@ -23,6 +20,28 @@ __all__ = [] # type: ignore[var-annotated] +def _copy_graph_module_and_signature( + ep: torch.export.ExportedProgram, +) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: + # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), + # and this can break placeholder names in some particular cases. + # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. + # So we manually overwrite placeholder names by reading the old graph. + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + + # iterate over old/new graph modules + for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] + old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] + new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] + # iterate over placeholders + assert len(old_phs) == len(new_phs) + for old_node, new_node in zip(old_phs, new_phs): + new_node.name = old_node.name + + return gm, new_graph_signature + + def _remove_detach_pass( gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature ) -> None: @@ -341,7 +360,8 @@ def _compiled_and_package( "aot_inductor.package": True, "aot_inductor.package_cpp_only": True, "always_keep_tensor_constants": True, - "aot_inductor.package_constants_in_so": False, + # we'll change this back to False once we enable weight deduping for standalone mode + "aot_inductor.package_constants_in_so": standalone, "aot_inductor.compile_standalone": standalone, } aoti_files_map = {} diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py index b91dfbb0db80..67bda0c34ce4 100644 --- a/torch/export/experimental/_utils.py +++ b/torch/export/experimental/_utils.py @@ -1,9 +1,11 @@ +import logging import typing from torch._inductor.utils import IndentedBuffer __all__ = [] # type: ignore[var-annotated] +logger = logging.getLogger(__name__) def _get_main_cpp_file( @@ -125,8 +127,10 @@ def _get_main_cpp_file( [ f"auto constants_map{i + 1} = std::make_shared();", f"auto constants_array{i + 1} = std::make_shared>();", - f"auto model{i + 1} = AOTInductorModel{model_name}::Create(", - f" constants_map{i + 1}, constants_array{i + 1}, device_str,", + f"auto model{i + 1} = std::make_unique(", + f" std::move(constants_map{i + 1}),", + f" std::move(constants_array{i + 1}),", + " device_str,", f' "{package_name}/data/aotinductor/{model_name}/");', f"model{i + 1}->load_constants();", ] @@ -154,7 +158,10 @@ def _get_main_cpp_file( ib.writeline("\n// Validate outputs") for i in range(len(model_names)): ib.writeline( - f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;""" + f"""std::cout << "output_tensor{i + 1}\\n" << output_tensor{i + 1} << std::endl;""" + ) + ib.writeline( + f"""torch::save(output_tensor{i + 1}, "output_tensor{i + 1}.pt");""" ) ib.writeline("return 0;") @@ -184,9 +191,14 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str "", "set(CMAKE_CXX_STANDARD 17)", "", - "find_package(Torch REQUIRED)", ] ) + + from torch._inductor.config import test_configs + + if test_configs.use_libtorch: + ib.writeline("find_package(Torch REQUIRED)") + if cuda: ib.writeline("find_package(CUDA REQUIRED)") @@ -200,6 +212,7 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str model_libs = " ".join(model_names) ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + if cuda: ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 4552cc5c59c7..125d2dd9c9bd 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -653,7 +653,10 @@ def update_arg(old_arg, new_ph): shape_env = _get_shape_env(gm) if shape_env is not None: with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + gm, + functools.partial( + _node_metadata_hook, metadata={"stack_trace": stack_trace} + ), ): insert_deferred_runtime_asserts( gm, @@ -1615,17 +1618,6 @@ def _update( verifiers=verifiers if verifiers is not None else self.verifiers, ) - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - setattr(result, k, copy.deepcopy(v, memo)) - - graph_module, graph_signature = _copy_graph_module_and_signature(self) - result = result._update(graph_module, graph_signature) - return result - def _get_shape_env(gm): vals = [ @@ -1643,30 +1635,6 @@ def _get_shape_env(gm): return v.node.shape_env -def _copy_graph_module_and_signature( - ep: "ExportedProgram", -) -> tuple[torch.fx.GraphModule, "ExportGraphSignature"]: - # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), - # and this can break placeholder names in some particular cases. - # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. - # So we manually overwrite placeholder names by reading the old graph. - gm = copy.deepcopy(ep.graph_module) - new_graph_signature = copy.deepcopy(ep.graph_signature) - - # iterate over old/new graph modules - for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] - old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] - new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] - # iterate over placeholders - assert len(old_phs) == len(new_phs) - for old_node, new_node in zip(old_phs, new_phs): - if new_node.name != old_node.name: - new_node.name = old_node.name - new_node.target = old_node.target - - return gm, new_graph_signature - - def _get_updated_range_constraints( gm: torch.fx.GraphModule, old_range_constraints: "Optional[dict[sympy.Symbol, Any]]" = None, diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index 4e1d21de660d..4238bac5899e 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -52,19 +52,21 @@ def _get_new_device( if isinstance(v, torch.Tensor): ep._constants[k] = v.to(_get_new_device(v.device, location)) - for node in ep.graph.nodes: - # move all the nodes kwargs with burnt-in device - if "device" in node.kwargs: - kwargs = node.kwargs.copy() - kwargs["device"] = _get_new_device(kwargs["device"], location) - node.kwargs = kwargs - # move all the tensor metadata - node.meta["val"] = pytree.tree_map( - lambda v: v.to(_get_new_device(v.device, location)) - if isinstance(v, torch.Tensor) - else v, - node.meta.get("val"), - ) + for m in ep.graph_module.modules(): + if isinstance(m, torch.fx.GraphModule): + for node in m.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) ep.validate() return ep diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 501060e85c6e..323253a1501b 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -12,6 +12,7 @@ import torch import torch.utils._pytree as pytree from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch._inductor.cpp_builder import normalize_path_separator from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs from torch.export.pt2_archive._package_weights import ( @@ -75,6 +76,8 @@ class PT2ArchiveWriter: """ def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] # NOTICE: version here is different from the archive_version # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version @@ -169,6 +172,8 @@ class PT2ArchiveReader: """ def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( "Invalid archive format" diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index dfb9b9f8074b..ea575727d918 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -145,7 +145,7 @@ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: co.co_name, co.co_qualname, # type: ignore[attr-defined] co.co_firstlineno, - co.co_lnotab, + co.co_linetable, co.co_exceptiontable, # type: ignore[attr-defined] co.co_freevars, co.co_cellvars, diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index a578723ea1cb..9f2c40904634 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1959,7 +1959,7 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: # nn_module_stack if node.op not in ["placeholder", "output"]: if "nn_module_stack" not in node.meta: - node.meta["nn_module_stack"] = self.module_stack + node.meta["nn_module_stack"] = self.module_stack.copy() # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): if isinstance(mod_cls, type): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 375019f7dc83..420537ccfd3f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4712,7 +4712,7 @@ def create_symfloatnode( self, sym: sympy.Expr, *, - hint: Optional[int], + hint: Optional[int | float | bool], source: Optional[Source] = None, ) -> FloatLikeType: """Create a SymFloat value from a symbolic expression""" @@ -4810,7 +4810,6 @@ def create_unbacked_symfloat(self) -> SymFloat: ) self.counter["create_unbacked_symbol"] += 1 if not self._ignore_fresh_unbacked_symbols_tls(): - print(f"adding {symbol}") self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() diff --git a/torch/fx/graph.py b/torch/fx/graph.py index bebfe099b0b9..514490513cbf 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1621,7 +1621,7 @@ def output(self, result: "Argument", type_expr: Optional[Any] = None): op="output", target="output", args=(result,), type_expr=type_expr ) - def _target_to_str(self, target: Target) -> str: + def _target_to_str(self, target: Optional[Target]) -> str: if callable(target): op = target.__name__ else: diff --git a/torch/fx/node.py b/torch/fx/node.py index f80638641658..0d9c67757a76 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -4,7 +4,7 @@ import logging import operator import types -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec @@ -15,6 +15,7 @@ normalize_function, normalize_module, ) +from torch.utils._dtype_abbrs import dtype_abbrs from .._ops import ops as _ops from ._compatibility import compatibility @@ -597,6 +598,8 @@ def format_node( self, placeholder_names: Optional[list[str]] = None, maybe_return_typename: Optional[list[str]] = None, + *, + include_tensor_metadata: bool = False, ) -> Optional[str]: """ Return a descriptive string representation of ``self``. @@ -618,6 +621,7 @@ def format_node( maybe_return_typename: A single-element list that will store a formatted string representing the output of the generated ``forward`` function. Internal use only. + include_tensor_metadata: Whether to include tensor metadata Returns: str: If 1) we're using ``format_node`` as an internal helper @@ -649,11 +653,36 @@ def format_node( maybe_return_typename[0] = f" -> {_type_repr(self.type)}" return f"return {self.args[0]}" else: - maybe_typename = ( - f"{_type_repr(self.type)} " if self.type is not None else "" + + def stringify_shape(shape: Iterable) -> str: + return f"[{', '.join([str(x) for x in shape])}]" + + meta_val = self.meta.get( + "val", + self.meta.get("tensor_meta", self.meta.get("example_value", None)), ) + type_annotation = "" + if ( + include_tensor_metadata + and isinstance(meta_val, torch.Tensor) + and meta_val.layout + not in ( + torch.sparse_csc, + torch.sparse_csr, + ) + ): + stride_annotation = f"{stringify_shape(meta_val.stride())}" + device_annotation = f"{meta_val.device}" + type_annotation = ( + f'Tensor "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' + f'{stride_annotation}{device_annotation}"' + ) + else: + type_annotation = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) return ( - f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"%{self.name} : {type_annotation}[num_users={len(self.users)}] = " f"{self.op}[target={self._pretty_print_target(self.target)}](" f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" ) diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index bc7537c23847..dd8edb50e161 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -203,7 +203,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: and node.target is torch.ops.aten._local_scalar_dense.default ): dtype = node.args[0].meta["val"].dtype - if dtype != torch.float64: + if not dtype.is_floating_point: continue assert isinstance(node.args[0], fx.Node), node.args[0] @@ -212,6 +212,10 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: expr_to_tensor_proxy[s] = MetaProxy( node.args[0], tracer=tracer, fake_mode=fake_mode ) + # Upcast the float tensor to torch.float64 to avoid precision problem + expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default( + expr_to_tensor_proxy[s], torch.float64 + ) expr_to_sym_proxy[s] = MetaProxy( node, tracer=tracer, fake_mode=fake_mode ) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 438661090942..6fc17b959424 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -18,16 +18,29 @@ class Partition: def __init__( - self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + self, + id: Optional[int] = None, + nodes: Optional[Iterable[Node]] = None, + node_orders: Optional[Iterable[int]] = None, ): self.id = id - self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + self.nodes: dict[Node, Optional[int]] = {} + if nodes is not None: + if node_orders is None: + self.nodes = dict.fromkeys(nodes, None) + else: + nodes_list = list(nodes) + node_orders_list = list(node_orders) + assert len(nodes_list) == len(node_orders_list), ( + "nodes and node_orders must have the same length" + ) + self.nodes = dict(zip(nodes_list, node_orders_list)) def __repr__(self) -> str: return str(self.nodes) - def add_node(self, node: Node): - self.nodes.update({node: None}) + def add_node(self, node: Node, node_order: Optional[int] = None): + self.nodes.update({node: node_order}) def remove_node(self, node: Node): del self.nodes[node] @@ -172,7 +185,7 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]): return merge_id, True - def merge_single_node(node: Node, id: Optional[int]): + def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]): def _update_partition_map(node: Node, id: int): # Iterate through all the users of this node and update the partition map to indicate # that there is a path from the partition id of this node to the target partition id. @@ -189,16 +202,19 @@ def _update_partition_map(node: Node, id: int): assignment.pop(node) elif id not in partitions_by_id: assignment[node] = id - partitions_by_id[id] = Partition(id=id, nodes=[node]) + assert node_order is not None + partitions_by_id[id] = Partition( + id=id, nodes=[node], node_orders=[node_order] + ) partition_users[id] = set(node.users) _update_partition_map(node, id) else: assignment[node] = id - partitions_by_id[id].add_node(node) + partitions_by_id[id].add_node(node, node_order) logger.debug("Proposing partitions...") - for node in reversed(self.graph_module.graph.nodes): + for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)): # use Dict as an ordered set to ensure deterministic partitioning result, don't care value merge_candidates: dict[int, None] = {} @@ -211,7 +227,7 @@ def _update_partition_map(node: Node, id: int): partition_id = next(new_partition_id) nodes_order[node] = partition_id partitions_order[partition_id] = partition_id - merge_single_node(node, partition_id) + merge_single_node(node, node_order, partition_id) merge_candidates[partition_id] = None # merge all possible partitions @@ -228,6 +244,14 @@ def _update_partition_map(node: Node, id: int): # in the graph, otherwise, this is a no-op self_id, _ = maybe_merge_partition(self_id, other_id) + # sort partition nodes based on descending node order + for partition in partitions_by_id.values(): + partition.nodes = dict( + sorted( + partition.nodes.items(), key=operator.itemgetter(1), reverse=True + ) + ) + # post processing to re-assign "getitem" nodes into upstream partition logger.debug("Reassigning getitem nodes to its producer node's partition...") nodes_reassignment: dict[Node, int] = {} @@ -248,7 +272,7 @@ def _update_partition_map(node: Node, id: int): if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): - merge_single_node(node, id) + merge_single_node(node, None, id) # filter out single node partitions if not self.allows_single_node_partition: diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index bb71a25971da..19e101a5c120 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -337,12 +337,13 @@ def match_symbol(symint, cb): torch._check, torch.ops.aten._assert_scalar.default, ): + cond = node.args[0] if node.args else node.kwargs.get("cond") if ( - node.args[0] == True # noqa: E712 - or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + cond == True # noqa: E712 + or (assert_expr := _get_sym_val(cond)) in expr_to_proxy and assert_expr in added_asserts ): - arg = node.args[0] + arg = cond gm.graph.erase_node(node) if isinstance(arg, fx.Node) and not arg.users: gm.graph.erase_node(arg) diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index d3ef35bdb107..e0b2ff63ba07 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -719,7 +719,7 @@ def extend_acc_subgraph(self, tag: str): """ # Dict that maps node to its users and ignore users that # are in the subgraph that has greater tag - deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1])) self.update_reverse_deps_for_fusions(deps) # Parent nodes of the subgraph diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 1b22490405de..33db9fd03d79 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -96,7 +96,7 @@ def fuse_as_graphmodule( gm: GraphModule, nodes: NodeList, module_name: str, - partition_lookup_table: _Optional[dict[Node, None]] = None, + partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None, *, always_return_tuple: bool = False, ) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: @@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: @compatibility(is_backward_compatible=False) def fuse_by_partitions( gm: GraphModule, - partitions: list[dict[Node, None]], + partitions: list[dict[Node, _Optional[int]]], prefix: str = "fused_", always_return_tuple: bool = False, ) -> GraphModule: diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 9efae5918d6e..4cfeeb6238ad 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -3,34 +3,38 @@ # to guarantee that compiling these symbols do not require linking libtorch # to ensure header-only-ness. +# torch/headeronly/util/shim_utils.h +TORCH_ERROR_CODE_CHECK + # c10/util/TypeCast.h convert -# c10/util/bit_cast.h, torch/headeronly/util/bit_cast.h +# torch/headeronly/util/bit_cast.h bit_cast -# c10/util/BFloat16-math.h, c10/util/BFloat16.h +# torch/headeronly/util/BFloat16.h BFloat16 # torch/headeronly/util/Float4_e2m1fn_x2.h Float4_e2m1fn_x2 -# c10/util/Float8_e4m3fn.h +# torch/headeronly/util/Float8_e4m3fn.h Float8_e4m3fn -# c10/util/Float8_e4m3fnuz.h +# torch/headeronly/util/Float8_e4m3fnuz.h Float8_e4m3fnuz -# c10/util/Float8_e5m2.h +# torch/headeronly/util/Float8_e5m2.h Float8_e5m2 -# c10/util/Float8_e5m2fnuz.h +# torch/headeronly/util/Float8_e5m2fnuz.h Float8_e5m2fnuz -# c10/util/Half.h -Half +# torch/headeronly/util/Float8_e8m0fnu.h +Float8_e8m0fnu # torch/headeronly/util/Half.h +Half fp16_ieee_from_fp32_value fp16_ieee_to_fp32_value @@ -38,7 +42,7 @@ fp16_ieee_to_fp32_value # fp32_from_bits called from fp16_ieee_to_fp32_value # fp32_to_bits called from fp16_ieee_from_fp32_value -# c10/util/complex.h +# c10/util/complex.h, torch/headeronly/util/complex.h complex # ATen/NumericUtils.h, c10/util/generic_math.h @@ -90,3 +94,8 @@ bits2x4 bits4x2 bits8 bits16 + +# torch/headeronly/core/ScalarType.h +NumScalarTypes +ScalarType +# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt index 3b8f0d5466de..93d2d7802b52 100644 --- a/torch/headeronly/CMakeLists.txt +++ b/torch/headeronly/CMakeLists.txt @@ -20,6 +20,7 @@ configure_file( file(GLOB HEADERONLY_HEADERS *.h + core/**/*.h cpu/**/*.h macros/*.h util/*.h diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h new file mode 100644 index 000000000000..0e426427997b --- /dev/null +++ b/torch/headeronly/core/ScalarType.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +// See [dtype Macros note] in c10/core/ScalarType.h regarding macros + +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ + +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +} // namespace c10 + +namespace torch::headeronly { +using c10::dummy_int1_7_t; +using c10::dummy_uint1_7_t; +using c10::NumScalarTypes; +using c10::ScalarType; +} // namespace torch::headeronly diff --git a/torch/headeronly/ovrsource_defs.bzl b/torch/headeronly/ovrsource_defs.bzl index c590f388ffb0..3c3030c048b1 100644 --- a/torch/headeronly/ovrsource_defs.bzl +++ b/torch/headeronly/ovrsource_defs.bzl @@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile): public_include_directories = ["../.."], public_preprocessor_flags = pp_flags, public_raw_headers = native.glob([ + "core/**/*.h", "cpu/**/*.h", "macros/*.h", "util/*.h", diff --git a/torch/headeronly/util/BFloat16.h b/torch/headeronly/util/BFloat16.h new file mode 100644 index 000000000000..2c1f805ac7b7 --- /dev/null +++ b/torch/headeronly/util/BFloat16.h @@ -0,0 +1,478 @@ +#pragma once + +// Defines the bloat16 type (brain floating-point). This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. + +#include +#include + +#include +#include +#include +#include +#include + +#if defined(__CUDACC__) && !defined(USE_ROCM) +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace c10 { + +struct alignas(2) BFloat16 { + uint16_t x; + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) && defined(__HIPCC__) + C10_HOST_DEVICE BFloat16() = default; +#else + BFloat16() = default; +#endif + + struct from_bits_t {}; + static constexpr C10_HOST_DEVICE from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits) {} + /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); + inline C10_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) && !defined(USE_ROCM) + inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); + explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { + out << (float)value; + return out; +} + +namespace detail { +inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + +#if defined(USE_ROCM) && defined(__HIPCC__) + float* tempRes; + + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; +#else + std::memcpy(&res, &tmp, sizeof(tmp)); +#endif + + return res; +} + +inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; + +#if defined(USE_ROCM) && defined(__HIPCC__) + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; +#else + std::memcpy(&res, &src, sizeof(res)); +#endif + + return res >> 16; +} + +inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +#if defined(USE_ROCM) && defined(__HIPCC__) + if (src != src) { +#elif defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + const uint32_t U32 = c10::bit_cast(src); + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} + +} // namespace detail + +//-------- the following is copied from c10/util/BFloat16-inl.h ---------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +/// Constructors +inline C10_HOST_DEVICE BFloat16::BFloat16(float value) + : +#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 800 + x(__bfloat16_as_ushort(__float2bfloat16(value))) +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + x(c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) +#else + // RNE by default + x(detail::round_to_nearest_even(value)) +#endif +{ +} + +/// Implicit conversions +inline C10_HOST_DEVICE BFloat16::operator float() const { +#if defined(__CUDACC__) && !defined(USE_ROCM) + return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + return float(*reinterpret_cast(&x)); +#else + return detail::f32_from_bits(x); +#endif +} + +#if defined(__CUDACC__) && !defined(USE_ROCM) +inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +inline C10_HOST_DEVICE BFloat16::BFloat16( + const sycl::ext::oneapi::bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + +C10_CLANG_DIAGNOSTIC_POP() +} // namespace c10 + +namespace torch::headeronly { + +namespace detail { +using c10::detail::bits_from_f32; +using c10::detail::f32_from_bits; +using c10::detail::round_to_nearest_even; +} // namespace detail + +using c10::BFloat16; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; +using c10::operator<; +using c10::operator>; +using c10::operator<<; +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::BFloat16 min() { + return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 lowest() { + return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 max() { + return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 epsilon() { + return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 round_error() { + return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 infinity() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 quiet_NaN() { + return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 signaling_NaN() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 denorm_min() { + return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); + } +}; + +} // namespace std diff --git a/torch/headeronly/util/Float8_e4m3fn.h b/torch/headeronly/util/Float8_e4m3fn.h new file mode 100644 index 000000000000..d54a8f40a6c1 --- /dev/null +++ b/torch/headeronly/util/Float8_e4m3fn.h @@ -0,0 +1,531 @@ +#pragma once + +/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// bias = 7 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include +#include + +#if defined(__cplusplus) +#include +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include + +namespace c10 { + +struct alignas(1) Float8_e4m3fn { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fn() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e4m3fn(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) { + out << (float)value; + return out; +} + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E4M3FN number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)input << 24; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(nonsign); +#elif defined(__SYCL_DEVICE_ONLY__) + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#elif defined(_MSC_VER) && !defined(__clang__) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#endif + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + /* + * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, + * the addition overflows it into bit 31, and the subsequent shift turns the + * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number + * is Nan, 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 4 so the exponent (4 bits originally) + * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0x07 + * for fp8e4m3fn number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x78, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { + /* + * Binary representation of 480.0f, which is the first value + * not representable in fp8e4m3fn range: + * 0 1111 111 - fp8e4m3fn + * 0 10000111 11100000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fn normal range + * into denorm representation + * magic number: ((127 - 7) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(141) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = 0x7f; + } else { + if (f_bits < (UINT32_C(121) << 23)) { + // Input number is smaller than 2^(-6), which is the smallest + // fp8e4m3fn normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +// -------- below is copied from c10/util/Float8_e4m3fn-inl.h --------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) + : x(detail::fp8e4m3fn_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const { + return detail::fp8e4m3fn_to_fp32_value(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const { + return (x & 0b01111111) == 0b01111111; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fn +operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator/( + const Float8_e4m3fn& a, + const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator+=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator-=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator*=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator/=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fn to float. + +C10_CLANG_DIAGNOSTIC_POP() + +} // namespace c10 + +namespace torch::headeronly { +using c10::Float8_e4m3fn; +using c10::operator<<; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fn min() { + return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn lowest() { + return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn max() { + return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn epsilon() { + return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn round_error() { + return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn quiet_NaN() { + return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn denorm_min() { + return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits()); + } +}; + +} // namespace std diff --git a/torch/headeronly/util/Float8_e4m3fnuz.h b/torch/headeronly/util/Float8_e4m3fnuz.h new file mode 100644 index 000000000000..772ffd9e96c6 --- /dev/null +++ b/torch/headeronly/util/Float8_e4m3fnuz.h @@ -0,0 +1,442 @@ +#pragma once + +/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as Float8_e4m3fn: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// The key differences versus Float8_e4m3fn are: +/// bias = 8 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fnuz& value) { + out << (float)value; + return out; +} + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { + /* + * Binary representation of 256.0f, which is the first value not representable + * (i.e. the first value which would overflow in to the sign bit, resulting in + * a NaN) in fp8e4m3fnuz range: + * 1 0000 000 - fp8e4m3fnuz + * 0 10000111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range + * into denorm representation + * magic number: ((127 - 8) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s. + return 0x80; + } + + if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { + // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +//------ below is copied from c10/util/Float8_e4m3fnuz-inl.h ------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) + : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return torch::headeronly::detail::fp8_fnuz_to_fp32_value<4, 3>(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { + return x == 0b10000000; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fnuz to float. + +C10_CLANG_DIAGNOSTIC_POP() + +} // namespace c10 + +namespace torch::headeronly { +using c10::Float8_e4m3fnuz; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; +using c10::operator<<; + +namespace detail { +using c10::detail::fp8e4m3fnuz_from_fp32_value; +} // namespace detail + +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fnuz min() { + return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz lowest() { + return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz max() { + return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz epsilon() { + return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz round_error() { + return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz infinity() { + // NaN (no infinities) + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz quiet_NaN() { + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz denorm_min() { + return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits()); + } +}; + +} // namespace std diff --git a/torch/headeronly/util/Float8_e5m2.h b/torch/headeronly/util/Float8_e5m2.h new file mode 100644 index 000000000000..aeee40d8e5b8 --- /dev/null +++ b/torch/headeronly/util/Float8_e5m2.h @@ -0,0 +1,456 @@ +#pragma once + +/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// bias = 15 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include + +#include + +namespace c10 { + +struct alignas(1) Float8_e5m2 { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) { + out << (float)value; + return out; +} + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E5M2 number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 26-30 24-25 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + uint16_t half_representation = input; + half_representation <<= 8; + return fp16_ieee_to_fp32_value(half_representation); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { + /* + * Binary representation of fp32 infinity + * 0 11111111 00000000000000000000000 + */ + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + + /* + * Binary representation of 65536.0f, which is the first value + * not representable in fp8e5m2 range: + * 0 11111 00 - fp8e5m2 + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2 normal range + * into denorm representation + * magic number: ((127 - 15) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); + } else { + if (f_bits < (UINT32_C(113) << 23)) { + // Input number is smaller than 2^(-14), which is the smallest + // fp8e5m2 normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint32_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +// -------- below is copied from c10/util/Float8_e5m2-inl.h --------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#define EXP_WIDTH_FP8 5 +#define MAN_WIDTH_FP8 2 +#define EXP_BIAS_FP8 15 + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) + : x(detail::fp8e5m2_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2::operator float() const { + return detail::fp8e5m2_to_fp32_value(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const { + return (x & 0b01111111) > 0b01111100; +} + +inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const { + return (x & 0b01111111) == 0b01111100; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2 +operator+(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator-(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator*(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator/( + const Float8_e5m2& a, + const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2& operator+=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator-=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator*=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator/=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2 to float. +C10_CLANG_DIAGNOSTIC_POP() +} // namespace c10 + +namespace torch::headeronly { +using c10::Float8_e5m2; +using c10::operator<<; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; + +namespace detail { +using c10::detail::fp8e5m2_from_fp32_value; +using c10::detail::fp8e5m2_to_fp32_value; +} // namespace detail +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2 min() { + return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 max() { + return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 lowest() { + return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 epsilon() { + return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 round_error() { + return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 infinity() { + return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 quiet_NaN() { + return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 denorm_min() { + return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits()); + } +}; + +} // namespace std diff --git a/torch/headeronly/util/Float8_e5m2fnuz.h b/torch/headeronly/util/Float8_e5m2fnuz.h new file mode 100644 index 000000000000..8bcb2ac07f76 --- /dev/null +++ b/torch/headeronly/util/Float8_e5m2fnuz.h @@ -0,0 +1,446 @@ +#pragma once + +/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as e5m2: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// The key differences that e5m2fnuz brings are: +/// bias = 16 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2fnuz& value) { + out << (float)value; + return out; +} + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { + /* + * Binary representation of 65536.0f, which is the first value not + * representable (i.e. the first value which would overflow in to the sign + * bit, resulting in a NaN) in fp8e4m3fnuz range: + * 1 00000 00 - fp8e5m2fnuz + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range + * into denormalized representation. + * magic number: ((127 - 16) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s + return 0x80; + } + + if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { + // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +//------ below is copied from c10/util/Float8_e5m2fnuz-inl.h ------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) + : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return torch::headeronly::detail::fp8_fnuz_to_fp32_value<5, 2>(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { + return x == 0b10000000; +} + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2fnuz to float. + +C10_CLANG_DIAGNOSTIC_POP() + +} // namespace c10 + +namespace torch::headeronly { +using c10::Float8_e5m2fnuz; +using c10::operator<<; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; + +namespace detail { +using c10::detail::fp8e5m2fnuz_from_fp32_value; +} +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2fnuz min() { + return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz max() { + return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz lowest() { + return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz epsilon() { + return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz round_error() { + return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz infinity() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } + // TODO(future): we are mapping neg_zero to both inf and NaN, this is + // surprising and we should figure out what to do about it. + static constexpr c10::Float8_e5m2fnuz quiet_NaN() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz denorm_min() { + return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits()); + } +}; + +} // namespace std diff --git a/torch/headeronly/util/Float8_e8m0fnu.h b/torch/headeronly/util/Float8_e8m0fnu.h new file mode 100644 index 000000000000..c5a70525f2f2 --- /dev/null +++ b/torch/headeronly/util/Float8_e8m0fnu.h @@ -0,0 +1,226 @@ +#pragma once + +/// Defines the Float8_e8m0fnu type (8-bit floating-point) including +/// conversions to standard C types +/// Binary configuration : +/// eeeeeeee +/// no sign bits +/// 8 exponent bits +/// no mantissa bits +/// +/// This is the E8M0 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.4.1) + +#include +#include + +// TODO(#146647): do we need to special case OPENCL? +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include +#include + +namespace c10 { + +struct alignas(1) Float8_e8m0fnu { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e8m0fnu() = default; + + constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e8m0fnu(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e8m0fnu& value) { + out << (float)value; + return out; +} + +namespace detail { +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { + // TODO(#146647): maybe rewrite without control flow + + uint32_t f_bits = c10::detail::fp32_to_bits(f); + + // extract the exponent + uint32_t exponent = (f_bits >> 23) & 0b11111111; + + // special case float32 NaN and +-inf to map to e8m0 nan + if (exponent == 0b11111111) { + return exponent; + } + + // next, we use guard, round, sticky bits and the LSB to implement round to + // nearest, with ties to even + + // guard bit - bit 23, or 22 zero-indexed + uint8_t g = (f_bits & 0x400000) > 0; + // round bit - bit 22, or 21 zero-indexed + uint8_t r = (f_bits & 0x200000) > 0; + // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed + uint8_t s = (f_bits & 0x1FFFFF) > 0; + // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the + // original float32 is denormal, and to 1 if the original float32 is normal. + uint8_t lsb = exponent > 0; + + // implement the RNE logic + bool round_up = false; + + // if g == 0, round down (no-op) + if (g == 1) { + if ((r == 1) || (s == 1)) { + // round up + round_up = true; + } else { + if (lsb == 1) { + // round up + round_up = true; + } + // if lsb == 0, round down (no-op) + } + } + + if (round_up) { + // adjust exponent + // note that if exponent was 255 we would have already returned earlier, so + // we know we can add one safely without running out of bounds + exponent++; + } + + return exponent; +} + +} // namespace detail + +//------- the below is from c10/util/Float8_e8m0fnu-inl.h ------// +// TODO(#146647): Can we remove the below warning? +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +/// Constructors +inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) + : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const { + // TODO(#146647): maybe rewrite without control flow + + // if exponent is zero, need to special case to return 2^-127 instead of zero + if (x == 0) { + return c10::detail::fp32_from_bits(0x00400000); + } + + // if exponent is NaN, need to special case to return properly encoded NaN + if (isnan()) { + return c10::detail::fp32_from_bits(0x7f800001); + } + + // leave sign at 0, set the exponent bits, leave stored mantissa at 0 + uint32_t res = x << 23; + + return c10::detail::fp32_from_bits(res); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { + return x == 0b11111111; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e8m0fnu to float. +C10_CLANG_DIAGNOSTIC_POP() + +} // namespace c10 + +namespace torch::headeronly { +using c10::Float8_e8m0fnu; +using c10::operator<<; + +namespace detail { +using c10::detail::fp8e8m0fnu_from_fp32_value; +} // namespace detail +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = false; + static constexpr auto has_denorm_loss = false; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 1; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 1; // just a 2! + static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -38; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e8m0fnu min() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu lowest() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu max() { + // 254 biased, which is 127 unbiased, so 2^127 + return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu epsilon() { + // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this + // is "the difference between 1.0 and the next representable value of the + // given floating-point type". The next representable value is 2.0, so the + // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. + return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu round_error() { + // 0.5 in float, which is 2^-1, and -1 + 127 = 126 + return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu quiet_NaN() { + return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits()); + } +}; + +} // namespace std diff --git a/c10/util/Float8_fnuz_cvt.h b/torch/headeronly/util/Float8_fnuz_cvt.h similarity index 89% rename from c10/util/Float8_fnuz_cvt.h rename to torch/headeronly/util/Float8_fnuz_cvt.h index 327f90d11a71..e2e21a8ce0f9 100644 --- a/c10/util/Float8_fnuz_cvt.h +++ b/torch/headeronly/util/Float8_fnuz_cvt.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -8,7 +8,7 @@ #include #endif -namespace c10::detail { +namespace torch::headeronly::detail { /* * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ @@ -61,4 +61,8 @@ inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { return fp32_from_bits(retval); } -} // namespace c10::detail +} // namespace torch::headeronly::detail + +namespace c10::detail { +using torch::headeronly::detail::fp8_fnuz_to_fp32_value; +} diff --git a/torch/headeronly/util/TypeSafeSignMath.h b/torch/headeronly/util/TypeSafeSignMath.h new file mode 100644 index 000000000000..561ea0467a08 --- /dev/null +++ b/torch/headeronly/util/TypeSafeSignMath.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wstring-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Returns false since we cannot have x < 0 if x is unsigned. +template +inline constexpr bool is_negative( + const T& /*x*/, + std::true_type /*is_unsigned*/) { + return false; +} + +/// Returns true if a signed variable x < 0 +template +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { + return x < T(0); +} + +/// Returns true if x < 0 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not :-( +template +inline constexpr bool is_negative(const T& x) { + return is_negative(x, std::is_unsigned()); +} + +/// Returns the sign of an unsigned variable x as 0, 1 +template +inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { + return T(0) < x; +} + +/// Returns the sign of a signed variable x as -1, 0, 1 +template +inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { + return (T(0) < x) - (x < T(0)); +} + +/// Returns the sign of x as -1, 0, 1 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not :-( +template +inline constexpr int signum(const T& x) { + return signum(x, std::is_unsigned()); +} + +/// Returns true if a and b are not both negative +template +inline constexpr bool signs_differ(const T& a, const U& b) { + return is_negative(a) != is_negative(b); +} + +// Suppress sign compare warning when compiling with GCC +// as later does not account for short-circuit rule before +// raising the warning, see https://godbolt.org/z/Tr3Msnz99 +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif + +/// Returns true if x is greater than the greatest value of the type Limit +template +inline constexpr bool greater_than_max(const T& x) { + constexpr bool can_overflow = + std::numeric_limits::digits > std::numeric_limits::digits; + return can_overflow && x > (std::numeric_limits::max)(); +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +/// Returns true if x < lowest(Limit). Standard comparison +template +inline constexpr bool less_than_lowest( + const T& x, + std::false_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < std::numeric_limits::lowest(); +} + +/// Returns false since all the limit is signed and therefore includes +/// negative values but x cannot be negative because it is unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::false_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x < 0, where 0 is constructed from T. +/// Limit is not signed, so its lower value is zero +template +inline constexpr bool less_than_lowest( + const T& x, + std::true_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < T(0); +} + +/// Returns false sign both types are unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::true_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x is less than the lowest value of type T +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not : +template +inline constexpr bool less_than_lowest(const T& x) { + return less_than_lowest( + x, std::is_unsigned(), std::is_unsigned()); +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +namespace torch::headeronly { +using c10::greater_than_max; +using c10::is_negative; +using c10::less_than_lowest; +using c10::signs_differ; +using c10::signum; +} // namespace torch::headeronly diff --git a/torch/headeronly/util/complex.h b/torch/headeronly/util/complex.h new file mode 100644 index 000000000000..e0a356436acb --- /dev/null +++ b/torch/headeronly/util/complex.h @@ -0,0 +1,616 @@ +#pragma once + +#include + +#include +#include + +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#endif + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wfloat-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") +#endif + +namespace c10 { + +// c10::complex is an implementation of complex numbers that aims +// to work on all devices supported by PyTorch +// +// Most of the APIs duplicates std::complex +// Reference: https://en.cppreference.com/w/cpp/numeric/complex +// +// [NOTE: Complex Operator Unification] +// Operators currently use a mix of std::complex, thrust::complex, and +// c10::complex internally. The end state is that all operators will use +// c10::complex internally. Until then, there may be some hacks to support all +// variants. +// +// +// [Note on Constructors] +// +// The APIs of constructors are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/complex +// +// Since C++14, all constructors are constexpr in std::complex +// +// There are three types of constructors: +// - initializing from real and imag: +// `constexpr complex( const T& re = T(), const T& im = T() );` +// - implicitly-declared copy constructor +// - converting constructors +// +// Converting constructors: +// - std::complex defines converting constructor between float/double/long +// double, +// while we define converting constructor between float/double. +// - For these converting constructors, upcasting is implicit, downcasting is +// explicit. +// - We also define explicit casting from std::complex/thrust::complex +// - Note that the conversion from thrust is not constexpr, because +// thrust does not define them as constexpr ???? +// +// +// [Operator =] +// +// The APIs of operator = are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D +// +// Since C++20, all operator= are constexpr. Although we are not building with +// C++20, we also obey this behavior. +// +// There are three types of assign operator: +// - Assign a real value from the same scalar type +// - In std, this is templated as complex& operator=(const T& x) +// with specialization `complex& operator=(T x)` for float/double/long +// double Since we only support float and double, on will use `complex& +// operator=(T x)` +// - Copy assignment operator and converting assignment operator +// - There is no specialization of converting assignment operators, which type +// is +// convertible is solely dependent on whether the scalar type is convertible +// +// In addition to the standard assignment, we also provide assignment operators +// with std and thrust +// +// +// [Casting operators] +// +// std::complex does not have casting operators. We define casting operators +// casting to std::complex and thrust::complex +// +// +// [Operator ""] +// +// std::complex has custom literals `i`, `if` and `il` defined in namespace +// `std::literals::complex_literals`. We define our own custom literals in the +// namespace `c10::complex_literals`. Our custom literals does not follow the +// same behavior as in std::complex, instead, we define _if, _id to construct +// float/double complex literals. +// +// +// [real() and imag()] +// +// In C++20, there are two overload of these functions, one it to return the +// real/imag, another is to set real/imag, they are both constexpr. We follow +// this design. +// +// +// [Operator +=,-=,*=,/=] +// +// Since C++20, these operators become constexpr. In our implementation, they +// are also constexpr. +// +// There are two types of such operators: operating with a real number, or +// operating with another complex number. For the operating with a real number, +// the generic template form has argument type `const T &`, while the overload +// for float/double/long double has `T`. We will follow the same type as +// float/double/long double in std. +// +// [Unary operator +-] +// +// Since C++20, they are constexpr. We also make them expr +// +// [Binary operators +-*/] +// +// Each operator has three versions (taking + as example): +// - complex + complex +// - complex + real +// - real + complex +// +// [Operator ==, !=] +// +// Each operator has three versions (taking == as example): +// - complex == complex +// - complex == real +// - real == complex +// +// Some of them are removed on C++20, but we decide to keep them +// +// [Operator <<, >>] +// +// These are implemented by casting to std::complex +// +// +// +// TODO(@zasdfgbnm): c10::complex is not currently supported, +// because: +// - lots of members and functions of c10::Half are not constexpr +// - thrust::complex only support float and double + +template +struct alignas(sizeof(T) * 2) complex { + using value_type = T; + + T real_ = T(0); + T imag_ = T(0); + + constexpr complex() = default; + C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) + : real_(re), imag_(im) {} + template + explicit constexpr complex(const std::complex& other) + : complex(other.real(), other.imag()) {} +#if defined(__CUDACC__) || defined(__HIPCC__) + template + explicit C10_HOST_DEVICE complex(const thrust::complex& other) + : real_(other.real()), imag_(other.imag()) {} +// NOTE can not be implemented as follow due to ROCm bug: +// explicit C10_HOST_DEVICE complex(const thrust::complex &other): +// complex(other.real(), other.imag()) {} +#endif + + // Use SFINAE to specialize casting constructor for c10::complex and + // c10::complex + template + C10_HOST_DEVICE explicit constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + template + C10_HOST_DEVICE constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + + constexpr complex& operator=(T re) { + real_ = re; + imag_ = 0; + return *this; + } + + constexpr complex& operator+=(T re) { + real_ += re; + return *this; + } + + constexpr complex& operator-=(T re) { + real_ -= re; + return *this; + } + + constexpr complex& operator*=(T re) { + real_ *= re; + imag_ *= re; + return *this; + } + + constexpr complex& operator/=(T re) { + real_ /= re; + imag_ /= re; + return *this; + } + + template + constexpr complex& operator=(const complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + + template + constexpr complex& operator+=(const complex& rhs) { + real_ += rhs.real(); + imag_ += rhs.imag(); + return *this; + } + + template + constexpr complex& operator-=(const complex& rhs) { + real_ -= rhs.real(); + imag_ -= rhs.imag(); + return *this; + } + + template + constexpr complex& operator*=(const complex& rhs) { + // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } + +#ifdef __APPLE__ +#define FORCE_INLINE_APPLE __attribute__((always_inline)) +#else +#define FORCE_INLINE_APPLE +#endif + template + constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) + __ubsan_ignore_float_divide_by_zero__ { + // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i + // the calculation below follows numpy's complex division + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + +#if defined(__GNUC__) && !defined(__clang__) + // std::abs is already constexpr by gcc + auto abs_c = std::abs(c); + auto abs_d = std::abs(d); +#else + auto abs_c = c < 0 ? -c : c; + auto abs_d = d < 0 ? -d : d; +#endif + + if (abs_c >= abs_d) { + if (abs_c == U(0) && abs_d == U(0)) { + /* divide by zeros should yield a complex inf or nan */ + real_ = a / abs_c; + imag_ = b / abs_d; + } else { + auto rat = d / c; + auto scl = U(1.0) / (c + d * rat); + real_ = (a + b * rat) * scl; + imag_ = (b - a * rat) * scl; + } + } else { + auto rat = c / d; + auto scl = U(1.0) / (d + c * rat); + real_ = (a * rat + b) * scl; + imag_ = (b * rat - a) * scl; + } + return *this; + } +#undef FORCE_INLINE_APPLE + + template + constexpr complex& operator=(const std::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } +#endif + + template + explicit constexpr operator std::complex() const { + return std::complex(std::complex(real(), imag())); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE explicit operator thrust::complex() const { + return static_cast>(thrust::complex(real(), imag())); + } +#endif + + // consistent with NumPy behavior + explicit constexpr operator bool() const { + return real() || imag(); + } + + C10_HOST_DEVICE constexpr T real() const { + return real_; + } + constexpr void real(T value) { + real_ = value; + } + C10_HOST_DEVICE constexpr T imag() const { + return imag_; + } + constexpr void imag(T value) { + imag_ = value; + } +}; + +namespace complex_literals { + +constexpr complex operator""_if(long double imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(long double imag) { + return complex(0.0, static_cast(imag)); +} + +constexpr complex operator""_if(unsigned long long imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(unsigned long long imag) { + return complex(0.0, static_cast(imag)); +} + +} // namespace complex_literals + +template +constexpr complex operator+(const complex& val) { + return val; +} + +template +constexpr complex operator-(const complex& val) { + return complex(-val.real(), -val.imag()); +} + +template +constexpr complex operator+(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const complex& lhs, const T& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const T& lhs, const complex& rhs) { + return complex(lhs + rhs.real(), rhs.imag()); +} + +template +constexpr complex operator-(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const complex& lhs, const T& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const T& lhs, const complex& rhs) { + complex result = -rhs; + return result += lhs; +} + +template +constexpr complex operator*(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const complex& lhs, const T& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const T& lhs, const complex& rhs) { + complex result = rhs; + return result *= lhs; +} + +template +constexpr complex operator/(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const complex& lhs, const T& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const T& lhs, const complex& rhs) { + complex result(lhs, T()); + return result /= rhs; +} + +// Define operators between integral scalars and c10::complex. std::complex does +// not support this when T is a floating-point number. This is useful because it +// saves a lot of "static_cast" when operate a complex and an integer. This +// makes the code both less verbose and potentially more efficient. +#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ + typename std::enable_if_t< \ + std::is_floating_point_v && std::is_integral_v, \ + int> = 0 + +template +constexpr c10::complex operator+(const c10::complex& a, const iT& b) { + return a + static_cast(b); +} + +template +constexpr c10::complex operator+(const iT& a, const c10::complex& b) { + return static_cast(a) + b; +} + +template +constexpr c10::complex operator-(const c10::complex& a, const iT& b) { + return a - static_cast(b); +} + +template +constexpr c10::complex operator-(const iT& a, const c10::complex& b) { + return static_cast(a) - b; +} + +template +constexpr c10::complex operator*(const c10::complex& a, const iT& b) { + return a * static_cast(b); +} + +template +constexpr c10::complex operator*(const iT& a, const c10::complex& b) { + return static_cast(a) * b; +} + +template +constexpr c10::complex operator/(const c10::complex& a, const iT& b) { + return a / static_cast(b); +} + +template +constexpr c10::complex operator/(const iT& a, const c10::complex& b) { + return static_cast(a) / b; +} + +#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION + +template +constexpr bool operator==(const complex& lhs, const complex& rhs) { + return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); +} + +template +constexpr bool operator==(const complex& lhs, const T& rhs) { + return (lhs.real() == rhs) && (lhs.imag() == T()); +} + +template +constexpr bool operator==(const T& lhs, const complex& rhs) { + return (lhs == rhs.real()) && (T() == rhs.imag()); +} + +template +constexpr bool operator!=(const complex& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const complex& lhs, const T& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const T& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const complex& x) { + return (os << static_cast>(x)); +} + +template +std::basic_istream& operator>>( + std::basic_istream& is, + complex& x) { + std::complex tmp; + is >> tmp; + x = tmp; + return is; +} + +template +C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::polar(r, theta)); +#else + // std::polar() requires r >= 0, so spell out the explicit implementation to + // avoid a branch. + return complex(r * std::cos(theta), r * std::sin(theta)); +#endif +} + +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) + : real_(real), imag_(imag) {} + C10_HOST_DEVICE inline complex(const c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline C10_HOST_DEVICE operator c10::complex() const { + return {real_, imag_}; + } + + constexpr C10_HOST_DEVICE Half real() const { + return real_; + } + constexpr C10_HOST_DEVICE Half imag() const { + return imag_; + } + + C10_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +} // namespace c10 + +namespace torch::headeronly { +using c10::complex; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; +using c10::operator==; +using c10::operator!=; +using c10::operator<<; +using c10::operator>>; +using c10::polar; + +namespace complex_literals { +using c10::complex_literals::operator""_if; +using c10::complex_literals::operator""_id; +} // namespace complex_literals + +} // namespace torch::headeronly + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/torch/headeronly/util/shim_utils.h b/torch/headeronly/util/shim_utils.h new file mode 100644 index 000000000000..5acb3e2e347c --- /dev/null +++ b/torch/headeronly/util/shim_utils.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include +#include + +#define TORCH_SUCCESS 0 +#define TORCH_FAILURE 1 + +namespace torch::headeronly::detail { +[[maybe_unused]] C10_NOINLINE static void throw_exception( + const char* call, + const char* file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} +} // namespace torch::headeronly::detail + +// This API is 100% inspired by AOTI_TORCH_ERROR_CODE_CHECK defined in +// pytorch/torch/csrc/inductor/aoti_runtime/utils.h to handle the returns +// of the APIs in the shim. We are genericizing this for more global use +// of the shim beyond AOTI, for examples, see torch/csrc/stable/ops.h. +#define TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != TORCH_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ + } diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 4cef60948ad9..eb5f885acc19 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -2,6 +2,7 @@ import logging import multiprocessing import multiprocessing.connection +import multiprocessing.spawn as mp_spawn import os import pickle import signal @@ -12,6 +13,11 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Optional +from torch.numa.binding import ( + maybe_get_temporary_python_executable_with_numa_bindings, + NumaOptions, +) + from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] @@ -236,6 +242,7 @@ def start_processes( join=True, daemon=False, start_method="spawn", + numa_options: Optional[NumaOptions] = None, ): # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), # this func will start processes in parallel if start_method is 'forkserver'. @@ -251,11 +258,43 @@ def start_processes( # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start start_parallel = False + if numa_options is not None and start_method != "spawn": + raise ValueError("NUMA binding is only compatible with spawn") + + if numa_options is not None and start_parallel: + raise ValueError("NUMA binding is not compatible with parallel start") + mp = multiprocessing.get_context(start_method) error_files = [None] * nprocs processes = [None] * nprocs + original_executable = mp_spawn.get_executable() def start_process(i): + # HACK: We want to force Process.start() to kick off the subprocess + # using a custom numactl command per rank. However, the API exposed + # by multiprocessing only allows us to override the executable for + # the entire context, and only with a single str rather than a tuple. + # Furthermore, there is no API for passing additional options, e.g. + # to make LOCAL_RANK available to the executable. + # + # In order to get around these limitations, we pre-compute + # the appropriate command containing NUMA bindings and store it in a + # temporary executable which passes Python args on to the original + # executable. Then, we call set_executable before and after each + # Process.start() call. + # + # This assumes that, under the hood, Process.start() for rank n + # will not call get_executable after start_process for rank n+1 + # calls set_executable again. We guarantee this by + # raising an exception if `start_parallel`, above. (Not clear + # if there would be a race condition otherwise, but we want to be safe.) + temporary_executable_path = ( + maybe_get_temporary_python_executable_with_numa_bindings( + python_executable_path=original_executable, + gpu_index=i, + numa_options=numa_options, + ) + ) # Each process is assigned a file to write tracebacks to. We # use the file being non-empty to indicate an exception # occurred (vs an expected shutdown). Note: this previously @@ -267,12 +306,19 @@ def start_process(i): ) tf.close() os.unlink(tf.name) - process = mp.Process( - target=_wrap, - args=(fn, i, args, tf.name), - daemon=daemon, - ) - process.start() + + try: + if temporary_executable_path is not None: + mp.set_executable(temporary_executable_path) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + finally: + if temporary_executable_path is not None: + mp.set_executable(original_executable) return i, process, tf.name if not start_parallel: diff --git a/torch/nativert/ModelRunner.cpp b/torch/nativert/ModelRunner.cpp index f1c2a35db14c..83cb0e00bd72 100644 --- a/torch/nativert/ModelRunner.cpp +++ b/torch/nativert/ModelRunner.cpp @@ -136,4 +136,21 @@ std::vector ModelRunner::runWithFlatInputsAndOutputs( return executor_->execute(std::move(flatInputs)); } +ModelRunnerHandle::ModelRunnerHandle( + const std::string& packagePath, + const std::string& modelName) + : impl_(std::make_unique(packagePath, modelName)) {} +ModelRunnerHandle::~ModelRunnerHandle() = default; + +c10::IValue ModelRunnerHandle::run( + const std::vector& args, + const std::unordered_map& kwargs) { + return impl_->run(args, kwargs); +} + +std::vector ModelRunnerHandle::runWithFlatInputsAndOutputs( + std::vector flatInputs) { + return impl_->runWithFlatInputsAndOutputs(std::move(flatInputs)); +} + } // namespace torch::nativert diff --git a/torch/nativert/ModelRunner.h b/torch/nativert/ModelRunner.h index 4c8875731885..e037e3b26ca8 100644 --- a/torch/nativert/ModelRunner.h +++ b/torch/nativert/ModelRunner.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include diff --git a/torch/nativert/detail/MPMCQueue.h b/torch/nativert/detail/MPMCQueue.h index 3b90503887bb..8301ce3fdb4c 100644 --- a/torch/nativert/detail/MPMCQueue.h +++ b/torch/nativert/detail/MPMCQueue.h @@ -55,6 +55,15 @@ class MPMCQueue { return true; } + /** + * Get the current size of the queue. + * @return The number of elements in the queue. + */ + size_t size() { + std::lock_guard lock(mutex_); + return storage_.size(); + } + private: std::mutex mutex_; std::deque storage_; diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index 932972ae2b5b..906a6ec32728 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -10,10 +10,6 @@ #include #include -// Maximum number of retries when trying to get a frame from -// clearedExecutionFrames_ -constexpr uint32_t kClearExecutionFrameRetries = 10; - namespace torch::nativert { Executor::Executor( @@ -29,7 +25,7 @@ Executor::Executor( ? std::optional(*graph_) : std::nullopt), executionFrames_(executorConfig_.maxNumConcurrentThreads), - clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), + inactiveExecutionFrames_(executorConfig_.maxNumConcurrentThreads), numExecutionFrames_(0), lastClearedTimestamp_(getCurrentTimestampSeconds()) { if (weights) { @@ -193,34 +189,12 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() { std::shared_ptr weights; weights_.withLock([&](auto& w) { weights = w; }); - // First try to get a frame from clearedExecutionFrames_ if clearing is in - // progress - if (C10_UNLIKELY(clearingInProgress_)) { - ExecutionFrameEntry frameEntry; - uint32_t retry = 0; - while ( - retry < - kClearExecutionFrameRetries) { // Limit retries to avoid infinite loop - if (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { - if (retry > 0) { - VLOG(1) << "Took " << retry - << " retries to pop from clearedExecutionFrames_"; - } - ExecutorFramePtr ptr{std::move(frameEntry.frame), *this}; - if (ptr->weightVersion() != weights->version()) { - ptr->setWeights(*weights); - } - return ptr; - } - retry++; - } - // If we couldn't get a frame from cleared pool after retries, move onto - // main pool - } - // Try to get a frame from the main pool or create a new one std::unique_ptr frame; - while (!executionFrames_.readIfNotEmpty(frame)) { + + // Try to get a frame from executionFrames_ or inactiveExecutionFrames_ + while (!executionFrames_.readIfNotEmpty(frame) && + !inactiveExecutionFrames_.readIfNotEmpty(frame)) { int64_t numFrames = numExecutionFrames_.load(); if (numFrames < executorConfig_.maxNumConcurrentThreads) { if (numExecutionFrames_.compare_exchange_strong( @@ -243,6 +217,7 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() { } void Executor::clearStaleExecutionFrames() { + LOG(INFO) << "Clearing stale execution frames"; if (!cleanupLock_.try_lock()) { // Another thread is already doing cleanup return; @@ -250,41 +225,48 @@ void Executor::clearStaleExecutionFrames() { // Update timestamp first to minimize contention lastClearedTimestamp_ = getCurrentTimestampSeconds(); - int numPopped = 0; + // Get the size of active execution frames queue directly + size_t activeFramesSize = executionFrames_.size(); + size_t inactiveFramesSize = inactiveExecutionFrames_.size(); + size_t total = activeFramesSize + inactiveFramesSize; + size_t numCleared = 0; std::unique_ptr frame; - // Move frames from executionFrames_ to clearedExecutionFrames_ - while (executionFrames_.readIfNotEmpty(frame)) { - ++numPopped; - // Keep the first popped entries up to minimum size - if (numPopped > executorConfig_.minNumExecutionFrames) { - // Discard stale frames - frame.reset(); - numExecutionFrames_ -= 1; - continue; - } + // If number of active frames is less than the configured min, then transfer + // the difference from inactive frames + size_t minFramesToKeep = std::min( + static_cast(executorConfig_.minNumExecutionFrames), total); + size_t framesToTransfer = + (minFramesToKeep - activeFramesSize) > minFramesToKeep + ? static_cast(0) + : minFramesToKeep - activeFramesSize; + ; + for (size_t i = 0; + i < framesToTransfer && inactiveExecutionFrames_.readIfNotEmpty(frame); + ++i) { + executionFrames_.writeIfNotFull(std::move(frame)); + } - ExecutionFrameEntry entry; - entry.used = false; - entry.frame = std::move(frame); - clearedExecutionFrames_.writeIfNotFull(std::move(entry)); - // Enable clients to pop from clearedExecutionFrames_ while clearing is in - // progress - clearingInProgress_ = true; + size_t newActiveFramesSize = executionFrames_.size(); + + // Clear remaining inactive frames (i.e. those that were not used in the last + // time interval) + while (inactiveExecutionFrames_.readIfNotEmpty(frame)) { + ++numCleared; + frame.reset(); + numExecutionFrames_ -= 1; } - uint32_t numPushed = 0; - ExecutionFrameEntry frameEntry; - // Move frames back from clearedExecutionFrames_ to executionFrames_ - while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { - ++numPushed; - executionFrames_.writeIfNotFull(std::move(frameEntry.frame)); - clearingInProgress_ = false; + // Move active frames to inactive so they are cleared next time if not used + // Check newActiveFramesSize > 0 to guuard against other threads adding + // frames to active queue during while loop + while (executionFrames_.readIfNotEmpty(frame) && newActiveFramesSize > 0) { + --newActiveFramesSize; + inactiveExecutionFrames_.writeIfNotFull(std::move(frame)); } - clearingInProgress_ = false; - VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped - << " ExecutionFrame instances in the pool"; + LOG(INFO) << "Cleared " << numCleared << " out of " << total + << " ExecutionFrame instances in the pool"; cleanupLock_.unlock(); } @@ -292,6 +274,8 @@ void Executor::clearStaleExecutionFrames() { void Executor::returnExecutorFrameToPool( std::unique_ptr frame) { // Check if it's time to clean up stale frames + // TODO: consider moving cleanup to a dedicated thread so it does not impact + // p99 latency if (executorConfig_.doExecutionFrameCleanup && lastClearedTimestamp_ + executorConfig_.executionFramePoolCleanupIntervalSec < @@ -301,21 +285,11 @@ void Executor::returnExecutorFrameToPool( try { frame->destroyBorrowedIValues(); - - // Create an entry with used=true - if (C10_UNLIKELY(!clearingInProgress_)) { - TORCH_CHECK( - executionFrames_.writeIfNotFull(std::move(frame)), - "ExecutionFrame pool full"); - } else { - ExecutionFrameEntry frameEntry; - frameEntry.used = true; - frameEntry.frame = std::move(frame); - - TORCH_CHECK( - clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)), - "Cleared ExecutionFrame pool full"); - } + // Always return to active execution frame pool, indicating that frame was + // used in the previous time interval + TORCH_CHECK( + executionFrames_.writeIfNotFull(std::move(frame)), + "ExecutionFrame pool full"); } catch (...) { sem_.release(); throw; diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index 4f40946b4b42..64f2372b9e85 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -122,7 +122,7 @@ class Executor { std::vector getDelegates(); // Get the number of execution frames in the pool - int getNumExecutionFrames() const { + auto getNumExecutionFrames() const { return numExecutionFrames_.load(); } @@ -149,25 +149,6 @@ class Executor { void clearStaleExecutionFrames(); private: - // Structure to track execution frame usage - struct ExecutionFrameEntry { - bool used{false}; - std::unique_ptr frame; - - // Add move constructor and assignment operator - ExecutionFrameEntry() = default; - ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept - : used(other.used), frame(std::move(other.frame)) {} - ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept { - used = other.used; - frame = std::move(other.frame); - return *this; - } - // Delete copy constructor and assignment operator - ExecutionFrameEntry(const ExecutionFrameEntry&) = delete; - ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete; - }; - void maybeRunConstantFolding(const std::shared_ptr& weights); void validateInputs(const std::vector& inputs) const; @@ -188,8 +169,8 @@ class Executor { c10::Semaphore sem_; torch::nativert::detail::MPMCQueue> executionFrames_; - torch::nativert::detail::MPMCQueue - clearedExecutionFrames_; + torch::nativert::detail::MPMCQueue> + inactiveExecutionFrames_; std::atomic_int64_t numExecutionFrames_; std::unique_ptr layoutPlanner_; diff --git a/torch/nativert/executor/ExecutorConfig.h b/torch/nativert/executor/ExecutorConfig.h index cbe596788714..fb57f2b6f2ef 100644 --- a/torch/nativert/executor/ExecutorConfig.h +++ b/torch/nativert/executor/ExecutorConfig.h @@ -9,9 +9,9 @@ namespace torch::nativert { struct ExecutorConfig { bool validateInputs = false; bool debugNan = false; - bool enableStaticCPUKernels = false; + bool enableStaticCPUKernels = true; bool runConstFolding = false; - bool doExecutionFrameCleanup = false; + bool doExecutionFrameCleanup = true; bool tryFreeUnmanagedValuesAfterUse = true; // allows up to max number of concurrent threads. int64_t maxNumConcurrentThreads = 8; diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp index 827e8cd05781..d6cb74bcde8c 100644 --- a/torch/nativert/executor/memory/LayoutManager.cpp +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -105,6 +105,7 @@ void LayoutManager::ensure_managed_storages(bool allocate) { auto* tensor = planned_tensors_[i]; at::StorageImpl& storage = *tensor->storage().unsafeGetStorageImpl(); + at::TensorImpl& tensor_impl = *tensor->unsafeGetTensorImpl(); if (C10_UNLIKELY(allocate)) { // from: https://fburl.com/code/4it00yph @@ -120,7 +121,7 @@ void LayoutManager::ensure_managed_storages(bool allocate) { // // For more information, see the doc comment for // intrusive_ptr::unsafe_adapt_non_heap_allocated. - tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + tensor_impl.set_storage_keep_dtype(at::Storage( c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( &storage_impl_buffer_.to_managed(storage), 1))); } else if ( @@ -130,12 +131,16 @@ void LayoutManager::ensure_managed_storages(bool allocate) { &storage_buf [i]) /* managed storage was replaced for some reason */) { storage.reset(); - tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + tensor_impl.set_storage_keep_dtype(at::Storage( c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object) &storage_buf[i], 1))); } + + // resize to zero so that we ensure that we don't access out-of-bounds + // addr's in the next iteration + tensor_impl.set_sizes_contiguous({0}); } } diff --git a/torch/nativert/graph/GraphUtils.cpp b/torch/nativert/graph/GraphUtils.cpp new file mode 100644 index 000000000000..ebe2d68cc0e7 --- /dev/null +++ b/torch/nativert/graph/GraphUtils.cpp @@ -0,0 +1,80 @@ +#include + +#include + +#include + +namespace torch::nativert { + +bool areAllIOTensorsAttributesOnCpu(const Node& node) { + const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta(); + + // Check inputs + for (auto& input : node.inputs()) { + if (input.value->type() == Type::Kind::Tensor) { + if (auto it = tensorValuesMeta.find(std::string{input.value->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } else if (input.value->type() == Type::Kind::TensorList) { + for (const auto& el : input.value->getListElements()) { + if (auto it = tensorValuesMeta.find(std::string{el->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } + } else { + // other input types doesn't affect if the node is on CPU or not + } + } + + // Check outputs + for (auto& output : node.outputs()) { + if (!output) { + // When a node's output is a Constant, its Value* is nullptr + // TODO: this is breaking the invariant of all nodes outputs are non-null + // in the graph. We should fix this. + continue; + } + if (output->type() == Type::Kind::Tensor) { + if (auto it = tensorValuesMeta.find(std::string{output->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } else if (output->type() == Type::Kind::TensorList) { + for (const auto& el : output->getListElements()) { + if (auto it = tensorValuesMeta.find(std::string{el->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } + } else { + // other output types doesn't affect if the node is on CPU or not + } + } + + // Check attributes + for (auto& attribute : node.attributes()) { + if (std::holds_alternative(attribute.value)) { + auto device = std::get(attribute.value); + if (!device.is_cpu()) { + return false; + } + } + } + return true; +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/GraphUtils.h b/torch/nativert/graph/GraphUtils.h new file mode 100644 index 000000000000..593317ebb29b --- /dev/null +++ b/torch/nativert/graph/GraphUtils.h @@ -0,0 +1,22 @@ +#pragma once + +namespace torch::nativert { + +class Node; + +/** + * Utility functions for working with Graph nodes and values. + */ + +/** + * Check if all input/output tensors are on CPU and all device-type attributes + * have the value of 'cpu'. This is a util function to check if a Node can use + * static dispatch CPU kernels. + * + * @param node The node to check + * @return true if all I/O tensors and device attributes are on CPU, false + * otherwise + */ +bool areAllIOTensorsAttributesOnCpu(const Node& node); + +} // namespace torch::nativert diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8eb962f8a308..1f26a4d90a4a 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -841,6 +841,46 @@ def _softmax_default(func, *args, **kwargs): return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) +@register_jagged_func( + torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any" +) +def _log_softmax_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if isinstance(new_kwargs["dim"], tuple): + raise RuntimeError( + "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor" + ) + + inp = new_kwargs.pop("input") + + ( + new_kwargs["dim"], + reduce_on_batch, + reduce_on_ragged, + _reduce_on_non_batch, + ) = _wrap_jagged_dims( + inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx + ) + + if reduce_on_batch: + raise RuntimeError( + "log_softmax(): not supported when reducing across the batch dimension for NestedTensor" + ) + + if reduce_on_ragged: + raise RuntimeError( + "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor" + ) + + # torch.log_softmax takes in the reduction dimension as an integer + new_kwargs["dim"] = new_kwargs["dim"][0] + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + @register_jagged_func( torch.ops.aten._softmax_backward_data.default, "grad_output: jt, output: jt, dim: any, input_dtype: any", diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 6b61c3a5799d..c3219644fee8 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3472,6 +3472,7 @@ def binary_cross_entropy( size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean", + label_smoothing: float = 0.0, ) -> Tensor: r"""Compute Binary Cross Entropy between the target and input probabilities. @@ -3490,9 +3491,11 @@ def binary_cross_entropy( elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Examples:: - >>> input = torch.randn(3, 2, requires_grad=True) >>> target = torch.rand(3, 2, requires_grad=False) >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target) @@ -3508,6 +3511,7 @@ def binary_cross_entropy( size_average=size_average, reduce=reduce, reduction=reduction, + label_smoothing=label_smoothing, ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) @@ -3523,6 +3527,13 @@ def binary_cross_entropy( new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) + assert 0 <= label_smoothing <= 1, ( + f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + ) + + if label_smoothing > 0: + target = target * (1 - label_smoothing) + (1 - target) * label_smoothing + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) @@ -3534,6 +3545,7 @@ def binary_cross_entropy_with_logits( reduce: Optional[bool] = None, reduction: str = "mean", pos_weight: Optional[Tensor] = None, + label_smoothing: float = 0.0, ) -> Tensor: r"""Compute Binary Cross Entropy between target and input logits. @@ -3560,9 +3572,11 @@ def binary_cross_entropy_with_logits( [C, H, W] the same pos_weights across the batch. To apply the same positive weight along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: ``None`` - + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Examples:: - >>> input = torch.randn(3, requires_grad=True) >>> target = torch.empty(3).random_(2) >>> loss = F.binary_cross_entropy_with_logits(input, target) @@ -3579,6 +3593,7 @@ def binary_cross_entropy_with_logits( reduce=reduce, reduction=reduction, pos_weight=pos_weight, + label_smoothing=label_smoothing, ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) @@ -3590,6 +3605,13 @@ def binary_cross_entropy_with_logits( f"Target size ({target.size()}) must be the same as input size ({input.size()})" ) + assert 0 <= label_smoothing <= 1, ( + f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + ) + + if label_smoothing > 0: + target = target * (1 - label_smoothing) + (1 - target) * label_smoothing + return torch.binary_cross_entropy_with_logits( input, target, weight, pos_weight, reduction_enum ) diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index d0b64447e900..580a768e4d9f 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -134,6 +134,7 @@ def binary_cross_entropy_with_logits( reduce: bool | None = ..., reduction: str = ..., pos_weight: Tensor | None = ..., + label_smoothing: float = ..., ) -> Tensor: ... __all__ += ["binary_cross_entropy_with_logits"] @@ -145,6 +146,7 @@ def binary_cross_entropy( size_average: bool | None = ..., reduce: bool | None = ..., reduction: str = ..., + label_smoothing: float = ..., ) -> Tensor: ... __all__ += ["binary_cross_entropy"] diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 6fa0d53c8a44..0b9468797d4c 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -692,6 +692,10 @@ class BCELoss(_WeightedLoss): elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -717,15 +721,21 @@ def __init__( size_average=None, reduce=None, reduction: str = "mean", + label_smoothing: float = 0.0, ) -> None: super().__init__(weight, size_average, reduce, reduction) + self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: """ Runs the forward pass. """ return F.binary_cross_entropy( - input, target, weight=self.weight, reduction=self.reduction + input, + target, + weight=self.weight, + reduction=self.reduction, + label_smoothing=self.label_smoothing, ) @@ -815,6 +825,10 @@ class BCEWithLogitsLoss(_Loss): [C, H, W] the same pos_weights across the batch. To apply the same positive weight along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: ``None`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -838,12 +852,14 @@ def __init__( reduce=None, reduction: str = "mean", pos_weight: Optional[Tensor] = None, + label_smoothing: float = 0.0, ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) self.register_buffer("pos_weight", pos_weight) self.weight: Optional[Tensor] self.pos_weight: Optional[Tensor] + self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: """Runs the forward pass.""" @@ -853,6 +869,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: self.weight, pos_weight=self.pos_weight, reduction=self.reduction, + label_smoothing=self.label_smoothing, ) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 2764dc27bb3e..f0c4914782f3 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -568,7 +568,9 @@ def register_buffer( raise KeyError('buffer name can\'t be empty string ""') elif hasattr(self, name) and name not in self._buffers: raise KeyError(f"attribute '{name}' already exists") - elif tensor is not None and not isinstance(tensor, torch.Tensor): + elif tensor is not None and not ( + isinstance(tensor, torch.Tensor) or hasattr(tensor, "__torch_function__") + ): raise TypeError( f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " "(torch Tensor or None required)" @@ -2024,7 +2026,10 @@ def remove_from(*dicts_or_sets) -> None: else: buffers = self.__dict__.get("_buffers") if isinstance(value, Buffer) or buffers is not None and name in buffers: - if value is not None and not isinstance(value, torch.Tensor): + if value is not None and not ( + isinstance(value, torch.Tensor) + or hasattr(value, "__torch_function__") + ): raise TypeError( f"cannot assign '{torch.typename(value)}' as buffer '{name}' " "(torch.nn.Buffer, torch.Tensor or None expected)" diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_cycle b/torch/numa/__init__.py similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_cycle rename to torch/numa/__init__.py diff --git a/torch/distributed/numa/binding.py b/torch/numa/binding.py similarity index 74% rename from torch/distributed/numa/binding.py rename to torch/numa/binding.py index 51876583ec56..7e4cc40aad5b 100644 --- a/torch/distributed/numa/binding.py +++ b/torch/numa/binding.py @@ -1,28 +1,31 @@ import os import shutil +import stat import subprocess import traceback from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass from enum import Enum +from logging import getLogger +from subprocess import run +from tempfile import mkstemp from typing import Callable, Optional, TypeVar import torch from torch._utils_internal import signpost_event -from torch.distributed.elastic.utils.logging import get_logger __all__ = [ - "maybe_wrap_with_numa_bindings", "AffinityMode", + "maybe_get_temporary_python_executable_with_numa_bindings", + "maybe_wrap_command_with_numa_bindings", "NumaOptions", ] - _NUMACTL_COMMAND = "numactl" -logger = get_logger(__file__) +logger = getLogger(__name__) class AffinityMode(str, Enum): @@ -40,10 +43,10 @@ class AffinityMode(str, Enum): @dataclass(frozen=True) class NumaOptions: affinity_mode: AffinityMode + """ - If true, we will silently return the original command if any of the following occur: - - An exception is raised as we compute the wrapped command. - - During a dry run of the wrapped command, numactl fails for any reason. + If true, we will fall back to using the original command/entrypoint if we fail to compute + or apply NUMA bindings. You should avoid using this option! It is only intended as a safety mechanism for facilitating mass rollouts of numa binding. @@ -51,135 +54,201 @@ class NumaOptions: should_fall_back_if_binding_fails: bool = False -def maybe_wrap_with_numa_bindings( - *, - entrypoint: str, - local_rank_to_args: dict[int, tuple], - numa_options: Optional[NumaOptions], -) -> tuple[str, dict[int, tuple]]: +def maybe_get_temporary_python_executable_with_numa_bindings( + *, python_executable_path: str, gpu_index: int, numa_options: Optional[NumaOptions] +) -> Optional[str]: """ Args: - entrypoint: The entrypoint to the program, such as might be input to Popen. - Example: "python" - local_rank_to_args: A mapping from local rank to args for the entrypoint. - Example: {0: ("trainer.py",)} - numa_options: See NumaOptions for details. - + python_executable_path: E.g., "/usr/local/bin/python" Returns: - A tuple of (entrypoint, local_rank_to_args), basically transforming the inputs, - where the entrypoint and args may now involve numa binding. - Example: ("numactl", {"0": ("--cpunodebind=0", "--preferred=0", "python", "trainer.py")}) + Path to a temporary file. This file can be executed just like the original python + executable, except it will first apply NUMA bindings. """ if numa_options is None: - return (entrypoint, local_rank_to_args) - - wrapped_local_rank_to_args = {} - for local_rank, args in local_rank_to_args.items(): - try: - numactl_command_options = _maybe_get_numactl_options( - command_args=(entrypoint, *[str(arg) for arg in args]), - gpu_index=local_rank, - numa_options=numa_options, - ) - except Exception: - if numa_options.should_fall_back_if_binding_fails: - # NOTE: If any element of the batch fails to apply NUMA bindings - # for any reason, we do not apply NUMA bindings to any element of the batch, - # for maximum safety. This only applies if fallback is enabled. - return (entrypoint, local_rank_to_args) - raise - wrapped_local_rank_to_args[local_rank] = ( - *numactl_command_options, - entrypoint, - *args, - ) - return (_NUMACTL_COMMAND, wrapped_local_rank_to_args) + logger.info("Received numa_options=None, not creating numa executable.") + return None + + if isinstance(python_executable_path, bytes): + python_executable_path = python_executable_path.decode() + + full_numactl_command = maybe_wrap_command_with_numa_bindings( + # "$@", i.e. pass through any args the python executable would have + # received. + command_args=(python_executable_path, '"$@"'), + gpu_index=gpu_index, + numa_options=numa_options, + ) + if full_numactl_command is None: + return None + + executable_path = _get_temporary_executable_for_command( + command_args=full_numactl_command + ) + logger.info("Returning python executable with NUMA bindings %s", executable_path) -def _maybe_get_numactl_options( + return executable_path + + +def maybe_wrap_command_with_numa_bindings( *, command_args: tuple[str, ...], gpu_index: int, - numa_options: NumaOptions, -) -> tuple[str, ...]: + numa_options: Optional[NumaOptions], +) -> Optional[tuple[str, ...]]: """ Args: - command_args: The args for a command, such as might be input to Popen. - Example: ("python", "trainer.py") - gpu_index: The index of the GPU that will be used by the subprocess which executes command_args. - Example: 0 - numa_options: See NumaOptions for details. + command_args: Full shell command, like ("/usr/local/bin/python", "train.py") + gpu_index: The index of the GPU which command_args should bind to Returns: - Depending on numa_options, something like - ("--cpunodebind=0", "--preferred=0") + command_args, but wrapped so that it runs with NUMA bindings corresponding to + gpu_index and numa_options. + E.g., ("numactl", "--cpunodebind=0", "/usr/local/bin/python", "train.py") """ + if not numa_options: + logger.info("Received numa_options=None, not applying bindings.") + return None + + kwargs = { + "command_args": command_args, + "gpu_index": gpu_index, + "numa_options": numa_options, + } + logger.info("Attempting to wrap command with NUMA bindings, given input %r", kwargs) + try: _raise_if_numactl_not_available() - if numa_options.affinity_mode == AffinityMode.NODE: - numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index) - elif numa_options.affinity_mode == AffinityMode.SOCKET: - numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index) - elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE: - numactl_command_options = _get_exclusive_numactl_options( - gpu_index=gpu_index - ) - elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX: - numactl_command_options = _get_core_complex_numactl_options( - gpu_index=gpu_index - ) - else: - raise ValueError( - f"Affinity mode {numa_options.affinity_mode} not supported." - ) - if numa_options.should_fall_back_if_binding_fails: - _raise_if_numactl_fails_dry_run(numactl_options=numactl_command_options) + numactl_options = _get_numactl_cli_options( + command_args=command_args, gpu_index=gpu_index, numa_options=numa_options + ) + logger.info("Computed numactl_options=%r", numactl_options) + + _raise_if_numactl_fails_dry_run(numactl_options=numactl_options) + logger.info("Validated numactl_options=%r", numactl_options) + + full_numactl_command = _get_assembled_command_from_pieces( + command_args=command_args, numactl_options=numactl_options + ) + logger.info( + "Successfully wrapped command with numa_bindings. Returning %r", + full_numactl_command, + ) signpost_event( category="numa_binding", name="wrap_command_success", - parameters={ - "original_command_args": command_args, - "gpu_index": gpu_index, - "numa_options": numa_options, - "numactl_command_options": numactl_command_options, - }, + parameters={**kwargs, "result": full_numactl_command}, ) - return numactl_command_options + return full_numactl_command except Exception: signpost_event( category="numa_binding", name="wrap_command_exception", parameters={ + **kwargs, "traceback": traceback.format_exc(), - "original_command_args": command_args, - "gpu_index": gpu_index, - "numa_options": numa_options, }, ) logger.exception( - """Failed to wrap command with NUMA bindings. - Input: - command_args=%r, - gpu_index=%d, - numa_options=%r, - """, - command_args, - gpu_index, - numa_options, + "Failed to wrap command with NUMA bindings for input = %r", kwargs ) + if numa_options.should_fall_back_if_binding_fails: + logger.warning("Falling back to original command without NUMA bindings.") + return None raise +def _get_temporary_executable_for_command( + *, + command_args: tuple[str, ...], +) -> str: + """ + Returns: + Path to a temporary file which executes the specified command. The executable + deletes itself the first time it runs, so do not try to run it multiple times. + """ + fd, path = mkstemp( + prefix="pytorch-numa-bind", + suffix=".sh", + ) + + # We do rm first to guarantee the file deletes itself. The rest of the file + # will still run as intended. + contents = f"""#!/bin/bash + +# If this file is more than a few minutes old and still exists on your machine, +# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such +# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163. + +rm -- "$0" +{" ".join(command_args)} +""" + + with os.fdopen(fd, "w") as file: + file.write(contents) + + # Ensure the file is fully synced, in order to avoid race condition + # from trying to execute it too early. + file.flush() + os.fsync(fd) + + # Make the script executable + os.chmod(path, stat.S_IRWXU) + + logger.info( + "Created temporary executable at path %s, with contents\n%s", path, contents + ) + + return path + + +def _get_numactl_cli_options( + *, + command_args: tuple[str, ...], + gpu_index: int, + numa_options: NumaOptions, +) -> tuple[str, ...]: + """ + Args: + command_args: The args for a command, such as might be input to Popen. + Example: ("python", "trainer.py") + gpu_index: The index of the GPU that will be used by the subprocess which executes command_args. + Example: 0 + numa_options: See NumaOptions for details. + + Returns: + Depending on numa_options, something like + ("--cpunodebind=0") + """ + if numa_options.affinity_mode == AffinityMode.NODE: + numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.SOCKET: + numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE: + numactl_command_options = _get_exclusive_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX: + numactl_command_options = _get_core_complex_numactl_options(gpu_index=gpu_index) + else: + raise ValueError(f"Affinity mode {numa_options.affinity_mode} not supported.") + + return numactl_command_options + + def _raise_if_numactl_fails_dry_run(*, numactl_options: tuple[str, ...]) -> None: noop_args = _get_assembled_command_from_pieces( # Execute arbitrary noop command_args=("true",), numactl_options=numactl_options, ) + + temporary_executable_path = _get_temporary_executable_for_command( + command_args=noop_args + ) + try: - subprocess.run( - noop_args, + run( + (temporary_executable_path,), stdout=subprocess.DEVNULL, # These allow us to capture the stderr as text stderr=subprocess.PIPE, @@ -219,14 +288,11 @@ def _get_node_numactl_options(*, gpu_index: int) -> tuple[str, ...]: Core logic of 'node' numa strategy. Returns options to be used with numactl. E.g., - ("--cpunodebind=0", "--preferred=0"). + ("--cpunodebind=0"). """ numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index) - return ( - f"--cpunodebind={numa_node_index}", - f"--preferred={numa_node_index}", - ) + return (f"--cpunodebind={numa_node_index}",) def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]: @@ -242,14 +308,7 @@ def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]: ) numa_node_indices_str = _get_ranges_str_from_ints(numa_node_indices) - return ( - f"--cpunodebind={numa_node_indices_str}", - ( - f"--preferred-many={numa_node_indices_str}" - if len(numa_node_indices) > 1 - else f"--preferred={numa_node_indices_str}" - ), - ) + return (f"--cpunodebind={numa_node_indices_str}",) def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]: @@ -321,7 +380,6 @@ def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]: return ( f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}", - f"--preferred={numa_node_index}", ) @@ -371,7 +429,6 @@ def _get_core_complex_numactl_options(*, gpu_index: int) -> tuple[str, ...]: return ( f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}", - f"--preferred={numa_node_index}", ) diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index a4e3eea2e1d2..85aa513c6d02 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -726,6 +726,12 @@ def _handle_output_node( # node.args[0] can be a tuple with more than one elements. This happens when, # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs for output in node.args[0]: # type: ignore[index,union-attr] + if output is None: + logger.warning( + "Output node %s has None output. The output is ignored in the exported graph. Please ensure the graph output order is expected", + node.name, + ) + continue output_value_name = output.name # type: ignore[union-attr] assert isinstance(output_value_name, str), ( f"Bug: Expected {output_value_name!r} to be a string" diff --git a/torch/onnx/_internal/exporter/_testing.py b/torch/onnx/_internal/exporter/_testing.py index 58f18d0cc923..c34c2f1a38c3 100644 --- a/torch/onnx/_internal/exporter/_testing.py +++ b/torch/onnx/_internal/exporter/_testing.py @@ -71,6 +71,9 @@ class names like "TorchExportNonStrictStrategy". # ONNX outputs are always real, so we need to convert torch complex outputs to real representations torch_outputs_adapted = [] for output in torch_outputs: + # ONNX graph does not support None outputs, so we skip them + if output is None: + continue if not isinstance(output, torch.Tensor): torch_outputs_adapted.append(torch.tensor(output)) elif torch.is_complex(output): diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 8bc6f0f9f4d2..80743c6a4912 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -177,6 +177,7 @@ def scaled_dot_product_attention( if symbolic_helper._is_none(attn_mask): mul_qk_add = mul_qk + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) elif ( _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL @@ -186,19 +187,24 @@ def scaled_dot_product_attention( const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values + # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. + # This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match + # the behavior of PyTorch with boolean masks. + attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight) elif _type_utils.JitScalarType.from_value(attn_mask) in ( _type_utils.JitScalarType.FLOAT, _type_utils.JitScalarType.HALF, _type_utils.JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) else: raise ValueError( f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" ) - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) - if dropout_p != 0: attn_weight = g.op( "Dropout", diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 6f9f6f1a3cf0..58ad582bebb9 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1344,7 +1344,7 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch - if self.is_better(current, self.best): + if self._is_better(current, self.best): self.best = current self.num_bad_epochs = 0 else: @@ -1386,7 +1386,7 @@ def _reduce_lr(self, epoch): def in_cooldown(self): # noqa: D102 return self.cooldown_counter > 0 - def is_better(self, a, best): # noqa: D102 + def _is_better(self, a, best): # noqa: D102 if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon @@ -1686,6 +1686,15 @@ def get_lr(self) -> list[float]: @override def state_dict(self) -> dict[str, Any]: # noqa: D102 + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ state = super().state_dict() # We are dropping the `_scale_fn_ref` attribute because it is a # `weakref.WeakMethod` and can't be pickled. diff --git a/torch/overrides.py b/torch/overrides.py index fe7af6bc4ff0..3304cfab5e19 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -488,7 +488,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.bernoulli: lambda input, generator=None, out=None: -1, torch.bilinear: lambda input1, input2, weight, bias: -1, torch.binary_cross_entropy_with_logits: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1 # noqa: B950 ), torch.bincount: lambda input, weights=None, minlength=0: -1, torch.binomial: lambda count, prob, generator=None: -1, @@ -851,10 +851,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: ), torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1, torch.nn.functional.binary_cross_entropy: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", label_smoothing=0.0: -1 ), torch.nn.functional.binary_cross_entropy_with_logits: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1 # noqa: B950 ), torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1, torch.nn.functional.cosine_embedding_loss: ( diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 09d7901c2d6c..08b0560f7932 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -2,6 +2,7 @@ """Import mangling. See mangling.md for details. """ + import re diff --git a/torch/package/importer.py b/torch/package/importer.py index 49b4512f79a6..8cfc1e336a45 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import importlib +import logging from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] _getattribute, @@ -13,6 +14,7 @@ __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] +log = logging.getLogger(__name__) class ObjNotFoundError(Exception): @@ -204,6 +206,20 @@ def _is_torchpackage_dummy(self, module): return True return module.__file__ is None + def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: + for importer in self._importers: + try: + return importer.get_name(obj, name) + except (ObjNotFoundError, ObjMismatchError) as e: + warning_message = ( + f"Tried to call get_name with obj {obj}, " + f"and name {name} on {importer} and got {e}" + ) + log.warning(warning_message) + raise ObjNotFoundError( + f"Could not find obj {obj} and name {name} in any of the importers {self._importers}" + ) + def import_module(self, module_name: str) -> ModuleType: last_err = None for importer in self._importers: diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 21446c626b9a..6118e8ce8096 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -605,9 +605,9 @@ def save_pickle( dependencies (bool, optional): If ``True``, we scan the source for dependencies. """ - assert (pickle_protocol == 4) or ( - pickle_protocol == 3 - ), "torch.package only supports pickle protocols 3 and 4" + assert (pickle_protocol == 4) or (pickle_protocol == 3), ( + "torch.package only supports pickle protocols 3 and 4" + ) filename = self._filename(package, resource) # Write the pickle data for `obj` diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index a97cf475b350..7291227e42ae 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -423,7 +423,12 @@ def _load_module(self, name: str, parent: str): module.__dict__.setdefault(old_name, new_name) return module - return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] + return self._make_module( + name, + cur.source_file, # type: ignore[attr-defined] + isinstance(cur, _PackageNode), + parent, + ) def _compile_source(self, fullpath: str, mangled_filename: str): source = self.zip_reader.get_record(fullpath) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index a90a371130e7..153d4560e264 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -7,6 +7,7 @@ An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. """ + import os from typing import Any from typing_extensions import TypeVarTuple, Unpack diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 7ad917d1e86b..d9f3a917c152 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -239,10 +239,12 @@ def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> tuple[Optional[bool], .. def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: signature = tuple( # Tensor - TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + TensorKey.from_tensor(i) + if isinstance(i, _TensorMetadata) # # TensorList - else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + else [TensorKey.from_tensor(j) for j in i] + if isinstance(i, list) # # Scalar and uncaptured inputs. else i diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index b1160324cb90..5b631ef743c6 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -124,9 +124,9 @@ def compute_self_time(self) -> None: for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) - assert ( - EventKey(curr_event) not in self.metrics - ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + assert EventKey(curr_event) not in self.metrics, ( + f"Duplicate id: {curr_event.id}, {curr_event.name}" + ) self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) self.metrics[ EventKey(curr_event) @@ -227,8 +227,7 @@ def new_old_event_comparator(event): while ( current_kernel_index < len(cuda_kernel_events) - and (cuda_kernel_events[current_kernel_index].start_ns()) - <= start_time # type: ignore[possibly-undefined] + and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined] ): current_kernel_index += 1 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 @@ -352,11 +351,11 @@ def get_optimizable_events(self, length: int = 1, print_enable: bool = True): output += "\n".join( [ - f"""{'-' * 80} + f"""{"-" * 80} Event: {event} Source code location: {source_code_location(event.event)} Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% -{'-' * 80}""" +{"-" * 80}""" for event in event_list ] ) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f7be416cfaa7..d88d6c5cad72 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -624,8 +624,7 @@ class profile(_KinetoProfile): ] ) as p: code_to_profile() - print(p.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: @@ -635,16 +634,17 @@ class profile(_KinetoProfile): # on different iterations of the training loop; # trace_handler is called every time a new trace becomes available def trace_handler(prof): - print(prof.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print( + prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) + ) # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], - # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record @@ -652,20 +652,15 @@ def trace_handler(prof): # after which the trace will become available # and on_trace_ready (when set) is called; # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=trace_handler, # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') # used when outputting for tensorboard - ) as p: - for iter in range(N): - code_iteration_to_profile(iter) - # send a signal to the profiler that the next iteration has started - p.step() + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`) diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index cfb13ac96271..5a68fbf02015 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -6,6 +6,7 @@ `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.fuser_method_mappings import ( _DEFAULT_OP_LIST_TO_FUSER_METHOD, fuse_conv_bn, diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py index 7acea4f84a2a..d6b8611d4a76 100644 --- a/torch/quantization/fx/_equalize.py +++ b/torch/quantization/fx/_equalize.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx._equalize import ( _convert_equalization_ref, _InputEqualizationObserver, diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py index 9d6ac350602b..30a661da41e5 100644 --- a/torch/quantization/fx/convert.py +++ b/torch/quantization/fx/convert.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.convert import convert diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 67527080304f..22ad750e9f87 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse import fuse diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index e29337b3f861..982d919655f3 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/torch/quantization/fx/graph_module.py b/torch/quantization/fx/graph_module.py index a71e980a57ba..74b63903d740 100644 --- a/torch/quantization/fx/graph_module.py +++ b/torch/quantization/fx/graph_module.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.graph_module import ( _is_observed_module, _is_observed_standalone_module, diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index 8b49f7c645d8..8585a21ad445 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.match_utils import ( _find_matches, _is_match, diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index 2a83e180fc4d..fa601d1eb619 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.pattern_utils import ( _register_fusion_pattern, _register_quant_pattern, diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index ca65dcc04dd0..a6007ef242af 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.prepare import prepare diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 20d8cc52ee4f..89f8d4406e91 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.quantize_handler import ( BatchNormQuantizeHandler, BinaryOpQuantizeHandler, diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py index a422cdd3142e..0820ea057078 100644 --- a/torch/quantization/fx/quantization_types.py +++ b/torch/quantization/fx/quantization_types.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index ef35559884b7..e45c82b8fb6f 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.utils import ( all_node_args_have_no_tensors, assert_and_get_unique_device, diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 6e6c7c1917c8..2163e2717b06 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -6,6 +6,7 @@ `torch/ao/quantization/observer.py`, while adding an import statement here. """ + from torch.ao.quantization.observer import ( _is_activation_post_process, _is_per_channel_script_obs_instance, diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 6bb7e14110cb..a02ff7d6f738 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -6,6 +6,7 @@ `torch/ao/quantization/qconfig.py`, while adding an import statement here. """ + from torch.ao.quantization.qconfig import ( _add_module_to_qconfig_obs_ctr, _assert_valid_qconfig, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 8b44a980ce82..faa24d391d31 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -6,6 +6,7 @@ `torch/ao/quantization/quantization_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.quantization_mappings import ( _get_special_act_post_process, _has_special_act_post_process, diff --git a/torch/serialization.py b/torch/serialization.py index 61a4acf68415..a6eb314fc1a8 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1426,7 +1426,7 @@ def _get_wo_message(message: str) -> str: "Please file an issue with the following so that we can make " "`weights_only=True` compatible with your use case: WeightsUnpickler error: " ) - updated_message += message + updated_message += "\n\n" + message return updated_message + DOCS_MESSAGE weights_only_not_set = weights_only is None diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 7d67de3f8384..e68c202f03e8 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -128,9 +128,7 @@ def _window_function_checks( >>> # Generates a periodic exponential window and decay factor equal to .5 >>> torch.signal.windows.exponential(10, sym=False,tau=.5) tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) - """.format( - **window_common_args - ), + """.format(**window_common_args), ) def exponential( M: int, @@ -452,9 +450,7 @@ def kaiser( >>> # Generates a periodic Hamming window. >>> torch.signal.windows.hamming(10, sym=False) tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hamming( M: int, @@ -508,9 +504,7 @@ def hamming( >>> # Generates a periodic Hann window. >>> torch.signal.windows.hann(10, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hann( M: int, @@ -564,9 +558,7 @@ def hann( >>> # Generates a periodic Blackman window. >>> torch.signal.windows.blackman(5, sym=False) tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def blackman( M: int, @@ -627,9 +619,7 @@ def blackman( >>> # Generates a periodic Bartlett window. >>> torch.signal.windows.bartlett(10, sym=False) tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def bartlett( M: int, @@ -704,9 +694,7 @@ def bartlett( >>> # Generates a periodic general cosine window with 2 coefficients. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_cosine( M, @@ -799,9 +787,7 @@ def general_cosine( >>> # Generates a periodic Hann window with the general Hamming window. >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_hamming( M, @@ -866,9 +852,7 @@ def general_hamming( >>> # Generates a periodic Nuttall window. >>> torch.signal.windows.general_hamming(5, sym=False) tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def nuttall( M: int, diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 39d78e8c26ab..31299314a85f 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck): For example: >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) - >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) + >>> x = ( + ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64) + ... .to_sparse_coo() + ... .requires_grad_(True) + ... ) >>> gradcheck(lambda x: x.to_sparse_csr(), x) True """ @@ -667,7 +671,7 @@ def restore_from_strided_representation(args): ) else: raise NotImplementedError( - f'conversion of {d["layout"]} strided representation to tensor' + f"conversion of {d['layout']} strided representation to tensor" ) new_args.append(a) return tuple(new_args) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a5e802084c28..ea36264d8f82 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): for b in range(nbatches): for i, r in enumerate(r_offsets): r0, r1 = divmod(r, N) - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] - for g in range(c_indices[i], c_indices[i+1]): + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for g in range(c_indices[i], c_indices[i + 1]): p = p_offsets[g] q0, q1 = divmod(q_offsets[g], N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. @@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): n = (r % N) // Ns r0, r1 = divmod(r, N) c0, c1 = c_indices[m], c_indices[m + 1] - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for i, p in enumerate(range(c0, c1)): q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] q0, q1 = divmod(q, N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 762874077c7a..89245246395a 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -97,6 +97,7 @@ kernel parameters for addmm-based operations. """ + __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] import inspect @@ -432,9 +433,9 @@ def from_key(key, parameters): def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): - assert ( - sparsity <= 1.0 and sparsity >= 0.0 - ), "sparsity should be a value between 0 and 1" + assert sparsity <= 1.0 and sparsity >= 0.0, ( + "sparsity should be a value between 0 and 1" + ) assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 721f25512794..b225eaabb320 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -465,14 +465,26 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUTLASS - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) - packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass( + pruned.t().contiguous() + ) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) + SparseSemiStructuredTensorCUTLASS( + dense.shape, + packed_cutlass, + meta_cutlass, + packed_t_cutlass, + meta_t_cutlass, + bitmask, + ) ``` """ # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. @@ -583,14 +595,19 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cusparselt = torch._cslt_compress(pruned) packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) + SparseSemiStructuredTensorCUSPARSELT( + dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask + ) ``` """ ( diff --git a/torch/special/__init__.py b/torch/special/__init__.py index be027caa94cb..dbc9314ad208 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -134,9 +134,7 @@ >>> torch.special.digamma(a) tensor([-0.5772, -1.9635]) -""".format( - **common_args - ), +""".format(**common_args), ) gammaln = _add_docstr( @@ -162,9 +160,7 @@ >>> torch.special.gammaln(a) tensor([ 0.5724, 0.0000, -0.1208]) -""".format( - **common_args - ), +""".format(**common_args), ) polygamma = _add_docstr( @@ -200,9 +196,7 @@ tensor([ 6.4939, 97.4091]) >>> torch.special.polygamma(4, a) tensor([ -24.8863, -771.4742]) -""".format( - **common_args - ), +""".format(**common_args), ) erf = _add_docstr( @@ -226,9 +220,7 @@ >>> torch.special.erf(torch.tensor([0, -1., 10.])) tensor([ 0.0000, -0.8427, 1.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfc = _add_docstr( @@ -253,9 +245,7 @@ >>> torch.special.erfc(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 1.8427, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfcx = _add_docstr( @@ -283,9 +273,7 @@ >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 5.0090, 0.0561]) -""".format( - **common_args - ), +""".format(**common_args), ) erfinv = _add_docstr( @@ -311,9 +299,7 @@ >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) tensor([ 0.0000, 0.4769, -inf]) -""".format( - **common_args - ), +""".format(**common_args), ) logit = _add_docstr( @@ -351,9 +337,7 @@ tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) >>> torch.special.logit(a, eps=1e-6) tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) -""".format( - **common_args - ), +""".format(**common_args), ) logsumexp = _add_docstr( @@ -362,9 +346,7 @@ logsumexp(input, dim, keepdim=False, *, out=None) Alias for :func:`torch.logsumexp`. -""".format( - **multi_dim_common - ), +""".format(**multi_dim_common), ) expit = _add_docstr( @@ -391,9 +373,7 @@ tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) >>> torch.special.expit(t) tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) -""".format( - **common_args - ), +""".format(**common_args), ) exp2 = _add_docstr( @@ -418,9 +398,7 @@ >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) tensor([ 1., 2., 8., 16.]) -""".format( - **common_args - ), +""".format(**common_args), ) expm1 = _add_docstr( @@ -448,9 +426,7 @@ >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) tensor([ 0., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) xlog1py = _add_docstr( @@ -495,9 +471,7 @@ tensor([1.6094, 3.2189, 4.8283]) >>> torch.special.xlog1py(2, y) tensor([2.7726, 2.1972, 1.3863]) -""".format( - **common_args - ), +""".format(**common_args), ) xlogy = _add_docstr( @@ -542,9 +516,7 @@ tensor([1.3863, 2.7726, 4.1589]) >>> torch.special.xlogy(2, y) tensor([2.1972, 1.3863, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) i0 = _add_docstr( @@ -570,9 +542,7 @@ >>> torch.i0(torch.arange(5, dtype=torch.float32)) tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) -""".format( - **common_args - ), +""".format(**common_args), ) i0e = _add_docstr( @@ -597,9 +567,7 @@ >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) -""".format( - **common_args - ), +""".format(**common_args), ) i1 = _add_docstr( @@ -624,9 +592,7 @@ >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) -""".format( - **common_args - ), +""".format(**common_args), ) i1e = _add_docstr( @@ -652,9 +618,7 @@ >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtr = _add_docstr( @@ -679,9 +643,7 @@ >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtri = _add_docstr( @@ -709,9 +671,7 @@ >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) -""".format( - **common_args - ), +""".format(**common_args), ) log_ndtr = _add_docstr( @@ -736,9 +696,7 @@ >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) -""".format( - **common_args - ), +""".format(**common_args), ) log1p = _add_docstr( @@ -779,9 +737,7 @@ tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) >>> torch.special.sinc(t) tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) -""".format( - **common_args - ), +""".format(**common_args), ) round = _add_docstr( @@ -886,9 +842,7 @@ tensor([1.6449, 0.0823]) >>> torch.special.zeta(2, torch.tensor([1., 2.])) tensor([1.6449, 0.6449]) -""".format( - **common_args - ), +""".format(**common_args), ) multigammaln = _add_docstr( @@ -925,9 +879,7 @@ >>> torch.special.multigammaln(a, 2) tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]]) -""".format( - **common_args - ), +""".format(**common_args), ) gammainc = _add_docstr( @@ -976,9 +928,7 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) gammaincc = _add_docstr( @@ -1026,9 +976,7 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) airy_ai = _add_docstr( @@ -1045,9 +993,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j0 = _add_docstr( @@ -1064,9 +1010,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j1 = _add_docstr( @@ -1083,9 +1027,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y0 = _add_docstr( @@ -1102,9 +1044,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y1 = _add_docstr( @@ -1121,9 +1061,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_t = _add_docstr( @@ -1154,9 +1092,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_u = _add_docstr( @@ -1188,9 +1124,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_v = _add_docstr( @@ -1208,9 +1142,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_w = _add_docstr( @@ -1228,9 +1160,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_h = _add_docstr( @@ -1256,9 +1186,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_he = _add_docstr( @@ -1284,9 +1212,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) laguerre_polynomial_l = _add_docstr( @@ -1312,9 +1238,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) legendre_polynomial_p = _add_docstr( @@ -1340,9 +1264,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i0 = _add_docstr( @@ -1359,9 +1281,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i1 = _add_docstr( @@ -1378,9 +1298,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k0 = _add_docstr( @@ -1397,9 +1315,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k1 = _add_docstr( @@ -1416,9 +1332,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k0 = _add_docstr( @@ -1435,9 +1349,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k1 = _add_docstr( @@ -1454,9 +1366,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_t = _add_docstr( @@ -1474,9 +1384,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_u = _add_docstr( @@ -1494,9 +1402,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_v = _add_docstr( @@ -1514,9 +1420,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_w = _add_docstr( @@ -1534,9 +1438,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) spherical_bessel_j0 = _add_docstr( @@ -1553,7 +1455,5 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 228c04cd312f..eff07c413deb 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1538,7 +1538,9 @@ def assert_close( >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. - >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + >>> torch.testing.assert_close( + ... actual, expected, msg="Argh, the tensors are not close!" + ... ) Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index e513b8d85603..23d80d6ceae4 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -115,11 +115,11 @@ def make_tensor( >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) - >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1) >>> # xdoctest: +SKIP tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA - >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) + >>> make_tensor((2, 2), device="cuda", dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') """ diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 0e95db1fdf37..dca0275f3887 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -291,7 +291,7 @@ def _get_torch_rocm_version(): if not TEST_WITH_ROCM or torch.version.hip is None: return (0, 0) rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha return tuple(int(x) for x in rocm_version.split(".")) def _check_cusparse_generic_available(): @@ -304,7 +304,7 @@ def _check_hipsparse_generic_available(): return False rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1)) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 01499280da8f..528497ba5457 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo intersect = set(except_for if except_for else []) & set( only_for if only_for else [] ) - assert ( - not intersect - ), f"device ({intersect}) appeared in both except_for and only_for" + assert not intersect, ( + f"device ({intersect}) appeared in both except_for and only_for" + ) # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): @@ -1407,9 +1407,9 @@ def __init__(self, num_required_devices): self.num_required_devices = num_required_devices def __call__(self, fn): - assert not hasattr( - fn, "num_required_devices" - ), f"deviceCountAtLeast redefinition for {fn.__name__}" + assert not hasattr(fn, "num_required_devices"), ( + f"deviceCountAtLeast redefinition for {fn.__name__}" + ) fn.num_required_devices = self.num_required_devices @wraps(fn) @@ -1474,13 +1474,13 @@ def only_fn(self, *args, **kwargs): # self.precision *2, max(1, self.precision)). class precisionOverride: def __init__(self, d): - assert isinstance( - d, dict - ), "precisionOverride not given a dtype : precision dict!" + assert isinstance(d, dict), ( + "precisionOverride not given a dtype : precision dict!" + ) for dtype in d.keys(): - assert isinstance( - dtype, torch.dtype - ), f"precisionOverride given unknown dtype {dtype}" + assert isinstance(dtype, torch.dtype), ( + f"precisionOverride given unknown dtype {dtype}" + ) self.d = d @@ -1513,12 +1513,12 @@ class toleranceOverride: def __init__(self, d): assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" for dtype, prec in d.items(): - assert isinstance( - dtype, torch.dtype - ), f"toleranceOverride given unknown dtype {dtype}" - assert isinstance( - prec, tol - ), "toleranceOverride not given a dtype : tol dict!" + assert isinstance(dtype, torch.dtype), ( + f"toleranceOverride given unknown dtype {dtype}" + ) + assert isinstance(prec, tol), ( + "toleranceOverride not given a dtype : tol dict!" + ) self.d = d @@ -1546,13 +1546,13 @@ def __init__(self, *args, device_type="all"): "all dtype variants must be. " f"Received non-list non-tuple dtype {str(arg)}" ) - assert all( - isinstance(dtype, torch.dtype) for dtype in arg - ), f"Unknown dtype in {str(arg)}" + assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( + f"Unknown dtype in {str(arg)}" + ) else: - assert all( - isinstance(arg, torch.dtype) for arg in args - ), f"Unknown dtype in {str(args)}" + assert all(isinstance(arg, torch.dtype) for arg in args), ( + f"Unknown dtype in {str(args)}" + ) self.args = args self.device_type = device_type diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index af1aafd3871a..d4cc6cde3cc5 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr): if err_substr.find("\nException raised from ") == -1 else err_substr.split("\nException raised from ")[0] ) - assert ( - actual in logging_err - ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" + assert actual in logging_err, ( + f"Did not find expected {actual} in ddp logging data error: {logging_err}" + ) def with_nccl_blocking_wait(func): @@ -294,9 +294,9 @@ def wrapper(*args, **kwargs): finally: # restore old values. if cached_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = cached_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + cached_nccl_async_error_handling + ) if cached_nccl_blocking_wait is not None: os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait @@ -812,7 +812,7 @@ def run_test(self, test_name: str, parent_pipe) -> None: sys.exit(TEST_SKIPS["generic"].exit_code) except Exception: logger.error( - "Caught exception: \n%s exiting " "process %s with exit code: %s", + "Caught exception: \n%s exiting process %s with exit code: %s", traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, @@ -1605,7 +1605,7 @@ def _init_pg(cls, rank, world_size, rdvz_file): @classmethod def _run_test_given_id(cls, test_id: str, **kwargs) -> None: # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' - test_name = test_id.split(".")[-1] + test_name = test_id.rsplit(".", maxsplit=1)[-1] # Get the test function from the test class self = cls(test_name) self.rank = cls.rank @@ -1689,9 +1689,7 @@ def _spawn_processes(cls, world_size) -> None: cls.processes.append(process) cls.task_queues.append(task_queue) cls.completion_queues.append(completion_queue) - logger.info( - "Started process %s with pid %s", rank, process.pid - ) # noqa: UP031 + logger.info("Started process %s with pid %s", rank, process.pid) # noqa: UP031 @classmethod def setUpClass(cls): diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 774ce179f33e..474bb689f0ad 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -121,6 +121,19 @@ def all_types_and_half(): return _all_types_and_half +_all_mps_types = ( + _dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types +) + + +def all_mps_types(): + return _all_mps_types + + +def all_mps_types_and(*dtypes): + return _all_mps_types + _validate_dtypes(*dtypes) + + _float8_types = _dispatch_dtypes( ( torch.float8_e4m3fn, diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index a9e24eb90ef8..0e50762893d7 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1285,10 +1285,10 @@ def _train_for_several_steps( loss = sharded_grad_scaler.scale(loss) if not mixed_precision and not use_pure_fp16: - assert ( - loss.dtype == torch.float32 - ), "loss data type should be float32, as the original \ + assert loss.dtype == torch.float32, ( + "loss data type should be float32, as the original \ parameter data type is float32." + ) else: if use_pure_fp16: self.assertEqual(loss.dtype, torch.float16) @@ -1354,9 +1354,9 @@ def _test_fsdp_parity( wrapper should provide data parallel semantics. If ``None``, then the callable defaults to the DDP constructor. """ - assert ( - fsdp_init_mode != FSDPInitMode.NO_FSDP - ), "Expects an FSDP init mode that wraps with FSDP" + assert fsdp_init_mode != FSDPInitMode.NO_FSDP, ( + "Expects an FSDP init mode that wraps with FSDP" + ) if init_kwargs is None: init_kwargs = {} lr = 1e-2 diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 41bb2b96bd93..506bf5488f3c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1161,8 +1161,8 @@ def make_arg_conj(size): def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): - alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) - beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) tests_list = [ ((2, 3), (2, 2), (2, 3), False), ((3, 3), (3, 3), (3, 3), False), diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 479c7119a5bf..f42ae06e7b30 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1463,9 +1463,14 @@ def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, tr ('reduction_mean', {'reduction': 'mean'}), ('reduction_none', {'reduction': 'none'}), ('weights', {'weight': make_weight((10,))}), + ('label_smoothing', {'label_smoothing': 0.15}), ] - def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): + def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0): + assert 0 <= label_smoothing <= 1 + if label_smoothing > 0: + t = t * (1 - label_smoothing) + (1 - t) * label_smoothing + result = -(t * i.log() + (1 - t) * (1 - i).log()) if weight is not None: @@ -1511,10 +1516,15 @@ def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, require ('reduction_mean', {'reduction': 'mean'}), ('reduction_none', {'reduction': 'none'}), ('weights', {'weight': make_weight((10,))}), - ('scalar_weights', {'weight': make_weight(())}) + ('scalar_weights', {'weight': make_weight(())}), + ('label_smoothing', {'label_smoothing': 0.15}), ] - def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None): + def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0): + assert 0 <= label_smoothing <= 1 + if label_smoothing > 0: + t = t * (1 - label_smoothing) + (1 - t) * label_smoothing + # TODO: add pos_weight to the definition here and corresponding SampleInputs max_val = (-i).clamp(min=0) result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_()) @@ -4064,11 +4074,6 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.LocalResponseNorm, module_inputs_func=module_inputs_torch_nn_LocalResponseNorm, - skips=( - # uses avg_pool3d which is not supported on MPS backend - DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'), - DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'), - DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'),) ), ModuleInfo(torch.nn.LayerNorm, module_inputs_func=module_inputs_torch_nn_LayerNorm, diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index a11f2cd3974b..0391a314568a 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -25,6 +25,7 @@ def mps_ops_modifier( "__rsub__", "__getitem__", "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", "abs", "add", "alias_copy", @@ -75,6 +76,7 @@ def mps_ops_modifier( "imag", "index_copy", "index_select", + "index_put", "isfinite", "isinf", "isreal", @@ -284,85 +286,6 @@ def mps_ops_modifier( "where", "byte", } - # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 - MACOS_BEFORE_13_3_XFAILLIST = { - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ - "cdist": [torch.float32], - # CPU Error: cpu not giving nan for x/0.0 - "atan2": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [torch.int8, torch.uint8, torch.bool, torch.float16], - # Unsupported dtypes - "cumsum": [torch.int64], - "cumprod": [torch.int64], - "cumulative_trapezoid": [torch.int64], - "masked.cumsum": [torch.int64], - "masked.cumprod": [torch.int64], - "linalg.vander": [torch.int64], - # Fail with `Expected 1.0 but got nan.` for empty tensors - # Caused by sample input at index 23: SampleInput( - # input=Tensor[size=(), device="mps:0", dtype=torch.float32], - # args=(0), - # kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'}, - # broadcasts_input=False, name='') - "masked.softmin": [torch.float32, torch.float16], - "masked.softmax": [torch.float32, torch.float16], - "masked.log_softmax": [torch.float32, torch.float16], - } - - MACOS_AFTER_13_1_XFAILLIST = { - # before macOS 13.2 it falls back to cpu and pass the forward pass - "grid_sampler_2d": [ - torch.float32, - torch.float16, - torch.bfloat16, - ], # Unsupported Border padding mode - } - - MACOS_13_3_XFAILLIST = { - # Failure due to precision issue for fp16 - # on both cpu and mps there are test cases that might produce inf result - # 'nn.functional.pairwise_distance': [torch.float16], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [ - torch.float16, - torch.int8, - torch.uint8, - torch.bool, - torch.bfloat16, - ], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [ - torch.int8, - torch.uint8, - torch.bool, - torch.float16, - torch.bfloat16, - ], - } MACOS_BEFORE_14_4_XFAILLIST = { # These ops work fine in 14.4 but fail in 14.2 or 13.x @@ -424,14 +347,8 @@ def mps_ops_modifier( "nn.functional.adaptive_max_pool3d": None, "nn.functional.interpolatearea": None, "nn.functional.interpolatebicubic": [torch.uint8], - "nn.functional.max_unpool1dgrad": None, - "nn.functional.max_unpool2dgrad": None, - "nn.functional.max_unpool3dgrad": None, "nn.functional.ctc_loss": None, "nn.functional.embedding_bag": None, - "nn.functional.max_unpool1d": None, - "nn.functional.max_unpool2d": None, - "nn.functional.max_unpool3d": None, "nn.functional.multi_margin_loss": None, "nn.functional.multilabel_margin_loss": None, "nn.functional.pdist": None, @@ -501,7 +418,6 @@ def mps_ops_modifier( torch.float16, ], # Unsupported dtypes - "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], "histc": [torch.float16, torch.bfloat16], "index_add": [torch.int64], # GEMM on MPS is not supported for integral types @@ -512,19 +428,9 @@ def mps_ops_modifier( torch.uint8, torch.int8, ], - "addmmdecomposed": [ - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [], - "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [], # returned output on CPU is float64 "bincount": [ torch.int16, @@ -629,6 +535,38 @@ def mps_ops_modifier( "linalg.matrix_rank": None, # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` "arange": [torch.uint8], + # before macOS 13.2 it falls back to cpu and pass the forward pass + "grid_sampler_2d": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], # Unsupported Border padding mode + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [ + torch.float16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + ], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [ + torch.int8, + torch.uint8, + torch.bool, + torch.float16, + torch.bfloat16, + ], } EMPTY_OPS_SKIPLIST = { @@ -696,43 +634,6 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ), ) - if ( - key in MACOS_BEFORE_13_3_XFAILLIST - and key not in xfail_exclusion - and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, - dtypes=MACOS_BEFORE_13_3_XFAILLIST[key], - ), - ) - - if ( - key in MACOS_AFTER_13_1_XFAILLIST - and key not in xfail_exclusion - and torch.backends.mps.is_macos13_or_newer(2) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key] - ), - ) - - if ( - key in MACOS_13_3_XFAILLIST - and key not in xfail_exclusion - and (MACOS_VERSION >= 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key] - ), - ) - # If ops is not supported for complex types, expect it to fail if key not in SUPPORTED_COMPLEX_OPS: addDecorator( @@ -823,7 +724,6 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "round": [torch.float16], # topk fails with duplicate indices "topk": [torch.float16], - "nn.functional.avg_pool3d": [torch.float32], } SKIPLIST_GRAD = { diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 780514e67439..96bab4a084c4 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs( trivial. That said, we sometimes want to test for all possible configs on an optimizer including all supported flags, so this helper returns all optim inputs. """ - assert all( - x in ["foreach", "fused", "differentiable"] for x in skip - ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" + assert all(x in ["foreach", "fused", "differentiable"] for x in skip), ( + "skip must be a subset of ['foreach', 'fused', 'differentiable']" + ) optim_inputs = optim_info.optim_inputs_func(device) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 211b282c4fc4..f8671379950e 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -611,7 +611,7 @@ def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): - # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py + # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py # default setup for affine quantization of activations x_dtype = x.dtype x = x.float() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 384db57e92ec..bfc568bc1464 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -297,7 +297,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs @@ -329,9 +329,10 @@ def extract_test_fn() -> Optional[Callable]: self_val = frame.f_locals["self"] if isinstance(self_val, unittest.TestCase): test_id = self_val.id() - test_name = test_id.split('.')[2] - test_fn = getattr(self_val, test_name).__func__ - return test_fn + *_, cls_name, test_name = test_id.rsplit('.', 2) + if cls_name == type(self_val).__name__ and test_name.startswith("test"): + test_fn = getattr(self_val, test_name).__func__ + return test_fn except Exception: pass return None @@ -2016,7 +2017,7 @@ def dec_fn(fn): def wrap_fn(self, *args, **kwargs): if TEST_WITH_ROCM: rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): reason = f"ROCm {rocm_version_tuple} is available but {version} required" diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 94bfead8a0c0..4eb6677a035e 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -31,6 +31,7 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, MultiProcessTestCase, MultiThreadedTestCase, run_subtests, @@ -41,6 +42,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +DEVICE_COUNT: int + if TEST_CUDA: DEVICE_TYPE = "cuda" PG_BACKEND = "nccl" @@ -334,6 +337,21 @@ def skip_unless_torch_gpu(method: T) -> T: return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) +class DTensorContinuousTestBase(MultiProcContinousTest): + @classmethod + def device_type(cls) -> str: + # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU + if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < cls.world_size: + return "cpu" + else: + return DEVICE_TYPE + + @classmethod + def backend_str(cls) -> str: + backend = dist.get_default_backend_for_device(DEVICE_TYPE) + return backend + + class DTensorTestBase(MultiProcessTestCase): @property def world_size(self) -> int: @@ -355,22 +373,26 @@ def backend(self) -> str: def build_device_mesh(self) -> DeviceMesh: return init_device_mesh(self.device_type, (self.world_size,)) - def init_pg(self, eager_init) -> None: + def init_pg(self, eager_init, backend: Optional[str] = None) -> None: if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - if self.backend not in [ + if backend is None: + backend = self.backend + + if backend not in [ "nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl", + "fake", ]: - raise RuntimeError(f"Backend {self.backend} not supported!") + raise RuntimeError(f"Backend {backend} not supported!") device_id = None - if "nccl" in self.backend or "xccl" in self.backend: + if "nccl" in backend or "xccl" in backend: # set device for nccl pg for collectives torch.accelerator.set_device_index(self.rank) # we only need to set device_id for nccl backend with eager init @@ -381,7 +403,7 @@ def init_pg(self, eager_init) -> None: # so the nccl communicator is immediately formed and we can use `ncclCommSplit` # for form subgroup to avoid unnecesssary overhead. dist.init_process_group( - backend=self.backend, + backend=backend, world_size=self.world_size, rank=self.rank, # pyre-ignore[16] init_method=f"file://{self.file_name}", # pyre-ignore[16] @@ -449,13 +471,17 @@ def run_subtests(self, *args, **kwargs): # wrapper to initialize comms (processgroup) -def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc: - def decorator(func, eager_init: bool = False): +def with_comms( + eager_init: Union[TestFunc, bool] = False, backend: Optional[str] = None +) -> TestFunc: + def decorator(func, eager_init: bool = False, backend: Optional[str] = None): @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: - self.init_pg(eager_init) + self.init_pg(eager_init, backend) try: func(self, *args, **kwargs) # type: ignore[misc] @@ -470,7 +496,7 @@ def wrapper( return ( decorator(func=eager_init) if callable(eager_init) - else partial(decorator, eager_init=eager_init) + else partial(decorator, eager_init=eager_init, backend=backend) ) diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 1ac9252d498e..61c21be3ca07 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -253,7 +253,11 @@ def train_batch( else: input_batches = batches - with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext(): + with ( + self.hybrid_module.join() + if simulate_uneven_inputs + else contextlib.nullcontext() + ): for b in input_batches: with dist_autograd.context() as context_id: output = self.hybrid_module.forward(b) @@ -261,8 +265,7 @@ def train_batch( dist_autograd.backward(context_id, [loss]) grads_dict = dist_autograd.get_gradients(context_id) gLogger.info( - "Loss is %s for mini batch: %s. " - "Grads dict has %s entries: %s", + "Loss is %s for mini batch: %s. Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index a34ee75cf600..0a2814c24645 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -11,7 +11,7 @@ class FakeStore(dist.Store): """ -def _create_fake_pg(prefix_store, rank, world_size, timeout): +def _create_fake_pg(common_opts, backend_opts): """ A fake process group (not related to FakeTensor) is a process group which doesn't actually do any communication, it just hallucinates some @@ -22,7 +22,11 @@ def _create_fake_pg(prefix_store, rank, world_size, timeout): for every collective. It should be used as a convenient tool when playing with distributed but don't care about the actual data. """ - return FakeProcessGroup(rank, world_size) + return FakeProcessGroup( + common_opts.group_rank, common_opts.group_size, backend_opts + ) -dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"]) +dist.Backend.register_backend( + "fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu"] +) diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 8a521d56f5f8..f1cf62aa64bd 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -69,13 +69,13 @@ def test_cpu(): TRITON_HAS_CPU = False -HAS_CUDA = torch.cuda.is_available() and HAS_TRITON +HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON -HAS_XPU = torch.xpu.is_available() and HAS_TRITON +HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON HAS_MPS = torch.mps.is_available() -HAS_GPU = HAS_CUDA or HAS_XPU +HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON GPU_TYPE = get_gpu_type() @@ -163,16 +163,16 @@ def inner(fn): skipCPUIf = functools.partial(skipDeviceIf, device="cpu") IS_A100 = LazyVal( - lambda: HAS_CUDA + lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 166912 ) IS_H100 = LazyVal( - lambda: HAS_CUDA + lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448 ) -IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +IS_BIG_GPU = LazyVal(lambda: HAS_CUDA_AND_TRITON and is_big_gpu()) def dummy_graph() -> GraphLowering: """ diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 5cd248792dcb..97dee3c7c0f4 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -162,9 +162,7 @@ def __init__( # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as # SampleInput(input, *args, **kwargs) but not to mix the two forms if args is not None or kwargs is not None: - assert ( - not var_args and not var_kwargs - ), """ + assert not var_args and not var_kwargs, """ A SampleInput can be constructed "naturally" with *args and **kwargs or by explicitly setting the "args" and "kwargs" parameters, but the two methods of construction cannot be mixed!""" @@ -226,7 +224,7 @@ def _repr_helper(self, formatter): f"name={repr(self.name)}", ] - return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + return f"SampleInput({', '.join(a for a in arguments if a is not None)})" def __repr__(self): return self._repr_helper(lambda x: x) @@ -1601,13 +1599,11 @@ def __post_init__(self): # returns a string identifier of the rule type @abstractmethod - def type(self) -> str: - ... + def type(self) -> str: ... # returns an appropriate context that handles the xfail, skips, etc. @abstractmethod - def get_context(self, test_case): - ... + def get_context(self, test_case): ... # useful for specifying xfails @@ -1791,8 +1787,10 @@ def __init__( # kwargs to use when calling the op. This is required for operators that # have other required parameters besides the input tensor. generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( - yield (), - {}, + yield ( + (), + {}, + ) ), # Options from the OpInfo base class **kwargs, @@ -2476,9 +2474,9 @@ def __init__( self.supports_one_python_scalar = True if self.supports_one_python_scalar: - assert ( - supports_rhs_python_scalar - ), "Can't support lhs and rhs Python scalars but not rhs scalars!" + assert supports_rhs_python_scalar, ( + "Can't support lhs and rhs Python scalars but not rhs scalars!" + ) # The following functions and classes are for testing elementwise unary operators. diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index e05299632d04..c5d08073803b 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs): op_info, device, dtype, requires_grad, **kwargs ): sample_input_args, sample_input_kwargs = ( - ord, - ) + sample_input.args, sample_input.kwargs.copy() + (ord,) + sample_input.args, + sample_input.kwargs.copy(), + ) yield SampleInput( sample_input.input.clone().requires_grad_(requires_grad), args=sample_input_args, @@ -276,8 +278,9 @@ def masked_samples(): for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs) ): if type(mask) != torch.Tensor: continue - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) if "keepdim" in sample_input_kwargs: sample_input_kwargs.pop("keepdim") diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 69b260d2833b..40687995470b 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -2,11 +2,13 @@ import unittest -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, HAS_GPU from torch.utils._triton import has_triton -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda_and_triton = unittest.skipUnless( + HAS_CUDA_AND_TRITON, "requires cuda and triton" +) requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") if has_triton(): diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 4ec4e5b59159..811b45fd1d69 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -112,7 +112,7 @@ def __init__( @staticmethod def string_or_list_of_string_to_list( - val: Optional[Union[str, list[str]]] + val: Optional[Union[str, list[str]]], ) -> Optional[list[str]]: if val is None: return None @@ -135,8 +135,7 @@ def Config( env_name_force: Optional[Union[str, list[str]]] = None, value_type: Optional[type] = None, alias: Optional[str] = None, - ) -> T: - ... + ) -> T: ... else: @@ -323,9 +322,9 @@ def __init__(self, config: _Config): # Ensure justknobs and envvars are allowlisted types if self.justknob is not None and self.default is not None: - assert isinstance( - self.default, bool - ), f"justknobs only support booleans, {self.default} is not a boolean" + assert isinstance(self.default, bool), ( + f"justknobs only support booleans, {self.default} is not a boolean" + ) if self.value_type is not None and ( config.env_name_default is not None or config.env_name_force is not None ): @@ -334,7 +333,9 @@ def __init__(self, config: _Config): str, Optional[bool], Optional[str], - ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ), ( + f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ) class ConfigModule(ModuleType): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 24c73061b716..5ddda2c7edb6 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -282,9 +282,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False Args: @@ -586,29 +586,28 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -664,8 +663,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -675,8 +673,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -686,8 +683,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -697,8 +693,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -708,8 +703,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -729,8 +723,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -740,8 +733,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -751,8 +743,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -762,8 +753,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -773,8 +763,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -812,8 +801,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -823,8 +811,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -834,8 +821,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -856,8 +842,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -867,8 +852,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -878,8 +862,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index 40ffd8f80a9e..0b555ffc27f9 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -12,7 +12,7 @@ def cache_method( - f: Callable[Concatenate[_C, _P], _T] + f: Callable[Concatenate[_C, _P], _T], ) -> Callable[Concatenate[_C, _P], _T]: """ Like `@functools.cache` but for methods. diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index 2bead0e00b12..b2a69fc0ff34 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import ( + Hashable, Iterable, Iterator, MutableSet, @@ -10,8 +11,8 @@ from typing import Any, cast, Optional, TypeVar -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T", bound=Hashable) +T_co = TypeVar("T_co", bound=Hashable, covariant=True) __all__ = ["OrderedSet"] diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 664994e6fe38..84353fbbebf7 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -302,14 +302,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # Subtypes which have __tensor_flatten__ and __tensor_unflatten__. class TensorWithFlatten(Protocol): - def __tensor_flatten__(self) -> tuple[Sequence[str], object]: - ... + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... @staticmethod def __tensor_unflatten__( inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... # It would be really nice to be able to say that the return of # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, @@ -318,26 +316,20 @@ def __tensor_unflatten__( shape: torch._C.Size @overload - def stride(self, dim: None = None) -> tuple[int, ...]: - ... + def stride(self, dim: None = None) -> tuple[int, ...]: ... @overload - def stride(self, dim: int) -> int: - ... + def stride(self, dim: int) -> int: ... @overload - def size(self, dim: None = None) -> tuple[int, ...]: - ... + def size(self, dim: None = None) -> tuple[int, ...]: ... @overload - def size(self, dim: int) -> int: - ... + def size(self, dim: int) -> int: ... - def storage_offset(self) -> int: - ... + def storage_offset(self) -> int: ... - def dim(self) -> int: - ... + def dim(self) -> int: ... @overload def to( @@ -347,8 +339,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -359,8 +350,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -370,8 +360,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 3e7cadc6dc7a..773e9f00e3d1 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -99,17 +99,13 @@ class KeyEntry(Protocol): - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: - ... + def __eq__(self, other: object) -> bool: ... - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def get(self, parent: Any) -> Any: - ... + def get(self, parent: Any) -> Any: ... class EnumEncoder(json.JSONEncoder): @@ -374,7 +370,7 @@ def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] - return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names _private_register_pytree_node( cls, @@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: def _tuple_flatten_with_keys( - d: tuple[T, ...] + d: tuple[T, ...], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]: def _dict_flatten_with_keys( - d: dict[Any, T] + d: dict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]: def _ordereddict_flatten_with_keys( - d: OrderedDict[Any, T] + d: OrderedDict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]: def _defaultdict_flatten_with_keys( - d: defaultdict[Any, T] + d: defaultdict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context @@ -1035,9 +1031,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False """ if is_leaf is not None and is_leaf(tree): @@ -1346,9 +1342,9 @@ def tree_map( See also :func:`tree_map_`. - >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)}) {'x': 8, 'y': (43, 65)} - >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None}) {'x': False, 'y': (False, False), 'z': True} If multiple inputs are given, the structure of the tree is taken from the first input; @@ -1432,29 +1428,28 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -1510,8 +1505,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1521,8 +1515,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1532,8 +1525,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1543,8 +1535,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1554,8 +1545,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -1575,8 +1565,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1586,8 +1575,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1597,8 +1585,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1608,8 +1595,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1619,8 +1605,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -1658,8 +1643,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1669,8 +1653,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1680,8 +1663,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -1702,8 +1684,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1713,8 +1694,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1724,8 +1704,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( @@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: raise NotImplementedError( - f'Deserializing {json_schema["type"]} in pytree is not registered.', + f"Deserializing {json_schema['type']} in pytree is not registered.", ) typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 39e981a78ac5..9b94a7b7a484 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -301,7 +301,7 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( - work_function: Callable[_P, _R] + work_function: Callable[_P, _R], ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 4f8d045e5554..2b6c159f5c3a 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: def _keep_float( - f: Callable[[Unpack[_Ts]], _T] + f: Callable[[Unpack[_Ts]], _T], ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: @@ -455,6 +455,12 @@ def _eval_is_nonnegative(self) -> Optional[bool]: def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] + def _ccode(self, printer): + p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) + abs_q = str(q) if self.args[1].is_positive else f"abs({q})" + return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}" + # Generic modulus: only defined on non-negative arguments class Mod(sympy.Function): @@ -920,10 +926,12 @@ def _find_localzeros(cls, values, **options): _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 - i.is_antihermitian for i in s.args # noqa: E731 + i.is_antihermitian + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_commutative = lambda s: _torf( # noqa: E731 - i.is_commutative for i in s.args # noqa: E731 + i.is_commutative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 @@ -937,10 +945,12 @@ def _find_localzeros(cls, values, **options): _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731 - i.is_nonnegative for i in s.args # noqa: E731 + i.is_nonnegative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonpositive = lambda s: _torf( # noqa: E731 - i.is_nonpositive for i in s.args # noqa: E731 + i.is_nonpositive + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 @@ -950,10 +960,12 @@ def _find_localzeros(cls, values, **options): _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731 - i.is_extended_real for i in s.args # noqa: E731 + i.is_extended_real + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_transcendental = lambda s: _torf( # noqa: E731 - i.is_transcendental for i in s.args # noqa: E731 + i.is_transcendental + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 1b360337a53b..e02e049cc36d 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -144,16 +144,14 @@ def __init__( self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn, - ) -> None: - ... + ) -> None: ... @overload def __init__( # type: ignore[misc] self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn, - ) -> None: - ... + ) -> None: ... def __init__(self, lower: AllIn, upper: AllIn) -> None: lower = simple_sympify(lower) @@ -240,15 +238,13 @@ def tighten(self, other) -> ValueRanges: def __and__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __and__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __and__(self: AllVR, other: AllVR) -> AllVR: if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): @@ -272,15 +268,13 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: def __or__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __or__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __or__(self: AllVR, other: AllVR) -> AllVR: if ValueRanges.unknown() in (self, other): @@ -343,8 +337,7 @@ def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: @overload @staticmethod - def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: - ... + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ... @overload @staticmethod @@ -384,8 +377,7 @@ def coordinatewise_increasing_map( x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2, - ) -> ExprVR: - ... + ) -> ExprVR: ... @overload @staticmethod @@ -393,8 +385,7 @@ def coordinatewise_increasing_map( # type: ignore[misc] x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2, - ) -> BoolVR: - ... + ) -> BoolVR: ... @staticmethod def coordinatewise_increasing_map( diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index e11a7afc09d8..5a83aede8d46 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -426,9 +426,9 @@ def func_name(*args, **kwargs): it is marked as private. It is a convenience function for backend implementers to more easily call the hooks into their backend extensions. """ - assert isinstance( - func_name, str - ), f"func_name must be `str`, but got `{type(func_name)}`." + assert isinstance(func_name, str), ( + f"func_name must be `str`, but got `{type(func_name)}`." + ) backend_name = _get_privateuse1_backend_name() custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 68a4da0731c0..3b291b1e60a4 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -44,7 +44,7 @@ def default_convert(data): >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) @@ -366,13 +366,13 @@ def default_collate(batch): >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: - >>> default_collate(['a', 'b', 'c']) + >>> default_collate(["a", "b", "c"]) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: - >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) + >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index 0efe39854fb0..b53c7aef9596 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -21,16 +21,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): torch.set_num_threads(1) torch.multiprocessing._set_thread_name("pt_data_pin") - - if device == "cuda": - torch.cuda.set_device(device_id) - elif device == "xpu": - torch.xpu.set_device(device_id) # type: ignore[attr-defined] - elif device == torch._C._get_privateuse1_backend_name(): - custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) - custom_device_mod.set_device(device_id) - elif device is None: - torch.accelerator.set_device_index(device_id) + torch.accelerator.set_device_index(device_id) def do_one_step(): try: @@ -78,7 +69,9 @@ def pin_memory(data, device=None): ) return clone else: - return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg] + return type(data)( + {k: pin_memory(sample, device) for k, sample in data.items()} + ) # type: ignore[call-arg] except TypeError: # The mapping type may not support `copy()` / `update(mapping)` # or `__init__(iterable)`. diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index a275e2e86b6f..97c7243e78ef 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. +r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing static methods. diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 48fac6b51656..991b4f00eb85 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ functions to be run in multiprocessing. E.g., the data loading worker loop is in `./_utils/worker.py`. """ + from __future__ import annotations import functools @@ -190,9 +191,8 @@ class DataLoader(Generic[_T_co]): persistent_workers (bool, optional): If ``True``, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``False``) - pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is - ``True``. If not given, the current :ref:`accelerator` will be the - default. This argument is discouraged and subject to deprecated. + pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator` + will be used as the device if ``pin_memory=True``. in_order (bool, optional): If ``False``, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) @@ -654,45 +654,43 @@ def __init__(self, loader: DataLoader) -> None: ws, rank = _get_distributed_settings() self._world_size = ws self._rank = rank - # If pin_memory_device not set, default behaviour is current accelerator. - # If pin_memory_device is set but pin_memory is not set, the default - # behaviour false. - if len(loader.pin_memory_device) == 0: - if loader.pin_memory and not torch.accelerator.is_available(): - warn_msg = ( - "'pin_memory' argument is set as true but no accelerator is found, " - "then device pinned memory won't be used." - ) - warnings.warn(warn_msg) - self._pin_memory = loader.pin_memory and torch.accelerator.is_available() - self._pin_memory_device = None - # Currently, pin_memory would raise error on the MPS backend (see - # https://github.com/pytorch/pytorch/issues/86060), so forcibly - # disable pin_memory on MPS. Remove this restriction once pinned - # memory allocation for MPS is fixed. - if ( - self._pin_memory - and (acc := torch.accelerator.current_accelerator()) is not None - and acc.type == "mps" - ): - self._pin_memory = False - warn_msg = ( - "'pin_memory' argument is set as true but not supported on MPS now, " - "then device pinned memory won't be used." - ) - warnings.warn(warn_msg) - else: - if not loader.pin_memory: - warn_msg = ( - "'pin_memory_device' is set but 'pin_memory' argument is not set, " - "then device pinned memory won't be used." - "please set 'pin_memory' to true, if you need to use the device pin memory" - ) - warnings.warn(warn_msg) + if loader.pin_memory and loader.pin_memory_device: + warnings.warn( + "pin_memory_device is deprecated, the current accelerator will be used as the device," + f"ignore pin_memory_device='{loader.pin_memory_device}'." + ) + if loader.pin_memory and not torch.accelerator.is_available(): + warn_msg = ( + "'pin_memory' argument is set as true but no accelerator is found, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + + # Enabling pin_memory in _BaseDataLoaderIter to support identical + # behavior in forked implementations using _BaseDataLoaderIter. + self._pin_memory = loader.pin_memory and torch.accelerator.is_available() + + # Set pin memory device based on the current accelerator. + self._pin_memory_device = ( + acc.type + if self._pin_memory + and (acc := torch.accelerator.current_accelerator()) is not None + else None + ) + + # Currently, pin_memory would raise error on the MPS backend (see + # https://github.com/pytorch/pytorch/issues/86060), so forcibly + # disable pin_memory on MPS. Remove this restriction once pinned + # memory allocation for MPS is fixed. + if self._pin_memory_device == "mps": + self._pin_memory = False + warn_msg = ( + "'pin_memory' argument is set as true but not supported on MPS now, " + "device pinned memory won't be used." + ) + warnings.warn(warn_msg) - self._pin_memory = loader.pin_memory - self._pin_memory_device = loader.pin_memory_device self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) @@ -1178,24 +1176,13 @@ def __init__(self, loader): # Queue is not type-annotated self._data_queue = queue.Queue() # type: ignore[var-annotated] - current_device = -1 - if self._pin_memory_device == "cuda": - current_device = torch.cuda.current_device() - elif self._pin_memory_device == "xpu": - current_device = torch.xpu.current_device() - elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): - custom_device_mod = getattr( - torch, torch._C._get_privateuse1_backend_name() - ) - current_device = custom_device_mod.current_device() - elif self._pin_memory_device is None: - current_device = torch.accelerator.current_device_index() + current_device_id = torch.accelerator.current_device_index() pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=( self._worker_result_queue, self._data_queue, - current_device, + current_device_id, self._pin_memory_thread_done_event, self._pin_memory_device, ), @@ -1222,7 +1209,10 @@ def __init__(self, loader): atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) # .pid can be None only before process is spawned (not the case, so ignore) - _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_worker_pids( + id(self), + tuple(w.pid for w in self._workers), # type: ignore[misc] + ) _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 13e28a19d626..0833f8fdf759 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -109,8 +109,7 @@ def __call__(self, *args, **kwargs): # Decorate with a functional argument if not ( - isinstance(args[0], type) - and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] + isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] ): raise TypeError( f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index d3eeee0ebfdd..506f642c411d 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> from torchdata.datapipes.iter import IterableWrapper, Mapper >>> dp = IterableWrapper(range(10)) >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor - >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> map_dp_2 = dp.map( + ... lambda x: x + 1 + ... ) # Using functional form (recommended) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> list(map_dp_2) @@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> list(it1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> it1 = iter(source_dp) - >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` + >>> it2 = iter( + ... source_dp + ... ) # The creation of a new iterator invalidates `it1` >>> next(it2) 0 >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 718e728c9389..41c6bb362af2 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]): >>> def add_one(x): ... return x + 1 >>> dp = IterableWrapper(range(10)) - >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred + >>> # Invocation via functional form is preferred + ... map_dp_1 = dp.map(add_one) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` @@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): ... def __init__(self, start, end): ... super(MyIterDataPipe).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe): ... ... def __len__(self): ... return self.end - self.start - ... >>> ds = MyIterDataPipe(start=3, end=7) >>> print(list(ds)) [3, 4, 5, 6] >>> def collate_fn(batch): ... return torch.tensor(batch, dtype=torch.float) - ... >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> print(list(collated_ds)) [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 4c602ce4eeda..f92edd6b7b39 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -38,15 +38,17 @@ def __init__( sampler_args: Optional[tuple] = None, sampler_kwargs: Optional[dict] = None, ) -> None: - assert isinstance( - datapipe, Sized - ), "Sampler class requires input datapipe implemented `__len__`" + assert isinstance(datapipe, Sized), ( + "Sampler class requires input datapipe implemented `__len__`" + ) super().__init__() self.datapipe = datapipe self.sampler_args = () if sampler_args is None else sampler_args self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs # https://github.com/python/mypy/pull/9629 will solve - self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc] + self.sampler = sampler( + *self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs + ) # type: ignore[misc] def __iter__(self) -> Iterator[_T_co]: return iter(self.sampler) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index deaca079c68c..8c6abc506210 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -116,16 +116,13 @@ class _ContainerTemplate(ABC): r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" @abstractmethod - def get_next_element_by_instance(self, instance_id: int): - ... + def get_next_element_by_instance(self, instance_id: int): ... @abstractmethod - def is_every_instance_exhausted(self) -> bool: - ... + def is_every_instance_exhausted(self) -> bool: ... @abstractmethod - def reset(self) -> None: - ... + def reset(self) -> None: ... @abstractmethod def get_length_by_instance(self, instance_id: int): @@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe): >>> # It can also filter out any element that gets `None` from the `classifier_fn` >>> def odd_or_even_no_zero(n): ... return n % 2 if n != 0 else None - >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) + >>> dp1, dp2 = source_dp.demux( + ... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True + ... ) >>> list(dp1) [2, 4] >>> list(dp2) @@ -428,7 +427,9 @@ def __new__( # When num_instances == 1, demux can be replaced by filter, # but keep it as Demultiplexer for the sake of consistency # like throwing Error when classification result is out of o range - container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract] + container = _DemultiplexerIterDataPipe( + datapipe, num_instances, classifier_fn, drop_none, buffer_size + ) # type: ignore[abstract] return [_ChildDataPipe(container, i) for i in range(num_instances)] @@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(3)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) >>> list(dp1.mux(dp2, dp3)) [0, 10, 20, 1, 11, 21, 2, 12, 22] """ def __init__(self, *datapipes): self.datapipes = datapipes - self.buffer: list = ( - [] - ) # Store values to be yielded only when every iterator provides one + self.buffer: list = [] # Store values to be yielded only when every iterator provides one def __iter__(self): iterators = [iter(x) for x in self.datapipes] @@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(5)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) >>> list(dp1.zip(dp2, dp3)) [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] """ diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 2542c89773bd..3025b809e12d 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]): Example: >>> # xdoctest: +SKIP - >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader - >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) + >>> from torchdata.datapipes.iter import ( + ... FileLister, + ... FileOpener, + ... StreamReader, + ... ) + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt")) >>> dp = FileOpener(dp) >>> dp = StreamReader(dp) >>> list(dp) diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 08d124fdc608..055d9c28b09b 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> from torchdata.datapipes.iter import IterableWrapper >>> def group_fn(file): ... return os.path.basename(file).split(".")[0] - >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) + >>> source_dp = IterableWrapper( + ... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"] + ... ) >>> dp0 = source_dp.groupby(group_key_fn=group_fn) >>> list(dp0) [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] @@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> list(dp1) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` - >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) + >>> dp2 = source_dp.groupby( + ... group_key_fn=group_fn, + ... buffer_size=3, + ... group_size=3, + ... guaranteed_group_size=2, + ... ) >>> list(dp2) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] """ diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 02865e8064f8..e1290df32372 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]): >>> dp = SequenceWrapper(range(10)) >>> list(dp) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) - >>> dp['a'] + >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400}) + >>> dp["a"] 100 """ diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index ee5bee8f1528..9db7309bdc52 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -45,8 +45,8 @@ def basichandlers(extension: str, data): Example: >>> import pickle - >>> data = pickle.dumps('some data') - >>> new_data = basichandlers('pickle', data) + >>> data = pickle.dumps("some data") + >>> new_data = basichandlers("pickle", data) >>> new_data some data @@ -169,9 +169,9 @@ class ImageHandler: """ def __init__(self, imagespec): - assert imagespec in list( - imagespecs.keys() - ), f"unknown image specification: {imagespec}" + assert imagespec in list(imagespecs.keys()), ( + f"unknown image specification: {imagespec}" + ) self.imagespec = imagespec.lower() def __call__(self, extension, data): @@ -205,18 +205,18 @@ def __call__(self, extension, data): return img elif atype == "numpy": result = np.asarray(img) - assert ( - result.dtype == np.uint8 - ), f"numpy image array should be type uint8, but got {result.dtype}" + assert result.dtype == np.uint8, ( + f"numpy image array should be type uint8, but got {result.dtype}" + ) if etype == "uint8": return result else: return result.astype("f") / 255.0 elif atype == "torch": result = np.asarray(img) - assert ( - result.dtype == np.uint8 - ), f"numpy image array should be type uint8, but got {result.dtype}" + assert result.dtype == np.uint8, ( + f"numpy image array should be type uint8, but got {result.dtype}" + ) if etype == "uint8": result = np.array(result.transpose(2, 0, 1)) diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index d0234c553ce6..e8164e015a66 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]): tensors: tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: - assert all( - tensors[0].size(0) == tensor.size(0) for tensor in tensors - ), "Size mismatch between tensors" + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), ( + "Size mismatch between tensors" + ) self.tensors = tensors def __getitem__(self, index): @@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]): >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) - >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + >>> dict_stack[0] == {"image": images[0], "text": texts[0]} Args: *args (Dataset): Datasets for stacking returned as tuple. @@ -323,9 +323,9 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: self.datasets = list(datasets) assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] for d in self.datasets: - assert not isinstance( - d, IterableDataset - ), "ConcatDataset does not support IterableDataset" + assert not isinstance(d, IterableDataset), ( + "ConcatDataset does not support IterableDataset" + ) self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): @@ -371,17 +371,17 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: def __iter__(self): for d in self.datasets: - assert isinstance( - d, IterableDataset - ), "ChainDataset only supports IterableDataset" + assert isinstance(d, IterableDataset), ( + "ChainDataset only supports IterableDataset" + ) yield from d def __len__(self): total = 0 for d in self.datasets: - assert isinstance( - d, IterableDataset - ), "ChainDataset only supports IterableDataset" + assert isinstance(d, IterableDataset), ( + "ChainDataset only supports IterableDataset" + ) total += len(d) # type: ignore[arg-type] return total diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index c92bdbb00e10..6c2e6dcaf2f4 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]): Example: >>> # xdoctest: +IGNORE_WANT("non-deterministic") - >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) + >>> list( + ... WeightedRandomSampler( + ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True + ... ) + ... ) [4, 4, 1, 4, 5] - >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) + >>> list( + ... WeightedRandomSampler( + ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False + ... ) + ... ) [0, 1, 4, 3, 2] """ @@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]): its size would be less than ``batch_size`` Example: - >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + >>> list( + ... BatchSampler( + ... SequentialSampler(range(10)), batch_size=3, drop_last=False + ... ) + ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] - >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + >>> list( + ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True) + ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """ diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 8ac97f2e2e82..4c7dec048152 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -49,6 +49,7 @@ class ModuleTracker: def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias + torch.nn.functional.linear = my_linear mod(torch.rand(2, 2)) diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 23e3a25c90f5..9a4ade5e71ea 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -6,6 +6,7 @@ This package is lazily initialized, so you can always import it, and use :func:`is_available()` to determine if your system supports XPU. """ + import threading import traceback from functools import lru_cache @@ -292,6 +293,7 @@ class StreamContext: ``None``. .. note:: Streams are per-device. """ + cur_stream: Optional["torch.xpu.Stream"] def __init__(self, stream: Optional["torch.xpu.Stream"]): @@ -438,7 +440,7 @@ def get_gencode_flags() -> str: arch_list = get_arch_list() if len(arch_list) == 0: return "" - return f'-device {",".join(arch for arch in arch_list)}' + return f"-device {','.join(arch for arch in arch_list)}" def _get_generator(device: torch.device) -> torch._C.Generator: diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 3ff40412898a..be00c49d7b1f 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -183,4 +183,6 @@ # The same BC rules apply as inductor_fallback_ops. aten_shimified_ops: dict[str, dict[str, list[str]]] = { "aten.fill_.Scalar": {}, + "aten.pad.default": {}, + "aten.narrow.default": {}, } diff --git a/torchgen/gen.py b/torchgen/gen.py index 7d1413827f35..b8290d6b8684 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -2849,14 +2849,13 @@ def main() -> None: # TODO: stop generating CUDA kernels for non-CUDA builds ignore_keys = set() + MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS} if options.mps or options.update_aoti_c_shim: - functions_keys.add(DispatchKey.MPS) + functions_keys.update(MPS_KEYS) aoti_backends.add(DispatchKey.MPS) else: - ignore_keys.add(DispatchKey.MPS) - - if DispatchKey.MPS in dispatch_keys: - del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] + ignore_keys.update(MPS_KEYS) + dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS] if options.xpu or options.update_aoti_c_shim: functions_keys.add(DispatchKey.XPU) diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 655f2bd65b02..36db26bb5ea6 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -744,7 +744,7 @@ def headers_for_aoti() -> str: f"c_shim_{device_name}.cpp", lambda: gen_aoti_c_shim( fallback_native_functions, - inductor_fallback_ops, + fallback_ops_dict, structured_func_group_dict, dispatch_key, backend_indices, diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index 0cb92dfc09e2..8047f033e3d2 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -168,4 +168,4 @@ def merge_operator_dicts( def strip_operator_overload_name(op_name: str) -> str: - return op_name.split(".")[0] + return op_name.split(".", maxsplit=1)[0]