diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index fc2c4b1c6ac1..4aa11f6b9e9c 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -2,7 +2,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@llvm-project//mlir:build_defs.bzl", "mlir_c_api_cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier") package( @@ -923,3 +924,96 @@ cc_binary( "@llvm-project//mlir:MlirOptLib", ], ) + +# C API bindings +mlir_c_api_cc_library( + name = "CAPITorchRegisterEverything", + srcs = ["lib/CAPI/Registration.cpp"], + hdrs = ["include/torch-mlir-c/Registration.h"], + capi_deps = ["@llvm-project//mlir:CAPIIR"], + deps = [ + ":TorchMLIRInitAll", + "@llvm-project//mlir:AllPassesAndDialects", + ], +) + +mlir_c_api_cc_library( + name = "CAPITorch", + srcs = ["lib/CAPI/Dialects.cpp"], + hdrs = ["include/torch-mlir-c/Dialects.h"], + capi_deps = ["@llvm-project//mlir:CAPIIR"], + deps = [":TorchMLIRTorchDialect"], +) + +# These flags are needed for pybind11 to work. +PYBIND11_COPTS = [ + "-fexceptions", + "-frtti", +] + +PYBIND11_FEATURES = [ + # Cannot use header_modules (parse_headers feature fails). + "-use_header_modules", +] + +# pybind11 extension module +cc_binary( + name = "_torchMlir.so", + srcs = [ + "include/torch-mlir-c/Registration.h", + "python/TorchMLIRModule.cpp", + ], + copts = PYBIND11_COPTS, + features = PYBIND11_FEATURES, + linkshared = 1, + linkstatic = 0, + deps = [ + ":CAPITorch", + ":CAPITorchRegisterEverything", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + ], +) + +# python files +td_library( + name = "TorchOpsPyTdFiles", + srcs = [":MLIRTorchOpsIncGenTdFiles"], + includes = ["include"], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_filegroup( + name = "TorchOpsPyGen", + includes = ["include"], + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=torch", + ], + "python/torch_mlir/dialects/_torch_ops_gen.py", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "python/torch_mlir/dialects/TorchBinding.td", + deps = [ + ":TorchOpsPyTdFiles", + "@llvm-project//mlir:AttrTdFiles", + ], +) + +filegroup( + name = "TorchOpsPyFiles", + srcs = [ + ":TorchOpsPyGen", + ], +) + +filegroup( + name = "TorchPyFiles", + srcs = glob(["python/**/*.py"]), +)