More progress on WhileOp codegen

This commit is contained in:
Yan Da
2022-04-05 15:55:48 +08:00
parent 76d9249724
commit c7ad928e60
5 changed files with 145 additions and 45 deletions

View File

@@ -638,17 +638,6 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference);
py::class_<mlir::ModuleOp>(m, "module")
// .def("set_attr")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("get_context", &mlir::ModuleOp::getContext)
;
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16)
@@ -753,13 +742,27 @@ void init_triton_ir(py::module &&m) {
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
;
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
// .def("set_attr")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>())
// // getters
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
.def_property_readonly("context", &mlir::OpBuilder::getContext, ret::reference)
.def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
@@ -883,6 +886,9 @@ void init_triton_ir(py::module &&m) {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
return self.createBlock(&parent);
})
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return new mlir::Block();
}, ret::reference)
@@ -900,7 +906,16 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::YieldOp>(loc, yields);
})
// // .def("create_while")
.def("create_while_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes,
std::vector<mlir::Value> &initArgs) -> mlir::scf::WhileOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
})
.def("create_condtion_op", [](mlir::OpBuilder &self, mlir::Value &cond,
std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ConditionOp>(loc, cond, args);
})
// miscellaneous
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {

View File

@@ -46,6 +46,7 @@ class CodeGenerator(ast.NodeVisitor):
# SSA-construction
# name => triton.language.tensor
self.local_defs: Dict[str, triton.language.tensor] = {}
self.global_uses: Dict[str, triton.language.tensor] = {}
def get_value(self, name):
''' This function:
@@ -57,6 +58,8 @@ class CodeGenerator(ast.NodeVisitor):
ret = None
if name in self.lscope:
ret = self.lscope[name]
if name not in self.local_defs:
self.global_uses[name] = ret
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
@@ -263,27 +266,6 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
# cond = cond.to(triton.language.int1, _builder=self.builder)
# current_bb = self.builder.get_insertion_block()
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
# if else_bb:
# self.builder.cond_br(cond.handle, then_bb, else_bb)
# else:
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
# self.builder.set_insert_block(then_bb)
# is_terminator = self.visit_compound_statement(node.body)
# # TODO: last statement is a terminator?
# if not is_terminator:
# self.builder.br(endif_bb)
# if else_bb:
# self.builder.set_insert_block(else_bb)
# is_terminator = self.visit_compound_statement(node.orelse)
# # TODO: last statement is a terminator?
# if not is_terminator:
# self.builder.br(endif_bb)
# self.builder.set_insert_block(endif_bb)
cond = cond.to(triton.language.int1, _builder=self.builder)
liveins = self.lscope.copy()
parent_defs = self.local_defs.copy()
@@ -413,22 +395,64 @@ class CodeGenerator(ast.NodeVisitor):
return getattr(op, fn)()
def visit_While(self, node):
current_bb = self.builder.get_insertion_block()
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
self.local_defs = {}
def continue_fn():
cond = self.visit(node.test)
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
insert_block = self.builder.get_insertion_block()
continue_fn()
self.builder.set_insert_block(loop_bb)
# condition (the before region)
cond_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(cond_block)
cond = self.visit(node.test)
# loop body (the after region)
loop_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_block)
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insertion_block()
self.builder.set_insert_block(next_bb)
loop_defs = self.local_defs
# collect loop-carried values
names = []
ret_types = []
init_args = []
yields = []
for name in loop_defs:
if name in liveins:
# We should not def new constexpr (?)
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type:
# these are loop-carried values
names.append(name)
ret_types.append(loop_defs[name].type.to_ir(self.builder))
init_args.append(liveins[name])
yields.append(loop_defs[name])
self.builder.set_insertion_point_to_end(insert_block)
while_op = self.builder.create_while_op(ret_types, init_args)
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before())
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
self.builder.create_condtion_op(cond.handle, [])
# merge the loop body
after_block = self.builder.create_block_with_parent(while_op.get_after())
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
self.builder.set_insertion_point_to_end(insert_block)
self.lscope = liveins
self.local_defs = prev_defs
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
for stmt in node.orelse:
assert False, "Not implemented"
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript(self, node):

View File

@@ -0,0 +1,43 @@
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,  # pointer to the first input vector
    y_ptr,  # pointer to the second input vector
    output_ptr,  # pointer to the output vector
    n_elements,  # total number of elements in each vector
    # BLOCK_SIZE: tl.constexpr,  # elements per program; hard-coded to 256 below
    # #                            until constexpr arguments are wired up again
):
    """Element-wise vector add: output[i] = x[i] + y[i].

    Each program instance on the 1D launch grid handles one contiguous
    256-element slice of the input; out-of-range lanes are masked off.
    """
    # Identify which slice this program is responsible for.
    program = tl.program_id(axis=0)  # 1D grid, so axis 0
    # E.g. with 256 elements per program, program 0 covers [0:256),
    # program 1 covers [256:512), and so on.
    slice_start = program * 256
    idx = slice_start + tl.arange(0, 256)
    # Guard the tail: lanes whose index falls past n_elements must neither
    # load nor store.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0)
    # Write the element-wise sum back to DRAM.
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)
# Driver: build sample CUDA tensors and lower the kernel to Triton IR.
n_elements = 1024
lhs = torch.rand(n_elements, device='cuda')
rhs = torch.rand(n_elements, device='cuda')
out = torch.empty_like(lhs)
# add_kernel[(1,)](lhs, rhs, out, n_elements, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
module, context = add_kernel.compile_to_ttir(lhs, rhs, out, n_elements, grid=(1,))
module.get_context()
module.dump()
# print(module)

18
rewrite-test/jit/while.py Normal file
View File

@@ -0,0 +1,18 @@
import triton
import triton.language as tl
import torch
# Spin-lock acquire: exercises WhileOp codegen on a loop with an empty body.
@triton.jit
def atomic(lock):
    # tl.atomic_cas(lock, 0, 1) atomically swaps 0 -> 1 and returns the old
    # value; keep spinning while the lock is still held elsewhere (old == 1).
    while tl.atomic_cas(lock, 0, 1) == 1:
        pass
# Minimal counted loop: exercises WhileOp codegen with a loop-carried scalar.
# `lb` and `value` are intentionally unused; the kernel only needs a while
# whose condition and body both reference the carried variable `c`.
@triton.jit
def generic_while(lb, value):
    c = -1
    while c <= 0:
        c += 1
# Driver: lower the spin-lock kernel to Triton IR and print the module.
lock_words = torch.zeros(32, dtype=torch.int32, device='cuda')
module_atomic, context_atomic = atomic.compile_to_ttir(lock_words, grid=(1,))
module_atomic.dump()