Replace MlirType with mlir::Type

This commit is contained in:
Yan Da
2022-04-01 18:46:46 +08:00
parent 4ad432f1fc
commit bde103fab0
3 changed files with 93 additions and 239 deletions

View File

@@ -4,11 +4,8 @@
#include "triton/driver/llvm.h"
#include "mlir/IR/Builders.h"
#include "mlir-c/IR.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/CAPI/IR.h"
// #include "mlir/IR/BuiltinOps.h"
// #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -652,13 +649,9 @@ void init_triton_ir(py::module &&m) {
.def("get_context", &mlir::ModuleOp::getContext)
;
py::class_<MlirType>(m, "type")
.def("is_integer", [](MlirType &self) -> bool {
return mlirTypeIsAInteger(self);
})
.def("is_fp16", [](MlirType &self) -> bool {
return mlirTypeIsABF16(self);
})
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16)
;
py::class_<mlir::Value>(m, "value")
@@ -782,74 +775,77 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(v));
})
.def("get_null_value", [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
  // Returns an arith.constant holding zero of the requested type.
  //   type: a builtin float or integer type.
  //   throws std::runtime_error("Not implemented") for any other type.
  auto loc = self.getUnknownLoc();
  // Derive the attribute from `type` instead of hard-coding f32, so that a
  // request for f16/bf16/f64 yields a constant of that exact type.
  if (type.isa<mlir::FloatType>())
    return self.create<mlir::arith::ConstantOp>(loc, self.getFloatAttr(type, 0.0));
  // Integer element types need a zero as well (e.g. for unary minus / zeros).
  if (type.isa<mlir::IntegerType>())
    return self.create<mlir::arith::ConstantOp>(loc, self.getIntegerAttr(type, 0));
  throw std::runtime_error("Not implemented");
})
// Types
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
return wrap(self.getNoneType());
.def("get_void_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getNoneType();
})
.def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI1Type());
.def("get_int1_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI1Type();
}) // or ret::copy?
.def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI8Type());
.def("get_int8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI8Type();
})
.def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::IntegerType>(16));
.def("get_int16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::IntegerType>(16);
})
.def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI32Type());
.def("get_int32_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI32Type();
})
.def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI64Type());
.def("get_int64_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI64Type();
})
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::triton::Float8Type>());
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::triton::BFloat8Type>());
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF16Type());
.def("get_half_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF16Type();
})
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getBF16Type());
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getBF16Type();
})
.def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF32Type());
.def("get_float_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF32Type();
})
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF64Type());
.def("get_double_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF64Type();
})
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
return wrap(
mlir::triton::PointerType::get(unwrap(type), addrSpace)
);
.def("get_ptr_ty", [](mlir::OpBuilder &self, mlir::Type &type, int addrSpace) -> mlir::Type {
return mlir::triton::PointerType::get(type, addrSpace);
})
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
std::vector<int64_t> &shape) -> MlirType {
return wrap(
mlir::RankedTensorType::get(shape, unwrap(elementType))
);
.def("get_block_ty", [](mlir::OpBuilder &self, mlir::Type &elementType,
std::vector<int64_t> &shape) -> mlir::Type {
return mlir::RankedTensorType::get(shape, elementType);
})
.def("get_function_ty", [](mlir::OpBuilder &self,
std::vector<MlirType> inTypes,
std::vector<MlirType> outTypes) -> MlirType {
llvm::SmallVector<mlir::Type, 4> inputsTypeList;
llvm::SmallVector<mlir::Type, 4> resultsTypeList;
(void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList);
(void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList);
return wrap(self.getFunctionType(inputsTypeList, resultsTypeList));
std::vector<mlir::Type> inTypes,
std::vector<mlir::Type> outTypes) -> mlir::Type {
return self.getFunctionType(inTypes, outTypes);
})
// Ops
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp {
.def("create_function", [](mlir::OpBuilder &self, std::string name, mlir::Type &funcType) -> mlir::FuncOp {
// TODO: loc
auto loc = self.getUnknownLoc();
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, name, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
  // Create a new block inside the region that owns the builder's current
  // block; the result is handed to Python as a reference (ret::reference).
  return self.createBlock(self.getBlock()->getParent());
}, ret::reference)
// Structured control flow
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
@@ -878,35 +874,35 @@ void init_triton_ir(py::module &&m) {
})
// Cast instructions
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::SIToFPOp>(loc, dstType, src);
})
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::UIToFPOp>(loc, dstType, src);
})
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::FPToSIOp>(loc, dstType, src);
})
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::FPToUIOp>(loc, dstType, src);
})
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::ExtFOp>(loc, dstType, src);
})
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)

View File

@@ -44,11 +44,8 @@ class CodeGenerator(ast.NodeVisitor):
'getattr': getattr,
}
# SSA-construction
# [name, bb] => triton.language.tensor
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
# bb => {name => phi}
self.incomplete_phis = {}
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
# name => triton.language.tensor
self.local_defs: Dict[str, triton.language.tensor] = {}
def get_value(self, name):
''' This function:
@@ -67,10 +64,7 @@ class CodeGenerator(ast.NodeVisitor):
elif name in self.builtins:
ret = self.builtins[name]
else:
print(self.lscope)
raise ValueError(f'{name} is not defined')
if self.is_triton_tensor(ret):
return self._get_tensor(name, self.builder.get_insertion_block())
return ret
def set_value(self, name: str,
@@ -81,99 +75,7 @@ class CodeGenerator(ast.NodeVisitor):
2. store tensor in self.lvalue
'''
self.lscope[name] = value
if isinstance(value, triton.language.tensor):
self._set_value(name, self.builder.get_insertion_block(), value)
#
# SSA-construction
#
def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
# Look up the tensor recorded for `name` in basic block `bb`.
# A falsy `bb` is replaced by the builder's current insertion block.
# Returns the entry stored in self.lvalues under (name, bb), falling back
# to the (name, None) entry; otherwise the table is printed and an
# assertion fails.
if not bb:
bb = self.builder.get_insertion_block()
# local value numbering
if (name, bb) in self.lvalues:
return self.lvalues[(name, bb)]
# param. FIXME: should delete this
if (name, None) in self.lvalues:
return self.lvalues[(name, None)]
# Debug aid: dump the whole table before failing.
print(self.lvalues)
assert False, f'Cannot find {name} in {bb}'
# global value numbering
# NOTE(review): the statements below appear to follow an unconditional
# `assert False`, so they would only run with assertions disabled -- confirm
# whether the recursive lookup is meant to be reachable.
saved_insert_point = self.builder.get_insert_point()
result = self._get_tensor_recursive(name, bb)
# Restore the insertion point saved before the recursive lookup.
self.builder.set_insert_point(saved_insert_point)
return result
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
# Resolve `name` for block `bb` from the block's predecessors, creating a
# phi through self._make_phi when needed, and return the resulting tensor.
# NOTE(review): indentation is not visible in this view, so which branches
# the trailing self._set_value(name, bb, result) belongs to must be
# confirmed against the real file.
preds = bb.get_predecessors()
# Type of the value currently bound to `name` in the local scope.
# (The local name `type` shadows the builtin inside this function.)
type = self.lscope[name].type
# some preds haven't been filled, create a phi as a proxy of the value
if bb not in self.sealed_blocks:
result = self._make_phi(type, len(preds), bb)
# Remember the phi under incomplete_phis[bb][name]; _seal_block reads
# this table later to add the operands.
if bb in self.incomplete_phis:
self.incomplete_phis[bb][name] = result
else:
self.incomplete_phis[bb] = {name: result}
elif len(preds) == 1:
# one predecessor: no phi needed, try get value from pred
result = self._get_tensor(name, preds[0])
elif len(preds) == 0:
# No predecessors: _get_tensor treats a falsy block as "use the
# builder's current insertion block".
result = self._get_tensor(name, None)
else: # multiple preds
phi = self._make_phi(type, len(preds), bb)
# Record the phi before filling its operands, so a lookup of `name` in
# `bb` made while the operands are collected finds this phi.
self._set_value(name, bb, phi)
result = self._add_phi_operands(name, phi)
self._set_value(name, bb, result)
return result
# returns a new phi tensor, which encapsulates an ir.phi_node
def _make_phi(self,
type: triton.language.dtype,
num_values: int,
bb: _triton.ir.basic_block) -> triton.language.tensor:
# type:       language-level dtype; converted with type.to_ir(self.builder).
# num_values: passed to builder.create_phi together with the IR type.
# bb:         block that receives the phi.
# Position the builder at the instruction returned by
# bb.get_first_non_phi() before creating the phi.
instr = bb.get_first_non_phi()
self.builder.set_insert_point((bb, instr))
ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
# When such an instruction exists, the builder is re-pointed at `bb` via
# set_insert_block.
# NOTE(review): presumably this moves the insertion point back to the end
# of the block -- confirm against the builder API.
if instr:
self.builder.set_insert_block(bb)
return triton.language.tensor(ir_phi, type)
# complete a phi node. (TODO: rename this as _complete_phis?)
# Note: since we try to remove trivial phis, the returned tensor might not be a phi
def _add_phi_operands(self, name: str,
phi: triton.language.tensor) -> triton.language.tensor:
# The owning block is read from the phi itself; one incoming value is added
# per predecessor of that block, looked up with self._get_tensor(name, pred).
bb = phi.handle.get_parent()
for pred in bb.get_predecessors():
v = self._get_tensor(name, pred)
phi.handle.add_incoming(v.handle, pred)
# The result of _try_remove_trivial_phi is what gets returned, so callers
# may receive a tensor other than the phi they passed in.
phi = self._try_remove_trivial_phi(phi)
return phi
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
# Record `value` as the definition of `name` in block `bb`; the table
# self.lvalues is keyed by the (name, bb) tuple.
self.lvalues[(name, bb)] = value
# # TODO: why we need this?
# self.module.set_instr_metadata(name, value.handle)
def _seal_block(self, bb: _triton.ir.basic_block):
# Mark `bb` as sealed: fill in the operands of every phi recorded for it in
# self.incomplete_phis, drop that record, and add `bb` to self.sealed_blocks.
# complete all incomplete phis
if bb in self.incomplete_phis:
for name, phi in self.incomplete_phis[bb].items():
result = self._add_phi_operands(name, phi)
# it's possible that this phi is trivial
# Only rebind `name` when the block's current value is still this phi
# (compared by handle).
if self._get_tensor(name, bb).handle == phi.handle:
self._set_value(name, bb, result)
del self.incomplete_phis[bb]
self.sealed_blocks.add(bb)
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
# Replace `phi` by its single distinct operand when it has exactly one.
# Returns `phi` unchanged otherwise, or a new tensor wrapping the surviving
# operand with phi's type.
# Distinct operands of the phi, ignoring references to the phi itself.
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
if len(unique_handles) != 1: # non-trivial phi
return phi
v = unique_handles.pop()
# Redirect every use of the phi to the operand, then delete the phi.
phi.handle.replace_all_uses_with(v)
phi.handle.erase_from_parent()
# TODO: remove trivial phis recursively
return triton.language.tensor(v, phi.type)
self.local_defs[name] = value
def is_triton_tensor(self, value):
    '''Tell whether `value` is an instance of triton.language.tensor.'''
    tensor_cls = triton.language.tensor
    return isinstance(value, tensor_cls)
@@ -229,7 +131,6 @@ class CodeGenerator(ast.NodeVisitor):
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
self._seal_block(entry)
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
@@ -294,6 +195,7 @@ class CodeGenerator(ast.NodeVisitor):
assert len(_names) == 1
names = _names[0]
values = self.visit(node.value)
print(f'visit_Assign({names}, {values})')
if not isinstance(names, tuple):
names = [names]
if not isinstance(values, tuple):
@@ -367,9 +269,7 @@ class CodeGenerator(ast.NodeVisitor):
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
# self._seal_block(then_bb)
# if else_bb:
# self._seal_block(else_bb)
# self.builder.cond_br(cond.handle, then_bb, else_bb)
# else:
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
@@ -384,17 +284,16 @@ class CodeGenerator(ast.NodeVisitor):
# # TODO: last statement is a terminator?
# if not is_terminator:
# self.builder.br(endif_bb)
# self._seal_block(endif_bb)
# self.builder.set_insert_block(endif_bb)
parent_lvalues = self.lvalues.copy()
parent_values = self.lscope.copy()
self.visit_compound_statement(node.body)
then_lvalues = self.lvalues.copy()
then_values = self.lvalues.copy()
assert node.orelse
self.lvalues = parent_lvalues
self.lscope = parent_values
self.visit_compound_statement(node.orelse)
else_lvalues = self.lvalues.copy()
else_values = self.lscope.copy()
self.lvalues = join_if_lvalues(then_lvalues, else_lvalues)
self.lvalues = join_if_lvalues(then_values, else_values)
else:
if isinstance(cond, triton.language.constexpr):
@@ -473,9 +372,6 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insertion_block()
self._seal_block(stop_bb)
self._seal_block(loop_bb)
self._seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
@@ -510,47 +406,7 @@ class CodeGenerator(ast.NodeVisitor):
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# # create nodes
# st_target = ast.Name(id=node.target.id, ctx=ast.Store())
# ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
# arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
# arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
# arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
# init_node = ast.Assign(targets=[st_target], value=arg_0)
# pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
# neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
# pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
# build_cond = lambda: triton.language.where(self.visit(pos_step_node),
# self.visit(pos_cond_node),
# self.visit(neg_cond_node),
# _builder=self.builder)
# # cond_node = neg_cond_node
# step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# # code generation
# current_bb = self.builder.get_insertion_block()
# loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
# next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
# def continue_fn():
# self.visit(step_node)
# cond = build_cond()
# return self.builder.cond_br(cond.handle, loop_bb, next_bb)
# self.visit(init_node)
# cond = build_cond()
# self.builder.cond_br(cond.handle, loop_bb, next_bb)
# self.builder.set_insert_block(loop_bb)
# self.visit_compound_statement(node.body)
# # TODO: handle case where body breaks control flow
# continue_fn()
# stop_bb = self.builder.get_insertion_block()
# self._seal_block(stop_bb)
# self._seal_block(loop_bb)
# self._seal_block(next_bb)
# self.builder.set_insert_block(next_bb)
# for stmt in node.orelse:
# ast.NodeVisitor.generic_visit(self, stmt)
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
@@ -558,22 +414,32 @@ class CodeGenerator(ast.NodeVisitor):
loop_body = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_body)
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
self.local_defs = set()
self.local_defs = {}
# visit loop body
parent_lvalues = self.lvalues.copy()
self.visit_compound_statement()
loop_lvalues = self.lvalues.copy()
self.visit_compound_statement(node.body)
# TODO: update insertion point
# TODO: create scf.forOp
# self.lvalues = join_loop_lvalues(parent_lvalues, loop_lvalues)
# self.make_for_op(parent_lvalues, loop_lvalues, lb, ub, step)
for_op = self.builder.create_for_op(lb, ub, step, [loop_init_args])
init_args = {}
yields = {}
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name])
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
init_args[name] = liveins[name]
yields[name] = self.local_defs[name]
# for_op = self.builder.create_for_op(lb, ub, step, [init_args])
self.lscope = liveins
self.local_defs = prev_defs
for stmt in node.orelse:
assert False
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
lower = self.visit(node.lower)
upper = self.visit(node.upper)
@@ -1079,17 +945,12 @@ class JITFunction:
lscope = generator.lscope.copy()
# TODO: clear values other than args
lvalues = generator.lvalues.copy()
# types = generator.module.get_types().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
generator.lvalues = lvalues
# generator.module.set_types(types)
return ret
except Exception as e:
node = generator.last_node
@@ -1222,9 +1083,6 @@ class JITFunction:
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
print(f'wargs: {wargs}')
print(f'constants: {constants}')
print(f'arg_types: {arg_types}')
# create IR module
context = _triton.ir.context()
context.load_triton()

View File

@@ -312,7 +312,7 @@ def minus(input: tl.tensor,
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr():
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
_0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
_0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return sub(_0, input, builder)
@@ -442,7 +442,7 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
_0 = ir.constant.get_null_value(dtype.to_ir(builder))
_0 = builder.get_null_value(dtype.to_ir(builder))
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(_0, shape), ret_ty)