[FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet). Integer signedness analysis was moved from C++ to Python (@daadaada). Cleaner frontend types (@daadaada). Moved SSA construction to a separate object (@ptillet). Co-authored-by: Yan Da <dyanab@connect.ust.hk>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/dispatch.h"
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -12,10 +11,12 @@
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "Python.h"
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
@@ -541,87 +542,6 @@ void init_triton_codegen(py::module &&m) {
|
||||
}, py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* User-facing language features */
|
||||
/*****************************************************************************/
|
||||
|
||||
// Registers the user-facing language functions on the Python module `m`.
// Every m.def below binds a Python-visible name to the same-named function in
// ir::dispatch and uses the `reference` return value policy.
// NOTE(review): in this diff view the function appears to be on the removed
// side of the hunk (the commit message says dispatch.cc moved to semantic.py)
// -- confirm before relying on it.
void init_triton_frontend(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
|
||||
// programming model
|
||||
m.def("program_id", &ir::dispatch::program_id, ret::reference);
|
||||
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
|
||||
// binary
|
||||
m.def("add", &ir::dispatch::add, ret::reference);
|
||||
m.def("sub", &ir::dispatch::sub, ret::reference);
|
||||
m.def("mul", &ir::dispatch::mul, ret::reference);
|
||||
m.def("truediv", &ir::dispatch::truediv, ret::reference);
|
||||
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
|
||||
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
|
||||
m.def("mod", &ir::dispatch::mod, ret::reference);
|
||||
m.def("and_", &ir::dispatch::and_, ret::reference);
|
||||
m.def("or_", &ir::dispatch::or_, ret::reference);
|
||||
m.def("xor_", &ir::dispatch::xor_, ret::reference);
|
||||
m.def("lshr", &ir::dispatch::lshr, ret::reference);
|
||||
m.def("shl", &ir::dispatch::shl, ret::reference);
|
||||
// unary
|
||||
m.def("plus", &ir::dispatch::plus, ret::reference);
|
||||
m.def("minus", &ir::dispatch::minus, ret::reference);
|
||||
m.def("invert", &ir::dispatch::invert, ret::reference);
|
||||
// comparison
|
||||
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
|
||||
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
|
||||
m.def("less_than", &ir::dispatch::less_than, ret::reference);
|
||||
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
|
||||
m.def("equal", &ir::dispatch::equal, ret::reference);
|
||||
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
|
||||
// block creation
|
||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||
// type manipulation
|
||||
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||
// Two function-pointer types with different parameter lists; the casts below
|
||||
// use them to pick which ir::dispatch::broadcast signature each Python name
|
||||
// ("broadcast" vs. "broadcast_to") is bound to.
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
|
||||
m.def("cast", &ir::dispatch::cast, ret::reference);
|
||||
// memory
|
||||
m.def("load", &ir::dispatch::load, ret::reference);
|
||||
m.def("store", &ir::dispatch::store, ret::reference);
|
||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
|
||||
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
|
||||
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
|
||||
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
|
||||
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
|
||||
// linear algebra
|
||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||
// indexing
|
||||
m.def("where", &ir::dispatch::where, ret::reference);
|
||||
// reduction
|
||||
m.def("min", &ir::dispatch::min, ret::reference);
|
||||
m.def("max", &ir::dispatch::max, ret::reference);
|
||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
|
||||
// math
|
||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// utilities
|
||||
m.def("clock", &ir::dispatch::clock, ret::reference);
|
||||
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
|
||||
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
@@ -631,16 +551,86 @@ void init_triton_ir(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
|
||||
.value("NONE", ir::load_inst::NONE)
|
||||
.value("CA", ir::load_inst::CA)
|
||||
.value("CG", ir::load_inst::CG)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", ir::load_inst::NORMAL)
|
||||
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
|
||||
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
|
||||
.value("ADD", ir::reduce_inst::ADD)
|
||||
.value("FADD", ir::reduce_inst::FADD)
|
||||
.value("MIN", ir::reduce_inst::MIN)
|
||||
.value("MAX", ir::reduce_inst::MAX)
|
||||
.value("FMIN", ir::reduce_inst::FMIN)
|
||||
.value("FMAX", ir::reduce_inst::FMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
.value("ADD", ir::atomic_rmw_op_t::Add)
|
||||
.value("FADD", ir::atomic_rmw_op_t::FAdd)
|
||||
.value("AND", ir::atomic_rmw_op_t::And)
|
||||
.value("OR", ir::atomic_rmw_op_t::Or)
|
||||
.value("XOR", ir::atomic_rmw_op_t::Xor)
|
||||
.value("XCHG", ir::atomic_rmw_op_t::Xchg)
|
||||
.value("MAX", ir::atomic_rmw_op_t::Max)
|
||||
.value("MIN", ir::atomic_rmw_op_t::Min)
|
||||
.value("UMIN", ir::atomic_rmw_op_t::UMin)
|
||||
.value("UMAX", ir::atomic_rmw_op_t::UMax);
|
||||
|
||||
py::class_<ir::context>(m, "context")
|
||||
.def(py::init<>());
|
||||
|
||||
auto value = py::class_<ir::value>(m, "value");
|
||||
value.def_property("name", &ir::value::get_name, &ir::value::set_name);
|
||||
value.def_property_readonly("type", &ir::value::get_type);
|
||||
py::class_<ir::value>(m, "value")
|
||||
.def("multiple_of", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::multiple_of, val);
|
||||
} else
|
||||
throw std::runtime_error("multiple_of");
|
||||
})
|
||||
.def("max_contiguous", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::max_contiguous, val);
|
||||
} else
|
||||
throw std::runtime_error("max_contiguous");
|
||||
})
|
||||
.def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
|
||||
if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
|
||||
instr->set_fdiv_ieee_rounding(val);
|
||||
else
|
||||
throw std::runtime_error("set_fdiv_ieee_rounding");
|
||||
})
|
||||
.def("is_phi", [](ir::value *self) {
|
||||
if (auto *pn = dynamic_cast<ir::phi_node*>(self))
|
||||
return true;
|
||||
return false;
|
||||
})
|
||||
.def("ops", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
return instr->ops();
|
||||
}
|
||||
throw std::runtime_error("cannot use ops()");
|
||||
})
|
||||
.def("replace_all_uses_with", &ir::value::replace_all_uses_with)
|
||||
.def("erase_from_parent", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self))
|
||||
return instr->erase_from_parent();
|
||||
throw std::runtime_error("cannot use erase_from_parent");
|
||||
})
|
||||
.def_property("name", &ir::value::get_name, &ir::value::set_name)
|
||||
.def_property_readonly("type", &ir::value::get_type);
|
||||
|
||||
py::class_<ir::user, ir::value>(m, "user");
|
||||
|
||||
py::class_<ir::constant, ir::user>(m, "constant");
|
||||
py::class_<ir::constant, ir::user>(m, "constant")
|
||||
.def("get_null_value", &ir::constant::get_null_value, ret::reference)
|
||||
.def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
|
||||
|
||||
py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||
.def("get", &ir::undef_value::get, ret::reference);
|
||||
@@ -651,18 +641,17 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
|
||||
|
||||
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value);
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value)
|
||||
.def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
|
||||
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction");
|
||||
py::class_<ir::phi_node, ir::user>(m, "phi_node");
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction")
|
||||
.def("get_parent", [](ir::instruction *self) {
|
||||
return self->get_parent();
|
||||
}, ret::reference);
|
||||
py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
|
||||
.def("add_incoming", &ir::phi_node::add_incoming);
|
||||
|
||||
py::class_<ir::type>(m, "type")
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("get_int_width", &ir::type::get_integer_bitwidth)
|
||||
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
||||
.def("make_function", &ir::function_type::get, ret::reference)
|
||||
.def("make_block", &ir::block_type::get, ret::reference)
|
||||
@@ -677,35 +666,39 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
||||
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
|
||||
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
||||
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
|
||||
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
|
||||
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
|
||||
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
|
||||
.def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)
|
||||
|
||||
.def("get_block_shapes", &ir::type::get_block_shapes)
|
||||
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("is_struct", &ir::type::is_struct_ty)
|
||||
.def("is_void", &ir::type::is_void_ty)
|
||||
.def("is_bool", &ir::type::is_bool_ty)
|
||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||
.def("is_fp16", &ir::type::is_fp16_ty)
|
||||
.def("is_bf16", &ir::type::is_bf16_ty)
|
||||
.def("is_fp32", &ir::type::is_fp32_ty)
|
||||
.def("is_fp64", &ir::type::is_fp64_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
|
||||
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
.def("is_struct", &ir::type::is_struct_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
|
||||
.def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference);
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference)
|
||||
.def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
|
||||
.def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
|
||||
|
||||
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
|
||||
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type")
|
||||
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
|
||||
@@ -723,21 +716,20 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get", &ir::struct_type::get, ret::reference)
|
||||
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||
|
||||
py::class_<ir::value_constructor>(m, "value_constructor")
|
||||
.def(py::init<ir::builder&>())
|
||||
.def("seal_block", &ir::value_constructor::seal_block)
|
||||
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
|
||||
.def("set_type", &ir::value_constructor::set_type)
|
||||
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
|
||||
.def("get_values", &ir::value_constructor::get_values, ret::reference)
|
||||
.def("set_values", &ir::value_constructor::set_values);
|
||||
|
||||
py::class_<ir::module>(m, "module")
|
||||
.def(py::init<std::string, ir::builder &>())
|
||||
.def("has_function", &ir::module::has_function)
|
||||
.def("get_function", &ir::module::get_function, ret::reference)
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||
.def("reset_ret_ty", &ir::module::reset_ret_ty)
|
||||
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
|
||||
const auto metadatas = self->get_metadatas();
|
||||
auto it = metadatas.find(name);
|
||||
if (it != metadatas.end())
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
|
||||
instr->set_metadata(it->second.first, it->second.second);
|
||||
}
|
||||
})
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
@@ -768,6 +760,13 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
||||
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
|
||||
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
|
||||
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
|
||||
ir::basic_block::iterator it = self->get_first_non_phi();
|
||||
if (it == self->end())
|
||||
return nullptr;
|
||||
return *it;
|
||||
}, ret::reference)
|
||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||
|
||||
py::class_<ir::builder::iterator>(m, "bb_iterator");
|
||||
@@ -783,22 +782,168 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
.def("ret", &ir::builder::create_ret, ret::reference)
|
||||
.def("get_insert_point", &ir::builder::get_insert_point)
|
||||
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
|
||||
// insertion block/point, insert points are represented as (*bb, *instr)
|
||||
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||
.def("get_insert_point", [](ir::builder *self) {
|
||||
ir::basic_block *bb = self->get_insert_block();
|
||||
ir::basic_block::iterator it = self->get_insert_point();
|
||||
ir::instruction *instr = it == bb->end() ? nullptr : *it;
|
||||
return std::make_pair(bb, instr);
|
||||
}, ret::reference)
|
||||
.def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
|
||||
ir::basic_block *bb = pt.first;
|
||||
ir::instruction *instr = pt.second;
|
||||
if (instr) {
|
||||
if (bb != instr->get_parent())
|
||||
throw std::runtime_error("invalid insertion point, instr not in bb");
|
||||
self->set_insert_point(instr);
|
||||
} else {
|
||||
assert(bb);
|
||||
self->set_insert_point(bb);
|
||||
}
|
||||
})
|
||||
// Constants
|
||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
||||
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
||||
.def("get_range", &ir::builder::get_range, ret::reference)
|
||||
// Types
|
||||
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
||||
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
|
||||
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
|
||||
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
||||
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
|
||||
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
||||
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
||||
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
||||
.def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
|
||||
.def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
|
||||
.def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
|
||||
// terminator instructions
|
||||
.def("create_br", &ir::builder::create_br, ret::reference)
|
||||
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
.def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
|
||||
.def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
|
||||
.def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
|
||||
.def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
|
||||
.def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
|
||||
.def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
|
||||
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
|
||||
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
|
||||
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
|
||||
// phi
|
||||
.def("create_phi", &ir::builder::create_phi, ret::reference)
|
||||
// Binary instructions
|
||||
.def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
|
||||
.def("create_fmul", &ir::builder::create_fmul, ret::reference)
|
||||
.def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
|
||||
.def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
|
||||
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
|
||||
.def("create_srem", &ir::builder::create_srem, ret::reference)
|
||||
.def("create_urem", &ir::builder::create_urem, ret::reference)
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sub", &ir::builder::create_sub, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_shl", &ir::builder::create_shl, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// GEP
|
||||
.def("create_gep", &ir::builder::create_gep, ret::reference)
|
||||
// Comparison (int)
|
||||
.def("create_icmp", &ir::builder::create_icmp, ret::reference)
|
||||
.def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
|
||||
.def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
|
||||
.def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
|
||||
.def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
|
||||
.def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
|
||||
.def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
|
||||
.def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
|
||||
.def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
|
||||
.def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
|
||||
.def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
|
||||
// Comparison (float)
|
||||
.def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
|
||||
.def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
|
||||
.def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
|
||||
.def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
|
||||
.def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
|
||||
.def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
|
||||
.def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
|
||||
.def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
|
||||
.def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
|
||||
.def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
|
||||
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
|
||||
.def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
|
||||
.def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
|
||||
// Logical
|
||||
.def("create_and", &ir::builder::create_and, ret::reference)
|
||||
.def("create_xor", &ir::builder::create_xor, ret::reference)
|
||||
.def("create_or", &ir::builder::create_or, ret::reference)
|
||||
// Input/Output
|
||||
.def("create_load", &ir::builder::create_load, ret::reference)
|
||||
.def("create_store", &ir::builder::create_store, ret::reference)
|
||||
.def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
|
||||
.def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
|
||||
// Block instruction
|
||||
.def("create_splat", &ir::builder::create_splat, ret::reference)
|
||||
.def("create_reshape", &ir::builder::create_reshape, ret::reference)
|
||||
.def("create_cat", &ir::builder::create_cat, ret::reference)
|
||||
.def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
|
||||
// atomic
|
||||
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
|
||||
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
|
||||
// Utilities
|
||||
.def("create_clock", &ir::builder::create_clock, ret::reference)
|
||||
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
|
||||
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
||||
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
||||
.def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
.def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
.def("create_sin", &ir::builder::create_sin, ret::reference)
|
||||
.def("create_log", &ir::builder::create_log, ret::reference)
|
||||
.def("create_dot", &ir::builder::create_dot, ret::reference)
|
||||
.def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
.def("create_select", &ir::builder::create_select, ret::reference)
|
||||
// struct
|
||||
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
|
||||
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
|
||||
// constants
|
||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
|
||||
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
||||
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
||||
.def("get_range", &ir::builder::get_range, ret::reference);
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
|
||||
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
|
||||
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
|
||||
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
|
||||
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
@@ -806,5 +951,4 @@ void init_triton(py::module &m) {
|
||||
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
init_triton_frontend(std::move(subm.def_submodule("frontend")));
|
||||
}
|
||||
|
@@ -37,7 +37,7 @@ matmul_data = {
|
||||
(256, 256, 256): {'float16': 0.027},
|
||||
(512, 512, 512): {'float16': 0.158},
|
||||
(1024, 1024, 1024): {'float16': 0.466},
|
||||
(2048, 2048, 2048): {'float16': 0.680},
|
||||
(2048, 2048, 2048): {'float16': 0.695},
|
||||
(4096, 4096, 4096): {'float16': 0.831},
|
||||
(8192, 8192, 8192): {'float16': 0.849},
|
||||
# tall-skinny
|
||||
|
@@ -1,5 +1,4 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import copy
|
||||
import itertools
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
@@ -12,7 +11,7 @@ from numpy.random import RandomState
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import TensorWrapper, reinterpret
|
||||
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -993,11 +992,17 @@ def test_noop(device='cuda'):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
spec_type = None
|
||||
|
||||
def cache_hook(*args, **kwargs):
|
||||
nonlocal spec_type
|
||||
spec_type = kwargs["compile"]["arg_types"][0][1]
|
||||
JITFunction.cache_hook = cache_hook
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
@@ -1006,11 +1011,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
pgm = kernel[(1, )](value, x)
|
||||
|
||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||
triton_ir = pgm.asm['ttir']
|
||||
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
|
||||
assert ir_value_type == value_type
|
||||
JITFunction.cache_hook = None
|
||||
assert spec_type == value_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1045,13 +1047,13 @@ def stub(X, alpha, grid_0, grid_1, grid_2):
|
||||
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
|
||||
|
||||
|
||||
def test_dyn_par(cond=True, device='cuda'):
|
||||
n_pids = 10
|
||||
# pids = torch.arange(n_pids, device=device)
|
||||
# alpha = 2.0
|
||||
# x_ref = pids * alpha
|
||||
x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||
# cond = torch.tensor([cond], device=device)
|
||||
stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||
print(x_tri)
|
||||
# triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
# def test_dyn_par(cond=True, device='cuda'):
|
||||
# n_pids = 10
|
||||
# # pids = torch.arange(n_pids, device=device)
|
||||
# # alpha = 2.0
|
||||
# # x_ref = pids * alpha
|
||||
# x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||
# # cond = torch.tensor([cond], device=device)
|
||||
# stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||
# print(x_tri)
|
||||
# # triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
@@ -102,3 +103,30 @@ def test_specialize(mode):
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
    '''Check which integer type name ends up in the JIT cache key for `value`.

    A hook installed on `JITFunction.cache_hook` captures the `key` keyword
    argument, splits it on '-', and the last component is matched against the
    expected signature pattern. `value_type` is the expected captured type
    name, or None when the pattern is expected not to match.
    NOTE(review): the `(1, None)` case presumably covers the value 1 being
    specialized away so that no type appears in the key -- confirm.
    '''

    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        nonlocal cache_str
        cache_str = kwargs['key'].split('-')

    triton.code_gen.JITFunction.cache_hook = get_cache_str
    try:
        reset_tmp_dir()
        x = torch.tensor([3.14159], device='cuda')
        kernel[(1, )](value, x)
    finally:
        # The hook is class-level state: always clear it, even when the launch
        # raises, so a failure here cannot leak the hook into other tests.
        triton.code_gen.JITFunction.cache_hook = None

    # Fail with a readable message instead of a TypeError on `cache_str[-1]`.
    assert cache_str is not None, "JITFunction.cache_hook was never invoked"

    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type
|
||||
|
@@ -6,7 +6,8 @@ __version__ = '2.0.0'
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
# submodules
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
|
||||
JITFunction, Config, Autotuner, reinterpret
|
||||
from . import language
|
||||
from . import code_gen
|
||||
from . import testing
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import builtins
|
||||
import functools
|
||||
@@ -11,7 +13,7 @@ import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from typing import Dict, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -21,26 +23,26 @@ import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
def mangle_ty(type):
|
||||
if type.is_ptr():
|
||||
return 'P' + mangle_ty(type.element)
|
||||
if type.is_int():
|
||||
return 'i' + str(type.get_int_width())
|
||||
if type.is_fp8():
|
||||
def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
if ty.is_int():
|
||||
return 'i' + str(ty.int_bitwidth)
|
||||
if ty.is_fp8():
|
||||
return 'fp8'
|
||||
if type.is_fp16():
|
||||
if ty.is_fp16():
|
||||
return 'fp16'
|
||||
if type.is_bf16():
|
||||
if ty.is_bf16():
|
||||
return 'bf16'
|
||||
if type.is_fp32():
|
||||
if ty.is_fp32():
|
||||
return 'fp32'
|
||||
if type.is_fp64():
|
||||
if ty.is_fp64():
|
||||
return 'fp64'
|
||||
if type.is_void():
|
||||
if ty.is_void():
|
||||
return 'V'
|
||||
if type.is_block():
|
||||
elt = mangle_ty(type.scalar)
|
||||
shape = '_'.join(map(str, type.shape))
|
||||
if ty.is_block():
|
||||
elt = mangle_ty(ty.scalar)
|
||||
shape = '_'.join(map(str, ty.shape))
|
||||
return f'{elt}S{shape}S'
|
||||
assert False, "Unsupport type"
|
||||
|
||||
@@ -56,8 +58,38 @@ def mangle_fn(name, arg_tys, constants):
|
||||
return ret
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def is_triton_tensor(value):
    '''Return True when `value` is an instance of `triton.language.tensor`.'''
    tensor_cls = triton.language.tensor
    return isinstance(value, tensor_cls)
|
||||
|
||||
|
||||
class ValueConstructor:
|
||||
def __init__(self, module, builder, gscope) -> None:
    '''Set up the name scopes and the bookkeeping used for SSA construction.

    :param module: IR module; stored as `self.module`.
    :param builder: IR builder; stored as `self.builder`.
    :param gscope: global scope mapping; stored as `self.gscope`.
    '''
    self.gscope = gscope
    self.lscope = {}
    self.builder = builder
    self.module = module
    # (name, basic block) -> triton.language.tensor
    self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
    # basic block -> {name -> phi}
    self.incomplete_phis = {}
    self.sealed_blocks: Set[_triton.ir.basic_block] = set()
    # Python built-ins that are resolvable by name through this object.
    self.builtins = dict(
        range=range,
        min=triton.language.minimum,
        float=float,
        int=int,
        print=print,
        isinstance=isinstance,
        getattr=getattr,
    )
|
||||
|
||||
def get_value(self, name):
|
||||
''' This function:
|
||||
1. make sure `name` is defined
|
||||
2. if `name` is triton.language.tensor, get stored tensor by calling
|
||||
`self._get_tensor()`
|
||||
'''
|
||||
# search node.id in local scope
|
||||
ret = None
|
||||
if name in self.lscope:
|
||||
@@ -70,21 +102,123 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if isinstance(ret, triton.language.block):
|
||||
handle = self.value_constructor.get_value(name)
|
||||
return triton.language.block(handle)
|
||||
if is_triton_tensor(ret):
|
||||
return self._get_tensor(name, self.builder.get_insert_block())
|
||||
return ret
|
||||
|
||||
def set_value(self, name, value):
|
||||
if isinstance(value, _triton.ir.value):
|
||||
value = triton.language.block(value)
|
||||
if isinstance(value, triton.language.block):
|
||||
self.value_constructor.set_value(name, value.handle)
|
||||
self.value_constructor.set_type(name, value.handle.type)
|
||||
def set_value(self, name: str,
              value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
    '''Bind `name` to `value` (the left-hand side of an assignment).

    Used from visit_Assign() and visit_FunctionDef(). Two things happen:
      1. the name is recorded in the local scope
         (FIXME: should consider control flow);
      2. if `value` is a triton.language.tensor, it is also stored for the
         builder's current insert block via `self._set_value()`.
    '''
    self.lscope[name] = value
    if not isinstance(value, triton.language.tensor):
        return
    current_bb = self.builder.get_insert_block()
    self._set_value(name, current_bb, value)
|
||||
|
||||
def is_triton_object(self, value):
|
||||
return isinstance(value, triton.language.block)
|
||||
#
|
||||
# SSA-construction
|
||||
#
|
||||
def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
# local value numbering
|
||||
if (name, bb) in self.lvalues:
|
||||
return self.lvalues[(name, bb)]
|
||||
# global value numbering
|
||||
saved_insert_point = self.builder.get_insert_point()
|
||||
result = self._get_tensor_recursive(name, bb)
|
||||
self.builder.set_insert_point(saved_insert_point)
|
||||
return result
|
||||
|
||||
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
preds = bb.get_predecessors()
|
||||
type = self.lscope[name].type
|
||||
# some preds haven't been filled, create a phi as a proxy of the value
|
||||
if bb not in self.sealed_blocks:
|
||||
result = self._make_phi(type, len(preds), bb)
|
||||
if bb in self.incomplete_phis:
|
||||
self.incomplete_phis[bb][name] = result
|
||||
else:
|
||||
self.incomplete_phis[bb] = {name: result}
|
||||
elif len(preds) == 1:
|
||||
# one predecessor: no phi needed, try get value from pred
|
||||
result = self._get_tensor(name, preds[0])
|
||||
elif len(preds) == 0:
|
||||
result = self._get_tensor(name, None)
|
||||
else: # multiple preds
|
||||
phi = self._make_phi(type, len(preds), bb)
|
||||
self._set_value(name, bb, phi)
|
||||
result = self._add_phi_operands(name, phi)
|
||||
self._set_value(name, bb, result)
|
||||
return result
|
||||
|
||||
# Returns a new phi tensor, which encapsulates an ir.phi_node.
def _make_phi(self,
              type: triton.language.dtype,
              num_values: int,
              bb: _triton.ir.basic_block) -> triton.language.tensor:
    '''Create a phi in `bb` with room for `num_values` incoming values.

    The builder is positioned at the first non-phi instruction of `bb` before
    the phi is created.
    '''
    first_non_phi = bb.get_first_non_phi()
    self.builder.set_insert_point((bb, first_non_phi))
    phi_handle = self.builder.create_phi(type.to_ir(self.builder), num_values)
    if first_non_phi:
        # NOTE(review): presumably moves the builder back to the end of `bb`
        # after inserting in front of an existing instruction -- confirm.
        self.builder.set_insert_block(bb)
    return triton.language.tensor(phi_handle, type)
|
||||
|
||||
# complete a phi node. (TODO: rename this as _complete_phis?)
|
||||
# Note: since we try to remove tryival phi, the return tensor might not be a phi
|
||||
def _add_phi_operands(self, name: str,
|
||||
phi: triton.language.tensor) -> triton.language.tensor:
|
||||
bb = phi.handle.get_parent()
|
||||
for pred in bb.get_predecessors():
|
||||
v = self._get_tensor(name, pred)
|
||||
phi.handle.add_incoming(v.handle, pred)
|
||||
phi = self._try_remove_trivial_phi(phi)
|
||||
return phi
|
||||
|
||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
||||
self.lvalues[(name, bb)] = value
|
||||
# TODO: why we need this?
|
||||
self.module.set_instr_metadata(name, value.handle)
|
||||
|
||||
def _seal_block(self, bb: _triton.ir.basic_block):
|
||||
# complete all incomplete phis
|
||||
if bb in self.incomplete_phis:
|
||||
for name, phi in self.incomplete_phis[bb].items():
|
||||
result = self._add_phi_operands(name, phi)
|
||||
# it's possible that this phi is trivial
|
||||
if self._get_tensor(name, bb).handle == phi.handle:
|
||||
self._set_value(name, bb, result)
|
||||
del self.incomplete_phis[bb]
|
||||
self.sealed_blocks.add(bb)
|
||||
|
||||
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
|
||||
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
|
||||
if len(unique_handles) != 1: # non-trivial phi
|
||||
return phi
|
||||
v = unique_handles.pop()
|
||||
phi.handle.replace_all_uses_with(v)
|
||||
phi.handle.erase_from_parent()
|
||||
# TODO: remove trivial phis recursively
|
||||
return triton.language.tensor(v, phi.type)
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, prototypes=None, module=None, is_kernel=False):
    '''Prepare a code generator for one function.

    :param context: IR context used to create the builder.
    :param prototype: prototype of the function being generated.
    :param gscope: global scope handed to the ValueConstructor.
    :param attributes: stored as `self.attributes`.
    :param constants: stored as `self.constants`.
    :param prototypes: mapping to store as `self.prototypes`; a fresh dict
        is created when None.
    :param module: IR module to emit into; a new one is created when None.
    :param is_kernel: stored as `self.is_kernel`.
    '''
    self.prototypes = prototypes if prototypes is not None else dict()
    self.builder = _triton.ir.builder(context)
    if module is None:
        module = _triton.ir.module('', self.builder)
    self.module = module
    self.prototype = prototype
    self.attributes = attributes
    self.constants = constants
    self.last_node = None
    self.is_kernel = is_kernel

    self.value_constructor = ValueConstructor(self.module, self.builder, gscope)
|
||||
|
||||
#
|
||||
# AST visitor
|
||||
#
|
||||
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
@@ -93,27 +227,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.value_constructor = _triton.ir.value_constructor(self.builder)
|
||||
self.module = _triton.ir.module('', self.builder) if module is None else module
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.last_node = None
|
||||
self.is_kernel = is_kernel
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
'min': triton.language.minimum,
|
||||
'float': float,
|
||||
'int': int,
|
||||
'print': print,
|
||||
'isinstance': isinstance,
|
||||
'getattr': getattr,
|
||||
}
|
||||
|
||||
def visit_Module(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
@@ -127,16 +240,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
ret = self.visit(node.value)
|
||||
if ret is None:
|
||||
return self.builder.ret_void()
|
||||
if isinstance(ret, _triton.ir.value):
|
||||
ret = self.builder.ret(ret)
|
||||
return ret
|
||||
if isinstance(ret, triton.language.block):
|
||||
ret = ret.handle
|
||||
if isinstance(ret, triton.language.constexpr):
|
||||
ret = triton.language.core._to_ir(ret, self.builder)
|
||||
# TODO: should return tl.block
|
||||
return self.builder.ret(ret)
|
||||
return triton.language.tensor(self.builder.ret_void(), triton.language.void)
|
||||
ret = triton.language.core._to_tensor(ret, self.builder)
|
||||
ret = triton.language.tensor(self.builder.ret(ret.handle), ret.type)
|
||||
return ret
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
@@ -152,8 +259,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants)
|
||||
fn = self.module.get_or_insert_function(fn_name, self.prototype)
|
||||
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
|
||||
self.prototypes[fn_name] = self.prototype
|
||||
fn = self.module.get_or_insert_function(fn_name, self.prototype.to_ir(self.builder))
|
||||
fn.set_is_kernel(self.is_kernel)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
@@ -171,23 +279,24 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insert_block()
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self.builder.set_insert_block(entry)
|
||||
self.value_constructor.seal_block(entry)
|
||||
self.value_constructor._seal_block(entry)
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.value_constructor.set_value(arg_name, arg_value)
|
||||
# visit function body
|
||||
has_ret = self.visit_compound_statement(node.body)
|
||||
# finalize
|
||||
if not has_ret:
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.module.reset_ret_ty(fn_name, self.last_ret.type)
|
||||
# self.module.reset_ret_type(node.name)
|
||||
# a bit hacky: we only know the return type at the last moment so we update type info here
|
||||
self.module.reset_ret_ty(fn_name, self.last_ret.type.to_ir(self.builder))
|
||||
self.prototype.ret_type = self.last_ret.type
|
||||
self.builder.set_insert_block(insert_pt)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
@@ -208,13 +317,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
value = self.visit(node.value)
|
||||
# constexpr
|
||||
if annotation == triton.language.constexpr:
|
||||
if target in self.lscope:
|
||||
if target in self.value_constructor.lscope:
|
||||
raise ValueError(f'{target} is already defined.'
|
||||
f' constexpr cannot be reassigned.')
|
||||
if not isinstance(value, triton.language.constexpr):
|
||||
value = triton.language.constexpr(value)
|
||||
self.lscope[target] = value
|
||||
return self.lscope[target]
|
||||
self.value_constructor.lscope[target] = value
|
||||
return self.value_constructor.lscope[target]
|
||||
# default: call visit_Assign
|
||||
return self.visit_Assign(node)
|
||||
|
||||
@@ -229,19 +338,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
if isinstance(values[0], _triton.ir.value):
|
||||
struct = values[0]
|
||||
ty = struct.type
|
||||
if ty.is_struct():
|
||||
values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)]
|
||||
if isinstance(values[0], triton.language.tensor) \
|
||||
and isinstance(values[0].type, triton.language.tuple_type):
|
||||
struct = values[0].handle
|
||||
tys = values[0].type.element_types
|
||||
values = [self.builder.extract_value(struct, i) for i in range(len(tys))]
|
||||
values = [triton.language.tensor(v, ty) for v, ty in zip(values, tys)]
|
||||
assert len(values) == len(names)
|
||||
for name, value in zip(names, values):
|
||||
# TODO: can we store constexpr here to support constant folding?
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language.core._to_ir(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
if not isinstance(value, triton.language.tensor):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
self.value_constructor.set_value(name, value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
name = node.target.id
|
||||
@@ -249,12 +360,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
rhs = ast.BinOp(lhs, node.op, node.value)
|
||||
assign = ast.Assign(targets=[node.target], value=rhs)
|
||||
self.visit(assign)
|
||||
return self.get_value(name)
|
||||
return self.value_constructor.get_value(name)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if type(node.ctx) == ast.Store:
|
||||
return node.id
|
||||
return self.get_value(node.id)
|
||||
return self.value_constructor.get_value(node.id)
|
||||
|
||||
def visit_Store(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@@ -266,23 +377,22 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = [self.visit(x) for x in node.elts]
|
||||
mode = type(args[0])
|
||||
# tuple of values -- create a struct
|
||||
if len(args) > 1 and mode == triton.language.block\
|
||||
if len(args) > 1 and mode == triton.language.tensor\
|
||||
and all([type(arg) == mode for arg in args]):
|
||||
args = [arg.handle for arg in args]
|
||||
tys = [arg.type for arg in args]
|
||||
struct_ty = _triton.ir.struct_type.get(tys, True)
|
||||
ret = _triton.ir.undef.get(struct_ty)
|
||||
tuple_ty = triton.language.tuple_type([arg.type for arg in args])
|
||||
ret = _triton.ir.undef.get(tuple_ty.to_ir(self.builder))
|
||||
for i, arg in enumerate(args):
|
||||
ret = self.builder.insert_value(ret, arg, i)
|
||||
ret = self.builder.insert_value(ret, arg.handle, i)
|
||||
ret = triton.language.tensor(ret, tuple_ty)
|
||||
return ret
|
||||
return tuple(args)
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
@@ -298,9 +408,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_object(lhs):
|
||||
if is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
elif is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -308,15 +418,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.block):
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
current_bb = self.builder.get_insert_block()
|
||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
self.value_constructor.seal_block(then_bb)
|
||||
self.value_constructor._seal_block(then_bb)
|
||||
if else_bb:
|
||||
self.value_constructor.seal_block(else_bb)
|
||||
self.value_constructor._seal_block(else_bb)
|
||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
else:
|
||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
@@ -331,7 +441,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: last statement is a terminator?
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self.value_constructor.seal_block(endif_bb)
|
||||
self.value_constructor._seal_block(endif_bb)
|
||||
self.builder.set_insert_block(endif_bb)
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
@@ -356,9 +466,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(node.ops) == 1
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
@@ -372,9 +482,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_object(lhs):
|
||||
if is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
elif is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -385,21 +495,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.core.constexpr):
|
||||
if isinstance(op, triton.language.constexpr):
|
||||
op = op.value
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_object(op):
|
||||
if is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
cond = self.visit(node.test)
|
||||
@@ -410,9 +520,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.value_constructor.seal_block(stop_bb)
|
||||
self.value_constructor.seal_block(loop_bb)
|
||||
self.value_constructor.seal_block(next_bb)
|
||||
self.value_constructor._seal_block(stop_bb)
|
||||
self.value_constructor._seal_block(loop_bb)
|
||||
self.value_constructor._seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -422,7 +532,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_object(lhs):
|
||||
if is_triton_tensor(lhs):
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
@@ -431,7 +541,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_For(self, node):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
if iterator != self.value_constructor.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
@@ -442,7 +552,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
for i in iterator(*iter_args):
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.value_constructor.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
@@ -465,8 +575,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
self.visit(step_node)
|
||||
@@ -481,9 +591,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: handle case where body breaks control flow
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.value_constructor.seal_block(stop_bb)
|
||||
self.value_constructor.seal_block(loop_bb)
|
||||
self.value_constructor.seal_block(next_bb)
|
||||
self.value_constructor._seal_block(stop_bb)
|
||||
self.value_constructor._seal_block(loop_bb)
|
||||
self.value_constructor._seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -514,7 +624,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.block)
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
@@ -523,25 +633,24 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# generate call
|
||||
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||
arg_vals = [arg.handle for arg in args if arg is not None]
|
||||
arg_types = [arg.type for arg in arg_vals]
|
||||
arg_types = [arg.type for arg in args if arg is not None]
|
||||
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
||||
# generate function def if necessary
|
||||
if not self.module.has_function(fn_name):
|
||||
ret_type = _triton.ir.type.get_void(self.builder.context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module)
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, prototypes=self.prototypes, module=self.module)
|
||||
generator.visit(fn.parse())
|
||||
symbol = self.module.get_function(fn_name)
|
||||
ret = self.builder.call(symbol, arg_vals)
|
||||
if not ret.type.is_void() and not ret.type.is_struct():
|
||||
ret = triton.language.block(ret)
|
||||
if not ret.type.is_void():
|
||||
ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
|
||||
return ret
|
||||
# built-in function
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
if sys.modules[fn.__module__] is triton.language.core:
|
||||
ret = fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
if fn in self.value_constructor.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
for arg in args]
|
||||
ret = fn(*args, **kws)
|
||||
@@ -698,7 +807,7 @@ class Kernel:
|
||||
}
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
return type_names[obj.dtype]
|
||||
if isinstance(obj, triton.language.core.constexpr):
|
||||
if isinstance(obj, triton.language.constexpr):
|
||||
obj = obj.value
|
||||
if isinstance(obj, int):
|
||||
if -2**31 <= obj < 2**31:
|
||||
@@ -730,34 +839,34 @@ class Kernel:
|
||||
return 'scalar', name
|
||||
|
||||
@staticmethod
|
||||
def _to_triton_ir(context, obj):
|
||||
def _to_triton_ir(obj):
|
||||
which, name = obj
|
||||
type_map = {
|
||||
'I': _triton.ir.type.get_int32,
|
||||
'L': _triton.ir.type.get_int64,
|
||||
'f': _triton.ir.type.get_fp32,
|
||||
'B': _triton.ir.type.get_int1,
|
||||
'f8': _triton.ir.type.get_fp8,
|
||||
'f16': _triton.ir.type.get_fp16,
|
||||
'bf16': _triton.ir.type.get_bf16,
|
||||
'f32': _triton.ir.type.get_fp32,
|
||||
'f64': _triton.ir.type.get_fp64,
|
||||
'i1': _triton.ir.type.get_int1,
|
||||
'i8': _triton.ir.type.get_int8,
|
||||
'i16': _triton.ir.type.get_int16,
|
||||
'i32': _triton.ir.type.get_int32,
|
||||
'i64': _triton.ir.type.get_int64,
|
||||
'u8': _triton.ir.type.get_uint8,
|
||||
'u16': _triton.ir.type.get_uint16,
|
||||
'u32': _triton.ir.type.get_uint32,
|
||||
'u64': _triton.ir.type.get_uint64,
|
||||
'I': triton.language.int32,
|
||||
'L': triton.language.int64,
|
||||
'f': triton.language.float32,
|
||||
'B': triton.language.int1,
|
||||
'f8': triton.language.float8,
|
||||
'f16': triton.language.float16,
|
||||
'bf16': triton.language.bfloat16,
|
||||
'f32': triton.language.float32,
|
||||
'f64': triton.language.float64,
|
||||
'i1': triton.language.int1,
|
||||
'i8': triton.language.int8,
|
||||
'i16': triton.language.int16,
|
||||
'i32': triton.language.int32,
|
||||
'i64': triton.language.int64,
|
||||
'u8': triton.language.uint8,
|
||||
'u16': triton.language.uint16,
|
||||
'u32': triton.language.uint32,
|
||||
'u64': triton.language.uint64,
|
||||
}
|
||||
# convert torch.Tensor to Triton IR pointers
|
||||
if which == 'ptr':
|
||||
elt_ty = type_map[name](context)
|
||||
return _triton.ir.type.make_ptr(elt_ty, 1)
|
||||
elt_ty = type_map[name]
|
||||
return triton.language.pointer_type(elt_ty, 1)
|
||||
# default path returns triton.ir.type directly
|
||||
return type_map[name](context)
|
||||
return type_map[name]
|
||||
|
||||
@staticmethod
|
||||
def pow2_divisor(N):
|
||||
@@ -1121,9 +1230,9 @@ class JITFunction:
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
|
||||
ret_type = _triton.ir.type.get_void(context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
|
File diff suppressed because it is too large
Load Diff
1052
python/triton/language/semantic.py
Normal file
1052
python/triton/language/semantic.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user