[FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py Integer signedness now moved from C++ to python Cleaner frontend type Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/dispatch.h"
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -12,10 +11,12 @@
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "Python.h"
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
@@ -541,84 +542,6 @@ void init_triton_codegen(py::module &&m) {
|
||||
}, py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* User-facing language features */
|
||||
/*****************************************************************************/
|
||||
|
||||
void init_triton_frontend(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
|
||||
// programming model
|
||||
m.def("program_id", &ir::dispatch::program_id, ret::reference);
|
||||
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
|
||||
// binary
|
||||
m.def("add", &ir::dispatch::add, ret::reference);
|
||||
m.def("sub", &ir::dispatch::sub, ret::reference);
|
||||
m.def("mul", &ir::dispatch::mul, ret::reference);
|
||||
m.def("truediv", &ir::dispatch::truediv, ret::reference);
|
||||
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
|
||||
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
|
||||
m.def("mod", &ir::dispatch::mod, ret::reference);
|
||||
m.def("and_", &ir::dispatch::and_, ret::reference);
|
||||
m.def("or_", &ir::dispatch::or_, ret::reference);
|
||||
m.def("xor_", &ir::dispatch::xor_, ret::reference);
|
||||
m.def("lshr", &ir::dispatch::lshr, ret::reference);
|
||||
m.def("shl", &ir::dispatch::shl, ret::reference);
|
||||
// unary
|
||||
m.def("plus", &ir::dispatch::plus, ret::reference);
|
||||
m.def("minus", &ir::dispatch::minus, ret::reference);
|
||||
m.def("invert", &ir::dispatch::invert, ret::reference);
|
||||
// comparison
|
||||
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
|
||||
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
|
||||
m.def("less_than", &ir::dispatch::less_than, ret::reference);
|
||||
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
|
||||
m.def("equal", &ir::dispatch::equal, ret::reference);
|
||||
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
|
||||
// block creation
|
||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||
// type manipuatation
|
||||
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
|
||||
m.def("cast", &ir::dispatch::cast, ret::reference);
|
||||
// memory
|
||||
m.def("load", &ir::dispatch::load, ret::reference);
|
||||
m.def("store", &ir::dispatch::store, ret::reference);
|
||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
|
||||
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
|
||||
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
|
||||
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
|
||||
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
|
||||
// linear algebra
|
||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||
// indexing
|
||||
m.def("where", &ir::dispatch::where, ret::reference);
|
||||
// reduction
|
||||
m.def("min", &ir::dispatch::min, ret::reference);
|
||||
m.def("max", &ir::dispatch::max, ret::reference);
|
||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
|
||||
// math
|
||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
|
||||
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
@@ -628,16 +551,86 @@ void init_triton_ir(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
|
||||
.value("NONE", ir::load_inst::NONE)
|
||||
.value("CA", ir::load_inst::CA)
|
||||
.value("CG", ir::load_inst::CG)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", ir::load_inst::NORMAL)
|
||||
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
|
||||
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
|
||||
.value("ADD", ir::reduce_inst::ADD)
|
||||
.value("FADD", ir::reduce_inst::FADD)
|
||||
.value("MIN", ir::reduce_inst::MIN)
|
||||
.value("MAX", ir::reduce_inst::MAX)
|
||||
.value("FMIN", ir::reduce_inst::FMIN)
|
||||
.value("FMAX", ir::reduce_inst::FMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
.value("ADD", ir::atomic_rmw_op_t::Add)
|
||||
.value("FADD", ir::atomic_rmw_op_t::FAdd)
|
||||
.value("AND", ir::atomic_rmw_op_t::And)
|
||||
.value("OR", ir::atomic_rmw_op_t::Or)
|
||||
.value("XOR", ir::atomic_rmw_op_t::Xor)
|
||||
.value("XCHG", ir::atomic_rmw_op_t::Xchg)
|
||||
.value("MAX", ir::atomic_rmw_op_t::Max)
|
||||
.value("MIN", ir::atomic_rmw_op_t::Min)
|
||||
.value("UMIN", ir::atomic_rmw_op_t::UMin)
|
||||
.value("UMAX", ir::atomic_rmw_op_t::UMax);
|
||||
|
||||
py::class_<ir::context>(m, "context")
|
||||
.def(py::init<>());
|
||||
|
||||
auto value = py::class_<ir::value>(m, "value");
|
||||
value.def_property("name", &ir::value::get_name, &ir::value::set_name);
|
||||
value.def_property_readonly("type", &ir::value::get_type);
|
||||
py::class_<ir::value>(m, "value")
|
||||
.def("multiple_of", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::multiple_of, val);
|
||||
} else
|
||||
throw std::runtime_error("multiple_of");
|
||||
})
|
||||
.def("max_contiguous", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::max_contiguous, val);
|
||||
} else
|
||||
throw std::runtime_error("max_contiguous");
|
||||
})
|
||||
.def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
|
||||
if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
|
||||
instr->set_fdiv_ieee_rounding(val);
|
||||
else
|
||||
throw std::runtime_error("set_fdiv_ieee_rounding");
|
||||
})
|
||||
.def("is_phi", [](ir::value *self) {
|
||||
if (auto *pn = dynamic_cast<ir::phi_node*>(self))
|
||||
return true;
|
||||
return false;
|
||||
})
|
||||
.def("ops", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
return instr->ops();
|
||||
}
|
||||
throw std::runtime_error("cannot use ops()");
|
||||
})
|
||||
.def("replace_all_uses_with", &ir::value::replace_all_uses_with)
|
||||
.def("erase_from_parent", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self))
|
||||
return instr->erase_from_parent();
|
||||
throw std::runtime_error("cannot use erase_from_parent");
|
||||
})
|
||||
.def_property("name", &ir::value::get_name, &ir::value::set_name)
|
||||
.def_property_readonly("type", &ir::value::get_type);
|
||||
|
||||
py::class_<ir::user, ir::value>(m, "user");
|
||||
|
||||
py::class_<ir::constant, ir::user>(m, "constant");
|
||||
py::class_<ir::constant, ir::user>(m, "constant")
|
||||
.def("get_null_value", &ir::constant::get_null_value, ret::reference)
|
||||
.def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
|
||||
|
||||
py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||
.def("get", &ir::undef_value::get, ret::reference);
|
||||
@@ -648,16 +641,17 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
|
||||
|
||||
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value);
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value)
|
||||
.def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
|
||||
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction");
|
||||
py::class_<ir::phi_node, ir::user>(m, "phi_node");
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction")
|
||||
.def("get_parent", [](ir::instruction *self) {
|
||||
return self->get_parent();
|
||||
}, ret::reference);
|
||||
py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
|
||||
.def("add_incoming", &ir::phi_node::add_incoming);
|
||||
|
||||
py::class_<ir::type>(m, "type")
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
||||
.def("make_function", &ir::function_type::get, ret::reference)
|
||||
.def("make_block", &ir::block_type::get, ret::reference)
|
||||
@@ -672,34 +666,38 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
||||
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
|
||||
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
||||
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
|
||||
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
|
||||
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
|
||||
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
|
||||
.def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)
|
||||
|
||||
.def("get_block_shapes", &ir::type::get_block_shapes)
|
||||
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("is_void", &ir::type::is_void_ty)
|
||||
.def("is_bool", &ir::type::is_bool_ty)
|
||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||
.def("is_fp16", &ir::type::is_fp16_ty)
|
||||
.def("is_bf16", &ir::type::is_bf16_ty)
|
||||
.def("is_fp32", &ir::type::is_fp32_ty)
|
||||
.def("is_fp64", &ir::type::is_fp64_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
|
||||
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
|
||||
.def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference);
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference)
|
||||
.def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
|
||||
.def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
|
||||
|
||||
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
|
||||
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type");
|
||||
py::class_<ir::integer_type, ir::type>(m, "integer_type");
|
||||
@@ -709,16 +707,15 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::module>(m, "module")
|
||||
.def(py::init<std::string, ir::builder &>())
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||
.def("seal_block", &ir::module::seal_block)
|
||||
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
||||
.def("set_type", &ir::module::set_type)
|
||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||
.def("get_values", &ir::module::get_values, ret::reference)
|
||||
.def("set_values", &ir::module::set_values)
|
||||
.def("get_types", &ir::module::get_types, ret::reference)
|
||||
.def("set_types", &ir::module::set_types)
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
|
||||
const auto metadatas = self->get_metadatas();
|
||||
auto it = metadatas.find(name);
|
||||
if (it != metadatas.end())
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
|
||||
instr->set_metadata(it->second.first, it->second.second);
|
||||
}
|
||||
})
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
py::enum_<eattr>(m, "attribute_kind")
|
||||
@@ -742,6 +739,13 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
||||
.def("create", &ir::basic_block::create, ret::reference)
|
||||
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
|
||||
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
|
||||
ir::basic_block::iterator it = self->get_first_non_phi();
|
||||
if (it == self->end())
|
||||
return nullptr;
|
||||
return *it;
|
||||
}, ret::reference)
|
||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||
|
||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
||||
@@ -752,17 +756,162 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("br", &ir::builder::create_br, ret::reference)
|
||||
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// insertion block/point, insert points are represented as (*bb, *instr)
|
||||
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||
// constants
|
||||
.def("get_insert_point", [](ir::builder *self) {
|
||||
ir::basic_block *bb = self->get_insert_block();
|
||||
ir::basic_block::iterator it = self->get_insert_point();
|
||||
ir::instruction *instr = it == bb->end() ? nullptr : *it;
|
||||
return std::make_pair(bb, instr);
|
||||
}, ret::reference)
|
||||
.def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
|
||||
ir::basic_block *bb = pt.first;
|
||||
ir::instruction *instr = pt.second;
|
||||
if (instr) {
|
||||
if (bb != instr->get_parent())
|
||||
throw std::runtime_error("invalid insertion point, instr not in bb");
|
||||
self->set_insert_point(instr);
|
||||
} else {
|
||||
assert(bb);
|
||||
self->set_insert_point(bb);
|
||||
}
|
||||
})
|
||||
// Constants
|
||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
|
||||
.def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
||||
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
||||
.def("get_range", &ir::builder::get_range, ret::reference);
|
||||
.def("get_range", &ir::builder::get_range, ret::reference)
|
||||
// Types
|
||||
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
||||
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
|
||||
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
|
||||
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
||||
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
|
||||
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
||||
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
||||
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
||||
.def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
|
||||
.def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
|
||||
.def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
|
||||
// terminator instructions
|
||||
.def("create_br", &ir::builder::create_br, ret::reference)
|
||||
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
.def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
|
||||
.def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
|
||||
.def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
|
||||
.def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
|
||||
.def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
|
||||
.def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
|
||||
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
|
||||
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
|
||||
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
|
||||
// phi
|
||||
.def("create_phi", &ir::builder::create_phi, ret::reference)
|
||||
// Binary instructions
|
||||
.def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
|
||||
.def("create_fmul", &ir::builder::create_fmul, ret::reference)
|
||||
.def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
|
||||
.def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
|
||||
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
|
||||
.def("create_srem", &ir::builder::create_srem, ret::reference)
|
||||
.def("create_urem", &ir::builder::create_urem, ret::reference)
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sub", &ir::builder::create_sub, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_shl", &ir::builder::create_shl, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// GEP
|
||||
.def("create_gep", &ir::builder::create_gep, ret::reference)
|
||||
// Comparison (int)
|
||||
.def("create_icmp", &ir::builder::create_icmp, ret::reference)
|
||||
.def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
|
||||
.def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
|
||||
.def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
|
||||
.def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
|
||||
.def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
|
||||
.def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
|
||||
.def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
|
||||
.def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
|
||||
.def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
|
||||
.def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
|
||||
// Comparison (float)
|
||||
.def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
|
||||
.def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
|
||||
.def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
|
||||
.def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
|
||||
.def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
|
||||
.def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
|
||||
.def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
|
||||
.def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
|
||||
.def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
|
||||
.def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
|
||||
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
|
||||
.def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
|
||||
.def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
|
||||
// Logical
|
||||
.def("create_and", &ir::builder::create_and, ret::reference)
|
||||
.def("create_xor", &ir::builder::create_xor, ret::reference)
|
||||
.def("create_or", &ir::builder::create_or, ret::reference)
|
||||
// Input/Output
|
||||
.def("create_load", &ir::builder::create_load, ret::reference)
|
||||
.def("create_store", &ir::builder::create_store, ret::reference)
|
||||
.def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
|
||||
.def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
|
||||
// Block instruction
|
||||
.def("create_splat", &ir::builder::create_splat, ret::reference)
|
||||
.def("create_reshape", &ir::builder::create_reshape, ret::reference)
|
||||
.def("create_cat", &ir::builder::create_cat, ret::reference)
|
||||
.def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
|
||||
// atomic
|
||||
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
|
||||
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
|
||||
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
||||
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
||||
.def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
.def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
.def("create_sin", &ir::builder::create_sin, ret::reference)
|
||||
.def("create_log", &ir::builder::create_log, ret::reference)
|
||||
.def("create_dot", &ir::builder::create_dot, ret::reference)
|
||||
.def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
.def("create_select", &ir::builder::create_select, ret::reference)
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
|
||||
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
|
||||
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
|
||||
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
|
||||
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
@@ -770,5 +919,4 @@ void init_triton(py::module &m) {
|
||||
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
init_triton_frontend(std::move(subm.def_submodule("frontend")));
|
||||
}
|
||||
|
Reference in New Issue
Block a user