Skip to content

Commit 20c44e3

Browse files
committed
[MLIR][Python] enable ptr dialect bindings
1 parent 996639d commit 20c44e3

File tree

8 files changed

+208
-10
lines changed

8 files changed

+208
-10
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- PtrDialect.h - C interface for the Ptr dialect -------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_C_DIALECT_PTR_H
10+
#define MLIR_C_DIALECT_PTR_H
11+
12+
#include "mlir-c/IR.h"
13+
14+
#ifdef __cplusplus
15+
extern "C" {
16+
#endif
17+
18+
//===----------------------------------------------------------------------===//
19+
// Dialect API.
20+
//===----------------------------------------------------------------------===//
21+
22+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Ptr, ptr);
23+
24+
//===----------------------------------------------------------------------===//
25+
// MemorySpaceAttrInterface API.
26+
//===----------------------------------------------------------------------===//
27+
28+
//===----------------------------------------------------------------------===//
29+
// Type API.
30+
//===----------------------------------------------------------------------===//
31+
32+
/// Checks if the given type is a Ptr type.
33+
MLIR_CAPI_EXPORTED bool mlirPtrTypeIsAPtrType(MlirType type);
34+
35+
MLIR_CAPI_EXPORTED MlirType mlirPtrGetPtrType(MlirAttribute memorySpace);
36+
37+
#ifdef __cplusplus
38+
}
39+
#endif
40+
41+
#endif // MLIR_C_DIALECT_PTR_H
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- DialectPtr.cpp - Pybind module for Ptr dialect API support ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "NanobindUtils.h"
10+
11+
#include "mlir-c/Dialect/PtrDialect.h"
12+
#include "mlir-c/IR.h"
13+
#include "mlir-c/Support.h"
14+
#include "mlir/Bindings/Python/Diagnostics.h"
15+
#include "mlir/Bindings/Python/Nanobind.h"
16+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
17+
18+
namespace nb = nanobind;
19+
20+
using namespace nanobind::literals;
21+
22+
using namespace mlir;
23+
using namespace mlir::python;
24+
using namespace mlir::python::nanobind_adaptors;
25+
26+
static void populateDialectPTRSubmodule(nanobind::module_ &m) {
27+
mlir_type_subclass(m, "PtrType", mlirPtrTypeIsAPtrType)
28+
.def_classmethod(
29+
"get",
30+
[](const nb::object &cls, MlirAttribute memorySpace) {
31+
return cls(mlirPtrGetPtrType(memorySpace));
32+
},
33+
"Gets an instance of PtrType with memory_space in the same context",
34+
nb::arg("cls"), nb::arg("memory_space"));
35+
}
36+
37+
NB_MODULE(_mlirDialectsPTR, m) {
38+
m.doc() = "MLIR PTR Dialect";
39+
40+
populateDialectPTRSubmodule(m);
41+
}

mlir/lib/CAPI/Dialect/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,12 @@ add_mlir_upstream_c_api_library(MLIRCAPISMT
278278
MLIRCAPIIR
279279
MLIRSMT
280280
)
281+
282+
add_mlir_upstream_c_api_library(MLIRCAPIPtrDialect
283+
PtrDialect.cpp
284+
285+
PARTIAL_SOURCES_INTENDED
286+
LINK_LIBS PUBLIC
287+
MLIRCAPIIR
288+
MLIRPtrDialect
289+
)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===- PtrDialect.cpp - C interface for the Ptr dialect -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/PtrDialect.h"
10+
#include "mlir/CAPI/Registration.h"
11+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
12+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
13+
#include "llvm/Support/Debug.h"
14+
15+
#define DEBUG_TYPE "ptr-dialect-capi"
16+
17+
using namespace mlir;
18+
using namespace ptr;
19+
20+
//===----------------------------------------------------------------------===//
21+
// Dialect API.
22+
//===----------------------------------------------------------------------===//
23+
24+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Ptr, ptr, mlir::ptr::PtrDialect)
25+
26+
bool mlirPtrTypeIsAPtrType(MlirType type) {
27+
return llvm::isa<ptr::PtrType>(unwrap(type));
28+
}
29+
30+
MlirType mlirPtrGetPtrType(MlirAttribute memorySpace) {
31+
MemorySpaceAttrInterface memorySpaceAttr =
32+
dyn_cast<MemorySpaceAttrInterface>(unwrap(memorySpace));
33+
if (!memorySpaceAttr) {
34+
LLVM_DEBUG(llvm::dbgs()
35+
<< "expected memory-space to be MemorySpaceAttrInterface");
36+
return {nullptr};
37+
}
38+
return wrap(ptr::PtrType::get(memorySpaceAttr));
39+
}

mlir/python/CMakeLists.txt

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,15 @@ declare_mlir_dialect_python_bindings(
516516
GEN_ENUM_BINDINGS
517517
)
518518

519+
declare_mlir_dialect_python_bindings(
520+
ADD_TO_PARENT MLIRPythonSources.Dialects
521+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
522+
TD_FILE dialects/PtrOps.td
523+
SOURCES dialects/ptr.py
524+
DIALECT_NAME ptr
525+
GEN_ENUM_BINDINGS
526+
)
527+
519528
################################################################################
520529
# Python extensions.
521530
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -579,7 +588,7 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
579588
MLIRCAPIRegisterEverything
580589
)
581590

582-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
591+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind
583592
MODULE_NAME _mlirDialectsLinalg
584593
ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
585594
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -593,7 +602,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
593602
MLIRCAPILinalg
594603
)
595604

596-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
605+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind
597606
MODULE_NAME _mlirDialectsGPU
598607
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
599608
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -607,7 +616,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
607616
MLIRCAPIGPU
608617
)
609618

610-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
619+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind
611620
MODULE_NAME _mlirDialectsLLVM
612621
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
613622
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -623,7 +632,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
623632
MLIRCAPITarget
624633
)
625634

626-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
635+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind
627636
MODULE_NAME _mlirDialectsQuant
628637
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
629638
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -637,7 +646,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
637646
MLIRCAPIQuant
638647
)
639648

640-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
649+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind
641650
MODULE_NAME _mlirDialectsNVGPU
642651
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
643652
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -651,7 +660,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
651660
MLIRCAPINVGPU
652661
)
653662

654-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
663+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind
655664
MODULE_NAME _mlirDialectsPDL
656665
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
657666
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -665,7 +674,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
665674
MLIRCAPIPDL
666675
)
667676

668-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
677+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind
669678
MODULE_NAME _mlirDialectsSparseTensor
670679
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
671680
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -679,7 +688,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
679688
MLIRCAPISparseTensor
680689
)
681690

682-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
691+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
683692
MODULE_NAME _mlirDialectsTransform
684693
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
685694
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -693,7 +702,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
693702
MLIRCAPITransformDialect
694703
)
695704

696-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
705+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind
697706
MODULE_NAME _mlirDialectsIRDL
698707
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
699708
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -761,7 +770,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
761770
MLIRCAPILinalg
762771
)
763772

764-
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
773+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
765774
MODULE_NAME _mlirDialectsSMT
766775
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
767776
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -778,6 +787,22 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
778787
MLIRCAPIExportSMTLIB
779788
)
780789

790+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Ptr.Nanobind
791+
MODULE_NAME _mlirDialectsPtr
792+
ADD_TO_PARENT MLIRPythonSources.Dialects.ptr
793+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
794+
PYTHON_BINDINGS_LIBRARY nanobind
795+
SOURCES
796+
DialectPtr.cpp
797+
# Headers must be included explicitly so they are installed.
798+
NanobindUtils.h
799+
PRIVATE_LINK_LIBS
800+
LLVMSupport
801+
EMBED_CAPI_LINK_LIBS
802+
MLIRCAPIIR
803+
MLIRCAPIPtrDialect
804+
)
805+
781806
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
782807
MODULE_NAME _mlirSparseTensorPasses
783808
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===- PTROps.td - Entry point for PTR bindings ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef BINDINGS_PYTHON_PTR_OPS
10+
#define BINDINGS_PYTHON_PTR_OPS
11+
12+
include "mlir/Dialect/Ptr/IR/PtrOps.td"
13+
14+
#endif // BINDINGS_PYTHON_PTR_OPS

mlir/python/mlir/dialects/ptr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._ptr_ops_gen import *
6+
from ._ptr_enum_gen import *

mlir/test/python/dialects/ptr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.dialects import ptr
4+
from mlir.ir import Context, Location, Module, InsertionPoint, Attribute
5+
6+
7+
def run(f):
8+
print("\nTEST:", f.__name__)
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
f(module)
13+
print(module)
14+
assert module.operation.verify()
15+
16+
17+
# CHECK-LABEL: TEST: test_smoke
18+
@run
19+
def test_smoke(_module):
20+
null_ptr = Attribute.parse("#ptr.null : !ptr.ptr<#llvm.address_space<1>>")
21+
null = ptr.constant(null_ptr)
22+
# CHECK: %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>>
23+
print(null)

0 commit comments

Comments
 (0)