diff --git a/.dep-versions b/.dep-versions index edc108195f..241af9ce25 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,16 +1,10 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. - -############# -# We track mlir submodule versions from jax 0.4.32 for now -# These are the earliest versions with complete upstream bufferization changes -# Versions are retrieved from -# python3 .github/workflows/set_dep_versions.py 0.4.32 -############# - +# To update JAX version alongside compatible dependency tags, run the following script: +# python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.0 -mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4 -llvm=5f74671c85877e03622e8d308aee15ed73ccee7c -enzyme=v0.0.149 +mhlo=617a9361d186199480c080c9e8c474a5e30c22d1 +llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158 +enzyme=v0.0.180 # Always remove custom PL/LQ versions before release. diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 7b36859997..c1e9a49775 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -187,9 +187,10 @@ jobs: run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt - export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch - if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + pushd $GITHUB_WORKSPACE/mlir/mlir-hlo + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + popd cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -215,6 +216,11 @@ jobs: if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + + export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp + export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch + if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \ @@ -222,7 +228,7 @@ jobs: -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DCMAKE_CXX_FLAGS="-fuse-ld=lld" - cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20 + cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21 - name: Save Enzyme Build id: save-enzyme-build diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 2b8ce6a8c4..a78e9bbeff 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -210,9 +210,10 @@ jobs: run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt - export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch - if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + pushd $GITHUB_WORKSPACE/mlir/mlir-hlo + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + popd cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -238,6 +239,11 @@ jobs: if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + + export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp + export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch + if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \ @@ -245,7 +251,7 @@ jobs: -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DCMAKE_CXX_FLAGS="-fuse-ld=lld" - cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20 + cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21 - name: Save Enzyme Build id: save-enzyme-build diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 0a1279423a..8f249f2a46 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -185,9 +185,10 @@ jobs: run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt - export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch - if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + pushd $GITHUB_WORKSPACE/mlir/mlir-hlo + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + popd cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -212,13 +213,17 @@ jobs: - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | + export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp + export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch + if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \ -DENZYME_STATIC_LIB=ON \ -DCMAKE_CXX_VISIBILITY_PRESET=default - cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20 + cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21 - name: Save Enzyme Build id: save-enzyme-build diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 1fe2832233..4d633bb298 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -146,7 +146,7 @@ jobs: sudo apt-get update sudo apt-get install -y python3 python3-pip cmake ninja-build clang lld python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} - python3 -m pip install numpy pybind11 + python3 -m pip install numpy pybind11 nanobind - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -194,7 +194,7 @@ jobs: uses: actions/cache@v4 with: path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }} + key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 - name: Get Cached LLVM Source id: cache-llvm-source @@ -351,7 +351,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }} + key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 fail-on-cache-miss: true - name: Get Cached Enzyme Source diff --git a/.github/workflows/set_dep_versions.py b/.github/workflows/set_dep_versions.py index 74df70ca06..ce9a998c18 100644 --- a/.github/workflows/set_dep_versions.py +++ b/.github/workflows/set_dep_versions.py @@ -32,11 +32,11 @@ assert os.path.isfile(dep_versions_path) assert os.path.isfile(catalyst_init_path) -url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/WORKSPACE" +url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE" response = requests.get(url) match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text) if not match: - url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/third_party/xla/workspace.bzl" + url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl" response = requests.get(url) match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text) xla_commit = match.group(1) @@ -67,21 +67,16 @@ response = requests.get(url).json() hlo_commit = response["items"][0]["sha"] -existing_text = open(dep_versions_path, "r", encoding="UTF-8").read() -match = re.search(r"enzyme=([a-zA-Z0-9]*)", existing_text) -enzyme_commit = match.group(1) - -with open(dep_versions_path, "w", encoding="UTF-8") as f: - f.write( - f"""\ -jax={jax_version} -mhlo={hlo_commit} -llvm={llvm_commit} -enzyme={enzyme_commit} -""" - ) - quote = '"' -cmd = f"sed -i 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}" -res = os.system(cmd) -assert res == 0 +# Update each version using sed +cmds = [ + f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}", + f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}", + f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}", + # Update jaxlib version in __init__.py + rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}", +] + +for cmd in cmds: + res = os.system(cmd) + assert res == 0 diff --git a/doc/dev/transforms.rst b/doc/dev/transforms.rst index 51a0686883..bf61912413 100644 --- a/doc/dev/transforms.rst +++ b/doc/dev/transforms.rst @@ -252,8 +252,8 @@ Note how the value ``%q2`` links the two operations together from definition ``( across several other instructions. As seen in the `pattern rewriter documentation `_, -a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the ``match`` -and ``rewrite`` methods (refer to the link for the full class and up to date information): +a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the +``matchAndRewrite`` method (refer to the link for the full class and up to date information): .. code-block:: cpp @@ -261,14 +261,13 @@ and ``rewrite`` methods (refer to the link for the full class and up to date inf { ... - LogicalResult match(QubitUnitaryOp op) const override { - // The ``match`` method returns ``success()`` if the pattern is a match, failure - // otherwise. - } - - void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) { - // The ``rewrite`` method performs mutations on the IR rooted at ``op`` using - // the provided rewriter. All mutations must go through the provided rewriter. + LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override { + // The `matchAndRewrite` method performs both the pattern matching and the mutation + // on the IR rooted at `op` using the provided rewriter. + // All mutations must go through the provided rewriter and IR mutation should only + // take place after the match is deemed successful. + // matchAndRewrite must return "success" if and only if the IR was modified. + // The root operation is required to either be: updated in-place, replaced, or erased. } ... @@ -286,11 +285,11 @@ the second is a list of qubits): QubitUnitary(*, QubitUnitary(*, *)) -Let's implement it in C++: +Let's add the pattern-matching logic to the ``matchAndRewrite`` method: .. code-block:: cpp - LogicalResult match(QubitUnitaryOp op) const override + LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override { ValueRange qbs = op.getInQubits(); Operation *parent = qbs[0].getDefiningOp(); @@ -314,6 +313,9 @@ Let's implement it in C++: return failure(); } + // Rewrite logic + // ... We have matched the pattern, now rewrite the IR here + return success(); } @@ -351,8 +353,8 @@ MLIR will automatically generate canonical ``get*`` methods for attributes like ``out_qubits``, and ``matrix``. When in doubt it's best to have a look at the generated C++ files in the build folder, named ``QuantumOps.h.inc`` and ``QuantumOps.cpp.inc`` in this instance. -Alright, now that we have the matching part, let's implement the actual transformation via the -``rewrite`` method. All we need to do is replace the original pattern with the following: +Alright, now that we have the matching part, let's add the actual transformation to the +``matchAndRewrite`` method. All we need to do is replace the original pattern with the following: .. code-block:: @@ -362,8 +364,13 @@ In C++ it will look as follows: .. code-block:: cpp - void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override { + + // Pattern matching logic + // ... match the pattern + + // Rewrite logic ValueRange qbs = op.getInQubits(); QubitUnitaryOp parentOp = cast(qbs[0].getDefiningOp()); @@ -410,11 +417,13 @@ In C++ it will look as follows: // The second unitary is not needed anymore // Whoever uses the second unitary, use the first one instead! op.replaceAllUsesWith(parentOp); + + return success(); } When writing transformations, the rewriter is the most important tool we have. It can create new operations for us, delete others, or change the place in the IR where we are choosing to make -changes (also called the insertion point). Let's have look at some of these elements: +changes (also called the insertion point). Let's have a look at some of these elements: - **Constructing new operations**: @@ -512,7 +521,7 @@ and other function operations, which themselves can contain other operations, an quantumPatterns.add(ctx); // Apply patterns in an iterative and greedy manner. - if (failed(applyPatternsAndFoldGreedily(op, std::move(quantumPatterns)))) { + if (failed(applyPatternsGreedily(op, std::move(quantumPatterns)))) { return signalPassFailure(); } } @@ -520,7 +529,7 @@ and other function operations, which themselves can contain other operations, an To apply patterns we need a `pattern applicator `_. There a few in MLIR but typically you can just use the greedy pattern rewrite driver -(``applyPatternsAndFoldGreedily``), which will iterative over the IR and apply patterns until a +(``applyPatternsGreedily``), which will iterative over the IR and apply patterns until a fixed point is reached. .. note:: @@ -565,12 +574,16 @@ gradient ops that specify the finite-difference method, indicated via the ``"fd" .. code-block:: cpp - LogicalResult FiniteDiffLowering::match(GradOp op) + LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter) { - if (op.getMethod() == "fd") - return success(); + // Pattern matching logic + if (op.getMethod() != "fd") + return failure(); - return failure(); + // Rewrite logic + // ... + + return success(); } For the rewriting part we'll want to introduce a few new elements, such as looking up symbols @@ -578,8 +591,13 @@ For the rewriting part we'll want to introduce a few new elements, such as looki .. code-block:: cpp - void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter) + LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter) { + // Pattern matching logic + if (op.getMethod() != "fd") + return failure(); + + // Rewrite logic // First let's find the function the grad operation is referencing. func::FuncOp callee = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); @@ -609,6 +627,8 @@ For the rewriting part we'll want to introduce a few new elements, such as looki // Populate the function body. populateFiniteDiffMethod(rewriter, op, gradFn); } + + return success(); } Symbols are string references to IR objects, which rather than containing a physical reference or @@ -711,18 +731,20 @@ Alright, our function should now look something like this: func.return %dx, %dy, %dz : f64, f64, f64 } -Finally, we have to amend our rewrite function to invoke the new function we created and delete the +Finally, we have to amend our ``matchAndRewrite`` function to invoke the new function we created and delete the ``GradOp`` from the IR: .. code-block:: cpp - void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter) + LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter) { ... populateFiniteDiffMethod(rewriter, op, gradFn); } rewriter.replaceOpWithNewOp(op, gradFn, op.getArgOperands()); + + return success(); } Note how we can create a new operation, take its results, and use those to replace another operation diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f8b7311b9e..f486fb6824 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -175,6 +175,14 @@ [(#1671)](https://github.com/PennyLaneAI/catalyst/pull/1671) [(#1681)](https://github.com/PennyLaneAI/catalyst/pull/1681) +* (Compiler developers only) The version of LLVM, mlir-hlo and Enzyme used by Catalyst is + updated to track those in jax 0.6.0. + [(#1752)](https://github.com/PennyLaneAI/catalyst/pull/1752) + + The LLVM version is updated to [commit 179d30f8c3fddd3c85056fd2b8e877a4a8513158](https://github.com/llvm/llvm-project/tree/179d30f8c3fddd3c85056fd2b8e877a4a8513158). + The mlir-hlo version is updated to [commit 617a9361d186199480c080c9e8c474a5e30c22d1](https://github.com/tensorflow/mlir-hlo/tree/617a9361d186199480c080c9e8c474a5e30c22d1). + The Enzyme version is updated to [v0.0.180](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.180). + * The clang-format and clang-tidy versions used by Catalyst have been updated to v20. [(#1721)](https://github.com/PennyLaneAI/catalyst/pull/1721) @@ -257,7 +265,7 @@ * Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf` [(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696) -* The assembly format of `MeasureOp` in the `Quantum` dialect and `MeasureInBasisOp` in the `MBQC` dialect now contains the `postselect` attribute. +* The assembly format of `MeasureOp` in the `Quantum` dialect and `MeasureInBasisOp` in the `MBQC` dialect now contains the `postselect` attribute. [(#1732)](https://github.com/PennyLaneAI/catalyst/pull/1732) * The bufferization of custom catalyst dialects has been migrated to the new one-shot diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 04f6c4178c..67a6e4ef45 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -240,6 +240,7 @@ def get_bufferization_stage(options: CompileOptions) -> List[str]: "func.func(buffer-hoisting)", "func.func(buffer-loop-hoisting)", "func.func(promote-buffers-to-stack)", + # TODO: migrate to new buffer deallocation "buffer-deallocation-pipeline" "func.func(buffer-deallocation)", "convert-arraylist-to-memref", "convert-bufferization-to-memref", diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 1951f2c3a2..f624a09fc7 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -47,7 +47,6 @@ set(ALL_MHLO_PASSES StablehloPasses MhloToArithmeticConversion MhloToMemrefConversion - MhloToStandard HloToLinalgUtils MhloToLinalg MhloToStablehlo diff --git a/mlir/Enzyme b/mlir/Enzyme index ed3ae59ec3..db0181320d 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit ed3ae59ec3c4ed082ba035e65488ef29e5d41ae0 +Subproject commit db0181320d6e425ee963bd496ed0d8dbb615be18 diff --git a/mlir/Makefile b/mlir/Makefile index 496dc26ee5..8c129d5b20 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -55,15 +55,8 @@ help: all: llvm mhlo enzyme dialects plugin .PHONY: llvm -llvm: TARGET_FILE := $(MK_DIR)/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp -llvm: PATCH_FILE := $(MK_DIR)/patches/mlir-buffer-deallocation.patch llvm: @echo "build LLVM and MLIR enabling Python bindings" - # Patch in MLIR buffer deallocation bugfix - # TODO: remove once https://github.com/llvm/llvm-project/pull/121582 is merged & the dep updated - @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \ - patch -p1 $(TARGET_FILE) $(PATCH_FILE); \ - fi cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -91,14 +84,17 @@ llvm: LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS) .PHONY: mhlo -mhlo: TARGET_FILE := $(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt -mhlo: PATCH_FILE := $(MK_DIR)/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch mhlo: @echo "build MLIR-HLO" - # Patch MHLO cmake dependency - # TODO: remove once https://github.com/openxla/xla/pull/15446 is merged - @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \ - patch -p1 $(TARGET_FILE) $(PATCH_FILE); \ + + # Patch MHLO shardy dependency + @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-remove-shardy.patch; then \ + git apply $(MK_DIR)/patches/mhlo-remove-shardy.patch; \ + fi + + # Patch MHLO passes removed from upstream + @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \ + git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \ fi cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ @@ -119,8 +115,14 @@ mhlo: LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $(MHLO_BUILD_DIR) --target check-mlir-hlo .PHONY: enzyme +enzyme: TARGET_FILE := $(MK_DIR)/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +enzyme: PATCH_FILE := $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch enzyme: @echo "build enzyme" + # Patch enzyme's dependency on nvidia fabs llvm intrinsics + @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \ + patch -p1 $(TARGET_FILE) $(PATCH_FILE); \ + fi cmake -G Ninja -S Enzyme/enzyme -B $(ENZYME_BUILD_DIR) \ -DENZYME_STATIC_LIB=ON \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ @@ -133,7 +135,7 @@ enzyme: -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \ -DCMAKE_POLICY_DEFAULT_CMP0116=NEW - cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-20 + cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21 .PHONY: plugin plugin: @@ -204,6 +206,7 @@ clean-llvm: clean-mhlo: @echo "clean HLO dialect build files" rm -rf $(MHLO_BUILD_DIR) + cd mlir-hlo; git clean -fd; git checkout . clean-enzyme: @echo "clean enzyme build files" diff --git a/mlir/include/Catalyst/IR/CatalystOps.td b/mlir/include/Catalyst/IR/CatalystOps.td index 9a58872b84..69ab0f54e5 100644 --- a/mlir/include/Catalyst/IR/CatalystOps.td +++ b/mlir/include/Catalyst/IR/CatalystOps.td @@ -170,7 +170,9 @@ def CallbackCallOp : Catalyst_Op<"callback_call", let arguments = (ins FlatSymbolRefAttr:$callee, - Variadic]>>:$inputs + Variadic]>>:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic); @@ -188,7 +190,9 @@ def LaunchKernelOp : Catalyst_Op<"launch_kernel", let arguments = (ins SymbolRefAttr:$callee, - Variadic]>>:$inputs + Variadic]>>:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic); diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index 7fb7122e58..f7872961ad 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -23,6 +23,7 @@ namespace catalyst { std::unique_ptr createAddExceptionHandlingPass(); std::unique_ptr createApplyTransformSequencePass(); std::unique_ptr createArrayListToMemRefPass(); +std::unique_ptr createBufferDeallocationPass(); std::unique_ptr createCatalystBufferizationPass(); std::unique_ptr createCatalystConversionPass(); std::unique_ptr createDetensorizeSCFPass(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index d22246e3c8..e1512e00e0 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -198,4 +198,88 @@ def InlineNestedModulePass : Pass<"inline-nested-module"> { ]; } + +// Legacy buffer deallocation pass. +// +// This pass has been modified from its original form in the LLVM project at +// https://github.com/llvm/llvm-project released under the Apache License, Version 2.0, +// with the following copyright notice: +// +// * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// * See https://llvm.org/LICENSE.txt for license information. +// * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> { + let summary = "Adds all required dealloc operations for all allocations in " + "the input program"; + let description = [{ + This pass implements an algorithm to automatically introduce all required + deallocation operations for all buffers in the input program. This ensures + that the resulting program does not have any memory leaks. + + + Input + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.cond_br %arg0, ^bb1, ^bb2 + ^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) + ^bb2: + %0 = memref.alloc() : memref<2xf32> + linalg.generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"]} %arg1, %0 { + ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): + %tmp1 = exp %gen1_arg0 : f32 + linalg.yield %tmp1 : f32 + }: memref<2xf32>, memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) + ^bb3(%1: memref<2xf32>): + "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return + } + } + + ``` + + Output + + ```mlir + #map0 = affine_map<(d0) -> (d0)> + module { + func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.cond_br %arg0, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %0 = memref.alloc() : memref<2xf32> + memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) + ^bb2: // pred: ^bb0 + %1 = memref.alloc() : memref<2xf32> + linalg.generic { + indexing_maps = [#map0, #map0], + iterator_types = ["parallel"]} %arg1, %1 { + ^bb0(%arg3: f32, %arg4: f32): + %4 = exp %arg3 : f32 + linalg.yield %4 : f32 + }: memref<2xf32>, memref<2xf32> + %2 = memref.alloc() : memref<2xf32> + memref.copy(%1, %2) : memref<2xf32>, memref<2xf32> + dealloc %1 : memref<2xf32> + cf.br ^bb3(%2 : memref<2xf32>) + ^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2 + memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32> + dealloc %3 : memref<2xf32> + return + } + + } + ``` + + }]; + let constructor = "mlir::bufferization::createBufferDeallocationPass()"; +} + #endif // CATALYST_PASSES diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td index 236aa3e4e6..8931e109e0 100644 --- a/mlir/include/Gradient/IR/GradientOps.td +++ b/mlir/include/Gradient/IR/GradientOps.td @@ -18,6 +18,7 @@ include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpBase.td" @@ -58,7 +59,9 @@ def GradOp : Gradient_Op<"grad", [ SymbolRefAttr:$callee, Variadic:$operands, OptionalAttr:$diffArgIndices, - OptionalAttr:$finiteDiffParam + OptionalAttr:$finiteDiffParam, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic]>>); @@ -82,7 +85,9 @@ def ValueAndGradOp : Gradient_Op<"value_and_grad", [ SymbolRefAttr:$callee, Variadic:$operands, OptionalAttr:$diffArgIndices, - OptionalAttr:$finiteDiffParam + OptionalAttr:$finiteDiffParam, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs @@ -184,7 +189,9 @@ def JVPOp : Gradient_Op<"jvp", [ Variadic:$params, Variadic:$tangents, OptionalAttr:$diffArgIndices, - OptionalAttr:$finiteDiffParam + OptionalAttr:$finiteDiffParam, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs @@ -217,7 +224,9 @@ def VJPOp : Gradient_Op<"vjp", [ Variadic:$params, Variadic:$cotangents, OptionalAttr:$diffArgIndices, - OptionalAttr:$finiteDiffParam + OptionalAttr:$finiteDiffParam, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index 6da2848cd4..036aea40bb 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -48,7 +48,9 @@ def ZneOp : Mitigation_Op<"zne", [DeclareOpInterfaceMethods, SymbolRefAttr:$callee, Variadic:$args, FoldingAttr:$folding, - RankedTensorOf<[AnySignlessIntegerOrIndex]>:$numFolds + RankedTensorOf<[AnySignlessIntegerOrIndex]>:$numFolds, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic]>>); diff --git a/mlir/include/Quantum/IR/QuantumDialect.h b/mlir/include/Quantum/IR/QuantumDialect.h index aad9d23a23..a4f7e32350 100644 --- a/mlir/include/Quantum/IR/QuantumDialect.h +++ b/mlir/include/Quantum/IR/QuantumDialect.h @@ -16,6 +16,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// // Quantum dialect declarations. @@ -29,3 +30,7 @@ #define GET_TYPEDEF_CLASSES #include "Quantum/IR/QuantumOpsTypes.h.inc" + +class QuantumMemory : public mlir::SideEffects::Resource::Base { + llvm::StringRef getName() final { return "QuantumMemory"; } +}; diff --git a/mlir/include/Quantum/IR/QuantumDialect.td b/mlir/include/Quantum/IR/QuantumDialect.td index d2899dd8c4..c6368eceff 100644 --- a/mlir/include/Quantum/IR/QuantumDialect.td +++ b/mlir/include/Quantum/IR/QuantumDialect.td @@ -17,6 +17,7 @@ include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// // Quantum dialect. @@ -72,4 +73,23 @@ def ResultType : Quantum_Type<"Result", "res"> { let summary = "A quantum measurement result."; } +//===----------------------------------------------------------------------===// +// Quantum resource abstractions. +//===----------------------------------------------------------------------===// + +def QuantumMemory : Resource<"QuantumMemory"> { + // This resource represents a generic piece of quantum memory. It can be used to + // model the resource bahavior of quantum operations, in order to help make decisions + // during various IR analyses. + // + // For an operation from a downstream dialect (e.g. the MBQC dialect) to use this resource, + // all the downstream dialect needs to do is to + // `#include "Quantum/IR/QuantumDialect.h"` + // `include "Quantum/IR/QuantumDialect.td"` + // in the downstream dialect's "Ops" cpp header tablegen file, both of which it is likely + // already doing. + // + // Note that `Resource` in tablegen does not have a `description` field. +} + #endif // QUANTUM_DIALECT diff --git a/mlir/include/Quantum/IR/QuantumOps.h b/mlir/include/Quantum/IR/QuantumOps.h index 4ccca45618..3479659cb5 100644 --- a/mlir/include/Quantum/IR/QuantumOps.h +++ b/mlir/include/Quantum/IR/QuantumOps.h @@ -14,17 +14,17 @@ #pragma once +#include + #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" -#include #include "Quantum/IR/QuantumInterfaces.h" @@ -44,10 +44,6 @@ class HermitianTrait : public TraitBase {}; } // namespace OpTrait } // namespace mlir -class QuantumMemory : public mlir::SideEffects::Resource::Base { - llvm::StringRef getName() final { return "QuantumMemory"; } -}; - //===----------------------------------------------------------------------===// // Quantum ops declarations. //===----------------------------------------------------------------------===// diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 61483e940c..f1fcf50c8e 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -17,7 +17,6 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" @@ -49,8 +48,6 @@ def NamedObservable : I32EnumAttr<"NamedObservable", def Unitary : NativeOpTrait<"UnitaryTrait">; def Hermitian : NativeOpTrait<"HermitianTrait">; -def QuantumMemory : Resource<"QuantumMemory">; - //===----------------------------------------------------------------------===// // Quantum dialect attributes. //===----------------------------------------------------------------------===// @@ -341,7 +338,7 @@ class UnitaryGate_Op traits = []> : let extraClassDeclaration = extraBaseClassDeclaration; } -def SetStateOp : Gate_Op<"set_state", []> { +def SetStateOp : Gate_Op<"set_state"> { let summary = "Set state to a complex vector."; let description = [{ This operation is useful for simulators implementing state preparation. @@ -372,7 +369,7 @@ def SetStateOp : Gate_Op<"set_state", []> { } -def SetBasisStateOp : Gate_Op<"set_basis_state", []> { +def SetBasisStateOp : Gate_Op<"set_basis_state"> { let summary = "Set basis state."; let description = [{ This operation is useful for simulators implementing set basis state. diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index aae67c64ca..c6f157e7c2 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -79,7 +79,7 @@ CallInterfaceCallable CallbackCallOp::getCallableForCallee() void CallbackCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } Operation::operand_range CallbackCallOp::getArgOperands() { return getInputs(); } @@ -97,7 +97,7 @@ CallInterfaceCallable LaunchKernelOp::getCallableForCallee() void LaunchKernelOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } Operation::operand_range LaunchKernelOp::getArgOperands() { return getInputs(); } diff --git a/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp b/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp index a1f03ba52b..a96a0bc2cb 100644 --- a/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp +++ b/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp @@ -187,7 +187,7 @@ struct LowerListInit : public OpConversionPattern { struct LowerListDealloc : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ListDeallocOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(ListDeallocOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto typeConverter = getTypeConverter(); @@ -210,7 +210,7 @@ struct LowerListDealloc : public OpConversionPattern { struct LowerListPush : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ListPushOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(ListPushOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto typeConverter = getTypeConverter(); @@ -231,7 +231,7 @@ struct LowerListPush : public OpConversionPattern { struct LowerListPop : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ListPopOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(ListPopOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto typeConverter = getTypeConverter(); @@ -252,7 +252,7 @@ struct LowerListPop : public OpConversionPattern { struct LowerListLoadData : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ListLoadDataOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(ListLoadDataOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto typeConverter = getTypeConverter(); diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp index a295156330..76cee871d4 100644 --- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp +++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp @@ -134,14 +134,16 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(ModuleOp moduleOp) auto i32Ty = IntegerType::get(ctx, 32); bool isVarArg = true; return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::personalityName, {}, i32Ty, - isVarArg); + isVarArg) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); auto voidTy = LLVM::LLVMVoidType::get(ctx); - return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy); + return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp) @@ -150,7 +152,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp) Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy); + moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp) @@ -159,7 +162,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp) Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy); + moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp) @@ -169,7 +173,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp) Type llvmInt64Type = IntegerType::get(moduleOp.getContext(), 64); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeDropRefName, - {ptrTy, llvmInt64Type}, voidTy); + {ptrTy, llvmInt64Type}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp) @@ -178,7 +183,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleO Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy); + moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp) @@ -187,7 +193,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleO Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy); + moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy) + .value(); } LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp) @@ -195,7 +202,8 @@ LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp) MLIRContext *ctx = moduleOp.getContext(); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::unrecoverableErrorName, {}, - voidTy); + voidTy) + .value(); } std::optional AsyncUtils::getCalleeSafe(LLVM::CallOp callOp) diff --git a/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp new file mode 100644 index 0000000000..012799bec7 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp @@ -0,0 +1,743 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file has been modified from its original form in the LLVM project at +// https://github.com/llvm/llvm-project released under the Apache License, Version 2.0, +// with the following copyright notice: +// +// * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// * See https://llvm.org/LICENSE.txt for license information. +// * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +/* + * This file is a modified version of + * + * https://github.com/llvm/llvm-project/blob/9b2fc66830b2e81d95ef272ddc51c6cff9ba23a1/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp + * + * The modifications are porting the pass from the upstream bufferization namespace to + * catalyst namespace. + */ + +// +// This file implements logic for computing correct alloc and dealloc positions. +// Furthermore, buffer deallocation also adds required new clone operations to +// ensure that all buffers are deallocated. The main class is the +// BufferDeallocationPass class that implements the underlying algorithm. In +// order to put allocations and deallocations at safe positions, it is +// significantly important to put them into the correct blocks. However, the +// liveness analysis does not pay attention to aliases, which can occur due to +// branches (and their associated block arguments) in general. For this purpose, +// BufferDeallocation firstly finds all possible aliases for a single value +// (using the BufferViewFlowAnalysis class). Consider the following example: +// +// ^bb0(%arg0): +// cf.cond_br %cond, ^bb1, ^bb2 +// ^bb1: +// cf.br ^exit(%arg0) +// ^bb2: +// %new_value = ... +// cf.br ^exit(%new_value) +// ^exit(%arg1): +// return %arg1; +// +// We should place the dealloc for %new_value in exit. However, we have to free +// the buffer in the same block, because it cannot be freed in the post +// dominator. However, this requires a new clone buffer for %arg1 that will +// contain the actual contents. Using the class BufferViewFlowAnalysis, we +// will find out that %new_value has a potential alias %arg1. In order to find +// the dealloc position we have to find all potential aliases, iterate over +// their uses and find the common post-dominator block (note that additional +// clones and buffers remove potential aliases and will influence the placement +// of the deallocs). In all cases, the computed block can be safely used to free +// the %new_value buffer (may be exit or bb2) as it will die and we can use +// liveness information to determine the exact operation after which we have to +// insert the dealloc. However, the algorithm supports introducing clone buffers +// and placing deallocs in safe locations to ensure that all buffers will be +// freed in the end. +// +// TODO: +// The current implementation does not support explicit-control-flow loops and +// the resulting code will be invalid with respect to program semantics. +// However, structured control-flow loops are fully supported. Furthermore, it +// doesn't accept functions which return buffers already. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SetOperations.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/Passes.h" + +// using namespace llvm; +using namespace mlir; +using namespace catalyst; + +namespace catalyst { + +#define GEN_PASS_DEF_BUFFERDEALLOCATION +#define GEN_PASS_DECL_BUFFERDEALLOCATION +#include "Catalyst/Transforms/Passes.h.inc" + +} // namespace catalyst + +// namespace mlir { +// namespace bufferization { +// #define GEN_PASS_DEF_BUFFERDEALLOCATION +// #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" +// } // namespace bufferization +// } // namespace mlir + +using namespace mlir; +using namespace mlir::bufferization; + +/// Walks over all immediate return-like terminators in the given region. +static LogicalResult +walkReturnOperations(Region *region, + llvm::function_ref func) +{ + for (Block &block : *region) { + Operation *terminator = block.getTerminator(); + // Skip non region-return-like terminators. + if (auto regionTerminator = dyn_cast(terminator)) { + if (failed(func(regionTerminator))) + return failure(); + } + } + return success(); +} + +/// Checks if all operations that have at least one attached region implement +/// the RegionBranchOpInterface. This is not required in edge cases, where we +/// have a single attached region and the parent operation has no results. +static bool validateSupportedControlFlow(Operation *op) +{ + WalkResult result = op->walk([&](Operation *operation) { + // Only check ops that are inside a function. + if (!operation->getParentOfType()) + return WalkResult::advance(); + + auto regions = operation->getRegions(); + // Walk over all operations in a region and check if the operation has at + // least one region and implements the RegionBranchOpInterface. If there + // is an operation that does not fulfill this condition, we cannot apply + // the deallocation steps. Furthermore, we accept cases, where we have a + // region that returns no results, since, in that case, the intra-region + // control flow does not affect the transformation. + size_t size = regions.size(); + if (((size == 1 && !operation->getResults().empty()) || size > 1) && + !dyn_cast(operation)) { + operation->emitError("All operations with attached regions need to " + "implement the RegionBranchOpInterface."); + } + + return WalkResult::advance(); + }); + return !result.wasSkipped(); +} + +namespace { + +//===----------------------------------------------------------------------===// +// Backedges analysis +//===----------------------------------------------------------------------===// + +/// A straight-forward program analysis which detects loop backedges induced by +/// explicit control flow. +class Backedges { + public: + using BlockSetT = SmallPtrSet; + using BackedgeSetT = llvm::DenseSet>; + + public: + /// Constructs a new backedges analysis using the op provided. + Backedges(Operation *op) { recurse(op); } + + /// Returns the number of backedges formed by explicit control flow. + size_t size() const { return edgeSet.size(); } + + /// Returns the start iterator to loop over all backedges. + BackedgeSetT::const_iterator begin() const { return edgeSet.begin(); } + + /// Returns the end iterator to loop over all backedges. + BackedgeSetT::const_iterator end() const { return edgeSet.end(); } + + private: + /// Enters the current block and inserts a backedge into the `edgeSet` if we + /// have already visited the current block. The inserted edge links the given + /// `predecessor` with the `current` block. + bool enter(Block ¤t, Block *predecessor) + { + bool inserted = visited.insert(¤t).second; + if (!inserted) + edgeSet.insert(std::make_pair(predecessor, ¤t)); + return inserted; + } + + /// Leaves the current block. + void exit(Block ¤t) { visited.erase(¤t); } + + /// Recurses into the given operation while taking all attached regions into + /// account. + void recurse(Operation *op) + { + Block *current = op->getBlock(); + // If the current op implements the `BranchOpInterface`, there can be + // cycles in the scope of all successor blocks. + if (isa(op)) { + for (Block *succ : current->getSuccessors()) + recurse(*succ, current); + } + // Recurse into all distinct regions and check for explicit control-flow + // loops. + for (Region ®ion : op->getRegions()) { + if (!region.empty()) + recurse(region.front(), current); + } + } + + /// Recurses into explicit control-flow structures that are given by + /// the successor relation defined on the block level. + void recurse(Block &block, Block *predecessor) + { + // Try to enter the current block. If this is not possible, we are + // currently processing this block and can safely return here. + if (!enter(block, predecessor)) + return; + + // Recurse into all operations and successor blocks. + for (Operation &op : block.getOperations()) + recurse(&op); + + // Leave the current block. + exit(block); + } + + /// Stores all blocks that are currently visited and on the processing stack. + BlockSetT visited; + + /// Stores all backedges in the format (source, target). + BackedgeSetT edgeSet; +}; + +//===----------------------------------------------------------------------===// +// BufferDeallocation +//===----------------------------------------------------------------------===// + +/// The buffer deallocation transformation which ensures that all allocs in the +/// program have a corresponding de-allocation. As a side-effect, it might also +/// introduce clones that in turn leads to additional deallocations. +class BufferDeallocation : public BufferPlacementTransformationBase { + public: + using AliasAllocationMapT = llvm::DenseMap; + + BufferDeallocation(Operation *op) + : BufferPlacementTransformationBase(op), dominators(op), postDominators(op) + { + } + + /// Checks if all allocation operations either provide an already existing + /// deallocation operation or implement the AllocationOpInterface. In + /// addition, this method initializes the internal alias to + /// AllocationOpInterface mapping in order to get compatible + /// AllocationOpInterface implementations for aliases. + LogicalResult prepare() + { + for (const BufferPlacementAllocs::AllocEntry &entry : allocs) { + // Get the defining allocation operation. + Value alloc = std::get<0>(entry); + auto allocationInterface = alloc.getDefiningOp(); + // If there is no existing deallocation operation and no implementation of + // the AllocationOpInterface, we cannot apply the BufferDeallocation pass. + if (!std::get<1>(entry) && !allocationInterface) { + return alloc.getDefiningOp()->emitError( + "Allocation is not deallocated explicitly nor does the operation " + "implement the AllocationOpInterface."); + } + + // Register the current allocation interface implementation. + aliasToAllocations[alloc] = allocationInterface; + + // Get the alias information for the current allocation node. + for (Value alias : aliases.resolve(alloc)) { + // TODO: check for incompatible implementations of the + // AllocationOpInterface. This could be realized by promoting the + // AllocationOpInterface to a DialectInterface. + aliasToAllocations[alias] = allocationInterface; + } + } + return success(); + } + + /// Performs the actual placement/creation of all temporary clone and dealloc + /// nodes. + LogicalResult deallocate() + { + // Add additional clones that are required. + if (failed(introduceClones())) + return failure(); + + // Place deallocations for all allocation entries. + return placeDeallocs(); + } + + private: + /// Introduces required clone operations to avoid memory leaks. + LogicalResult introduceClones() + { + // Initialize the set of values that require a dedicated memory free + // operation since their operands cannot be safely deallocated in a post + // dominator. + SetVector valuesToFree; + llvm::SmallDenseSet> visitedValues; + SmallVector, 8> toProcess; + + // Check dominance relation for proper dominance properties. If the given + // value node does not dominate an alias, we will have to create a clone in + // order to free all buffers that can potentially leak into a post + // dominator. + auto findUnsafeValues = [&](Value source, Block *definingBlock) { + auto it = aliases.find(source); + if (it == aliases.end()) + return; + for (Value value : it->second) { + if (valuesToFree.count(value) > 0) + continue; + Block *parentBlock = value.getParentBlock(); + // Check whether we have to free this particular block argument or + // generic value. We have to free the current alias if it is either + // defined in a non-dominated block or it is defined in the same block + // but the current value is not dominated by the source value. + if (!dominators.dominates(definingBlock, parentBlock) || + (definingBlock == parentBlock && isa(value))) { + toProcess.emplace_back(value, parentBlock); + valuesToFree.insert(value); + } + else if (visitedValues.insert(std::make_tuple(value, definingBlock)).second) + toProcess.emplace_back(value, definingBlock); + } + }; + + // Detect possibly unsafe aliases starting from all allocations. + for (BufferPlacementAllocs::AllocEntry &entry : allocs) { + Value allocValue = std::get<0>(entry); + findUnsafeValues(allocValue, allocValue.getDefiningOp()->getBlock()); + } + // Try to find block arguments that require an explicit free operation + // until we reach a fix point. + while (!toProcess.empty()) { + auto current = toProcess.pop_back_val(); + findUnsafeValues(std::get<0>(current), std::get<1>(current)); + } + + // Update buffer aliases to ensure that we free all buffers and block + // arguments at the correct locations. + aliases.remove(valuesToFree); + + // Add new allocs and additional clone operations. + for (Value value : valuesToFree) { + if (!isa(value.getType())) { + continue; + } + if (failed(isa(value) ? introduceBlockArgCopy(cast(value)) + : introduceValueCopyForRegionResult(value))) + return failure(); + + // Register the value to require a final dealloc. Note that we do not have + // to assign a block here since we do not want to move the allocation node + // to another location. + allocs.registerAlloc(std::make_tuple(value, nullptr)); + } + return success(); + } + + /// Introduces temporary clones in all predecessors and copies the source + /// values into the newly allocated buffers. + LogicalResult introduceBlockArgCopy(BlockArgument blockArg) + { + // Allocate a buffer for the current block argument in the block of + // the associated value (which will be a predecessor block by + // definition). + Block *block = blockArg.getOwner(); + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + // Get the terminator and the value that will be passed to our + // argument. + Operation *terminator = (*it)->getTerminator(); + auto branchInterface = cast(terminator); + SuccessorOperands operands = + branchInterface.getSuccessorOperands(it.getSuccessorIndex()); + + // Query the associated source value. + Value sourceValue = operands[blockArg.getArgNumber()]; + if (!sourceValue) { + return failure(); + } + // Wire new clone and successor operand. + // Create a new clone at the current location of the terminator. + auto clone = introduceCloneBuffers(sourceValue, terminator); + if (failed(clone)) + return failure(); + operands.slice(blockArg.getArgNumber(), 1).assign(*clone); + } + + // Check whether the block argument has implicitly defined predecessors via + // the RegionBranchOpInterface. This can be the case if the current block + // argument belongs to the first block in a region and the parent operation + // implements the RegionBranchOpInterface. + Region *argRegion = block->getParent(); + Operation *parentOp = argRegion->getParentOp(); + RegionBranchOpInterface regionInterface; + if (&argRegion->front() != block || + !(regionInterface = dyn_cast(parentOp))) + return success(); + + if (failed(introduceClonesForRegionSuccessors( + regionInterface, argRegion->getParentOp()->getRegions(), blockArg, + [&](RegionSuccessor &successorRegion) { + // Find a predecessor of our argRegion. + return successorRegion.getSuccessor() == argRegion; + }))) + return failure(); + + // Check whether the block argument belongs to an entry region of the + // parent operation. In this case, we have to introduce an additional clone + // for buffer that is passed to the argument. + SmallVector successorRegions; + regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), + successorRegions); + auto *it = llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) { + return successorRegion.getSuccessor() == argRegion; + }); + if (it == successorRegions.end()) + return success(); + + // Determine the actual operand to introduce a clone for and rewire the + // operand to point to the clone instead. + auto operands = regionInterface.getEntrySuccessorOperands(argRegion); + size_t operandIndex = llvm::find(it->getSuccessorInputs(), blockArg).getIndex() + + operands.getBeginOperandIndex(); + Value operand = parentOp->getOperand(operandIndex); + assert(operand == operands[operandIndex - operands.getBeginOperandIndex()] && + "region interface operands don't match parentOp operands"); + auto clone = introduceCloneBuffers(operand, parentOp); + if (failed(clone)) + return failure(); + + parentOp->setOperand(operandIndex, *clone); + return success(); + } + + /// Introduces temporary clones in front of all associated nested-region + /// terminators and copies the source values into the newly allocated buffers. + LogicalResult introduceValueCopyForRegionResult(Value value) + { + // Get the actual result index in the scope of the parent terminator. + Operation *operation = value.getDefiningOp(); + auto regionInterface = cast(operation); + // Filter successors that return to the parent operation. + auto regionPredicate = [&](RegionSuccessor &successorRegion) { + // If the RegionSuccessor has no associated successor, it will return to + // its parent operation. + return !successorRegion.getSuccessor(); + }; + // Introduce a clone for all region "results" that are returned to the + // parent operation. This is required since the parent's result value has + // been considered critical. Therefore, the algorithm assumes that a clone + // of a previously allocated buffer is returned by the operation (like in + // the case of a block argument). + return introduceClonesForRegionSuccessors(regionInterface, operation->getRegions(), value, + regionPredicate); + } + + /// Introduces buffer clones for all terminators in the given regions. The + /// regionPredicate is applied to every successor region in order to restrict + /// the clones to specific regions. + template + LogicalResult introduceClonesForRegionSuccessors(RegionBranchOpInterface regionInterface, + MutableArrayRef regions, + Value argValue, + const TPredicate ®ionPredicate) + { + for (Region ®ion : regions) { + // Query the regionInterface to get all successor regions of the current + // one. + SmallVector successorRegions; + regionInterface.getSuccessorRegions(region, successorRegions); + // Try to find a matching region successor. + RegionSuccessor *regionSuccessor = llvm::find_if(successorRegions, regionPredicate); + if (regionSuccessor == successorRegions.end()) + continue; + // Get the operand index in the context of the current successor input + // bindings. + size_t operandIndex = + llvm::find(regionSuccessor->getSuccessorInputs(), argValue).getIndex(); + + // Iterate over all immediate terminator operations to introduce + // new buffer allocations. Thereby, the appropriate terminator operand + // will be adjusted to point to the newly allocated buffer instead. + if (failed(walkReturnOperations( + ®ion, [&](RegionBranchTerminatorOpInterface terminator) { + // Get the actual mutable operands for this terminator op. + auto terminatorOperands = + terminator.getMutableSuccessorOperands(*regionSuccessor); + // Extract the source value from the current terminator. + // This conversion needs to exist on a separate line due to a + // bug in GCC conversion analysis. + OperandRange immutableTerminatorOperands = terminatorOperands; + Value sourceValue = immutableTerminatorOperands[operandIndex]; + // Create a new clone at the current location of the terminator. + auto clone = introduceCloneBuffers(sourceValue, terminator); + if (failed(clone)) + return failure(); + // Wire clone and terminator operand. + terminatorOperands.slice(operandIndex, 1).assign(*clone); + return success(); + }))) + return failure(); + } + return success(); + } + + /// Creates a new memory allocation for the given source value and clones + /// its content into the newly allocated buffer. The terminator operation is + /// used to insert the clone operation at the right place. + FailureOr introduceCloneBuffers(Value sourceValue, Operation *terminator) + { + // Avoid multiple clones of the same source value. This can happen in the + // presence of loops when a branch acts as a backedge while also having + // another successor that returns to its parent operation. Note: that + // copying copied buffers can introduce memory leaks since the invariant of + // BufferDeallocation assumes that a buffer will be only cloned once into a + // temporary buffer. Hence, the construction of clone chains introduces + // additional allocations that are not tracked automatically by the + // algorithm. + if (clonedValues.contains(sourceValue)) + return sourceValue; + // Create a new clone operation that copies the contents of the old + // buffer to the new one. + auto clone = buildClone(terminator, sourceValue); + if (succeeded(clone)) { + // Remember the clone of original source value. + clonedValues.insert(*clone); + } + return clone; + } + + /// Finds correct dealloc positions according to the algorithm described at + /// the top of the file for all alloc nodes and block arguments that can be + /// handled by this analysis. + LogicalResult placeDeallocs() + { + // Move or insert deallocs using the previously computed information. + // These deallocations will be linked to their associated allocation nodes + // since they don't have any aliases that can (potentially) increase their + // liveness. + for (const BufferPlacementAllocs::AllocEntry &entry : allocs) { + Value alloc = std::get<0>(entry); + auto aliasesSet = aliases.resolve(alloc); + assert(!aliasesSet.empty() && "must contain at least one alias"); + + // Determine the actual block to place the dealloc and get liveness + // information. + Block *placementBlock = findCommonDominator(alloc, aliasesSet, postDominators); + const LivenessBlockInfo *livenessInfo = liveness.getLiveness(placementBlock); + + // We have to ensure that the dealloc will be after the last use of all + // aliases of the given value. We first assume that there are no uses in + // the placementBlock and that we can safely place the dealloc at the + // beginning. + Operation *endOperation = &placementBlock->front(); + + // Iterate over all aliases and ensure that the endOperation will point + // to the last operation of all potential aliases in the placementBlock. + for (Value alias : aliasesSet) { + // Ensure that the start operation is at least the defining operation of + // the current alias to avoid invalid placement of deallocs for aliases + // without any uses. + Operation *beforeOp = endOperation; + if (alias.getDefiningOp() && + !(beforeOp = placementBlock->findAncestorOpInBlock(*alias.getDefiningOp()))) + continue; + + Operation *aliasEndOperation = livenessInfo->getEndOperation(alias, beforeOp); + // Check whether the aliasEndOperation lies in the desired block and + // whether it is behind the current endOperation. If yes, this will be + // the new endOperation. + if (aliasEndOperation->getBlock() == placementBlock && + endOperation->isBeforeInBlock(aliasEndOperation)) + endOperation = aliasEndOperation; + } + // endOperation is the last operation behind which we can safely store + // the dealloc taking all potential aliases into account. + + // If there is an existing dealloc, move it to the right place. + Operation *deallocOperation = std::get<1>(entry); + if (deallocOperation) { + deallocOperation->moveAfter(endOperation); + } + else { + // If the Dealloc position is at the terminator operation of the + // block, then the value should escape from a deallocation. + Operation *nextOp = endOperation->getNextNode(); + if (!nextOp) + continue; + // If there is no dealloc node, insert one in the right place. + if (failed(buildDealloc(nextOp, alloc))) + return failure(); + } + } + return success(); + } + + /// Builds a deallocation operation compatible with the given allocation + /// value. If there is no registered AllocationOpInterface implementation for + /// the given value (e.g. in the case of a function parameter), this method + /// builds a memref::DeallocOp. + LogicalResult buildDealloc(Operation *op, Value alloc) + { + OpBuilder builder(op); + auto it = aliasToAllocations.find(alloc); + if (it != aliasToAllocations.end()) { + // Call the allocation op interface to build a supported and + // compatible deallocation operation. + auto dealloc = it->second.buildDealloc(builder, alloc); + if (!dealloc) + return op->emitError() << "allocations without compatible deallocations are " + "not supported"; + } + else { + // Build a "default" DeallocOp for unknown allocation sources. + builder.create(alloc.getLoc(), alloc); + } + return success(); + } + + /// Builds a clone operation compatible with the given allocation value. If + /// there is no registered AllocationOpInterface implementation for the given + /// value (e.g. in the case of a function parameter), this method builds a + /// bufferization::CloneOp. + FailureOr buildClone(Operation *op, Value alloc) + { + OpBuilder builder(op); + auto it = aliasToAllocations.find(alloc); + if (it != aliasToAllocations.end()) { + // Call the allocation op interface to build a supported and + // compatible clone operation. + auto clone = it->second.buildClone(builder, alloc); + if (clone) + return *clone; + return (LogicalResult)(op->emitError() << "allocations without compatible clone ops " + "are not supported"); + } + // Build a "default" CloneOp for unknown allocation sources. + return builder.create(alloc.getLoc(), alloc).getResult(); + } + + /// The dominator info to find the appropriate start operation to move the + /// allocs. + DominanceInfo dominators; + + /// The post dominator info to move the dependent allocs in the right + /// position. + PostDominanceInfo postDominators; + + /// Stores already cloned buffers to avoid additional clones of clones. + ValueSetT clonedValues; + + /// Maps aliases to their source allocation interfaces (inverse mapping). + AliasAllocationMapT aliasToAllocations; +}; + +//===----------------------------------------------------------------------===// +// BufferDeallocationPass +//===----------------------------------------------------------------------===// + +/// The actual buffer deallocation pass that inserts and moves dealloc nodes +/// into the right positions. Furthermore, it inserts additional clones if +/// necessary. It uses the algorithm described at the top of the file. +struct BufferDeallocationPass + : public catalyst::impl::BufferDeallocationBase { + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + LogicalResult deallocateBuffers(Operation *op) + { + if (isa(op)) { + WalkResult result = op->walk([&](func::FuncOp funcOp) { + if (failed(deallocateBuffers(funcOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + // Ensure that there are supported loops only. + Backedges backedges(op); + if (backedges.size()) { + op->emitError("Only structured control-flow loops are supported."); + return failure(); + } + + // Check that the control flow structures are supported. + if (!validateSupportedControlFlow(op)) + return failure(); + + // Gather all required allocation nodes and prepare the deallocation phase. + BufferDeallocation deallocation(op); + + // Check for supported AllocationOpInterface implementations and prepare the + // internal deallocation pass. + if (failed(deallocation.prepare())) + return failure(); + + // Place all required temporary clone and dealloc nodes. + if (failed(deallocation.deallocate())) + return failure(); + + return success(); + } + + void runOnOperation() override + { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + if (failed(deallocateBuffers(func))) + signalPassFailure(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// BufferDeallocationPass construction +//===----------------------------------------------------------------------===// + +std::unique_ptr catalyst::createBufferDeallocationPass() +{ + return std::make_unique(); +} diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 6bd01f6615..c5b93d1602 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -321,7 +321,8 @@ struct CallbackCallOpInterface } SmallVector emptyRets; - rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); + rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs); return success(); } diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 48d4aa362b..b4776af6a9 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ file(GLOB SRC ApplyTransformSequencePass.cpp ArrayListToMemRefPass.cpp AsyncUtils.cpp + BufferDeallocation.cpp BufferizableOpInterfaceImpl.cpp catalyst_to_llvm.cpp DetectQNodes.cpp diff --git a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp index 9a62a283e7..d516605b4c 100644 --- a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp +++ b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp @@ -113,24 +113,23 @@ LogicalResult DetectCallsInAsyncRegionsTransform::matchAndRewrite(LLVM::CallOp c struct AddExceptionHandlingTransform : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(LLVM::CallOp op) const override; - void rewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; }; /* Here we only match with calls that have the { catalyst.preInvoke } annotations. * The reason behind this separation between the previous pattern and this one, * is that this pattern can potentially be reused as long as this single annotation is present. */ -LogicalResult AddExceptionHandlingTransform::match(LLVM::CallOp callOp) const +LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp, + PatternRewriter &rewriter) const { // The following is a valid match // llvm.call @callee() { catalyst.preInvoke } bool validCandidate = AsyncUtils::isScheduledForTransformation(callOp); - return validCandidate ? success() : failure(); -} + if (!validCandidate) { + return failure(); + } -void AddExceptionHandlingTransform::rewrite(LLVM::CallOp callOp, PatternRewriter &rewriter) const -{ auto moduleOp = callOp->getParentOfType(); // Here, we are adding a reference to the personality declaration. // From the documentation: https://llvm.org/docs/ExceptionHandling.html#exception-tables @@ -255,6 +254,7 @@ void AddExceptionHandlingTransform::rewrite(LLVM::CallOp callOp, PatternRewriter // // llvm.func caller() attributes { catalyst.preHandleError } AsyncUtils::scheduleAnalysisForErrorHandling(caller, rewriter); + return success(); } /* The next step is to inspect callers of the previous caller. @@ -264,8 +264,7 @@ void AddExceptionHandlingTransform::rewrite(LLVM::CallOp callOp, PatternRewriter struct RemoveAbortAndPutsInsertCallTransform : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(LLVM::CallOp op) const override; - void rewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; }; // In this pattern we are looking for function calls to functions annotated @@ -276,27 +275,21 @@ struct RemoveAbortAndPutsInsertCallTransform : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(LLVM::CallOp op) const override; - void rewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; }; -LogicalResult LivenessAnalysisDropRef::match(LLVM::CallOp op) const +LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink, + PatternRewriter &rewriter) const { // We match on function calls that have the sink attribute. // llvm.call @__catalyst__host__rt__unrecoverable_error() { catalyst.sink } - return AsyncUtils::isSink(op) ? success() : failure(); -} + if (!AsyncUtils::isSink(sink)) { + return failure(); + } -void LivenessAnalysisDropRef::rewrite(LLVM::CallOp sink, PatternRewriter &rewriter) const -{ auto caller = AsyncUtils::getCaller(sink); SmallVector sources; @@ -557,6 +550,7 @@ void LivenessAnalysisDropRef::rewrite(LLVM::CallOp sink, PatternRewriter &rewrit // NEVER CALL: // cleanupSource(annotatedCalls, rewriter); AsyncUtils::cleanupSink(sink, rewriter); + return success(); } // We now can cleanup the source @@ -926,8 +920,11 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns2), config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns2), config))) { signalPassFailure(); } @@ -947,7 +944,7 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns3), config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns3), config))) { signalPassFailure(); } @@ -957,7 +954,7 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns4), config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns4), config))) { signalPassFailure(); } @@ -967,7 +964,7 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns5), config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns5), config))) { signalPassFailure(); } } diff --git a/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp b/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp index 0ba0838277..21282ffd57 100644 --- a/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp +++ b/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp @@ -352,7 +352,7 @@ struct DetensorizeSCFPass : public impl::DetensorizeSCFPassBase(context); patterns.add(context); patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Catalyst/Transforms/GEPInboundsPass.cpp b/mlir/lib/Catalyst/Transforms/GEPInboundsPass.cpp index 2543bbfcf7..eee9575d3a 100644 --- a/mlir/lib/Catalyst/Transforms/GEPInboundsPass.cpp +++ b/mlir/lib/Catalyst/Transforms/GEPInboundsPass.cpp @@ -42,7 +42,7 @@ struct GEPInboundsPass : impl::GEPInboundsPassBase { RewritePatternSet patterns(&getContext()); populateGEPInboundsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -50,4 +50,4 @@ struct GEPInboundsPass : impl::GEPInboundsPassBase { std::unique_ptr createGEPInboundsPass() { return std::make_unique(); } -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp index 958857ae52..f540222b77 100644 --- a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp +++ b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp @@ -379,18 +379,19 @@ struct AnnotateWithFullyQualifiedNamePass void runOnOperation() override { MLIRContext *context = &getContext(); - - // Do not fold to save in compile time. GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; + // TODO: Update to the following lines the next time we update llvm + // config.setStrictness(GreedyRewriteStrictness::ExistingOps); + // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); RewritePatternSet annotate(context); auto root = getOperation(); auto parent = root->getParentOp(); annotate.add(context, parent); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(annotate), config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(annotate), config))) { signalPassFailure(); } } @@ -410,7 +411,9 @@ struct InlineNestedSymbolTablePass : PassWrapper(context, &symbolTables); bool run = _stopAfterStep >= 2 || _stopAfterStep == 0; - if (run && - failed(applyPatternsAndFoldGreedily(symbolTable, std::move(renameFunctions), config))) { + if (run && failed(applyPatternsGreedily(symbolTable, std::move(renameFunctions), config))) { signalPassFailure(); } RewritePatternSet inlineNested(context); inlineNested.add(context); run = _stopAfterStep >= 3 || _stopAfterStep == 0; - if (run && - failed(applyPatternsAndFoldGreedily(symbolTable, std::move(inlineNested), config))) { + if (run && failed(applyPatternsGreedily(symbolTable, std::move(inlineNested), config))) { signalPassFailure(); } @@ -467,15 +468,14 @@ struct InlineNestedSymbolTablePass : PassWrapper( context, &old_to_new); run = _stopAfterStep >= 4 || _stopAfterStep == 0; - if (run && - failed(applyPatternsAndFoldGreedily(symbolTable, std::move(nestedToFlat), config))) { + if (run && failed(applyPatternsGreedily(symbolTable, std::move(nestedToFlat), config))) { signalPassFailure(); } RewritePatternSet cleanup(context); cleanup.add(context); run = _stopAfterStep >= 5 || _stopAfterStep == 0; - if (run && failed(applyPatternsAndFoldGreedily(symbolTable, std::move(cleanup), config))) { + if (run && failed(applyPatternsGreedily(symbolTable, std::move(cleanup), config))) { signalPassFailure(); } } diff --git a/mlir/lib/Catalyst/Transforms/MemrefCopyToLinalgCopyPass.cpp b/mlir/lib/Catalyst/Transforms/MemrefCopyToLinalgCopyPass.cpp index 4eb5745bf5..c00b10f380 100644 --- a/mlir/lib/Catalyst/Transforms/MemrefCopyToLinalgCopyPass.cpp +++ b/mlir/lib/Catalyst/Transforms/MemrefCopyToLinalgCopyPass.cpp @@ -44,7 +44,7 @@ struct MemrefCopyToLinalgCopyPass RewritePatternSet patterns(&getContext()); populateMemrefCopyToLinalgCopyPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -55,4 +55,4 @@ std::unique_ptr createMemrefCopyToLinalgCopyPass() return std::make_unique(); } -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 56f11d9ff3..304eaf8cd5 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -29,6 +29,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createAnnotateFunctionPass); mlir::registerPass(catalyst::createApplyTransformSequencePass); mlir::registerPass(catalyst::createArrayListToMemRefPass); + mlir::registerPass(catalyst::createBufferDeallocationPass); mlir::registerPass(catalyst::createCatalystConversionPass); mlir::registerPass(catalyst::createCopyGlobalMemRefPass); mlir::registerPass(catalyst::createCommuteCliffordTPPRPass); diff --git a/mlir/lib/Catalyst/Transforms/TBAAPatterns.cpp b/mlir/lib/Catalyst/Transforms/TBAAPatterns.cpp index 0adedfd6b7..1dba88e5f0 100644 --- a/mlir/lib/Catalyst/Transforms/TBAAPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/TBAAPatterns.cpp @@ -12,20 +12,87 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/PatternMatch.h" +#include +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" #include "Catalyst/Transforms/TBAAUtils.h" using namespace mlir; namespace catalyst { + +bool isFromExtractAlignedPointerAsIndexOp(Operation *op) +{ + // Returns true if the memref.store/load ops deal with memrefs coming from a + // memref.extract_aligned_pointer_as_index op + + SetVector backwardSlice; + BackwardSliceOptions options; + // Upstream slice analysis fails if encounters a block argument in an op with + // more than one region + // https://github.com/llvm/llvm-project/blob/179d30f8c3fddd3c85056fd2b8e877a4a8513158/mlir/lib/Analysis/SliceAnalysis.cpp#L109 + options.omitBlockArguments = true; + getBackwardSlice(op, &backwardSlice, options); + bool found = std::find_if(backwardSlice.begin(), backwardSlice.end(), [](const Operation *op) { + return isa(op); + }) != backwardSlice.end(); + return found; +} + +bool isMemrefArgOfDeallocHelper(Operation *op) +{ + // Some memref.store and load ops are in the `dealloc_helper` function + // generated by the --buffer--deallocation--pipeline. + // `index` types on this function's arguments are pointers, not ints + // + // See https://mlir.llvm.org/docs/OwnershipBasedBufferDeallocation/#generic-lowering + // * The generated function takes two MemRefs of indices and three MemRefs of + // * booleans as arguments: + // * The first argument A should contain the result of the + // * extract_aligned_pointer_as_index operation applied to the MemRefs to be deallocated + // * The second argument B should contain the result of the + // * extract_aligned_pointer_as_index operation applied to the MemRefs to be retained + assert(isa(op) || isa(op)); + auto parentOp = dyn_cast(op->getParentOp()); + if (!parentOp) { + return false; + } + + auto str = parentOp.getName(); + llvm::StringRef dellocRef = "dealloc_helper"; + if (!(str.compare(dellocRef) == 0)) { + return false; + } + + Value mem; + if (isa(op)) { + mem = dyn_cast(op).getMemref(); + } + else if (isa(op)) { + mem = dyn_cast(op).getMemref(); + } + + auto funcArgs = parentOp.getCallableRegion()->getArguments(); + auto memPos = std::find(funcArgs.begin(), funcArgs.end(), mem); + if (std::distance(funcArgs.begin(), memPos) >= 2) { + // Only the first two arguments to `dealloc_helper` are the memrefs we are looking for + // Note that std::find returns end iterator if not found + // so distance >=2 also covers the case for not found + return false; + } + + return true; +} + void setTag(mlir::Type baseType, catalyst::TBAATree *tree, mlir::MLIRContext *ctx, mlir::LLVM::AliasAnalysisOpInterface newOp) { @@ -71,7 +138,13 @@ struct MemrefLoadTBAARewritePattern : public ConvertOpToLLVMPatternconvertType(type.getElementType()), dataPtr, 0, false, loadOp.getNontemporal()); - if (isAnyOf(baseType)) { + // Index can be used as a pointer. + if (isa(baseType) && + (isFromExtractAlignedPointerAsIndexOp(loadOp) || isMemrefArgOfDeallocHelper(loadOp))) { + mlir::LLVM::TBAATagAttr tag = tree->getTag("any pointer"); + op.setTBAATags(ArrayAttr::get(loadOp.getContext(), tag)); + } + else if (isAnyOf(baseType)) { setTag(baseType, tree, loadOp.getContext(), op); } else { @@ -102,7 +175,13 @@ struct MemrefStoreTBAARewritePattern : public ConvertOpToLLVMPattern(storeOp, adaptor.getValue(), dataPtr, 0, false, storeOp.getNontemporal()); - if (isAnyOf(baseType)) { + // Index can be used as a pointer. + if (isa(baseType) && (isFromExtractAlignedPointerAsIndexOp(storeOp) || + isMemrefArgOfDeallocHelper(storeOp))) { + mlir::LLVM::TBAATagAttr tag = tree->getTag("any pointer"); + op.setTBAATags(ArrayAttr::get(storeOp.getContext(), tag)); + } + else if (isAnyOf(baseType)) { setTag(baseType, tree, storeOp.getContext(), op); } else { diff --git a/mlir/lib/Catalyst/Transforms/TBAATagsPass.cpp b/mlir/lib/Catalyst/Transforms/TBAATagsPass.cpp index ae92ab9782..3741e78b99 100644 --- a/mlir/lib/Catalyst/Transforms/TBAATagsPass.cpp +++ b/mlir/lib/Catalyst/Transforms/TBAATagsPass.cpp @@ -12,20 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "Catalyst/Transforms/Passes.h" -#include "Catalyst/Transforms/TBAAUtils.h" - #include "Catalyst/Transforms/Patterns.h" +#include "Catalyst/Transforms/TBAAUtils.h" #include "Gradient/IR/GradientInterfaces.h" using namespace mlir; diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index daba674513..d427d49386 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -334,8 +334,10 @@ struct CustomCallOpPattern : public OpConversionPattern { ModuleOp mod = op->getParentOfType(); rewriter.setInsertionPointToStart(mod.getBody()); - LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn( - mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr}, /*ret_type=*/voidType); + LLVM::LLVMFuncOp customCallFnOp = + mlir::LLVM::lookupOrCreateFn(mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr}, + /*ret_type=*/voidType) + .value(); customCallFnOp.setPrivate(); rewriter.restoreInsertionPoint(point); @@ -435,15 +437,14 @@ struct CustomCallOpPattern : public OpConversionPattern { struct DefineCallbackOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult match(CallbackOp op) const override + LogicalResult matchAndRewrite(CallbackOp op, CallbackOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // Only match with ops without an entry block - return !op.empty() ? failure() : success(); - } + if (!op.empty()) { + return failure(); + } - void rewrite(CallbackOp op, CallbackOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { Block *entry; rewriter.modifyOpInPlace(op, [&] { entry = op.addEntryBlock(); }); PatternRewriter::InsertionGuard guard(rewriter); @@ -465,9 +466,11 @@ struct DefineCallbackOpPattern : public OpConversionPattern { bool isVarArg = true; ModuleOp mod = op->getParentOfType(); auto typeConverter = getTypeConverter(); - LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn( - mod, "__catalyst_inactive_callback", {/*args=*/i64, i64, i64}, - /*ret_type=*/voidType, isVarArg); + LLVM::LLVMFuncOp customCallFnOp = + mlir::LLVM::lookupOrCreateFn(mod, "__catalyst_inactive_callback", + {/*args=*/i64, i64, i64}, + /*ret_type=*/voidType, isVarArg) + .value(); SmallVector passthroughs; auto keyAttr = StringAttr::get(ctx, "nofree"); passthroughs.push_back(keyAttr); @@ -483,20 +486,21 @@ struct DefineCallbackOpPattern : public OpConversionPattern { } rewriter.create(loc, customCallFnOp, callArgs); rewriter.create(loc, TypeRange{}, ValueRange{}); + return success(); } }; struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult match(CallbackOp op) const override + LogicalResult matchAndRewrite(CallbackOp op, CallbackOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // Only match with ops with an entry block - return !op.empty() ? success() : failure(); - } - void rewrite(CallbackOp op, CallbackOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + if (op.empty()) { + return failure(); + } + ModuleOp mod = op->getParentOfType(); rewriter.setInsertionPointToStart(mod.getBody()); @@ -511,6 +515,7 @@ struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern { auto typeConverter = getTypeConverter(); gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc()); rewriter.eraseOp(op); + return success(); } }; @@ -541,7 +546,8 @@ struct CallbackCallOpPattern : public OpConversionPattern { struct CustomGradOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult match(gradient::CustomGradOp op) const override + LogicalResult matchAndRewrite(gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // only match after all three are func.func auto callee = op.getCalleeAttr(); @@ -552,22 +558,14 @@ struct CustomGradOpPattern : public OpConversionPattern auto forwardOp = mod.lookupSymbol(forward); auto reverseOp = mod.lookupSymbol(reverse); auto ready = calleeOp && forwardOp && reverseOp; - return ready ? success() : failure(); - } + if (!ready) { + return failure(); + } - void rewrite(gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { auto loc = op.getLoc(); - ModuleOp mod = op->getParentOfType(); - auto callee = op.getCalleeAttr(); - auto forward = op.getForwardAttr(); - auto reverse = op.getReverseAttr(); - auto calleeOp = mod.lookupSymbol(callee); - auto forwardOp = mod.lookupSymbol(forward); - auto reverseOp = mod.lookupSymbol(reverse); gradient::insertEnzymeCustomGradient(rewriter, mod, loc, calleeOp, forwardOp, reverseOp); rewriter.eraseOp(op); + return success(); } }; diff --git a/mlir/lib/Catalyst/Transforms/disable_assertion.cpp b/mlir/lib/Catalyst/Transforms/disable_assertion.cpp index e4648f3b22..318d992838 100644 --- a/mlir/lib/Catalyst/Transforms/disable_assertion.cpp +++ b/mlir/lib/Catalyst/Transforms/disable_assertion.cpp @@ -38,7 +38,7 @@ struct DisableAssertionPass : impl::DisableAssertionPassBase RewritePatternSet patterns(&getContext()); populateScatterPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Catalyst/Utils/CallGraph.cpp b/mlir/lib/Catalyst/Utils/CallGraph.cpp index 37dc62212c..569f711fc5 100644 --- a/mlir/lib/Catalyst/Utils/CallGraph.cpp +++ b/mlir/lib/Catalyst/Utils/CallGraph.cpp @@ -31,8 +31,8 @@ void traverseCallGraph(func::FuncOp start, SymbolTableCollection *symbolTable, processFunc(callable); callable.walk([&](CallOpInterface callOp) { - if (auto nextFunc = - dyn_cast_or_null(callOp.resolveCallable(symbolTable))) { + if (auto nextFunc = dyn_cast_or_null( + mlir::call_interface_impl::resolveCallable(callOp, symbolTable))) { if (!visited.contains(nextFunc)) { visited.insert(nextFunc); frontier.push_back(nextFunc); diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 53b7d7f353..d72ef39ebb 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -765,7 +765,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & if (runLLC && (inType == InputType::LLVMIR)) { TimingScope llcTiming = timing.nest("llc"); // Set data layout before LLVM passes or the default one is used. - std::string targetTriple = llvm::sys::getDefaultTargetTriple(); + llvm::Triple targetTriple(llvm::sys::getDefaultTargetTriple()); llvm::InitializeAllTargetInfos(); llvm::InitializeAllTargets(); diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 8293a98c97..46f6905a66 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -73,16 +73,12 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addPass(catalyst::createGradientPreprocessingPass()); pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass()); /////////// - mlir::bufferization::OneShotBufferizationOptions options; + mlir::bufferization::OneShotBufferizePassOptions options; options.bufferizeFunctionBoundaries = true; options.allowReturnAllocsFromLoops = true; - options.setFunctionBoundaryTypeConversion( - mlir::bufferization::LayoutMapOption::IdentityLayoutMap); - options.unknownTypeConverterFn = [=](Value value, Attribute memorySpace, - const mlir::bufferization::BufferizationOptions &options) { - auto tensorType = cast(value.getType()); - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); - }; + options.functionBoundaryTypeConversion = + mlir::bufferization::LayoutMapOption::IdentityLayoutMap; + options.unknownTypeConversion = mlir::bufferization::LayoutMapOption::IdentityLayoutMap; pm.addPass(mlir::bufferization::createOneShotBufferizePass(options)); ////////////// pm.addPass(mlir::createCanonicalizerPass()); @@ -90,9 +86,10 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addNestedPass(mlir::bufferization::createBufferHoistingPass()); pm.addNestedPass(mlir::bufferization::createBufferLoopHoistingPass()); pm.addNestedPass(mlir::bufferization::createPromoteBuffersToStackPass()); - pm.addNestedPass(mlir::bufferization::createBufferDeallocationPass()); + // TODO: migrate to new buffer deallocation "buffer-deallocation-pipeline" + pm.addNestedPass(catalyst::createBufferDeallocationPass()); pm.addPass(catalyst::createArrayListToMemRefPass()); - pm.addPass(mlir::createBufferizationToMemRefPass()); + pm.addPass(mlir::createConvertBufferizationToMemRefPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createCopyGlobalMemRefPass()); } @@ -102,7 +99,7 @@ void createLLVMDialectLoweringPipeline(OpPassManager &pm) pm.addPass(catalyst::createGradientConversionPass()); pm.addPass(catalyst::createMemrefCopyToLinalgCopyPass()); pm.addNestedPass(mlir::createConvertLinalgToLoopsPass()); - pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createSCFToControlFlowPass()); pm.addPass(mlir::memref::createExpandStridedMetadataPass()); pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::arith::createArithExpandOpsPass()); diff --git a/mlir/lib/Gradient/IR/GradientOps.cpp b/mlir/lib/Gradient/IR/GradientOps.cpp index 237f264bc7..6b46eeb12f 100644 --- a/mlir/lib/Gradient/IR/GradientOps.cpp +++ b/mlir/lib/Gradient/IR/GradientOps.cpp @@ -130,7 +130,7 @@ CallInterfaceCallable GradOp::getCallableForCallee() { return getCalleeAttr(); } void GradOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); }; Operation::operand_range GradOp::getArgOperands() { return getOperands(); } @@ -187,7 +187,7 @@ CallInterfaceCallable ValueAndGradOp::getCallableForCallee() { return getCalleeA void ValueAndGradOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); }; Operation::operand_range ValueAndGradOp::getArgOperands() { return getOperands(); } @@ -260,7 +260,7 @@ CallInterfaceCallable JVPOp::getCallableForCallee() { return getCalleeAttr(); } void JVPOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); }; Operation::operand_range JVPOp::getArgOperands() { return getOperands(); } @@ -367,7 +367,7 @@ CallInterfaceCallable VJPOp::getCallableForCallee() { return getCalleeAttr(); } void VJPOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); }; Operation::operand_range VJPOp::getArgOperands() { return getOperands(); } diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp index a79a4b408e..e3a472945e 100644 --- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp @@ -211,10 +211,10 @@ struct AdjointOpPattern : public ConvertOpToLLVMPattern { return op.emitOpError("adjoint can only return MemRef or tuple thereof"); } - // The callee of the adjoint op must return as a single result the quantum register. + // The callee of the adjoint op must return 2 results: the quantum register and the expval. func::FuncOp callee = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); - assert(callee && callee.getNumResults() == 1 && "invalid qfunc symbol in adjoint op"); + assert(callee && callee.getNumResults() == 2 && "invalid qfunc symbol in adjoint op"); StringRef cacheFnName = "__catalyst__rt__toggle_recorder"; StringRef gradFnName = "__catalyst__qis__Gradient"; @@ -319,8 +319,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { LowerToLLVMOptions options = getTypeConverter()->getOptions(); if (options.useGenericFunctions) { LLVM::LLVMFuncOp allocFn = - LLVM::lookupOrCreateGenericAllocFn(moduleOp, getTypeConverter()->getIndexType()); - LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(moduleOp); + LLVM::lookupOrCreateGenericAllocFn(moduleOp, getTypeConverter()->getIndexType()) + .value(); + LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(moduleOp).value(); // Register the previous functions as llvm globals (for Enzyme) // With the following piece of metadata, shadow memory is allocated with @@ -824,10 +825,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { struct ForwardOpPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult match(ForwardOp op) const override { return success(); } - - void rewrite(ForwardOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(ForwardOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // convert all arguments to pointers... auto typeConverter = getTypeConverter(); @@ -865,16 +864,15 @@ struct ForwardOpPattern : public ConvertOpToLLVMPattern { rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc()); rewriter.eraseOp(op); + return success(); } }; struct ReverseOpPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult match(ReverseOp op) const override { return success(); } - - void rewrite(ReverseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(ReverseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto argc = op.getArgc(); auto resc = op.getResc(); @@ -991,21 +989,21 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern { catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc()); rewriter.eraseOp(op); + return success(); } }; struct ReturnOpPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult match(ReturnOp op) const override { return success(); } - - void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); if (op.getEmpty()) { auto returnOp = rewriter.create(loc, ValueRange{}); rewriter.replaceOp(op, returnOp); - return; + return success(); } auto tape = adaptor.getTape(); @@ -1016,7 +1014,7 @@ struct ReturnOpPattern : public ConvertOpToLLVMPattern { Value nullPtr = rewriter.create(loc, ptrType); auto returnOp = rewriter.create(loc, nullPtr); rewriter.replaceOp(op, returnOp); - return; + return success(); } SmallVector tapeStructVals(tape); @@ -1035,6 +1033,7 @@ struct ReturnOpPattern : public ConvertOpToLLVMPattern { auto returnOp = rewriter.create(loc, result); rewriter.replaceOp(op, returnOp); + return success(); } }; diff --git a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp index 8e20015222..e25d144d55 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp @@ -12,32 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "Adjoint.hpp" - #include #include #include #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/IRMapping.h" #include "Gradient/IR/GradientOps.h" #include "Gradient/Utils/DifferentialQNode.h" #include "Gradient/Utils/GradientShape.h" #include "Quantum/IR/QuantumOps.h" +#include "Adjoint.hpp" + namespace catalyst { namespace gradient { -LogicalResult AdjointLowering::match(func::FuncOp op) const +LogicalResult AdjointLowering::matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const { - if (getQNodeDiffMethod(op) == "adjoint" && requiresCustomGradient(op)) - return success(); - - return failure(); -} + if (!(getQNodeDiffMethod(op) == "adjoint" && requiresCustomGradient(op))) { + return failure(); + } -void AdjointLowering::rewrite(func::FuncOp op, PatternRewriter &rewriter) const -{ Location loc = op.getLoc(); rewriter.setInsertionPointAfter(op); @@ -47,15 +44,45 @@ void AdjointLowering::rewrite(func::FuncOp op, PatternRewriter &rewriter) const // Register the quantum gradient on the quantum-only split-out QNode. registerCustomGradient(op, FlatSymbolRefAttr::get(qGradFn)); + return success(); } func::FuncOp AdjointLowering::discardAndReturnReg(PatternRewriter &rewriter, Location loc, func::FuncOp callee) { + // TODO: we do not support multiple return statements (which can happen for unstructured + // control flow), i.e. our gradient functions will have just one block. + assert(callee.getBody().hasOneBlock() && + "Gradients with unstructured control flow are not supported"); + + // Since the return value is guaranteed to be discarded, then let's change the return type + // to be only the quantum register and the expval. + // + // We also need to return the expval to avoid dead code elimination downstream from + // removing the expval op in the body. + // TODO: we only support grad on expval op for now SmallVector deallocs; - for (auto op : callee.getOps()) { - deallocs.push_back(op); - }; + SmallVector expvalOps; + SmallVector deviceReleaseOps; + for (Operation &op : callee.getBody().getOps()) { + if (isa(op)) { + deallocs.push_back(cast(op)); + continue; + } + else if (isa(op)) { + if (isa(op)) { + expvalOps.push_back(cast(op)); + continue; + } + else { + callee.emitOpError() << "Adjoint gradient is only supported on expval measurements"; + return callee; + } + } + else if (isa(op)) { + deviceReleaseOps.push_back(cast(op)); + } + } // If there are no deallocs leave early then this transformation // is invalid. This is because the caller will expect a quantum register @@ -67,11 +94,20 @@ func::FuncOp AdjointLowering::discardAndReturnReg(PatternRewriter &rewriter, Loc return callee; } - // Since the return value is guaranteed to be discarded, then let's change the return type - // to be only the quantum register. + size_t numDeviceReleases = deviceReleaseOps.size(); + if (numDeviceReleases > 1) { + callee.emitOpError() << "Invalid number of device release ops: " << numDeviceReleases; + return callee; + } + + // Create clone, return type is qreg and float for the expvals std::string fnName = callee.getName().str() + ".nodealloc"; Type qregType = quantum::QuregType::get(rewriter.getContext()); - FunctionType fnType = rewriter.getFunctionType(callee.getArgumentTypes(), qregType); + Type f64Type = rewriter.getF64Type(); + SmallVector retTypes{qregType}; + std::for_each(expvalOps.begin(), expvalOps.end(), + [&](const quantum::ExpvalOp &) { retTypes.push_back(f64Type); }); + FunctionType fnType = rewriter.getFunctionType(callee.getArgumentTypes(), retTypes); StringAttr visibility = rewriter.getStringAttr("private"); func::FuncOp unallocFn = @@ -82,25 +118,32 @@ func::FuncOp AdjointLowering::discardAndReturnReg(PatternRewriter &rewriter, Loc rewriter.setInsertionPointAfter(callee); unallocFn = rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); - // clone the body. - rewriter.cloneRegionBefore(callee.getBody(), unallocFn.getBody(), unallocFn.end()); - rewriter.setInsertionPointToStart(&unallocFn.getBody().front()); - // Let's capture the qreg. - quantum::DeallocOp localDealloc = *unallocFn.getOps().begin(); + // Clone the body. + IRMapping mapper; + rewriter.cloneRegionBefore(callee.getBody(), unallocFn.getBody(), unallocFn.end(), mapper); + rewriter.setInsertionPointToStart(&unallocFn.getBody().front()); - // Let's return the qreg and erase the device release. - unallocFn.walk([&](Operation *op) { - if (isa(op)) { - rewriter.eraseOp(op); - } - else if (isa(op)) { - op->setOperands(localDealloc.getOperand()); - } + // Let's return the qreg+expval and erase the device release. + // Fine for now: only one block in body so only one dealloc and one expval + SmallVector returnVals{mapper.lookup(deallocs[0])->getOperand(0)}; + std::for_each(expvalOps.begin(), expvalOps.end(), [&](const quantum::ExpvalOp &expval) { + returnVals.push_back(mapper.lookup(expval)); }); + // Create the return + // Again, assume just one block for now + Operation *returnOp = unallocFn.getBody().front().getTerminator(); + assert(isa(returnOp) && "adjoint block must terminate with return op"); + returnOp->setOperands(returnVals); + + // Erase the device release. + for (auto op : deviceReleaseOps) { + rewriter.eraseOp(mapper.lookup(op)); + } + // Let's erase the deallocation. - rewriter.eraseOp(localDealloc); + rewriter.eraseOp(mapper.lookup(deallocs[0])); } return unallocFn; diff --git a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.hpp b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.hpp index 2b3b42c0eb..d986c08c8d 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.hpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.hpp @@ -25,8 +25,7 @@ namespace gradient { struct AdjointLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(func::FuncOp op) const override; - void rewrite(func::FuncOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const override; private: static func::FuncOp genQGradFunction(PatternRewriter &rewriter, Location loc, diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp index 20c59d4e0c..66fbcee25d 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp @@ -30,17 +30,12 @@ namespace catalyst { namespace gradient { -LogicalResult FiniteDiffLowering::match(GradOp op) const +LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter) const { - if (op.getMethod() == "fd") { - return success(); + if (op.getMethod() != "fd") { + return failure(); } - return failure(); -} - -void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter) const -{ Location loc = op.getLoc(); const std::vector &diffArgIndices = computeDiffArgIndices(op.getDiffArgIndices()); std::stringstream uniquer; @@ -67,6 +62,7 @@ void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter) const } rewriter.replaceOpWithNewOp(op, gradFn, op.getArgOperands()); + return success(); } void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location loc, diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.hpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.hpp index f5547b00c2..79be485ec7 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.hpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.hpp @@ -29,8 +29,7 @@ namespace gradient { struct FiniteDiffLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(GradOp op) const override; - void rewrite(GradOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(GradOp op, PatternRewriter &rewriter) const override; private: static void computeFiniteDiff(PatternRewriter &rewriter, Location loc, func::FuncOp gradFn, diff --git a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp index 7844fd3e76..672819d3e7 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp @@ -97,7 +97,8 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew auto gradOp = rewriter.create(loc, grad_result_types, op.getMethod(), op.getCallee(), calleeOperands, op.getDiffArgIndicesAttr(), - op.getFiniteDiffParamAttr()); + op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr); std::vector einsumResults; for (size_t nout = 0; nout < funcResultTypes.size(); nout++) { @@ -219,7 +220,8 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew auto gradOp = rewriter.create(loc, grad_result_types, op.getMethod(), op.getCallee(), calleeOperands, op.getDiffArgIndicesAttr(), - op.getFiniteDiffParamAttr()); + op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr); std::vector einsumResults; for (size_t nparam = 0; nparam < func_diff_operand_indices.size(); nparam++) { diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.cpp index 77305c7db4..6f955b1cde 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.cpp @@ -23,16 +23,13 @@ namespace catalyst { namespace gradient { -LogicalResult ParameterShiftLowering::match(func::FuncOp op) const +LogicalResult ParameterShiftLowering::matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const { - if (getQNodeDiffMethod(op) == "parameter-shift" && requiresCustomGradient(op)) { - return success(); + if (!(getQNodeDiffMethod(op) == "parameter-shift" && requiresCustomGradient(op))) { + return failure(); } - return failure(); -} -void ParameterShiftLowering::rewrite(func::FuncOp op, PatternRewriter &rewriter) const -{ Location loc = op.getLoc(); rewriter.setInsertionPointAfter(op); @@ -51,6 +48,7 @@ void ParameterShiftLowering::rewrite(func::FuncOp op, PatternRewriter &rewriter) // Register the quantum gradient on the quantum-only split-out QNode. registerCustomGradient(op, FlatSymbolRefAttr::get(qGradFn)); + return success(); } std::pair ParameterShiftLowering::analyzeFunction(func::FuncOp callee) diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.hpp b/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.hpp index b3f2a639de..d267e42886 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.hpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/ParameterShift.hpp @@ -29,8 +29,7 @@ namespace gradient { struct ParameterShiftLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(func::FuncOp op) const override; - void rewrite(func::FuncOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const override; private: static std::pair analyzeFunction(func::FuncOp callee); diff --git a/mlir/lib/Gradient/Transforms/gradient_lowering.cpp b/mlir/lib/Gradient/Transforms/gradient_lowering.cpp index 4ddbbde756..aa57621a5b 100644 --- a/mlir/lib/Gradient/Transforms/gradient_lowering.cpp +++ b/mlir/lib/Gradient/Transforms/gradient_lowering.cpp @@ -59,7 +59,7 @@ struct GradientLoweringPass : impl::GradientLoweringPassBase createGradientPostprocessingPass() return std::make_unique(); } -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp index a0a003991c..10cca45df1 100644 --- a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp +++ b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp @@ -42,7 +42,7 @@ struct GradientPreprocessingPass : impl::GradientPreprocessingPassBase createGradientPreprocessingPass() return std::make_unique(); } -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Mitigation/IR/MitigationOps.cpp b/mlir/lib/Mitigation/IR/MitigationOps.cpp index 79801d21f7..3f42015240 100644 --- a/mlir/lib/Mitigation/IR/MitigationOps.cpp +++ b/mlir/lib/Mitigation/IR/MitigationOps.cpp @@ -40,7 +40,7 @@ CallInterfaceCallable ZneOp::getCallableForCallee() { return getCalleeAttr(); } void ZneOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); }; Operation::operand_range ZneOp::getArgOperands() { return getOperands(); } diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index c6c7cf1782..616ab66d98 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -72,11 +72,9 @@ func::FuncOp createZneFunc(func::FuncOp funcOp, PatternRewriter &rewriter) return fnFoldedOp; } -LogicalResult ZneLowering::match(mitigation::ZneOp op) const { return success(); } - // TODO: Optimize the traversal of call graphs (currently used twice) // Also all functions exploree in the call graph get their ZNE version. -void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const +LogicalResult ZneLowering::matchAndRewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); auto moduleOp = op->getParentOfType(); @@ -226,6 +224,8 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const .getResult(0); // Replace the original results rewriter.replaceOp(op, resultValues); + + return success(); } // In *.cpp module only, to keep extraneous headers out of *.hpp diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp index 6a6bcdd6c7..d01109570b 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp @@ -28,8 +28,7 @@ namespace mitigation { struct ZneLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(mitigation::ZneOp op) const override; - void rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const override; private: static FlatSymbolRefAttr getOrInsertFoldedCircuit(Location loc, PatternRewriter &builder, diff --git a/mlir/lib/Mitigation/Transforms/mitigation_lowering.cpp b/mlir/lib/Mitigation/Transforms/mitigation_lowering.cpp index 425a2bd295..ac78e6b888 100644 --- a/mlir/lib/Mitigation/Transforms/mitigation_lowering.cpp +++ b/mlir/lib/Mitigation/Transforms/mitigation_lowering.cpp @@ -44,7 +44,7 @@ struct MitigationLoweringPass : impl::MitigationLoweringPassBase { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); ModuleOp moduleOp = op->getParentOfType(); auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetState", - {ptrTy, i64}, voidTy, isVarArg); + {ptrTy, i64}, voidTy, isVarArg) + .value(); SmallVector args; @@ -958,7 +959,8 @@ struct SetBasisStateOpPattern : public OpConversionPattern { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); ModuleOp moduleOp = op->getParentOfType(); auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState", - {ptrTy, i64}, voidTy, isVarArg); + {ptrTy, i64}, voidTy, isVarArg) + .value(); SmallVector args; diff --git a/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp b/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp index d3b12718f2..b13b6c3904 100644 --- a/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/adjoint_lowering.cpp @@ -57,7 +57,7 @@ struct AdjointLoweringPass : impl::AdjointLoweringPassBase RewritePatternSet patterns(&getContext()); populateAdjointPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/annotate_function.cpp b/mlir/lib/Quantum/Transforms/annotate_function.cpp index 621e9d0a73..d479034ea0 100644 --- a/mlir/lib/Quantum/Transforms/annotate_function.cpp +++ b/mlir/lib/Quantum/Transforms/annotate_function.cpp @@ -83,18 +83,18 @@ void annotate(FunctionOpInterface op, PatternRewriter &rewriter, const char *att struct AnnotateFunctionPattern : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult match(FunctionOpInterface op) const override; - void rewrite(FunctionOpInterface op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(FunctionOpInterface op, PatternRewriter &rewriter) const override; }; -LogicalResult AnnotateFunctionPattern::match(FunctionOpInterface op) const +LogicalResult AnnotateFunctionPattern::matchAndRewrite(FunctionOpInterface op, + PatternRewriter &rewriter) const { - return successfulMatchLeaf(op) ? success() : failure(); -} + if (!successfulMatchLeaf(op)) { + return failure(); + } -void AnnotateFunctionPattern::rewrite(FunctionOpInterface op, PatternRewriter &rewriter) const -{ annotate(op, rewriter, hasInvalidGradientOp); + return success(); } std::optional getFuncOp(const CallGraphNode *node, CallGraph &cg) @@ -156,21 +156,21 @@ struct PropagateAnnotationPattern : public OpInterfaceRewritePattern(context); patterns.add(context, cg); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp index b461dc8a60..28ccb09c3e 100644 --- a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp +++ b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp @@ -166,23 +166,23 @@ void applyCopyGlobalMemRefTransform(func::FuncOp op, PatternRewriter &rewriter) struct CopyGlobalMemRefTransform : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(func::FuncOp op) const override; - void rewrite(func::FuncOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const override; }; -LogicalResult CopyGlobalMemRefTransform::match(func::FuncOp op) const +LogicalResult CopyGlobalMemRefTransform::matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const { bool isCandidate = hasCWrapperButNoCopyWrapperAttribute(op); - if (!isCandidate) + if (!isCandidate) { return failure(); + } + if (!hasMemRefReturnTypes(op)) { + return failure(); + } - return hasMemRefReturnTypes(op) ? success() : failure(); -} - -void CopyGlobalMemRefTransform::rewrite(func::FuncOp op, PatternRewriter &rewriter) const -{ setCopyWrapperAttribute(op, rewriter); applyCopyGlobalMemRefTransform(op, rewriter); + return success(); } } // namespace @@ -201,7 +201,7 @@ struct CopyGlobalMemRefPass : impl::CopyGlobalMemRefPassBase(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp index 386023eb04..4c91fbf67c 100644 --- a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp +++ b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp @@ -173,17 +173,16 @@ void wrapResultsAndArgsInTwoStructs(LLVM::LLVMFuncOp op, PatternRewriter &rewrit struct EmitCatalystPyInterfaceTransform : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(LLVM::LLVMFuncOp op) const override; - void rewrite(LLVM::LLVMFuncOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(LLVM::LLVMFuncOp op, PatternRewriter &rewriter) const override; }; -LogicalResult EmitCatalystPyInterfaceTransform::match(LLVM::LLVMFuncOp op) const +LogicalResult EmitCatalystPyInterfaceTransform::matchAndRewrite(LLVM::LLVMFuncOp op, + PatternRewriter &rewriter) const { - return isFunctionMLIRCWrapper(op) ? success() : failure(); -} + if (!isFunctionMLIRCWrapper(op)) { + return failure(); + } -void EmitCatalystPyInterfaceTransform::rewrite(LLVM::LLVMFuncOp op, PatternRewriter &rewriter) const -{ // Find substr after _mlir_ciface_ std::string _mlir_ciface = "_mlir_ciface_"; size_t _mlir_ciface_len = _mlir_ciface.length(); @@ -194,6 +193,7 @@ void EmitCatalystPyInterfaceTransform::rewrite(LLVM::LLVMFuncOp op, PatternRewri rewriter.modifyOpInPlace(op, [&] { op.setSymName(newName); }); wrapResultsAndArgsInTwoStructs(op, rewriter, functionNameWithoutPrefix); + return success(); } } // namespace @@ -212,16 +212,21 @@ struct EmitCatalystPyInterfacePass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); + GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; config.maxIterations = 1; + // TODO: Update to the following lines the next time we update llvm + // config.setStrictness(GreedyRewriteStrictness::ExistingOps); + // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + // config.setMaxIterations(1); auto op = getOperation(); SmallVector targets; op->walk([&](LLVM::LLVMFuncOp func) { targets.push_back(func); }); - if (failed(applyOpPatternsAndFold(targets, std::move(patterns), config))) { + if (failed(applyOpPatternsGreedily(targets, std::move(patterns), config))) { signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/ions_decompositions.cpp b/mlir/lib/Quantum/Transforms/ions_decompositions.cpp index c4d833816e..e40bb0f1b5 100644 --- a/mlir/lib/Quantum/Transforms/ions_decompositions.cpp +++ b/mlir/lib/Quantum/Transforms/ions_decompositions.cpp @@ -46,12 +46,12 @@ struct IonsDecompositionPass : impl::IonsDecompositionPassBase { &getContext()); catalyst::quantum::MultiRZOp::getCanonicalizationPatterns(patternsCanonicalization, &getContext()); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patternsCanonicalization)))) { + if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) { return signalPassFailure(); } @@ -60,7 +60,7 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase { populateLoopBoundaryPatterns(patterns, 1); populateMergeRotationsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Quantum/Transforms/quantum_to_llvm.cpp b/mlir/lib/Quantum/Transforms/quantum_to_llvm.cpp index 9b85840467..e42ebd48d4 100644 --- a/mlir/lib/Quantum/Transforms/quantum_to_llvm.cpp +++ b/mlir/lib/Quantum/Transforms/quantum_to_llvm.cpp @@ -80,6 +80,7 @@ struct QuantumConversionPass : impl::QuantumConversionPassBase= 21 +- case Intrinsic::nvvm_fabs: +- case Intrinsic::nvvm_fabs_ftz: +-#else +- case Intrinsic::nvvm_fabs_f: +- case Intrinsic::nvvm_fabs_d: +- case Intrinsic::nvvm_fabs_ftz_f: +-#endif + case Intrinsic::fabs: + // No direction check as always valid + updateAnalysis( diff --git a/mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch b/mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch deleted file mode 100644 index acb1f0e834..0000000000 --- a/mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch +++ /dev/null @@ -1,24 +0,0 @@ -From 431da821402484bfa128d0559788c86e33f14dc2 Mon Sep 17 00:00:00 2001 -From: Tzung-Han Juang -Date: Fri, 26 Jul 2024 10:39:45 -0400 -Subject: [PATCH 1/1] Add PassesIncGen in transforms CMakeList - ---- - mhlo/transforms/CMakeLists.txt | 1 + - 1 file changed, 1 insertion(+) - -diff --git a/mhlo/transforms/CMakeLists.txt b/mhlo/transforms/CMakeLists.txt -index ecec1370..9e80604f 100644 ---- a/mhlo/transforms/CMakeLists.txt -+++ b/mhlo/transforms/CMakeLists.txt -@@ -195,6 +195,7 @@ add_mlir_library(ChloPasses - MLIRhlo_opsIncGen - MLIRChloLegalizeToHloIncGen - MLIRMhloPassIncGen -+ PassesIncGen - - LINK_COMPONENTS - Core --- -2.34.1 - diff --git a/mlir/patches/mhlo-add-back-necessary-passes.patch b/mlir/patches/mhlo-add-back-necessary-passes.patch new file mode 100644 index 0000000000..b56ede8dd5 --- /dev/null +++ b/mlir/patches/mhlo-add-back-necessary-passes.patch @@ -0,0 +1,1317 @@ +From b1728b65b1511cd5ef3e11650b9e416d3fad068f Mon Sep 17 00:00:00 2001 +From: paul0403 +Date: Thu, 29 May 2025 11:00:56 -0400 +Subject: [PATCH] restore the removed mhlo passes we need: + mhlo-legalize-control-flow, mhlo-legalize-to-std, hlo-legalize-sort + +--- + mhlo/transforms/CMakeLists.txt | 6 + + .../legalize_control_flow.cc | 288 +++++++++ + .../transforms/legalize_sort/legalize_sort.cc | 577 ++++++++++++++++++ + .../legalize_to_standard.cc | 243 ++++++++ + .../legalize_to_standard_patterns.td | 92 +++ + mhlo/transforms/mhlo_passes.td | 19 + + mhlo/transforms/passes.h | 4 + + 7 files changed, 1229 insertions(+) + create mode 100644 mhlo/transforms/legalize_control_flow/legalize_control_flow.cc + create mode 100644 mhlo/transforms/legalize_sort/legalize_sort.cc + create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard.cc + create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td + +diff --git a/mhlo/transforms/CMakeLists.txt b/mhlo/transforms/CMakeLists.txt +index d6848633..26d3b419 100644 +--- a/mhlo/transforms/CMakeLists.txt ++++ b/mhlo/transforms/CMakeLists.txt +@@ -26,14 +26,20 @@ set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.t + mlir_tablegen(chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc -gen-rewriters) + add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) + ++set(LLVM_TARGET_DEFINITIONS legalize_to_standard/legalize_to_standard_patterns.td) ++mlir_tablegen(legalize_to_standard/generated_legalize_to_standard.inc -gen-rewriters) ++add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) + + + add_mlir_library(MhloPasses + collapse_elementwise_map/collapse_elementwise_map.cc + convert_to_signless/convert_to_signless_pass.cc + expand_hlo_tuples/expand_hlo_tuples.cc ++ legalize_control_flow/legalize_control_flow.cc + legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc + legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc ++ legalize_sort/legalize_sort.cc ++ legalize_to_standard/legalize_to_standard.cc + legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc + legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc + materialize_broadcasts/materialize_broadcasts.cc +diff --git a/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc b/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc +new file mode 100644 +index 00000000..9d473b9a +--- /dev/null ++++ b/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc +@@ -0,0 +1,288 @@ ++/* Copyright 2019 The OpenXLA Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++// This file implements logic for lowering MHLO dialect to SCF dialect. ++#include ++#include ++#include ++ ++#include "llvm/Support/Casting.h" ++#include "mhlo/IR/hlo_ops.h" ++#include "mhlo/transforms/passes.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" ++#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project ++#include "mlir/IR/Block.h" ++#include "mlir/IR/Builders.h" ++#include "mlir/IR/BuiltinTypes.h" ++#include "mlir/IR/Diagnostics.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/TypeUtilities.h" ++#include "mlir/Pass/Pass.h" ++#include "mlir/Support/LLVM.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/DialectConversion.h" ++ ++namespace mlir { ++namespace mhlo { ++ ++#define GEN_PASS_DEF_LEGALIZECONTROLFLOWPASS ++#include "mhlo/transforms/mhlo_passes.h.inc" ++ ++namespace { ++ ++// All transformations in this file take mhlo blocks which end with ++// mhlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an ++// entire block with the only change being return -> yield. ++void inlineMhloRegionIntoSCFRegion(PatternRewriter& rewriter, Region& mhlo, ++ Region& scf) { ++ // Remove an existing block, then move the region over. ++ if (!scf.empty()) rewriter.eraseBlock(&scf.back()); ++ rewriter.inlineRegionBefore(mhlo, scf, scf.end()); ++ // Fix up the terminator. ++ PatternRewriter::InsertionGuard guard(rewriter); ++ rewriter.setInsertionPointToEnd(&scf.back()); ++ auto* terminator = scf.back().getTerminator(); ++ rewriter.replaceOpWithNewOp(terminator, ++ terminator->getOperands()); ++} ++ ++// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor ++// or a 1 element tensor. To handle this, collapse shape before extracting the ++// scalar value when necessary. ++Value extractTensorValue(OpBuilder& b, Value tensor) { ++ auto loc = tensor.getLoc(); ++ if (mlir::cast(tensor.getType()).hasRank() && ++ mlir::cast(tensor.getType()).getRank() != 0) { ++ tensor = b.create( ++ loc, tensor, SmallVector()); ++ } ++ return b.create(loc, tensor, ValueRange()); ++} ++ ++struct ScfForBounds { ++ Value lb; ++ Value ub; ++ Value step; ++ unsigned indexArgIndex; ++}; ++ ++std::optional extractForBounds(mhlo::WhileOp op) { ++ auto& cond = op.getCond().front(); ++ auto& body = op.getBody().front(); ++ if (cond.getOperations().size() != 2) return std::nullopt; ++ ++ auto matchBbArg = [](Value v, Block& block) -> std::optional { ++ if (!mlir::isa(v) || v.getParentBlock() != &block) ++ return std::nullopt; ++ return mlir::cast(v).getArgNumber(); ++ }; ++ ++ auto compare = llvm::dyn_cast(cond.front()); ++ // If the rhs of the comapare is defined outside the block, it's a constant ++ // within the loop. ++ if (!compare || ++ compare.getComparisonDirection() != mhlo::ComparisonDirection::LT || ++ compare.getRhs().getParentBlock() == &cond || ++ !getElementTypeOrSelf(compare.getLhs().getType()) ++ .isSignlessIntOrIndex()) { ++ return std::nullopt; ++ } ++ ++ auto iterArg = matchBbArg(compare.getLhs(), cond); ++ if (!iterArg) return std::nullopt; ++ ++ auto add = llvm::dyn_cast_or_null( ++ body.getTerminator()->getOperand(*iterArg).getDefiningOp()); ++ if (!add || matchBbArg(add.getLhs(), body) != iterArg || ++ add.getRhs().getParentBlock() == &body) { ++ return std::nullopt; ++ } ++ ++ ScfForBounds bounds; ++ bounds.ub = compare.getRhs(); ++ bounds.step = add.getRhs(); ++ bounds.lb = op->getOperand(*iterArg); ++ bounds.indexArgIndex = *iterArg; ++ return bounds; ++} ++ ++// Rewrites `mhlo.while` to `scf.while` or `scf.for`. ++struct WhileOpPattern : public OpConversionPattern { ++ using OpConversionPattern::OpConversionPattern; ++ ++ LogicalResult matchAndRewrite( ++ mhlo::WhileOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter& rewriter) const override { ++ auto loc = op.getLoc(); ++ ++ if (auto bounds = extractForBounds(op)) { ++ auto newForOp = rewriter.create( ++ loc, extractTensorValue(rewriter, bounds->lb), ++ extractTensorValue(rewriter, bounds->ub), ++ extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); ++ ++ rewriter.setInsertionPointToEnd(newForOp.getBody()); ++ // Inline while body, and only replace the mhlo.return with an scf.yield. ++ inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), ++ newForOp.getRegion()); ++ auto indexArg = newForOp.getRegion().insertArgument( ++ unsigned{0}, newForOp.getLowerBound().getType(), loc); ++ auto oldIndexArg = ++ newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); ++ rewriter.setInsertionPointToStart(&newForOp.getRegion().front()); ++ auto indexArgTensor = rewriter.create( ++ loc, oldIndexArg.getType(), indexArg); ++ oldIndexArg.replaceAllUsesWith(indexArgTensor); ++ ++ rewriter.replaceOp(op, newForOp.getResults()); ++ return success(); ++ } ++ ++ auto newWhileOp = rewriter.create(loc, op.getResultTypes(), ++ adaptor.getOperands()); ++ ++ // Inline while condition. The block is the same, except the boolean result ++ // needs to be extracted and used with an scf.condition. ++ rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), ++ newWhileOp.getBefore().end()); ++ auto conditionReturn = ++ cast(newWhileOp.getBefore().front().getTerminator()); ++ rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); ++ Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); ++ rewriter.replaceOpWithNewOp( ++ conditionReturn, i1, newWhileOp.getBeforeArguments()); ++ ++ // Inline while body, and only replace the mhlo.return with an scf.yield. ++ inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), ++ newWhileOp.getAfter()); ++ ++ rewriter.replaceOp(op, newWhileOp.getResults()); ++ return success(); ++ } ++}; ++ ++// Rewrites `mhlo.if` to `scf.if`. ++struct IfOpPattern : public OpConversionPattern { ++ using OpConversionPattern::OpConversionPattern; ++ ++ LogicalResult matchAndRewrite( ++ mhlo::IfOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter& rewriter) const override { ++ auto scfIf = rewriter.create( ++ op.getLoc(), op.getResultTypes(), ++ extractTensorValue(rewriter, adaptor.getPred()), ++ /*withElseRegion=*/true); ++ inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), ++ scfIf.getThenRegion()); ++ inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), ++ scfIf.getElseRegion()); ++ rewriter.replaceOp(op, scfIf.getResults()); ++ return success(); ++ } ++}; ++ ++// Rewrites `mhlo.case` to a nested `scf.if`. ++struct CaseOpPattern : public OpConversionPattern { ++ using OpConversionPattern::OpConversionPattern; ++ ++ // Recursively create if/else ops to handle each possible value in a case op. ++ scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor, ++ PatternRewriter& outerBuilder) const { ++ Location loc = op.getLoc(); ++ Value idxValue = adaptor.getIndex(); ++ auto finalIdx = op.getBranches().size() - 2; ++ ++ // Determine if the current index matches the case index. ++ auto scalarType = idxValue.getType(); ++ auto shapedType = mlir::cast(scalarType); ++ auto constAttr = DenseElementsAttr::get( ++ shapedType, {mlir::cast( ++ outerBuilder.getI32IntegerAttr(currentIdx))}); ++ Value currentIdxVal = outerBuilder.create( ++ loc, idxValue.getType(), constAttr); ++ ++ auto scfIf = outerBuilder.create( ++ loc, op.getResultTypes(), ++ extractTensorValue(outerBuilder, outerBuilder.create( ++ loc, idxValue, currentIdxVal, ++ ComparisonDirection::EQ)), ++ /*withElseRegion=*/true); ++ inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], ++ scfIf.getThenRegion()); ++ int nextIdx = currentIdx + 1; ++ // Don't recurse for the final default block. ++ if (currentIdx == static_cast(finalIdx)) { ++ inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], ++ scfIf.getElseRegion()); ++ } else { ++ PatternRewriter::InsertionGuard guard(outerBuilder); ++ outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back()); ++ auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder); ++ outerBuilder.create(op.getLoc(), innerIf.getResults()); ++ } ++ return scfIf; ++ } ++ ++ LogicalResult matchAndRewrite( ++ mhlo::CaseOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter& rewriter) const override { ++ // Inline the op if there is only a default block. ++ if (op.getBranches().size() == 1) { ++ Block& block = op.getBranches().front().front(); ++ auto results = block.getTerminator()->getOperands(); ++ // Remove the mhlo.return terminator, then inline the block. ++ rewriter.eraseOp(block.getTerminator()); ++ rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), ++ /*argValues=*/{}); ++ rewriter.replaceOp(op, results); ++ return success(); ++ } ++ ++ // Begin recursion with case 0. ++ rewriter.replaceOp( ++ op, createNestedCases(0, op, adaptor, rewriter).getResults()); ++ return success(); ++ } ++}; ++ ++struct LegalizeControlFlowPass ++ : public impl::LegalizeControlFlowPassBase { ++ // Perform the lowering to MLIR control flow. ++ void runOnOperation() override { ++ func::FuncOp f = getOperation(); ++ MLIRContext* ctx = f.getContext(); ++ ++ RewritePatternSet patterns(&getContext()); ++ patterns.add(&getContext()); ++ ++ mlir::ConversionTarget target(*ctx); ++ target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); ++ target.addIllegalOp(); ++ ++ if (failed(applyPartialConversion(f, target, std::move(patterns)))) { ++ signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace mhlo ++} // namespace mlir ++ ++std::unique_ptr> ++mlir::mhlo::createLegalizeControlFlowPass() { ++ return std::make_unique(); ++} +diff --git a/mhlo/transforms/legalize_sort/legalize_sort.cc b/mhlo/transforms/legalize_sort/legalize_sort.cc +new file mode 100644 +index 00000000..8ba9de9a +--- /dev/null ++++ b/mhlo/transforms/legalize_sort/legalize_sort.cc +@@ -0,0 +1,577 @@ ++/* Copyright 2019 The OpenXLA Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++// This file implements logic for lowering mhlo.sort to the SCF dialect. ++#include ++#include ++#include ++ ++#include "llvm/ADT/STLExtras.h" ++#include "mhlo/IR/hlo_ops.h" ++#include "mhlo/transforms/passes.h" ++#include "mlir/Dialect/Arith/IR/Arith.h" ++#include "mlir/Dialect/Arith/Utils/Utils.h" ++#include "mlir/Dialect/Bufferization/IR/Bufferization.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/Dialect/MemRef/IR/MemRef.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" ++#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project ++#include "mlir/IR/Block.h" ++#include "mlir/IR/Builders.h" ++#include "mlir/IR/BuiltinTypes.h" ++#include "mlir/IR/IRMapping.h" ++#include "mlir/IR/ImplicitLocOpBuilder.h" ++#include "mlir/IR/Location.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/TypeRange.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Pass/Pass.h" ++#include "mlir/Support/LLVM.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/DialectConversion.h" ++ ++namespace mlir { ++namespace mhlo { ++ ++#define GEN_PASS_DEF_HLOLEGALIZESORTPASS ++#include "mhlo/transforms/mhlo_passes.h.inc" ++ ++namespace { ++ ++using ::mlir::arith::AddIOp; ++using ::mlir::arith::MinSIOp; ++using ::mlir::arith::SelectOp; ++ ++constexpr int64_t kInsertionSortSize = 16; ++ ++// Inlines the `comparator` region (without terminator) at the current insertion ++// point, replacing the arguments with the given values from `lhs` and `rhs`. ++Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, ++ SmallVector& rhs, Region& comparator) { ++ assert(comparator.hasOneBlock() && "Comparator must have only one block."); ++ Block& block = comparator.front(); ++ assert(block.getTerminator()->getOperands().size() == 1 && ++ "Comparator must return a single value"); ++ ++ IRMapping mapping; ++ for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { ++ Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; ++ Type type = RankedTensorType::get({}, value.getType()); ++ mapping.map(arg, b.create(type, value)); ++ } ++ ++ for (Operation& op : block.without_terminator()) b.clone(op, mapping); ++ Value result = mapping.lookup(block.getTerminator()->getOperands().front()); ++ ++ return b.create(result, ValueRange()); ++} ++ ++// Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range ++// [`left`;`right`). `arrayMemrefs` must be sorted according to `comparator`. ++Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, ++ SmallVector& pivots, ValueRange arrayMemrefs, ++ Region& comparator) { ++ SmallVector types{leftInit.getType(), rightInit.getType()}; ++ ArithBuilder arith(b, b.getLoc()); ++ ++ // while ( ++ auto whileOp = ++ b.create(types, SmallVector{leftInit, rightInit}); ++ OpBuilder::InsertionGuard guard(b); ++ ++ // left < right) { ++ Block* before = b.createBlock(&whileOp.getBefore(), {}, types, ++ {whileOp.getLoc(), whileOp.getLoc()}); ++ { ++ Value left = before->getArgument(0), right = before->getArgument(1); ++ b.setInsertionPointToEnd(before); ++ b.create(arith.slt(left, right), before->getArguments()); ++ } ++ ++ Block* after = b.createBlock(&whileOp.getAfter(), {}, types, ++ {whileOp.getLoc(), whileOp.getLoc()}); ++ { ++ Value left = after->getArgument(0), right = after->getArgument(1); ++ b.setInsertionPointToEnd(after); ++ // int mid = (left + right) >> 1; ++ Value one = b.create(1); ++ Value mid = b.create(arith.add(left, right), one); ++ Value midPlusOne = b.create(mid, one); ++ ++ auto arraysAtMid = llvm::to_vector( ++ llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { ++ return b.create(arrayMemref, mid); ++ })); ++ Value cond = emitComparison(b, pivots, arraysAtMid, comparator); ++ // if (comparator(pivot, array[mid])) ++ // right = mid; ++ // else ++ // left = mid + 1; ++ Value newLeft = arith.select(cond, left, midPlusOne); ++ Value newRight = arith.select(cond, mid, right); ++ ++ // } ++ b.create(ValueRange{newLeft, newRight}); ++ } ++ ++ return whileOp.getResult(0); ++} ++ ++SmallVector loadTensorElements(ImplicitLocOpBuilder& b, ++ ValueRange tensors, Value index) { ++ return llvm::to_vector(llvm::map_range(tensors, [&](Value tensor) -> Value { ++ return b.create(tensor, index); ++ })); ++} ++ ++SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, ++ ValueRange memrefs, Value index) { ++ return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { ++ Type type = mlir::cast(memref.getType()).getElementType(); ++ return b.create(type, memref, index); ++ })); ++} ++ ++void storeMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, ++ Value index, ValueRange values) { ++ for (auto [value, memref] : llvm::zip(values, memrefs)) { ++ b.create(value, memref, index); ++ } ++} ++ ++// Insertion sorts `inputTensors` in the range [`lo`; `hi`), storing the results ++// in `outputMemrefs`. `inputTensors` and `outputMemrefs` must all be rank 1 and ++// of identical size. ++void emitInsertionSort(ImplicitLocOpBuilder& b, Value lo, Value hi, ++ ValueRange inputTensors, ValueRange outputMemrefs, ++ mlir::Region& comparator) { ++ ArithBuilder arith(b, b.getLoc()); ++ Value zero = b.create(0); ++ Value one = b.create(1); ++ ++ // array[lo] = tensors[lo]; ++ storeMemrefElements(b, outputMemrefs, lo, ++ loadTensorElements(b, inputTensors, lo)); ++ ++ // for (int start = lo + 1; start < hi; ++start) ++ { ++ auto forOp = b.create(arith.add(lo, one), hi, one); ++ OpBuilder::InsertionGuard outerGuard(b); ++ b.setInsertionPointToStart(forOp.getBody()); ++ Value start = forOp.getInductionVar(); ++ ++ // T pivot = tensors[start]; ++ auto pivots = loadTensorElements(b, inputTensors, start); ++ ++ // int index = binarySearch(lo, start, pivot, array, comparator); ++ auto index = ++ emitBinarySearch(b, lo, start, pivots, outputMemrefs, comparator); ++ ++ // int n = start - index; // The number of elements to move ++ Value n = arith.sub(start, index); ++ ++ // memmove(&array[index + 1], &array[index], n * sizeof(T)) ++ // memref::CopyOp would be nice to use here, but: ++ // 1. It lowers to a quite inefficient library call in the general case ++ // (strides != 1). ++ // 2. It implements memcpy semantics, but we need memmove here. ++ // So we go with a loop instead. ++ auto copyForOp = b.create(zero, n, one); ++ { ++ OpBuilder::InsertionGuard innerGuard(b); ++ b.setInsertionPointToStart(copyForOp.getBody()); ++ Value copyLoopIndex = copyForOp.getBody()->getArgument(0); ++ ++ Value dstIndex = arith.sub(start, copyLoopIndex); ++ Value srcIndex = arith.sub(dstIndex, one); ++ storeMemrefElements(b, outputMemrefs, dstIndex, ++ loadMemrefElements(b, outputMemrefs, srcIndex)); ++ } ++ // array[index] = pivot; ++ storeMemrefElements(b, outputMemrefs, index, pivots); ++ } ++} ++ ++void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, ++ ValueRange readBufs, ValueRange writeBufs, ++ mlir::Region& comparator) { ++ ArithBuilder arith(b, b.getLoc()); ++ // The while loop runs until we reach the end of either interval. It has three ++ // loop-carried variables: ++ // 1. current output index ++ // 2. current read index for interval 1 ++ // 3. current read index for interval 2 ++ SmallVector whileArgTypes{lo.getType(), lo.getType(), mid.getType()}; ++ SmallVector whileInitArgs{lo, lo, mid}; ++ SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); ++ ++ // while( ++ auto whileOp = b.create(whileArgTypes, whileInitArgs); ++ { ++ OpBuilder::InsertionGuard guard(b); ++ { ++ Block* before = ++ b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); ++ Value i0 = before->getArgument(1), i1 = before->getArgument(2); ++ b.setInsertionPointToEnd(before); ++ ++ // i0 < mid && i1 < hi) { ++ Value inbounds0 = arith.slt(i0, mid); ++ Value inbounds1 = arith.slt(i1, hi); ++ ++ b.create(arith._and(inbounds0, inbounds1), ++ before->getArguments()); ++ } ++ ++ { ++ Block* after = ++ b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); ++ Value iOut = after->getArgument(0), i0 = after->getArgument(1), ++ i1 = after->getArgument(2); ++ b.setInsertionPointToEnd(after); ++ ++ // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; ++ SmallVector vals0 = loadMemrefElements(b, readBufs, i0); ++ SmallVector vals1 = loadMemrefElements(b, readBufs, i1); ++ ++ // writeBufs[iOut] = comparator(vals1, vals0) ++ // ? readBufs[i1++] : readBufs[i0++]; ++ Value cmp = emitComparison(b, vals1, vals0, comparator); ++ SmallVector pickedVals; ++ for (auto [val0, val1] : llvm::zip(vals0, vals1)) { ++ pickedVals.push_back(b.create(cmp, val1, val0)); ++ } ++ storeMemrefElements(b, writeBufs, iOut, pickedVals); ++ ++ Value one = b.create(1); ++ Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); ++ Value nexti1 = b.create(cmp, arith.add(i1, one), i1); ++ // ++iOut; ++ Value nextIOut = b.create(iOut, one); ++ b.create(ValueRange{nextIOut, nexti0, nexti1}); ++ } ++ } ++ ++ // At this point, exactly one of the input ranges will have leftover elements. ++ Value iOut = whileOp->getResult(0); ++ Value i0 = whileOp->getResult(1); ++ Value i1 = whileOp->getResult(2); ++ ++ // We could use memref::CopyOp here, but typically, there aren't many leftover ++ // elements for randomly shuffled inputs. ++ Value leftoverIn0 = arith.slt(i0, mid); ++ Value start = arith.select(leftoverIn0, i0, i1); ++ Value end = arith.select(leftoverIn0, mid, hi); ++ Value n = arith.sub(end, start); ++ ++ Value zero = b.create(0); ++ Value one = b.create(1); ++ auto forOp = b.create(zero, n, one); ++ b.setInsertionPointToStart(forOp.getBody()); ++ Value copyIndex = forOp.getBody()->getArgument(0); ++ ++ Value srcIndex = arith.add(start, copyIndex); ++ Value dstIndex = arith.add(iOut, copyIndex); ++ storeMemrefElements(b, writeBufs, dstIndex, ++ loadMemrefElements(b, readBufs, srcIndex)); ++} ++ ++// Emits a bottom up merge sort of `inputTensors` in the range [`lo`; `hi`), and ++// writes the results to either `outputs0` or `outputs1`. ++// Returns 0 if the results are in `outputs0`, 1 if they are in `outputs1`. ++// TODO(jreiffers): Consider implementing top-down merge sort. ++Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, ++ int64_t staticSortDimSize, ValueRange inputTensors, ++ ValueRange outputs0, ValueRange outputs1, ++ mlir::Region& comparator) { ++ ArithBuilder arith(b, b.getLoc()); ++ Value size = arith.sub(hi, lo); ++ ++ Value zero = b.create(0); ++ Value insertionSortSize = ++ b.create(kInsertionSortSize); ++ ++ // Run insertion sort on blocks of size kInsertionSortSize. ++ // for (int start = 0; start < size; start += kInsertionSortSize) { ++ { ++ auto forOp = b.create(zero, size, insertionSortSize); ++ OpBuilder::InsertionGuard guard(b); ++ b.setInsertionPointToStart(forOp.getBody()); ++ Value start = forOp.getBody()->getArgument(0); ++ Value end = arith.add( ++ b.create(arith.add(start, insertionSortSize), size), lo); ++ emitInsertionSort(b, start, end, inputTensors, outputs0, comparator); ++ } ++ ++ Value initParity = b.create(0, 1); ++ if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { ++ return initParity; ++ } ++ ++ // The while arguments are: ++ // 1. the current size ++ // 2. the original index of the buffers we're currently reading from ++ // 3. the buffers we're currently reading from ++ // 4. the buffers we're currently writing to. ++ // ++ // 1 gets doubled each iteration, 2 gets negated, 3 and 4 are swapped. ++ // int currentSize = 16; ++ SmallVector whileInitArgs{insertionSortSize, initParity}; ++ // First we read from `outputs0` (initialized by the insertion sort above). ++ llvm::copy(outputs0, std::back_inserter(whileInitArgs)); ++ llvm::copy(outputs1, std::back_inserter(whileInitArgs)); ++ ++ SmallVector whileArgTypes; ++ for (auto val : whileInitArgs) whileArgTypes.push_back(val.getType()); ++ ++ SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); ++ ++ // while ( ++ auto whileOp = b.create(whileArgTypes, whileInitArgs); ++ OpBuilder::InsertionGuard guard(b); ++ ++ // currentSize < totalSize) ++ { ++ Block* before = ++ b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); ++ Value currentSize = before->getArgument(0); ++ b.setInsertionPointToEnd(before); ++ b.create(arith.slt(currentSize, size), ++ before->getArguments()); ++ } ++ ++ size_t numArgs = inputTensors.size(); ++ // { ++ { ++ Block* after = ++ b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); ++ ++ Value currentSize = after->getArgument(0); ++ Value parity = after->getArgument(1); ++ auto readBufs = after->getArguments().drop_front(2).take_front(numArgs); ++ auto writeBufs = after->getArguments().take_back(numArgs); ++ ++ Value twoCurrentSize = arith.add(currentSize, currentSize); ++ ++ // for (int start = 0; start < size; start += 2*currentSize) { ++ { ++ auto forOp = b.create(zero, size, twoCurrentSize); ++ b.setInsertionPointToStart(forOp.getBody()); ++ Value start = forOp.getBody()->getArgument(0); ++ ++ Value mid = b.create(size, arith.add(start, currentSize)); ++ Value end = b.create(size, arith.add(start, twoCurrentSize)); ++ emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); ++ b.setInsertionPointAfter(forOp); ++ } ++ // } ++ ++ // parity = !parity; ++ Value one = b.create(1, 1); ++ Value notParity = arith.sub(one, parity); ++ // currentSize *= 2; ++ SmallVector nextWhileArgs{twoCurrentSize, notParity}; ++ llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); ++ llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); ++ b.create(nextWhileArgs); ++ } ++ // } ++ ++ // The result is the parity bit. ++ return whileOp.getResults().drop_front(1).front(); ++} ++ ++// Helper struct for extracting 1d slices from tensors and memrefs. ++struct Slicer { ++ Slicer(OpBuilder& b, uint64_t sortDim, Value sortDimSize, ValueRange ivs) ++ : sizes(ivs.size() + 1, b.getI64IntegerAttr(1)), ++ strides(ivs.size() + 1, b.getI64IntegerAttr(1)) { ++ sizes[sortDim] = sortDimSize; ++ for (size_t i = 0; i < ivs.size() + 1; ++i) { ++ if (i == sortDim) { ++ offsets.push_back(b.getI64IntegerAttr(0)); ++ } else { ++ offsets.push_back(ivs[i - static_cast(i > sortDim)]); ++ } ++ } ++ } ++ ++ RankedTensorType toSlicedType(RankedTensorType sourceType) { ++ return tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( ++ /*resultRank=*/1, sourceType, offsets, sizes, strides); ++ } ++ ++ MemRefType toSlicedType(MemRefType sourceType) { ++ return mlir::cast(memref::SubViewOp::inferRankReducedResultType( ++ {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, ++ strides)); ++ } ++ ++ template ++ Value slice(ImplicitLocOpBuilder& b, Value input) { ++ Ty ty = mlir::cast(input.getType()); ++ return b.create(toSlicedType(ty), input, offsets, sizes, strides) ++ .getResult(); ++ } ++ ++ Value apply(ImplicitLocOpBuilder& b, Value input) { ++ Type inTy = input.getType(); ++ if (mlir::isa(inTy)) { ++ return slice(b, input); ++ } ++ assert(mlir::isa(inTy)); ++ return slice(b, input); ++ } ++ ++ SmallVector offsets; ++ SmallVector sizes; ++ SmallVector strides; ++}; ++ ++SmallVector sliceMemrefsOrTensors(ImplicitLocOpBuilder& b, ++ SmallVector& ivs, ++ Value sortDimSize, ++ ValueRange memrefsOrTensors, ++ SortOp op) { ++ if (ivs.empty()) return memrefsOrTensors; ++ ++ SmallVector outputs; ++ Slicer slicer(b, op.getDimension(), sortDimSize, ivs); ++ // Create subviews/slices. ++ for (Value out : memrefsOrTensors) { ++ outputs.push_back(slicer.apply(b, out)); ++ } ++ ++ return outputs; ++} ++ ++struct SortOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(SortOp op, ++ PatternRewriter& rewriter) const override { ++ ImplicitLocOpBuilder b(op.getLoc(), rewriter); ++ ++ // Note: the output memrefs aren't necessarily the ones that we return, ++ SmallVector outputMemrefs; ++ SmallVector scratchMemrefs; ++ ++ Value firstOperand = op.getOperands().front(); ++ auto firstOperandType = mlir::cast(firstOperand.getType()); ++ int64_t inputRank = firstOperandType.getRank(); ++ ++ Value sortDimSize = b.createOrFold( ++ firstOperand, b.create(op.getDimension())); ++ int64_t staticSortDimSize = firstOperandType.getDimSize(op.getDimension()); ++ ++ SmallVector dynamicDims; ++ for (int i = 0; i < inputRank; ++i) { ++ if (!firstOperandType.isDynamicDim(i)) continue; ++ Value index = b.create(i); ++ Value dimOp = b.create(firstOperand, index); ++ dynamicDims.push_back(dimOp); ++ } ++ ++ // Allocate output and scratch memrefs. If the size of the sort dimension is ++ // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused ++ // and will be cleaned up later. ++ for (auto input : op.getOperands()) { ++ auto inputType = mlir::cast(input.getType()); ++ auto memRefType = ++ MemRefType::get(inputType.getShape(), inputType.getElementType()); ++ ++ outputMemrefs.push_back( ++ b.create(memRefType, dynamicDims)); ++ scratchMemrefs.push_back( ++ b.create(memRefType, dynamicDims)); ++ } ++ ++ b.setInsertionPoint(op); ++ Value zero = b.create(0); ++ Value one = b.create(1); ++ ++ Value forInitArg = b.create(0, 1); ++ SmallVector forOps; ++ SmallVector ivs; ++ forOps.reserve(inputRank - 1); ++ ivs.reserve(inputRank - 1); ++ for (int64_t i = 0; i < inputRank; ++i) { ++ if (i != static_cast(op.getDimension())) { ++ Value dim = b.create(i); ++ Value ub = b.create(firstOperand, dim); ++ scf::ForOp& forOp = forOps.emplace_back( ++ b.create(zero, ub, one, ValueRange{forInitArg})); ++ ivs.push_back(forOp.getInductionVar()); ++ b.setInsertionPointToStart(&forOp.getRegion().front()); ++ } ++ } ++ SmallVector inputs = ++ sliceMemrefsOrTensors(b, ivs, sortDimSize, op.getOperands(), op); ++ SmallVector outputs = ++ sliceMemrefsOrTensors(b, ivs, sortDimSize, outputMemrefs, op); ++ SmallVector scratches = ++ sliceMemrefsOrTensors(b, ivs, sortDimSize, scratchMemrefs, op); ++ ++ Value parity = ++ emitBottomUpMergeSort(b, zero, sortDimSize, staticSortDimSize, inputs, ++ outputs, scratches, op.getRegion()); ++ ++ // Pass the parity bit through the for loops. ++ for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { ++ b.setInsertionPointToEnd(&forOps[i].getRegion().front()); ++ b.create(ValueRange{parity}); ++ parity = forOps[i]->getResult(0); ++ } ++ b.setInsertionPoint(op); ++ ++ SmallVector outputTensors; ++ for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { ++ outputTensors.push_back(b.create( ++ b.create(parity, out1, out0), /*restrict=*/true)); ++ } ++ ++ rewriter.replaceOp(op, outputTensors); ++ return success(); ++ } ++}; ++ ++struct LegalizeSortPass ++ : public impl::HloLegalizeSortPassBase { ++ // Perform the lowering to MLIR control flow. ++ void runOnOperation() override { ++ func::FuncOp f = getOperation(); ++ MLIRContext* ctx = f.getContext(); ++ ++ RewritePatternSet patterns(ctx); ++ patterns.add(ctx); ++ ++ mlir::ConversionTarget target(*ctx); ++ target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); ++ target.addIllegalOp(); ++ ++ if (failed(applyPartialConversion(f, target, std::move(patterns)))) { ++ signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace mhlo ++} // namespace mlir ++ ++std::unique_ptr> ++mlir::mhlo::createLegalizeSortPass() { ++ return std::make_unique(); ++} +diff --git a/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc b/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc +new file mode 100644 +index 00000000..be752397 +--- /dev/null ++++ b/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc +@@ -0,0 +1,243 @@ ++/* Copyright 2019 The OpenXLA Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++// This file implements logic for lowering MHLO dialect to Standard dialect. ++ ++#include ++#include ++#include ++ ++#include "mhlo/IR/hlo_ops.h" ++#include "mhlo/transforms/passes.h" ++#include "mhlo/transforms/rewriters.h" ++#include "mlir/Dialect/Arith/IR/Arith.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/Dialect/Math/IR/Math.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/Pass/Pass.h" ++#include "mlir/Support/LLVM.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++ ++namespace mlir { ++namespace { ++#include "legalize_to_standard/generated_legalize_to_standard.inc" ++} // end anonymous namespace ++namespace mhlo { ++ ++#define GEN_PASS_DEF_LEGALIZETOSTANDARDPASS ++#include "mhlo/transforms/mhlo_passes.h.inc" ++ ++namespace { ++ ++class CompareIConvert : public OpRewritePattern { ++ public: ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(mhlo::CompareOp op, ++ PatternRewriter &rewriter) const override { ++ auto lhs = op.getLhs(); ++ auto rhs = op.getRhs(); ++ auto lhsType = mlir::cast(lhs.getType()); ++ auto rhsType = mlir::cast(rhs.getType()); ++ ++ // Broadcasting not supported by this rewrite. ++ if (lhsType.getShape() != rhsType.getShape()) return failure(); ++ ++ if (!lhsType.getElementType().isSignlessInteger() || ++ !rhsType.getElementType().isSignlessInteger()) ++ return failure(); ++ ++ std::optional comparePredicate = std::nullopt; ++ switch (op.getComparisonDirection()) { ++ case ComparisonDirection::EQ: ++ comparePredicate = arith::CmpIPredicate::eq; ++ break; ++ case ComparisonDirection::NE: ++ comparePredicate = arith::CmpIPredicate::ne; ++ break; ++ case ComparisonDirection::LT: ++ comparePredicate = arith::CmpIPredicate::slt; ++ break; ++ case ComparisonDirection::LE: ++ comparePredicate = arith::CmpIPredicate::sle; ++ break; ++ case ComparisonDirection::GT: ++ comparePredicate = arith::CmpIPredicate::sgt; ++ break; ++ case ComparisonDirection::GE: ++ comparePredicate = arith::CmpIPredicate::sge; ++ break; ++ } ++ ++ if (!comparePredicate.has_value()) return failure(); ++ ++ rewriter.replaceOpWithNewOp(op, comparePredicate.value(), ++ lhs, rhs); ++ return success(); ++ } ++}; ++ ++class CompareFConvert : public OpRewritePattern { ++ public: ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(mhlo::CompareOp op, ++ PatternRewriter &rewriter) const override { ++ auto lhs = op.getLhs(); ++ auto rhs = op.getRhs(); ++ auto lhsType = mlir::cast(lhs.getType()); ++ auto rhsType = mlir::cast(rhs.getType()); ++ ++ // Broadcasting not supported by this rewrite. ++ if (lhsType.getShape() != rhsType.getShape()) return failure(); ++ ++ if (!mlir::isa(lhsType.getElementType()) || ++ !mlir::isa(rhsType.getElementType())) ++ return failure(); ++ ++ std::optional comparePredicate = std::nullopt; ++ switch (op.getComparisonDirection()) { ++ case ComparisonDirection::EQ: ++ comparePredicate = arith::CmpFPredicate::OEQ; ++ break; ++ case ComparisonDirection::NE: ++ comparePredicate = arith::CmpFPredicate::UNE; ++ break; ++ case ComparisonDirection::LT: ++ comparePredicate = arith::CmpFPredicate::OLT; ++ break; ++ case ComparisonDirection::LE: ++ comparePredicate = arith::CmpFPredicate::OLE; ++ break; ++ case ComparisonDirection::GT: ++ comparePredicate = arith::CmpFPredicate::OGT; ++ break; ++ case ComparisonDirection::GE: ++ comparePredicate = arith::CmpFPredicate::OGE; ++ break; ++ } ++ ++ if (!comparePredicate.has_value()) return failure(); ++ ++ rewriter.replaceOpWithNewOp(op, comparePredicate.value(), ++ lhs, rhs); ++ return success(); ++ } ++}; ++ ++// Replace IotaOp with an integer constant. A ConvertOp is added to ++// convert the integer constant to iota result type. For complex types, the real ++// part is replaced with the generated constant and the imaginary part is ++// replaced with zero tensor. ++class ConvertIotaOp : public OpRewritePattern { ++ public: ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(mhlo::IotaOp op, ++ PatternRewriter &rewriter) const override { ++ auto outputType = mlir::cast(op.getType()); ++ auto outputSize = outputType.getNumElements(); ++ auto dimension = op.getIotaDimension(); ++ auto maxDimSize = outputType.getDimSize(dimension); ++ ++ auto elementType = outputType.getElementType(); ++ int bitwidth; ++ ++ auto complexTy = mlir::dyn_cast(elementType); ++ Type intOrFloatTy = elementType; ++ if (complexTy) intOrFloatTy = complexTy.getElementType(); ++ ++ bitwidth = intOrFloatTy.getIntOrFloatBitWidth(); ++ llvm::SmallVector values; ++ values.reserve(outputSize); ++ ++ int64_t increaseStride = outputSize; ++ for (uint64_t i = 0; i <= dimension; i++) { ++ increaseStride /= outputType.getDimSize(i); ++ } ++ ++ int64_t currentValue = 0; ++ for (int i = 0; i < outputSize; i++) { ++ int64_t value = (currentValue / increaseStride) % maxDimSize; ++ values.push_back(APInt(bitwidth, value)); ++ ++currentValue; ++ } ++ ++ auto intShapeType = RankedTensorType::get( ++ outputType.getShape(), ++ IntegerType::get(rewriter.getContext(), bitwidth)); ++ auto loc = op.getLoc(); ++ auto integerConst = rewriter.create( ++ loc, DenseIntElementsAttr::get(intShapeType, values)); ++ ++ auto intOrFloatShapeTy = ++ RankedTensorType::get(outputType.getShape(), intOrFloatTy); ++ ++ auto iotaConst = ++ rewriter.create(loc, intOrFloatShapeTy, integerConst); ++ ++ // For int/float types we are done, replace op and return. ++ if (!complexTy) { ++ rewriter.replaceOp(op, iotaConst.getResult()); ++ return success(); ++ } ++ ++ // For complex types, generate a constant tensor of zeroes for the imaginary ++ // part and use iota_const for real part. ++ auto zeroes = rewriter.create( ++ loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); ++ auto imagZeroes = ++ rewriter.create(loc, intOrFloatShapeTy, zeroes); ++ rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); ++ return success(); ++ } ++}; ++ ++} // end anonymous namespace ++ ++namespace { ++struct LegalizeToStandardPass ++ : public impl::LegalizeToStandardPassBase { ++ void getDependentDialects(DialectRegistry ®istry) const override { ++ registry ++ .insert(); ++ } ++ ++ /// Perform the lowering to Standard dialect. ++ void runOnOperation() override; ++}; ++} // end anonymous namespace ++ ++std::unique_ptr> ++createLegalizeToStdPass() { ++ return std::make_unique(); ++} ++ ++void populateMhloToStdPatterns(RewritePatternSet *patterns, ++ mlir::MLIRContext *ctx) { ++ mlir::populateWithGenerated(*patterns); ++ patterns->add(ctx); ++} ++ ++/// Perform the lowering to standard dialect. ++void LegalizeToStandardPass::runOnOperation() { ++ RewritePatternSet patterns(&getContext()); ++ mlir::mhlo::populateMhloToStdPatterns(&patterns, &getContext()); ++ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) ++ return signalPassFailure(); ++} ++ ++} // end namespace mhlo ++} // end namespace mlir +diff --git a/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td b/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td +new file mode 100644 +index 00000000..f4d24608 +--- /dev/null ++++ b/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td +@@ -0,0 +1,92 @@ ++/* Copyright 2019 The OpenXLA Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++// This is the legalization pattern definition file for MHLO to StandardOps. ++ ++include "mlir/IR/OpBase.td" ++include "mlir/Dialect/Arith/IR/ArithOps.td" ++include "mlir/Dialect/Math/IR/MathOps.td" ++include "mlir/Dialect/Func/IR/FuncOps.td" ++include "mhlo/IR/hlo_ops.td" ++ ++//===----------------------------------------------------------------------===// ++// Nullary op patterns. ++//===----------------------------------------------------------------------===// ++ ++def : Pat<(MHLO_ConstantOp ElementsAttr:$value), ++ (Arith_ConstantOp $value)>; ++ ++//===----------------------------------------------------------------------===// ++// Binary op patterns. ++//===----------------------------------------------------------------------===// ++ ++def IsSameSizePred : CPred< ++ "cast($0.getType()).getShape() " ++ "== cast($1.getType()).getShape()">; ++def IsSameSizeConstraint : Constraint; ++def createFastMathNone : NativeCodeCall< ++ "::mlir::arith::FastMathFlagsAttr::get(" ++ "$_builder.getContext(), ::mlir::arith::FastMathFlags::none" ++ ")">; ++def createOverflowNone : NativeCodeCall< ++ "::mlir::arith::IntegerOverflowFlagsAttr::get(" ++ "$_builder.getContext(), ::mlir::arith::IntegerOverflowFlags::none" ++ ")">; ++ ++ ++// Unary Lowering Patterns. ++def : Pat<(MHLO_CeilOp MHLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; ++ ++// Binary Lowering Patterns. ++def : Pat<(MHLO_AndOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_AndIOp $l, $r), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_OrOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_OrIOp $l, $r), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_AddOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), ++ (Arith_AddFOp $l, $r, (createFastMathNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_SubtractOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), ++ (Arith_SubFOp $l, $r, (createFastMathNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_MulOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), ++ (Arith_MulFOp $l, $r, (createFastMathNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_DivOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), ++ (Arith_DivFOp $l, $r, (createFastMathNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_RemOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), ++ (Arith_RemFOp $l, $r, (createFastMathNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_AddOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_AddIOp $l, $r, (createOverflowNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_SubtractOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_SubIOp $l, $r, (createOverflowNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_MulOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_MulIOp $l, $r, (createOverflowNone )), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_DivOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_DivSIOp $l, $r), ++ [(IsSameSizeConstraint $l, $r)]>; ++def : Pat<(MHLO_RemOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), ++ (Arith_RemSIOp $l, $r), ++ [(IsSameSizeConstraint $l, $r)]>; ++ ++def : Pat<(MHLO_SelectOp $pred, $tv, $fv), ++ (SelectOp $pred, $tv, $fv), ++ [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; +diff --git a/mhlo/transforms/mhlo_passes.td b/mhlo/transforms/mhlo_passes.td +index 853531c1..378f8944 100644 +--- a/mhlo/transforms/mhlo_passes.td ++++ b/mhlo/transforms/mhlo_passes.td +@@ -15,6 +15,25 @@ limitations under the License. + + include "mlir/Pass/PassBase.td" + ++def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "func::FuncOp"> { ++ let summary = "Legalize from MHLO control flow to SCF control flow."; ++ let constructor = "createLegalizeControlFlowPass()"; ++ let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; ++} ++ ++def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "func::FuncOp"> { ++ let summary = "Legalize from MHLO dialect to standard dialect."; ++ let constructor = "createLegalizeToStdPass()"; ++} ++ ++def HloLegalizeSortPass : Pass<"hlo-legalize-sort", "func::FuncOp"> { ++ let summary = "Legalize from MHLO sort to SCF control flow."; ++ let constructor = "createLegalizeSortPass()"; ++ let dependentDialects = ["arith::ArithDialect", ++ "bufferization::BufferizationDialect", ++ "scf::SCFDialect", "tensor::TensorDialect"]; ++} ++ + def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "func::FuncOp"> { + let summary = "Legalize CHLO's with XLA counterparts, like TopK and Erf."; + let description = [{ +diff --git a/mhlo/transforms/passes.h b/mhlo/transforms/passes.h +index 3d2aa3b3..3f03b2df 100644 +--- a/mhlo/transforms/passes.h ++++ b/mhlo/transforms/passes.h +@@ -37,6 +37,10 @@ namespace mhlo { + #define GEN_PASS_DECL + #include "mhlo/transforms/mhlo_passes.h.inc" + ++std::unique_ptr> createLegalizeControlFlowPass(); ++std::unique_ptr> createLegalizeSortPass(); ++std::unique_ptr> createLegalizeToStdPass(); ++ + /// Lowers from HLO dialect to Arithmetic dialect. + std::unique_ptr> createLegalizeToArithmeticPass(); + +-- +2.34.1 + diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch new file mode 100644 index 0000000000..f78200bdab --- /dev/null +++ b/mlir/patches/mhlo-remove-shardy.patch @@ -0,0 +1,132 @@ +From 70172e8399383d6c1964d73a2d20cba3c55a3279 Mon Sep 17 00:00:00 2001 +From: paul0403 +Date: Thu, 29 May 2025 10:06:35 -0400 +Subject: [PATCH] remove shardy dependency + +--- + bindings/c/CMakeLists.txt | 1 - + stablehlo_ext/CMakeLists.txt | 1 + + stablehlo_ext/analysis/CMakeLists.txt | 3 ++- + stablehlo_ext/transforms/CMakeLists.txt | 7 ++++++- + stablehlo_ext/transforms/stablehlo_refine_shapes.cpp | 3 --- + tests/lit.cfg.py | 1 + + tools/mlir-hlo-opt/mlir-hlo-opt.cc | 2 -- + 7 files changed, 10 insertions(+), 8 deletions(-) + +diff --git a/bindings/c/CMakeLists.txt b/bindings/c/CMakeLists.txt +index fd2a5c2c..53d916d5 100644 +--- a/bindings/c/CMakeLists.txt ++++ b/bindings/c/CMakeLists.txt +@@ -10,7 +10,6 @@ add_mlir_public_c_api_library(MLIRHLOCAPIDialects + MhloPasses + MhloToArithmeticConversion + MhloToMemrefConversion +- MhloToStandard + MhloToLinalg + MhloToStablehlo + StablehloToMhlo +diff --git a/stablehlo_ext/CMakeLists.txt b/stablehlo_ext/CMakeLists.txt +index 3e55a89d..e8d318f1 100644 +--- a/stablehlo_ext/CMakeLists.txt ++++ b/stablehlo_ext/CMakeLists.txt +@@ -12,5 +12,6 @@ + # See the License for the specific language governing permissions and + # limitations under the License. + ++add_subdirectory(analysis) + add_subdirectory(IR) + add_subdirectory(transforms) +diff --git a/stablehlo_ext/analysis/CMakeLists.txt b/stablehlo_ext/analysis/CMakeLists.txt +index 726d340d..0c0259b8 100644 +--- a/stablehlo_ext/analysis/CMakeLists.txt ++++ b/stablehlo_ext/analysis/CMakeLists.txt +@@ -1,5 +1,6 @@ + add_mlir_library(MhloAnalysis +- shape_component_analysis.cc ++ shape_component_analysis.cpp ++ PARTIAL_SOURCES_INTENDED + + DEPENDS + mlir-headers +diff --git a/stablehlo_ext/transforms/CMakeLists.txt b/stablehlo_ext/transforms/CMakeLists.txt +index ee58f490..2d7cc22c 100644 +--- a/stablehlo_ext/transforms/CMakeLists.txt ++++ b/stablehlo_ext/transforms/CMakeLists.txt +@@ -20,9 +20,14 @@ add_mlir_dialect_library(StablehloExtensionPasses + PARTIAL_SOURCES_INTENDED + chlo_recompose_ops.cpp + chlo_preserve_high_level_ops.cpp ++ sink_constants_to_control_flow.cpp ++ stablehlo_add_quant_dequant_conv.cpp + stablehlo_canonicalize_dynamism.cpp ++ stablehlo_canonicalize_from_hlo_import.cpp ++ stablehlo_legalize_quant_composite.cpp ++ stablehlo_prepare_for_hlo_export.cpp + stablehlo_refine_shapes.cpp +- sdy_refine_shapes.cpp ++ symbolic_shape_optimization.cpp + + DEPENDS + StablehloExtensionPassesIncGen +diff --git a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp +index cabd6a9f..2e64b4ed 100644 +--- a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp ++++ b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp +@@ -34,7 +34,6 @@ limitations under the License. + #include "stablehlo_ext/IR/base.h" + #include "stablehlo_ext/IR/stablehlo_ops.h" + #include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc +-#include "stablehlo_ext/transforms/sdy_refine_shapes.h" + + namespace mlir { + namespace stablehlo_ext { +@@ -154,7 +153,6 @@ struct StablehloRefineShapesPass + patterns->add(context); + patterns->add(context); + patterns->add(context); +- populateSdyShapeRefinementPatterns(patterns, context); + }; + + if (failed(stablehlo::refineEntryFunction(*context, func, +@@ -172,7 +170,6 @@ void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, + patterns->add(context); + patterns->add(context); + patterns->add(context); +- populateSdyShapeRefinementPatterns(patterns, context); + } + + } // namespace stablehlo_ext +diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py +index ab20fbb5..6c61aec5 100644 +--- a/tests/lit.cfg.py ++++ b/tests/lit.cfg.py +@@ -32,6 +32,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + + # suffixes: A list of file extensions to treat as test files. + config.suffixes = ['.mlir'] ++config.excludes = ['sdy_refine_shapes.mlir'] + + # test_source_root: The root path where tests are located. + config.test_source_root = os.path.dirname(__file__) +diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/tools/mlir-hlo-opt/mlir-hlo-opt.cc +index f018cbdc..b4474850 100644 +--- a/tools/mlir-hlo-opt/mlir-hlo-opt.cc ++++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cc +@@ -20,7 +20,6 @@ limitations under the License. + #include "mlir/InitAllExtensions.h" + #include "mlir/InitAllPasses.h" + #include "mlir/Tools/mlir-opt/MlirOptMain.h" +-#include "shardy/dialect/sdy/ir/dialect.h" + #include "stablehlo/dialect/Register.h" + #include "stablehlo_ext/transforms/passes.h" + #include "transforms/gpu_passes.h" +@@ -41,6 +40,5 @@ int main(int argc, char** argv) { + registerAllExtensions(registry); + mhlo::registerAllMhloDialects(registry); + stablehlo::registerAllDialects(registry); +- registry.insert(); + return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); + } +-- +2.34.1 + diff --git a/mlir/patches/mlir-buffer-deallocation.patch b/mlir/patches/mlir-buffer-deallocation.patch deleted file mode 100644 index 852e6e84c5..0000000000 --- a/mlir/patches/mlir-buffer-deallocation.patch +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/mlir/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp -index a0a81d4add..7b7be9e577 100644 ---- a/mlir/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp -+++ b/mlir/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp -@@ -308,6 +308,9 @@ private: - - // Add new allocs and additional clone operations. - for (Value value : valuesToFree) { -+ if (!isa(value.getType())) { -+ continue; -+ } - if (failed(isa(value) - ? introduceBlockArgCopy(cast(value)) - : introduceValueCopyForRegionResult(value))) diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index 23c3138bf4..effc229a64 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -106,7 +106,7 @@ module @test1 { // CHECK-LABEL: @foo( // CHECK-SAME: [[arg0:%.+]]: tensor) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref + // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : tensor to memref // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> () %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor) diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir index 29ab3dfbe9..20ce7c6bc7 100644 --- a/mlir/test/Catalyst/ConversionTest.mlir +++ b/mlir/test/Catalyst/ConversionTest.mlir @@ -160,13 +160,13 @@ module @test1 { // call @callback_1([[ptr0]], [[ptr1]]) - %0 = bufferization.to_memref %arg0 : memref + %0 = bufferization.to_memref %arg0 : tensor to memref %1 = bufferization.alloc_tensor() {memory_space = 0 : i64} : tensor - %2 = bufferization.to_memref %1 : memref + %2 = bufferization.to_memref %1 : tensor to memref catalyst.callback_call @callback_1(%0, %2) : (memref, memref) -> () - %3 = bufferization.to_tensor %2 : memref + %3 = bufferization.to_tensor %2 : memref to tensor return %3 : tensor } } diff --git a/mlir/test/Gradient/BufferizationTest.mlir b/mlir/test/Gradient/BufferizationTest.mlir index 1b72897b65..4a8f9a246e 100644 --- a/mlir/test/Gradient/BufferizationTest.mlir +++ b/mlir/test/Gradient/BufferizationTest.mlir @@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>) // CHECK-LABEL: @adjoint_with_tensor_arg func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) { - // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64> + // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64> // CHECK: [[alloc:%.+]] = memref.alloc(%arg1) : memref // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref) : (memref<2xf64>) -> () %grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor @@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>) // CHECK-LABEL: @adjoint_with_multiple_results func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) { - // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64> + // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64> // CHECK: [[alloc0:%.+]] = memref.alloc(%arg1) : memref // CHECK: [[alloc1:%.+]] = memref.alloc(%arg1) : memref // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]] @@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64) // CHECK-LABEL: @backprop_scalar_in func.func @backprop_scalar_in(%arg0: f64, %arg1: tensor) { - // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref + // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref // CHECK: [[dim1:%.+]] = memref.dim [[cotangentSource]] // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]] @@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor) // CHECK-LABEL: @backprop_tensor_in func.func @backprop_tensor_in(%arg0: tensor, %arg1: tensor) { - // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : memref - // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref + // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor to memref + // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref // CHECK: [[dim2:%.+]] = memref.dim [[cotangentSource]] // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]] @@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>) // CHECK-LABEL: @backprop_multiple_tensors_in func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor) { - // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : memref<10xf64> - // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : memref<2xf64> + // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64> + // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64> // CHECK: memref.alloc // CHECK: memref.copy // CHECK: [[argShadow1:%.+]] = memref.alloc() : memref<10xf64> @@ -171,10 +171,9 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor, ten // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref - // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64> + // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref + // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref, memref<2xf64> - // CHECK: } %0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) gradient.return {empty = false} %0#0, %0#1 : tensor, tensor<2xf64> @@ -193,9 +192,8 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor, %arg1: tensor<2xf64>) // CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64> // CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor) -> tensor<2xf64> - // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64> + // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = true} [[res]] : memref<2xf64> - // CHECK: } %0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor) -> tensor<2xf64> gradient.return {empty = true} %0 : tensor<2xf64> diff --git a/mlir/test/Gradient/ConversionTest.mlir b/mlir/test/Gradient/ConversionTest.mlir index 04fca8c5d3..1beed4bcaf 100644 --- a/mlir/test/Gradient/ConversionTest.mlir +++ b/mlir/test/Gradient/ConversionTest.mlir @@ -18,7 +18,7 @@ // Native Gradients // ////////////////////// -func.func private @circuit.nodealloc(%arg0: f32) -> (!quantum.reg) +func.func private @circuit.nodealloc(%arg0: f32) -> (!quantum.reg, f64) // CHECK-DAG: llvm.func @__catalyst__rt__toggle_recorder(i1) // CHECK-DAG: llvm.func @__catalyst__qis__Gradient(i64, ...) @@ -29,7 +29,7 @@ func.func @adjoint(%arg0: f32, %arg1 : index) -> (memref, memref) // CHECK-DAG: [[F:%.+]] = llvm.mlir.constant(false) : i1 // CHECK: llvm.call @__catalyst__rt__toggle_recorder([[T]]) : (i1) -> () - // CHECK: [[QREG:%.+]] = call @circuit.nodealloc(%arg0) + // CHECK: [[QREG_and_expval:%.+]]:2 = call @circuit.nodealloc(%arg0) // CHECK: llvm.call @__catalyst__rt__toggle_recorder([[F]]) // CHECK-DAG: [[C1:%.+]] = llvm.mlir.constant(1 : i64) : i64 @@ -38,7 +38,7 @@ func.func @adjoint(%arg0: f32, %arg1 : index) -> (memref, memref) // CHECK: [[GRAD2:%.+]] = llvm.alloca [[C1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: llvm.call @__catalyst__qis__Gradient([[C2]], [[GRAD1]], [[GRAD2]]) - // CHECK: quantum.dealloc [[QREG]] + // CHECK: quantum.dealloc [[QREG_and_expval]]#0 %alloc0 = memref.alloc(%arg1) : memref %alloc1 = memref.alloc(%arg1) : memref gradient.adjoint @circuit.nodealloc(%arg0) size(%arg1) in(%alloc0, %alloc1 : memref, memref) : (f32) -> () diff --git a/mlir/test/Gradient/PostProcessingTest.mlir b/mlir/test/Gradient/PostProcessingTest.mlir index 764b5912b5..2403372410 100644 --- a/mlir/test/Gradient/PostProcessingTest.mlir +++ b/mlir/test/Gradient/PostProcessingTest.mlir @@ -23,17 +23,17 @@ func.func private @callback_fn_fwd(tensor<2xf64>) -> (tensor, tensor<2xf64> // CHECK-SAME: { gradient.forward @callback_fn_fwd.fwd(%arg0: memref<2xf64>) -> (memref, memref<2xf64>) attributes {argc = 1 : i64, implementation = @callback_fn_fwd, resc = 1 : i64, tape = 1 : i64} { - // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> + // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64> // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref - // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64> + // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref + // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64> // CHECK: memref.copy [[res0]], %arg2 : memref to memref // CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64> - %0 = bufferization.to_tensor %arg0 : memref<2xf64> + %0 = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64> %1:2 = func.call @callback_fn_fwd(%0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - %2 = bufferization.to_memref %1#0 : memref - %3 = bufferization.to_memref %1#1 : memref<2xf64> + %2 = bufferization.to_memref %1#0 : tensor to memref + %3 = bufferization.to_memref %1#1 : tensor<2xf64> to memref<2xf64> gradient.return {empty = false} %2, %3 : memref, memref<2xf64> } @@ -47,16 +47,16 @@ func.func private @callback_fn_vjp(tensor<2xf64>, tensor) -> tensor<2xf64> // CHECK-SAME: { gradient.reverse @callback_fn_vjp.rev(%arg0: memref, %arg1: memref<2xf64>) -> memref<2xf64> attributes {argc = 1 : i64, implementation = @callback_fn_vjp, resc = 1 : i64, tape = 1 : i64} { - // CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64> - // CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref + // CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64> to tensor<2xf64> + // CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref to tensor // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[tape]], [[cotan]]) : (tensor<2xf64>, tensor) -> tensor<2xf64> - // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64> + // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64> // CHECK: memref.copy [[res]], %arg1 : memref<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = true} - %0 = bufferization.to_tensor %arg1 : memref<2xf64> - %1 = bufferization.to_tensor %arg0 : memref + %0 = bufferization.to_tensor %arg1 : memref<2xf64> to tensor<2xf64> + %1 = bufferization.to_tensor %arg0 : memref to tensor %2 = func.call @callback_fn_vjp(%0, %1) : (tensor<2xf64>, tensor) -> tensor<2xf64> - %3 = bufferization.to_memref %2 : memref<2xf64> + %3 = bufferization.to_memref %2 : tensor<2xf64> to memref<2xf64> gradient.return {empty = true} %3 : memref<2xf64> } diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir index fb53b96d32..7b94860c44 100644 --- a/mlir/test/Quantum/BufferizationTest.mlir +++ b/mlir/test/Quantum/BufferizationTest.mlir @@ -15,7 +15,7 @@ // RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex> + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> // CHECK: {{%.+}} = quantum.unitary([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.bit %out_qubits = quantum.unitary(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.bit @@ -25,7 +25,7 @@ func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { // ----- func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex> + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> // CHECK: {{%.+}} = quantum.hermitian([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.obs %obs = quantum.hermitian(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.obs @@ -35,7 +35,7 @@ func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { // ----- func.func @hamiltonian(%obs: !quantum.obs, %coeffs: tensor<1xf64>){ - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<1xf64> + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<1xf64> to memref<1xf64> // CHECK: {{%.+}} = quantum.hamiltonian([[memref]] : memref<1xf64>) %arg0 : !quantum.obs %hamil = quantum.hamiltonian(%coeffs: tensor<1xf64>) %obs : !quantum.obs diff --git a/mlir/test/Quantum/ConversionTest.mlir b/mlir/test/Quantum/ConversionTest.mlir index fb5b46a372..035f0147d2 100644 --- a/mlir/test/Quantum/ConversionTest.mlir +++ b/mlir/test/Quantum/ConversionTest.mlir @@ -221,9 +221,9 @@ module @custom_gate { // CHECK: llvm.func @__catalyst__qis__RX(f64, !llvm.ptr, !llvm.ptr) // CHECK-LABEL: @test func.func @test(%q0: !quantum.bit, %p: f64) -> () { + // CHECK: [[nullptr:%.+]] = llvm.mlir.zero // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)> - // CHECK: [[nullptr:%.+]] = llvm.mlir.zero // CHECK: [[true:%.+]] = llvm.mlir.constant(true) // CHECK: [[off0:%.+]] = llvm.getelementptr inbounds [[alloca]][0, 0] // CHECK: [[off1:%.+]] = llvm.getelementptr inbounds [[alloca]][0, 1] @@ -374,16 +374,14 @@ func.func @tensor(%obs : !quantum.obs) { // CHECK: llvm.func @__catalyst__qis__HamiltonianObs(!llvm.ptr, i64, ...) -> i64 // CHECK-LABEL: @hamiltonian func.func @hamiltonian(%obs : !quantum.obs, %p1 : memref<1xf64>, %p2 : memref<3xf64>) { - // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) - // CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[memrefvar:%.+]] = llvm.mlir.undef + // CHECK: [[memrefvar:%.+]] = llvm.mlir.poison // CHECK: [[memrefvar0:%.+]] = llvm.insertvalue %arg1, [[memrefvar]][0] // CHECK: [[memrefvar1:%.+]] = llvm.insertvalue %arg2, [[memrefvar0]][1] // CHECK: [[memrefvar2:%.+]] = llvm.insertvalue %arg3, [[memrefvar1]][2] // CHECK: [[memrefvar3:%.+]] = llvm.insertvalue %arg4, [[memrefvar2]][3, 0] // CHECK: [[memrefvar4:%.+]] = llvm.insertvalue %arg5, [[memrefvar3]][4, 0] - // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast [[memrefvar4]] - // CHECK: [[memrefvar4:%.+]] = builtin.unrealized_conversion_cast [[cast]] + // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) + // CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: llvm.store [[memrefvar4]], [[alloca]] // CHECK: llvm.call @__catalyst__qis__HamiltonianObs([[alloca]], [[c1]], %arg0) @@ -397,16 +395,14 @@ func.func @hamiltonian(%obs : !quantum.obs, %p1 : memref<1xf64>, %p2 : memref<3x // CHECK-LABEL: @hamiltonian func.func @hamiltonian(%obs : !quantum.obs, %p1 : memref<1xf64>, %p2 : memref<3xf64>) { - // CHECK: [[memrefvar:%.+]] = llvm.mlir.undef + // CHECK: [[memrefvar:%.+]] = llvm.mlir.poison // CHECK: [[memrefvar0:%.+]] = llvm.insertvalue %arg6, [[memrefvar]][0] // CHECK: [[memrefvar1:%.+]] = llvm.insertvalue %arg7, [[memrefvar0]][1] // CHECK: [[memrefvar2:%.+]] = llvm.insertvalue %arg8, [[memrefvar1]][2] // CHECK: [[memrefvar3:%.+]] = llvm.insertvalue %arg9, [[memrefvar2]][3, 0] // CHECK: [[memrefvar4:%.+]] = llvm.insertvalue %arg10, [[memrefvar3]][4, 0] - // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast [[memrefvar4]] // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[memrefvar4:%.+]] = builtin.unrealized_conversion_cast [[cast]] // CHECK: [[c3:%.+]] = llvm.mlir.constant(3 : i64) // CHECK: llvm.store [[memrefvar4]], [[alloca]] // CHECK: llvm.call @__catalyst__qis__HamiltonianObs([[alloca]], [[c3]], %arg0, %arg0, %arg0) @@ -588,7 +584,6 @@ func.func @state(%q : !quantum.bit) { // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[qb:%.+]] = builtin.unrealized_conversion_cast %arg0 // CHECK: [[c0:%.+]] = llvm.mlir.constant(0 : i64) // CHECK: llvm.call @__catalyst__qis__State([[ptr]], [[c0]]) %alloc1 = memref.alloc() : memref<2xcomplex> @@ -621,13 +616,13 @@ func.func @controlled_circuit(%1 : !quantum.bit, %2 : !quantum.bit, %3 : !quantu %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + // CHECK: [[true:%.+]] = llvm.mlir.constant(true) // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca0:%.+]] = llvm.alloca [[c1]] x i1 // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca1:%.+]] = llvm.alloca [[c1]] x !llvm.ptr // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[mod:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)> - // CHECK: [[true:%.+]] = llvm.mlir.constant(true) // CHECK-DAG: [[cst6:%.+]] = llvm.mlir.constant(6.0 @@ -661,13 +656,13 @@ func.func @controlled_circuit(%1 : !quantum.bit, %2 : !quantum.bit, %3 : !quantu %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 %true = llvm.mlir.constant (1 : i1) :i1 + // CHECK: [[cst6:%.+]] = llvm.mlir.constant(6.0 // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca0:%.+]] = llvm.alloca [[c1]] x i1 // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[alloca1:%.+]] = llvm.alloca [[c1]] x !llvm.ptr // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[mod:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)> - // CHECK: [[cst6:%.+]] = llvm.mlir.constant(6.0 // CHECK: [[true:%.+]] = llvm.mlir.constant(true) // CHECK: [[offset0:%.+]] = llvm.getelementptr inbounds [[mod]][0, 0] diff --git a/requirements.txt b/requirements.txt index 7d6c019116..97747179e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,9 @@ pip>=22.3 # Build dependencies for non-Python components # Do not allow NumPy 2.0.0 due to a bug with their C API that blocks the usage of the Stable ABI; # this bug was fixed in 2.0.1 (https://github.com/numpy/numpy/pull/26995) -nanobind numpy!=2.0.0 +# llvm requires nanobind 2.4 or higher +nanobind>=2.4 pybind11>=2.12.0 PyYAML