Replace MlirType with mlir::Type
This commit is contained in:
@@ -4,11 +4,8 @@
|
||||
#include "triton/driver/llvm.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
// #include "mlir/IR/BuiltinOps.h"
|
||||
// #include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
@@ -652,13 +649,9 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_context", &mlir::ModuleOp::getContext)
|
||||
;
|
||||
|
||||
py::class_<MlirType>(m, "type")
|
||||
.def("is_integer", [](MlirType &self) -> bool {
|
||||
return mlirTypeIsAInteger(self);
|
||||
})
|
||||
.def("is_fp16", [](MlirType &self) -> bool {
|
||||
return mlirTypeIsABF16(self);
|
||||
})
|
||||
py::class_<mlir::Type>(m, "type")
|
||||
.def("is_integer", &mlir::Type::isInteger)
|
||||
.def("is_fp16", &mlir::Type::isF16)
|
||||
;
|
||||
|
||||
py::class_<mlir::Value>(m, "value")
|
||||
@@ -782,74 +775,77 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(v));
|
||||
})
|
||||
.def("get_null_value", [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (type.isa<mlir::FloatType>())
|
||||
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(0.0));
|
||||
else
|
||||
throw std::runtime_error("Not implemented");
|
||||
})
|
||||
|
||||
// Types
|
||||
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
|
||||
return wrap(self.getNoneType());
|
||||
.def("get_void_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getNoneType();
|
||||
})
|
||||
.def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI1Type());
|
||||
.def("get_int1_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getI1Type();
|
||||
}) // or ret::copy?
|
||||
.def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI8Type());
|
||||
.def("get_int8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getI8Type();
|
||||
})
|
||||
.def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::IntegerType>(16));
|
||||
.def("get_int16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getType<mlir::IntegerType>(16);
|
||||
})
|
||||
.def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI32Type());
|
||||
.def("get_int32_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getI32Type();
|
||||
})
|
||||
.def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI64Type());
|
||||
.def("get_int64_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getI64Type();
|
||||
})
|
||||
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::triton::Float8Type>());
|
||||
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getType<mlir::triton::Float8Type>();
|
||||
})
|
||||
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::triton::BFloat8Type>());
|
||||
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getType<mlir::triton::BFloat8Type>();
|
||||
})
|
||||
.def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF16Type());
|
||||
.def("get_half_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getF16Type();
|
||||
})
|
||||
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getBF16Type());
|
||||
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getBF16Type();
|
||||
})
|
||||
.def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF32Type());
|
||||
.def("get_float_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getF32Type();
|
||||
})
|
||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF64Type());
|
||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getF64Type();
|
||||
})
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
|
||||
return wrap(
|
||||
mlir::triton::PointerType::get(unwrap(type), addrSpace)
|
||||
);
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, mlir::Type &type, int addrSpace) -> mlir::Type {
|
||||
return mlir::triton::PointerType::get(type, addrSpace);
|
||||
})
|
||||
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
|
||||
std::vector<int64_t> &shape) -> MlirType {
|
||||
return wrap(
|
||||
mlir::RankedTensorType::get(shape, unwrap(elementType))
|
||||
);
|
||||
.def("get_block_ty", [](mlir::OpBuilder &self, mlir::Type &elementType,
|
||||
std::vector<int64_t> &shape) -> mlir::Type {
|
||||
return mlir::RankedTensorType::get(shape, elementType);
|
||||
})
|
||||
.def("get_function_ty", [](mlir::OpBuilder &self,
|
||||
std::vector<MlirType> inTypes,
|
||||
std::vector<MlirType> outTypes) -> MlirType {
|
||||
llvm::SmallVector<mlir::Type, 4> inputsTypeList;
|
||||
llvm::SmallVector<mlir::Type, 4> resultsTypeList;
|
||||
(void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList);
|
||||
(void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList);
|
||||
return wrap(self.getFunctionType(inputsTypeList, resultsTypeList));
|
||||
std::vector<mlir::Type> inTypes,
|
||||
std::vector<mlir::Type> outTypes) -> mlir::Type {
|
||||
return self.getFunctionType(inTypes, outTypes);
|
||||
})
|
||||
|
||||
// Ops
|
||||
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp {
|
||||
.def("create_function", [](mlir::OpBuilder &self, std::string name, mlir::Type &funcType) -> mlir::FuncOp {
|
||||
// TODO: loc
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
return self.create<mlir::FuncOp>(loc, name, funcTy);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
}, ret::reference)
|
||||
// Structured control flow
|
||||
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
|
||||
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
|
||||
@@ -878,35 +874,35 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
|
||||
// Cast instructions
|
||||
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_cast", &ir::builder::create_cast)
|
||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
||||
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::SIToFPOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::UIToFPOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::FPToSIOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::FPToUIOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::ExtFOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
||||
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), src);
|
||||
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||
// .def("create_downcast", &ir::builder::create_downcast)
|
||||
|
@@ -44,11 +44,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
'getattr': getattr,
|
||||
}
|
||||
# SSA-construction
|
||||
# [name, bb] => triton.language.tensor
|
||||
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
|
||||
# bb => {name => phi}
|
||||
self.incomplete_phis = {}
|
||||
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
|
||||
# name => triton.language.tensor
|
||||
self.local_defs: Dict[str, triton.language.tensor] = {}
|
||||
|
||||
def get_value(self, name):
|
||||
''' This function:
|
||||
@@ -67,10 +64,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
elif name in self.builtins:
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
print(self.lscope)
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if self.is_triton_tensor(ret):
|
||||
return self._get_tensor(name, self.builder.get_insertion_block())
|
||||
return ret
|
||||
|
||||
def set_value(self, name: str,
|
||||
@@ -81,99 +75,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
2. store tensor in self.lvalue
|
||||
'''
|
||||
self.lscope[name] = value
|
||||
if isinstance(value, triton.language.tensor):
|
||||
self._set_value(name, self.builder.get_insertion_block(), value)
|
||||
|
||||
#
|
||||
# SSA-construction
|
||||
#
|
||||
def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
if not bb:
|
||||
bb = self.builder.get_insertion_block()
|
||||
# local value numbering
|
||||
if (name, bb) in self.lvalues:
|
||||
return self.lvalues[(name, bb)]
|
||||
# param. FIXME: should delete this
|
||||
if (name, None) in self.lvalues:
|
||||
return self.lvalues[(name, None)]
|
||||
print(self.lvalues)
|
||||
assert False, f'Cannot find {name} in {bb}'
|
||||
# global value numbering
|
||||
saved_insert_point = self.builder.get_insert_point()
|
||||
result = self._get_tensor_recursive(name, bb)
|
||||
self.builder.set_insert_point(saved_insert_point)
|
||||
return result
|
||||
|
||||
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
preds = bb.get_predecessors()
|
||||
type = self.lscope[name].type
|
||||
# some preds haven't been filled, create a phi as a proxy of the value
|
||||
if bb not in self.sealed_blocks:
|
||||
result = self._make_phi(type, len(preds), bb)
|
||||
if bb in self.incomplete_phis:
|
||||
self.incomplete_phis[bb][name] = result
|
||||
else:
|
||||
self.incomplete_phis[bb] = {name: result}
|
||||
elif len(preds) == 1:
|
||||
# one predecessor: no phi needed, try get value from pred
|
||||
result = self._get_tensor(name, preds[0])
|
||||
elif len(preds) == 0:
|
||||
result = self._get_tensor(name, None)
|
||||
else: # multiple preds
|
||||
phi = self._make_phi(type, len(preds), bb)
|
||||
self._set_value(name, bb, phi)
|
||||
result = self._add_phi_operands(name, phi)
|
||||
self._set_value(name, bb, result)
|
||||
return result
|
||||
|
||||
# returns a new phi tensor, which encausulate an ir.phi_node
|
||||
def _make_phi(self,
|
||||
type: triton.language.dtype,
|
||||
num_values: int,
|
||||
bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
instr = bb.get_first_non_phi()
|
||||
self.builder.set_insert_point((bb, instr))
|
||||
ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
|
||||
if instr:
|
||||
self.builder.set_insert_block(bb)
|
||||
return triton.language.tensor(ir_phi, type)
|
||||
|
||||
# complete a phi node. (TODO: rename this as _complete_phis?)
|
||||
# Note: since we try to remove tryival phi, the return tensor might not be a phi
|
||||
def _add_phi_operands(self, name: str,
|
||||
phi: triton.language.tensor) -> triton.language.tensor:
|
||||
bb = phi.handle.get_parent()
|
||||
for pred in bb.get_predecessors():
|
||||
v = self._get_tensor(name, pred)
|
||||
phi.handle.add_incoming(v.handle, pred)
|
||||
phi = self._try_remove_trivial_phi(phi)
|
||||
return phi
|
||||
|
||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
||||
self.lvalues[(name, bb)] = value
|
||||
# # TODO: why we need this?
|
||||
# self.module.set_instr_metadata(name, value.handle)
|
||||
|
||||
def _seal_block(self, bb: _triton.ir.basic_block):
|
||||
# complete all incomplete phis
|
||||
if bb in self.incomplete_phis:
|
||||
for name, phi in self.incomplete_phis[bb].items():
|
||||
result = self._add_phi_operands(name, phi)
|
||||
# it's possible that this phi is trivial
|
||||
if self._get_tensor(name, bb).handle == phi.handle:
|
||||
self._set_value(name, bb, result)
|
||||
del self.incomplete_phis[bb]
|
||||
self.sealed_blocks.add(bb)
|
||||
|
||||
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
|
||||
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
|
||||
if len(unique_handles) != 1: # non-trivial phi
|
||||
return phi
|
||||
v = unique_handles.pop()
|
||||
phi.handle.replace_all_uses_with(v)
|
||||
phi.handle.erase_from_parent()
|
||||
# TODO: remove trivial phis recursively
|
||||
return triton.language.tensor(v, phi.type)
|
||||
self.local_defs[name] = value
|
||||
|
||||
def is_triton_tensor(self, value):
|
||||
return isinstance(value, triton.language.tensor)
|
||||
@@ -229,7 +131,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
self._seal_block(entry)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -294,6 +195,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(_names) == 1
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
print(f'visit_Assign({names}, {values})')
|
||||
if not isinstance(names, tuple):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
@@ -367,9 +269,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
# self._seal_block(then_bb)
|
||||
# if else_bb:
|
||||
# self._seal_block(else_bb)
|
||||
# self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
# else:
|
||||
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
@@ -384,17 +284,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# # TODO: last statement is a terminator?
|
||||
# if not is_terminator:
|
||||
# self.builder.br(endif_bb)
|
||||
# self._seal_block(endif_bb)
|
||||
# self.builder.set_insert_block(endif_bb)
|
||||
parent_lvalues = self.lvalues.copy()
|
||||
parent_values = self.lscope.copy()
|
||||
self.visit_compound_statement(node.body)
|
||||
then_lvalues = self.lvalues.copy()
|
||||
then_values = self.lvalues.copy()
|
||||
assert node.orelse
|
||||
self.lvalues = parent_lvalues
|
||||
self.lscope = parent_values
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_lvalues = self.lvalues.copy()
|
||||
else_values = self.lscope.copy()
|
||||
|
||||
self.lvalues = join_if_lvalues(then_lvalues, else_lvalues)
|
||||
self.lvalues = join_if_lvalues(then_values, else_values)
|
||||
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
@@ -473,9 +372,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insertion_block()
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -510,47 +406,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
# # create nodes
|
||||
# st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
# ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||
# arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
||||
# arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
||||
# arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
||||
# init_node = ast.Assign(targets=[st_target], value=arg_0)
|
||||
# pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||
# neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||
# pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||
# build_cond = lambda: triton.language.where(self.visit(pos_step_node),
|
||||
# self.visit(pos_cond_node),
|
||||
# self.visit(neg_cond_node),
|
||||
# _builder=self.builder)
|
||||
# # cond_node = neg_cond_node
|
||||
# step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# # code generation
|
||||
# current_bb = self.builder.get_insertion_block()
|
||||
# loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
# next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
# def continue_fn():
|
||||
# self.visit(step_node)
|
||||
# cond = build_cond()
|
||||
# return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||
|
||||
# self.visit(init_node)
|
||||
# cond = build_cond()
|
||||
# self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||
# self.builder.set_insert_block(loop_bb)
|
||||
# self.visit_compound_statement(node.body)
|
||||
# # TODO: handle case where body breaks control flow
|
||||
# continue_fn()
|
||||
# stop_bb = self.builder.get_insertion_block()
|
||||
# self._seal_block(stop_bb)
|
||||
# self._seal_block(loop_bb)
|
||||
# self._seal_block(next_bb)
|
||||
# self.builder.set_insert_block(next_bb)
|
||||
|
||||
# for stmt in node.orelse:
|
||||
# ast.NodeVisitor.generic_visit(self, stmt)
|
||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||
@@ -558,22 +414,32 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
loop_body = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_body)
|
||||
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = set()
|
||||
self.local_defs = {}
|
||||
|
||||
# visit loop body
|
||||
parent_lvalues = self.lvalues.copy()
|
||||
self.visit_compound_statement()
|
||||
loop_lvalues = self.lvalues.copy()
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
# TODO: update insertion point
|
||||
# TODO: create scf.forOp
|
||||
# self.lvalues = join_loop_lvalues(parent_lvalues, loop_lvalues)
|
||||
# self.make_for_op(parent_lvalues, loop_lvalues, lb, ub, step)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [loop_init_args])
|
||||
init_args = {}
|
||||
yields = {}
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
init_args[name] = liveins[name]
|
||||
yields[name] = self.local_defs[name]
|
||||
# for_op = self.builder.create_for_op(lb, ub, step, [init_args])
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
def visit_Slice(self, node):
|
||||
lower = self.visit(node.lower)
|
||||
upper = self.visit(node.upper)
|
||||
@@ -1079,17 +945,12 @@ class JITFunction:
|
||||
lscope = generator.lscope.copy()
|
||||
|
||||
# TODO: clear values other than args
|
||||
lvalues = generator.lvalues.copy()
|
||||
# types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
|
||||
generator.lvalues = lvalues
|
||||
# generator.module.set_types(types)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
@@ -1222,9 +1083,6 @@ class JITFunction:
|
||||
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
|
||||
|
||||
print(f'wargs: {wargs}')
|
||||
print(f'constants: {constants}')
|
||||
print(f'arg_types: {arg_types}')
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
|
@@ -312,7 +312,7 @@ def minus(input: tl.tensor,
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr():
|
||||
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
||||
_0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
_0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
return sub(_0, input, builder)
|
||||
|
||||
|
||||
@@ -442,7 +442,7 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
|
||||
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
_0 = ir.constant.get_null_value(dtype.to_ir(builder))
|
||||
_0 = builder.get_null_value(dtype.to_ir(builder))
|
||||
ret_ty = tl.block_type(dtype, shape)
|
||||
return tl.tensor(builder.create_splat(_0, shape), ret_ty)
|
||||
|
||||
|
Reference in New Issue
Block a user