From f2ab318614c2e038272a3ca230a859d346e088db Mon Sep 17 00:00:00 2001 From: Yan Da Date: Tue, 22 Mar 2022 21:53:22 +0800 Subject: [PATCH] New python binding --- CMakeLists.txt | 7 +- include/triton/ir/TritonOps.td | 35 +- lib/ir/CMakeLists.txt | 14 - python/src/triton.cc | 950 ++++++++++++++++++--------------- python/triton/language/core.py | 48 +- 5 files changed, 593 insertions(+), 461 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7e7bd9c9b..140b64dc8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -174,10 +174,15 @@ add_subdirectory(lib) add_library(triton SHARED ${PYTHON_SRC}) +find_package(PythonLibs REQUIRED) + target_link_libraries(triton TritonIR TritonDriver - TritonCodeGen + # TritonCodeGen + + MLIRCAPIIR + ${PYTHON_LIBRARIES} ) target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index dca22ff4f..aaf44d949 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -91,6 +91,25 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect, // // Load/Store Ops // +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + ]> { + let cppNamespace = "::mlir::triton"; +} +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { let summary = "load"; @@ -157,10 +176,13 @@ def TT_RedOpAttr : I32EnumAttr< /*name*/"RedOp", /*summary*/"", /*case*/ [ - I32EnumAttrCase, + I32EnumAttrCase, I32EnumAttrCase<"MAX", 2, "max">, I32EnumAttrCase<"MIN", 3, "min">, - I32EnumAttrCase<"XOR_SUM", 4, "xor_sum"> + I32EnumAttrCase<"FADD", 4, "fadd">, + 
I32EnumAttrCase<"FMAX", 5, "fmax">, + I32EnumAttrCase<"FMIN", 6, "fmin">, + I32EnumAttrCase<"XOR", 7, "xor"> ]> { let cppNamespace = "::mlir::triton"; } @@ -179,10 +201,11 @@ def TT_AtomicRMWAttr : I32EnumAttr< I32EnumAttrCase<"OR", 2, "or">, I32EnumAttrCase<"XOR", 3, "xor">, I32EnumAttrCase<"ADD", 4, "add">, - I32EnumAttrCase<"MAX", 5, "max">, - I32EnumAttrCase<"MIN", 6, "min">, - I32EnumAttrCase<"UMAX", 7, "umax">, - I32EnumAttrCase<"UMIN", 8, "umin"> + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin"> ]> { let cppNamespace = "::mlir::triton"; } diff --git a/lib/ir/CMakeLists.txt b/lib/ir/CMakeLists.txt index 4a6ee9d69..59155ddc5 100644 --- a/lib/ir/CMakeLists.txt +++ b/lib/ir/CMakeLists.txt @@ -18,17 +18,3 @@ add_mlir_dialect_library(TritonIR MLIRTensor ) - -# add_library(TritonIR -# Dialect.cpp -# Ops.cpp -# Types.cpp -# ) - -# target_link_libraries(TritonIR PUBLIC -# MLIRIR -# MLIRArithmetic -# MLIRControlFlow -# MLIRFunc -# MLIRTensor -# ) diff --git a/python/src/triton.cc b/python/src/triton.cc index b47f45796..73b81b8d2 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,16 +1,23 @@ -#include "triton/codegen/pass.h" -#include "triton/codegen/target.h" +// #include "triton/codegen/pass.h" +// #include "triton/codegen/target.h" #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir-c/IR.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir/CAPI/IR.h" +// #include "mlir/IR/BuiltinOps.h" +// #include "mlir/IR/MLIRContext.h" + +#include "triton/ir/Dialect.h" +#include "triton/ir/Types.h" #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" +#include #include #include #include @@ -24,7 +31,7 @@ #include namespace py = pybind11; -namespace ir = triton::ir; 
+// namespace ir = triton::ir; namespace drv = triton::driver; @@ -464,84 +471,84 @@ std::tuple hip_load_binary(const std::string& name, asm_map_ // Compile Triton-IR to assembly // --------------------------------------- -// CUDA -std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ +// // CUDA +// std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, +// uint64_t device, int num_warps, int num_stages, +// asm_map_t &asm_map){ - int n_shared_bytes; - py::gil_scoped_release allow_threads; - llvm::LLVMContext ctx; - // device properties - CUdevice dev = (CUdevice)device; - size_t major = cuGetInfo(dev); - size_t minor = cuGetInfo(dev); - size_t cc = major*10 + minor; - int version; - std::string ptxas_path = drv::path_to_ptxas(version); - // Triton-IR -> NVPTX LLVM-IR - triton::codegen::nvidia_cu_target target(cc); - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); - std::string tmp; - llvm::raw_string_ostream llir(tmp); - llir << *llvm; - llir.flush(); - asm_map["llir"] = py::cast(tmp); - // LLVM-IR -> PTX - std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version); - asm_map["ptx"] = py::cast(ptx); - // PTX -> Binary - std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc); - if(!cubin.empty()){ - py::bytes bytes(cubin); - asm_map["cubin"] = bytes; - } - return std::make_tuple(name, asm_map, n_shared_bytes); -} +// int n_shared_bytes; +// py::gil_scoped_release allow_threads; +// llvm::LLVMContext ctx; +// // device properties +// CUdevice dev = (CUdevice)device; +// size_t major = cuGetInfo(dev); +// size_t minor = cuGetInfo(dev); +// size_t cc = major*10 + minor; +// int version; +// std::string ptxas_path = drv::path_to_ptxas(version); +// // Triton-IR -> NVPTX LLVM-IR +// triton::codegen::nvidia_cu_target target(cc); +// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 
cc, num_warps, num_stages, n_shared_bytes); +// std::string tmp; +// llvm::raw_string_ostream llir(tmp); +// llir << *llvm; +// llir.flush(); +// asm_map["llir"] = py::cast(tmp); +// // LLVM-IR -> PTX +// std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version); +// asm_map["ptx"] = py::cast(ptx); +// // PTX -> Binary +// std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc); +// if(!cubin.empty()){ +// py::bytes bytes(cubin); +// asm_map["cubin"] = bytes; +// } +// return std::make_tuple(name, asm_map, n_shared_bytes); +// } -// HIP -std::tuple hip_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ - llvm::LLVMContext ctx; - // Triton-IR -> NVPTX LLVM-IR - triton::codegen::amd_cl_target target; - int n_shared_bytes; - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes); - std::string tmp; - llvm::raw_string_ostream llir(tmp); - llir << *llvm; - llir.flush(); - asm_map["llir"] = py::cast(tmp); - // LLVM-IR -> HSA-CO - std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908"); - asm_map["hsaco"] = py::cast(path); - return std::make_tuple(name, asm_map, n_shared_bytes); -} +// // HIP +// std::tuple hip_compile_ttir(const std::string& name, ir::module &ir, +// uint64_t device, int num_warps, int num_stages, +// asm_map_t &asm_map){ +// llvm::LLVMContext ctx; +// // Triton-IR -> NVPTX LLVM-IR +// triton::codegen::amd_cl_target target; +// int n_shared_bytes; +// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes); +// std::string tmp; +// llvm::raw_string_ostream llir(tmp); +// llir << *llvm; +// llir.flush(); +// asm_map["llir"] = py::cast(tmp); +// // LLVM-IR -> HSA-CO +// std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908"); +// asm_map["hsaco"] = py::cast(path); +// return std::make_tuple(name, asm_map, n_shared_bytes); +// } -void 
init_triton_codegen(py::module &&m) { - m.def( - "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) { - std::string name = ir.get_function_list()[0]->get_name(); - // record asm as we generate - asm_map_t asm_map; - std::ostringstream ttir; - ir.print(ttir); - asm_map["ttir"] = py::cast(ttir.str()); - llvm::LLVMContext ctx; - if(backend == CUDA) - return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - }, py::return_value_policy::take_ownership); - m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ - py::gil_scoped_release allow_threads; - if(backend == CUDA) - return cu_load_binary(name, asm_map, n_shared_bytes, dev); - if(backend == ROCM) - return hip_load_binary(name, asm_map, n_shared_bytes, dev); - }, py::return_value_policy::take_ownership); -} +// void init_triton_codegen(py::module &&m) { +// m.def( +// "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) { +// std::string name = ir.get_function_list()[0]->get_name(); +// // record asm as we generate +// asm_map_t asm_map; +// std::ostringstream ttir; +// ir.print(ttir); +// asm_map["ttir"] = py::cast(ttir.str()); +// llvm::LLVMContext ctx; +// if(backend == CUDA) +// return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); +// if(backend == ROCM) +// return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); +// }, py::return_value_policy::take_ownership); +// m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ +// py::gil_scoped_release allow_threads; +// if(backend == CUDA) +// return cu_load_binary(name, asm_map, n_shared_bytes, dev); +// if(backend == ROCM) +// return hip_load_binary(name, asm_map, 
n_shared_bytes, dev); +// }, py::return_value_policy::take_ownership); +// } /*****************************************************************************/ @@ -552,367 +559,478 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; - py::enum_(m, "CACHE_MODIFIER") - .value("NONE", ir::load_inst::NONE) - .value("CA", ir::load_inst::CA) - .value("CG", ir::load_inst::CG) + py::enum_(m, "CACHE_MODIFIER") + .value("NONE", mlir::triton::CacheModifier::NONE) + .value("CA", mlir::triton::CacheModifier::CA) + .value("CG", mlir::triton::CacheModifier::CG) .export_values(); - py::enum_(m, "EVICTION_POLICY") - .value("NORMAL", ir::load_inst::NORMAL) - .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) - .value("EVICT_LAST", ir::load_inst::EVICT_LAST) + py::enum_(m, "EVICTION_POLICY") + .value("NORMAL", mlir::triton::EvictionPolicy::NORMAL) + .value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST) .export_values(); - py::enum_(m, "REDUCE_OP") - .value("ADD", ir::reduce_inst::ADD) - .value("FADD", ir::reduce_inst::FADD) - .value("MIN", ir::reduce_inst::MIN) - .value("MAX", ir::reduce_inst::MAX) - .value("FMIN", ir::reduce_inst::FMIN) - .value("FMAX", ir::reduce_inst::FMAX) - .value("XOR", ir::reduce_inst::XOR); + py::enum_(m, "REDUCE_OP") + .value("ADD", mlir::triton::RedOp::ADD) + .value("FADD", mlir::triton::RedOp::FADD) + .value("MIN", mlir::triton::RedOp::MIN) + .value("MAX", mlir::triton::RedOp::MAX) + .value("FMIN", mlir::triton::RedOp::FMIN) + .value("FMAX", mlir::triton::RedOp::FMAX) + .value("XOR", mlir::triton::RedOp::XOR); - py::enum_(m, "ATOMIC_OP") - .value("ADD", ir::atomic_rmw_op_t::Add) - .value("FADD", ir::atomic_rmw_op_t::FAdd) - .value("AND", ir::atomic_rmw_op_t::And) - .value("OR", ir::atomic_rmw_op_t::Or) - .value("XOR", ir::atomic_rmw_op_t::Xor) - .value("XCHG", ir::atomic_rmw_op_t::Xchg) - .value("MAX", 
ir::atomic_rmw_op_t::Max) - .value("MIN", ir::atomic_rmw_op_t::Min) - .value("UMIN", ir::atomic_rmw_op_t::UMin) - .value("UMAX", ir::atomic_rmw_op_t::UMax); + py::enum_(m, "ATOMIC_OP") + .value("ADD", mlir::triton::RMWOp::ADD) + .value("FADD", mlir::triton::RMWOp::FADD) + .value("AND", mlir::triton::RMWOp::AND) + .value("OR", mlir::triton::RMWOp::OR) + .value("XOR", mlir::triton::RMWOp::XOR) + // .value("XCHG", mlir::triton::RMWOp::Xchg) + .value("MAX", mlir::triton::RMWOp::MAX) + .value("MIN", mlir::triton::RMWOp::MIN) + .value("UMIN", mlir::triton::RMWOp::UMIN) + .value("UMAX", mlir::triton::RMWOp::UMAX); - py::class_(m, "context") - .def(py::init<>()); + py::class_(m, "context") + .def(py::init<>()) + .def("load_triton", [](mlir::MLIRContext &self) { + self.getOrLoadDialect(); + }); + // .def(py::init([](){ + // mlir::MLIRContext context; + // context.getOrLoadDialect(); + // // TODO: should we return a (raw/unique) pointer here? + // return context; + // })); - py::class_(m, "value") - .def("multiple_of", [](ir::value *self, int val) { - if (auto *instr = dynamic_cast(self)) { - instr->set_metadata(ir::metadata::multiple_of, val); - } else - throw std::runtime_error("multiple_of"); + // py::class_(m, "value") + // .def("multiple_of", [](ir::value *self, int val) { + // if (auto *instr = dynamic_cast(self)) { + // instr->set_metadata(ir::metadata::multiple_of, val); + // } else + // throw std::runtime_error("multiple_of"); + // }) + // .def("max_contiguous", [](ir::value *self, int val) { + // if (auto *instr = dynamic_cast(self)) { + // instr->set_metadata(ir::metadata::max_contiguous, val); + // } else + // throw std::runtime_error("max_contiguous"); + // }) + // .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { + // if (auto *instr = dynamic_cast(self)) + // instr->set_fdiv_ieee_rounding(val); + // else + // throw std::runtime_error("set_fdiv_ieee_rounding"); + // }) + // .def("ops", [](ir::value *self) { + // if (auto *instr = 
dynamic_cast(self)) { + // return instr->ops(); + // } + // throw std::runtime_error("cannot use ops()"); + // }) + // .def("replace_all_uses_with", &ir::value::replace_all_uses_with) + // .def("erase_from_parent", [](ir::value *self) { + // if (auto *instr = dynamic_cast(self)) + // return instr->erase_from_parent(); + // throw std::runtime_error("cannot use erase_from_parent"); + // }) + // .def_property("name", &ir::value::get_name, &ir::value::set_name) + // .def_property_readonly("type", &ir::value::get_type); + + // // // Do we need undef in TritonIR ? + // // py::class_(m, "undef") + // // .def("get", &ir::undef_value::get, ret::reference); + + py::class_(m, "type") + .def("is_integer", [](MlirType &self) { + return mlirTypeIsAInteger(self); }) - .def("max_contiguous", [](ir::value *self, int val) { - if (auto *instr = dynamic_cast(self)) { - instr->set_metadata(ir::metadata::max_contiguous, val); - } else - throw std::runtime_error("max_contiguous"); + .def("is_fp16", [](MlirType &self) { + return mlirTypeIsAF16(self); }) - .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { - if (auto *instr = dynamic_cast(self)) - instr->set_fdiv_ieee_rounding(val); - else - throw std::runtime_error("set_fdiv_ieee_rounding"); - }) - .def("is_phi", [](ir::value *self) { - if (auto *pn = dynamic_cast(self)) - return true; - return false; - }) - .def("ops", [](ir::value *self) { - if (auto *instr = dynamic_cast(self)) { - return instr->ops(); - } - throw std::runtime_error("cannot use ops()"); - }) - .def("replace_all_uses_with", &ir::value::replace_all_uses_with) - .def("erase_from_parent", [](ir::value *self) { - if (auto *instr = dynamic_cast(self)) - return instr->erase_from_parent(); - throw std::runtime_error("cannot use erase_from_parent"); - }) - .def_property("name", &ir::value::get_name, &ir::value::set_name) - .def_property_readonly("type", &ir::value::get_type); + ; - py::class_(m, "user"); - - py::class_(m, "constant") - .def("get_null_value", 
&ir::constant::get_null_value, ret::reference) - .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference); - - py::class_(m, "undef") - .def("get", &ir::undef_value::get, ret::reference); - - py::class_(m, "constant_int") - .def_property_readonly("value", &ir::constant_int::get_value) - .def("__int__", [](ir::constant_int *self) { return self->get_value(); }) - .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); - - py::class_(m, "constant_float") - .def_property_readonly("value", &ir::constant_fp::get_value) - .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); - - py::class_(m, "instruction") - .def("get_parent", [](ir::instruction *self) { - return self->get_parent(); - }, ret::reference); - py::class_(m, "phi_node") - .def("add_incoming", &ir::phi_node::add_incoming); - - py::class_(m, "type") - .def("make_ptr", &ir::pointer_type::get, ret::reference) - .def("make_function", &ir::function_type::get, ret::reference) - .def("make_block", &ir::block_type::get, ret::reference) - .def("get_void", &ir::type::get_void_ty, ret::reference) - .def("get_fp8", &ir::type::get_fp8_ty, ret::reference) - .def("get_fp16", &ir::type::get_fp16_ty, ret::reference) - .def("get_bf16", &ir::type::get_bf16_ty, ret::reference) - .def("get_fp32", &ir::type::get_fp32_ty, ret::reference) - .def("get_fp64", &ir::type::get_fp64_ty, ret::reference) - .def("get_int1", &ir::type::get_int1_ty, ret::reference) - .def("get_int8", &ir::type::get_int8_ty, ret::reference) - .def("get_int16", &ir::type::get_int16_ty, ret::reference) - .def("get_int32", &ir::type::get_int32_ty, ret::reference) - .def("get_int64", &ir::type::get_int64_ty, ret::reference) - .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) - - .def("get_block_shapes", &ir::type::get_block_shapes) - - .def("is_ptr", &ir::type::is_pointer_ty) - .def("is_int", static_cast(&ir::type::is_integer_ty)) - 
.def("is_floating", &ir::type::is_floating_point_ty) - .def("is_block", &ir::type::is_block_ty) - .def("is_void", &ir::type::is_void_ty) - .def("is_bool", &ir::type::is_bool_ty) - .def("is_fp8", &ir::type::is_fp8_ty) - .def("is_fp16", &ir::type::is_fp16_ty) - .def("is_bf16", &ir::type::is_bf16_ty) - .def("is_fp32", &ir::type::is_fp32_ty) - .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) - .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) - - .def("repr", &ir::type::repr) - .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) - .def_property_readonly("scalar", &ir::type::get_scalar_ty) - .def_property_readonly("context", &ir::type::get_context, ret::reference) - .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth) - .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); - - py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) - .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); - - py::class_(m, "function_type"); - py::class_(m, "integer_type"); - py::class_(m, "block_type") - .def_property_readonly("shape", &ir::block_type::get_shapes) - .def_property_readonly("numel", &ir::type::get_tile_num_elements); - - py::class_(m, "module") - .def(py::init()) - .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { - const auto metadatas = self->get_metadatas(); - auto it = metadatas.find(name); - if (it != metadatas.end()) - if (auto *instr = dynamic_cast(value)) { - 
instr->set_metadata(it->second.first, it->second.second); + py::class_(m, "operation") + .def("add_entry_block", [](MlirOperation &self) -> MlirBlock { + // if (auto FunctionOp = unwrap(self)->dyn_cast()) { + if (auto info = unwrap(self)->getRegisteredInfo()) { + if (mlir::TypeID::get() == info->getTypeID()) { + auto FunctionOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); + mlir::Block *entry = FunctionOp.addEntryBlock(); + return wrap(entry); } + throw std::runtime_error("Only FuncOp can call add_entry_block"); + } else + throw std::runtime_error("Unknown error"); + }, ret::reference) // this should be automatic? + .def("dump", [](MlirOperation &self) -> void { + unwrap(self)->dump(); }) - .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); + ; - using eattr = ir::attribute_kind_t; - py::enum_(m, "attribute_kind") - .value("readonly", eattr::readonly) - .value("writeonly", eattr::writeonly) - .value("noalias", eattr::noalias) - .value("aligned", eattr::aligned) - .value("multiple_of", eattr::multiple_of) - .value("retune", eattr::retune) - .value("not_implemented", eattr::not_implemented); + py::class_(m, "block") + ; - py::class_(m, "attribute") - .def(py::init()); + // py::class_(m, "float8_type") + // .def_static("get", &mlir::triton::Float8Type::get); + // py::class_(m, "bfloat8_type") + // .def_static("get", &mlir::triton::BFloat8Type::get); + // py::class_(m, "pointer_type") + // .def_static("get", &mlir::triton::PointerType::get); + // py::class_(m, "function_type") + // .def_static("get", &mlir::FunctionType::get); + // py::class_(m, "integer_type") + // .def_static("get", &mlir::IntegerType::get); + // py::class_(m, "block_type") + // .def_static("get", &mlir::RankedTensorType::get); - py::class_(m, "function") - .def_property_readonly("args", &ir::function::args) - .def_property_readonly("attrs", &ir::function::attrs) - .def("add_attr", &ir::function::add_attr); + // py::class_(m, "module") + // .def(py::init()) + 
// .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { + // const auto metadatas = self->get_metadatas(); + // auto it = metadatas.find(name); + // if (it != metadatas.end()) + // if (auto *instr = dynamic_cast(value)) { + // instr->set_metadata(it->second.first, it->second.second); + // } + // }) + // .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); - py::class_(m, "argument"); + // using eattr = ir::attribute_kind_t; + // py::enum_(m, "attribute_kind") + // .value("readonly", eattr::readonly) + // .value("writeonly", eattr::writeonly) + // .value("noalias", eattr::noalias) + // .value("aligned", eattr::aligned) + // .value("multiple_of", eattr::multiple_of) + // .value("retune", eattr::retune) + // .value("not_implemented", eattr::not_implemented); - py::class_(m, "basic_block") - .def("create", &ir::basic_block::create, ret::reference) - .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) - .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { - ir::basic_block::iterator it = self->get_first_non_phi(); - if (it == self->end()) - return nullptr; - return *it; - }, ret::reference) - .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); + // py::class_(m, "attribute"); + // // .def(py::init()); + + // py::class_(m, "function") + // .def_property_readonly("args", &ir::function::args) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + + // // // We don't need to expose mlir::Block (?) 
+ // // py::class_(m, "basic_block") + // // // .def("create", &ir::basic_block::create, ret::reference) + // // .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) + // // .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "builder", py::dynamic_attr()) - .def(py::init()) - // getters - .def_property_readonly("context", &ir::builder::get_context, ret::reference) - // control flow - .def("br", &ir::builder::create_br, ret::reference) - .def("cond_br", &ir::builder::create_cond_br, ret::reference) - .def("ret_void", &ir::builder::create_ret_void, ret::reference) - // insertion block/point, insert points are represented as (*bb, *instr) - .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) - .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) - .def("get_insert_point", [](ir::builder *self) { - ir::basic_block *bb = self->get_insert_block(); - ir::basic_block::iterator it = self->get_insert_point(); - ir::instruction *instr = it == bb->end() ? 
nullptr : *it; - return std::make_pair(bb, instr); - }, ret::reference) - .def("set_insert_point", [](ir::builder *self, std::pair pt) { - ir::basic_block *bb = pt.first; - ir::instruction *instr = pt.second; - if (instr) { - if (bb != instr->get_parent()) - throw std::runtime_error("invalid insertion point, instr not in bb"); - self->set_insert_point(instr); - } else { - assert(bb); - self->set_insert_point(bb); - } + .def(py::init()) + // // getters + // .def_property_readonly("context", &ir::builder::get_context, ret::reference); + // // control flow + // .def("br", &ir::builder::create_br, ret::reference) + // .def("cond_br", &ir::builder::create_cond_br, ret::reference) + // .def("ret_void", &ir::builder::create_ret_void, ret::reference) + // // insertion block/point, insert points are represented as (*bb, *instr) + .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, MlirBlock &block) -> void{ + self.setInsertionPointToStart(unwrap(block)); }) - // Constants - .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) - .def("get_uint32", &ir::builder::get_int32, ret::reference) - .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) - .def("get_uint64", &ir::builder::get_int64, ret::reference) - .def("get_float16", &ir::builder::get_float16, ret::reference) - .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference) - // Types - .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) - .def("get_int1_ty", &mlir::OpBuilder::getI1Type, ret::reference) - .def("get_int8_ty", &mlir::OpBuilder::getI8Type, ret::reference) - .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) - .def("get_int32_ty", &mlir::OpBuilder::getI32Type, ret::reference) - .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) - 
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) - .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) - .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) - .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) - .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) - // terminator instructions - .def("create_br", &ir::builder::create_br, ret::reference) - .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) - .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) - // Cast instructions - .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) - .def("create_cast", &ir::builder::create_cast, ret::reference) - .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) - .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) - .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) - .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) - .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) - .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) - .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) - .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) - .def("create_downcast", &ir::builder::create_downcast, ret::reference) - // phi - .def("create_phi", &ir::builder::create_phi, ret::reference) - // Binary instructions - .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) - .def("create_fmul", &ir::builder::create_fmul, ret::reference) - .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) - .def("create_frem", &ir::builder::create_frem, ret::reference) - .def("create_fadd", &ir::builder::create_fadd, ret::reference) - .def("create_fsub", &ir::builder::create_fsub, ret::reference) - .def("create_mul", &ir::builder::create_mul, ret::reference, - py::arg("lhs"), py::arg("rhs"), - 
py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) - .def("create_udiv", &ir::builder::create_udiv, ret::reference) - .def("create_srem", &ir::builder::create_srem, ret::reference) - .def("create_urem", &ir::builder::create_urem, ret::reference) - .def("create_add", &ir::builder::create_add, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sub", &ir::builder::create_sub, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_shl", &ir::builder::create_shl, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_lshr", &ir::builder::create_lshr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_ashr", &ir::builder::create_ashr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - // GEP - .def("create_gep", &ir::builder::create_gep, ret::reference) - // Comparison (int) - .def("create_icmp", &ir::builder::create_icmp, ret::reference) - .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) - .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) - .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) - .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) - .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) - .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) - .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) - .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) - .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) - .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) - // Comparison (float) - .def("create_fcmp", 
&ir::builder::create_fcmp, ret::reference) - .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) - .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) - .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) - .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) - .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) - .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) - .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) - .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) - .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) - .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) - .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) - .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) - // Logical - .def("create_and", &ir::builder::create_and, ret::reference) - .def("create_xor", &ir::builder::create_xor, ret::reference) - .def("create_or", &ir::builder::create_or, ret::reference) - // Input/Output - .def("create_load", &ir::builder::create_load, ret::reference) - .def("create_store", &ir::builder::create_store, ret::reference) - .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) - .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) - // Block instruction - .def("create_splat", &ir::builder::create_splat, ret::reference) - .def("create_reshape", &ir::builder::create_reshape, ret::reference) - .def("create_cat", &ir::builder::create_cat, ret::reference) - .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) - // atomic - .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) - .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + // .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) + // .def("set_insert_block", (void 
(ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) + // .def("get_insert_point", [](ir::builder *self) { + // ir::basic_block *bb = self->get_insert_block(); + // ir::basic_block::iterator it = self->get_insert_point(); + // ir::instruction *instr = it == bb->end() ? nullptr : *it; + // return std::make_pair(bb, instr); + // }, ret::reference) + // .def("set_insert_point", [](ir::builder *self, std::pair pt) { + // ir::basic_block *bb = pt.first; + // ir::instruction *instr = pt.second; + // if (instr) { + // if (bb != instr->get_parent()) + // throw std::runtime_error("invalid insertion point, instr not in bb"); + // self->set_insert_point(instr); + // } else { + // assert(bb); + // self->set_insert_point(bb); + // } + // }) + // Use arith.ConstantOp to create constants + // // Constants + // .def("get_int1", &ir::builder::get_int1, ret::reference) + // .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) + // .def("get_uint32", &ir::builder::get_int32, ret::reference) + // .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) + // .def("get_uint64", &ir::builder::get_int64, ret::reference) + // .def("get_float16", &ir::builder::get_float16, ret::reference) + // .def("get_float32", &ir::builder::get_float32, ret::reference) + // .def("get_range", &ir::builder::get_range, ret::reference) - // Built-in instruction - .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) - .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) - .def("create_exp", &ir::builder::create_exp, ret::reference) - .def("create_cos", &ir::builder::create_cos, ret::reference) - .def("create_sin", &ir::builder::create_sin, ret::reference) - .def("create_log", &ir::builder::create_log, ret::reference) - .def("create_dot", &ir::builder::create_dot, ret::reference) - .def("create_trans", 
&ir::builder::create_trans, ret::reference) - .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) - .def("create_reduce", &ir::builder::create_reduce, ret::reference) - .def("create_select", &ir::builder::create_select, ret::reference) - // Intrinsics - // These have no place in the IR, and hopefully they can be removed at some point - .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) - .def("create_barrier", &ir::builder::create_barrier, ret::reference); + // Types + .def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType { + return wrap(self.getNoneType()); + }, ret::reference) + .def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getI1Type()); + }, ret::reference) // or ret::copy? + .def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getI8Type()); + }, ret::reference) + .def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getType(16)); + }, ret::reference) + .def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getI32Type()); + }, ret::reference) + .def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getI64Type()); + }, ret::reference) + .def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getType()); + }, ret::reference) + .def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getType()); + }, ret::reference) + .def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getF16Type()); + }, ret::reference) + .def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getBF16Type()); + }, ret::reference) + .def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getF32Type()); + }, ret::reference) + .def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType { + return wrap(self.getF64Type()); + }, ret::reference) + .def("get_function_ty", [](mlir::OpBuilder &self, + std::vector inTypes, + 
std::vector outTypes) -> MlirType { + llvm::SmallVector inputsTypeList; + llvm::SmallVector resultsTypeList; + (void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList); + (void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList); + return wrap(self.getFunctionType(inputsTypeList, resultsTypeList)); + }, ret::reference) + + // Ops + .def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> MlirOperation { + // TODO: loc + auto loc = self.getUnknownLoc(); + if (auto funcTy = unwrap(funcType).dyn_cast()) { + return wrap(self.create(loc, name, funcTy)); + } + throw std::runtime_error("invalid function type"); + }, ret::reference) + // // Structured control flow + // .def("create_scf_for", [](mlir::OpBuilder &self) { + // return self.create(/*fill this*/); + // }) + // .def("create_scf_yield") + // .def("create_scf_if") + // .def("create_scf_while") + + // miscellious + .def("create_make_range", [](mlir::OpBuilder &self, int start, int end){ + auto loc = self.getUnknownLoc(); + auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type()); + return wrap(self.create(loc, retType, start, end).getOperation()); + }, ret::reference) + .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, self.getI32Type(), axis).getOperation()); + }) + + // // Cast instructions + // .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) + // .def("create_cast", &ir::builder::create_cast, ret::reference) + // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) + // .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) + // .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) + // .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) + // .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) + // .def("create_fp_ext", &ir::builder::create_fp_ext, 
ret::reference) + // .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) + // .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) + // .def("create_downcast", &ir::builder::create_downcast, ret::reference) + // // Binary instructions + // .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) + // .def("create_fmul", &ir::builder::create_fmul, ret::reference) + // .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) + // .def("create_frem", &ir::builder::create_frem, ret::reference) + // .def("create_fadd", &ir::builder::create_fadd, ret::reference) + // .def("create_fsub", &ir::builder::create_fsub, ret::reference) + .def("create_mul", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + // Check lhs & rhs have single result (?) + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_sdiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_udiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_srem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_urem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, 
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_add", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_sub", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + .def("create_shl", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + }, ret::reference) + // .def("create_lshr", &ir::builder::create_lshr, ret::reference, + // py::arg("lhs"), py::arg("rhs"), + // py::arg("has_nuw")=false, py::arg("has_nsw")=false) + // .def("create_ashr", &ir::builder::create_ashr, ret::reference, + // py::arg("lhs"), py::arg("rhs"), + // py::arg("has_nuw")=false, py::arg("has_nsw")=false) + // // GEP + // .def("create_gep", [](mlir::OpBuilder &self, MlirOperation &ptr, MlirOperation &offset) -> MlirOperation { + // auto loc = self.getUnknownLoc(); + // }, ret::reference) + // Comparison (int) + .def("create_icmpSLE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::sle, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpSLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::slt, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + 
).getOperation()); + }, ret::reference) + .def("create_icmpSGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::sge, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpSGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::sgt, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpULE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::ule, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpULT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::ult, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpUGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::uge, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpUGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::ugt, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpEQ", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> 
MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::eq, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + .def("create_icmpNE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpIPredicate::ne, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + // Comparison (float) + .def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + auto loc = self.getUnknownLoc(); + return wrap(self.create( + loc, mlir::arith::CmpFPredicate::OLT, + unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) + ).getOperation()); + }, ret::reference) + // .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) + // .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) + // .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) + // .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) + // .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) + // .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) + // .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) + // .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) + // .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) + // .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) + // .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) + // // Logical + // .def("create_and", &ir::builder::create_and, ret::reference) + // .def("create_xor", &ir::builder::create_xor, ret::reference) + // .def("create_or", &ir::builder::create_or, ret::reference) + // // Input/Output + // .def("create_load", &ir::builder::create_load, ret::reference) + // .def("create_store", 
&ir::builder::create_store, ret::reference) + // .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) + // .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) + // // Block instruction + // .def("create_splat", &ir::builder::create_splat, ret::reference) + // .def("create_reshape", &ir::builder::create_reshape, ret::reference) + // .def("create_cat", &ir::builder::create_cat, ret::reference) + // .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) + // // atomic + // .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) + // .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + + // // Built-in instruction + // .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) + // .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) + // .def("create_exp", &ir::builder::create_exp, ret::reference) + // .def("create_cos", &ir::builder::create_cos, ret::reference) + // .def("create_sin", &ir::builder::create_sin, ret::reference) + // .def("create_log", &ir::builder::create_log, ret::reference) + // .def("create_dot", &ir::builder::create_dot, ret::reference) + // .def("create_trans", &ir::builder::create_trans, ret::reference) + // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) + // .def("create_reduce", &ir::builder::create_reduce, ret::reference) + // .def("create_select", &ir::builder::create_select, ret::reference) + // // Intrinsics + // // These have no place in the IR, and hopefully they can be removed at some point + // .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) + // .def("create_barrier", &ir::builder::create_barrier, ret::reference); + ; } void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); - init_triton_codegen(std::move(subm.def_submodule("code_gen"))); + // init_triton_codegen(std::move(subm.def_submodule("code_gen"))); 
init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); } diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 81b9fe790..ea1b24940 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -395,30 +395,30 @@ class constexpr: class tensor: - # infer dtype from ir type - @staticmethod - def _to_dtype(ir_type): - # block type - if ir_type.is_block(): - scalar_ty = tensor._to_dtype(ir_type.scalar) - return block_type(scalar_ty, ir_type.get_block_shapes()) - # pointer type - if ir_type.is_ptr(): - element_ty = tensor._to_dtype(ir_type.element) - return pointer_type(element_ty) - # primitive type - if ir_type.is_void(): return void - if ir_type.is_int1(): return int1 - if ir_type.is_int8(): return int8 - if ir_type.is_int16(): return int16 - if ir_type.is_int32(): return int32 - if ir_type.is_int64(): return int64 - if ir_type.is_fp8(): return float8 - if ir_type.is_fp16(): return float16 - if ir_type.is_bf16(): return bfloat16 - if ir_type.is_fp32(): return float32 - if ir_type.is_fp64(): return float64 - raise ValueError(f"Unsupported type {ir_type.repr()}") + # # infer dtype from ir type + # @staticmethod + # def _to_dtype(ir_type): + # # block type + # if ir_type.is_block(): + # scalar_ty = tensor._to_dtype(ir_type.scalar) + # return block_type(scalar_ty, ir_type.get_block_shapes()) + # # pointer type + # if ir_type.is_ptr(): + # element_ty = tensor._to_dtype(ir_type.element) + # return pointer_type(element_ty) + # # primitive type + # if ir_type.is_void(): return void + # if ir_type.is_int1(): return int1 + # if ir_type.is_int8(): return int8 + # if ir_type.is_int16(): return int16 + # if ir_type.is_int32(): return int32 + # if ir_type.is_int64(): return int64 + # if ir_type.is_fp8(): return float8 + # if ir_type.is_fp16(): return float16 + # if ir_type.is_bf16(): return bfloat16 + # if ir_type.is_fp32(): return float32 + # if ir_type.is_fp64(): 
return float64 + # raise ValueError(f"Unsupported type {ir_type.repr()}") def __init__(self, handle, type: dtype): # IR handle