From c7ad928e6011aef3a3177d5a95018cc7e69207f5 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Tue, 5 Apr 2022 15:55:48 +0800 Subject: [PATCH] More progress on WhileOp codegen --- python/src/triton.cc | 41 ++++++--- python/triton/code_gen.py | 88 ++++++++++++------- .../if-else/{vecadd-cond.py => if-else.py} | 0 rewrite-test/jit/vecadd.py | 43 +++++++++ rewrite-test/jit/while.py | 18 ++++ 5 files changed, 145 insertions(+), 45 deletions(-) rename rewrite-test/jit/if-else/{vecadd-cond.py => if-else.py} (100%) create mode 100644 rewrite-test/jit/vecadd.py create mode 100644 rewrite-test/jit/while.py diff --git a/python/src/triton.cc b/python/src/triton.cc index 953c62a40..c0bc61693 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -638,17 +638,6 @@ void init_triton_ir(py::module &&m) { // // py::class_(m, "undef") // // .def("get", &ir::undef_value::get, ret::reference); - py::class_(m, "module") - // .def("set_attr") - .def("dump", [](mlir::ModuleOp &self) -> void { - self.dump(); - }) - .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { - self.push_back(funcOp); - }) - .def("get_context", &mlir::ModuleOp::getContext) - ; - py::class_(m, "type") .def("is_integer", &mlir::Type::isInteger) .def("is_fp16", &mlir::Type::isF16) @@ -753,13 +742,27 @@ void init_triton_ir(py::module &&m) { .def("get_else_yield", &mlir::scf::IfOp::elseYield) ; py::class_(m, "YieldOp"); + py::class_(m, "WhileOp") + .def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference) + .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference); + py::class_(m, "CondtionOp"); + + py::class_(m, "module") + // .def("set_attr") + .def("dump", [](mlir::ModuleOp &self) -> void { + self.dump(); + }) + .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + ; py::class_(m, "InsertPoint"); py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) // // getters - // .def_property_readonly("context", 
&ir::builder::get_context, ret::reference); + .def_property_readonly("context", &mlir::OpBuilder::getContext, ret::reference) .def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp { auto loc = self.getUnknownLoc(); return self.create(loc); @@ -883,6 +886,9 @@ void init_triton_ir(py::module &&m) { mlir::Region *parent = self.getBlock()->getParent(); return self.createBlock(parent); }, ret::reference) + .def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* { + return self.createBlock(&parent); + }) .def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* { return new mlir::Block(); }, ret::reference) @@ -900,7 +906,16 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, yields); }) - // // .def("create_while") + .def("create_while_op", [](mlir::OpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> mlir::scf::WhileOp { + auto loc = self.getUnknownLoc(); + return self.create(loc, retTypes, initArgs); + }) + .def("create_condtion_op", [](mlir::OpBuilder &self, mlir::Value &cond, + std::vector &args) -> mlir::scf::ConditionOp { + auto loc = self.getUnknownLoc(); + return self.create(loc, cond, args); + }) // miscellious .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value { diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index e65a21b5c..3313c5328 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -46,6 +46,7 @@ class CodeGenerator(ast.NodeVisitor): # SSA-construction # name => triton.language.tensor self.local_defs: Dict[str, triton.language.tensor] = {} + self.global_uses: Dict[str, triton.language.tensor] = {} def get_value(self, name): ''' This function: @@ -57,6 +58,8 @@ class CodeGenerator(ast.NodeVisitor): ret = None if name in self.lscope: ret = self.lscope[name] + if name not in self.local_defs: + self.global_uses[name] = ret # search node.id in global scope elif name in 
self.gscope: ret = self.gscope[name] @@ -263,27 +266,6 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): - # cond = cond.to(triton.language.int1, _builder=self.builder) - # current_bb = self.builder.get_insertion_block() - # then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) - # else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None - # endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - # if else_bb: - # self.builder.cond_br(cond.handle, then_bb, else_bb) - # else: - # self.builder.cond_br(cond.handle, then_bb, endif_bb) - # self.builder.set_insert_block(then_bb) - # is_terminator = self.visit_compound_statement(node.body) - # # TODO: last statement is a terminator? - # if not is_terminator: - # self.builder.br(endif_bb) - # if else_bb: - # self.builder.set_insert_block(else_bb) - # is_terminator = self.visit_compound_statement(node.orelse) - # # TODO: last statement is a terminator? 
- # if not is_terminator: - # self.builder.br(endif_bb) - # self.builder.set_insert_block(endif_bb) cond = cond.to(triton.language.int1, _builder=self.builder) liveins = self.lscope.copy() parent_defs = self.local_defs.copy() @@ -413,22 +395,64 @@ class CodeGenerator(ast.NodeVisitor): return getattr(op, fn)() def visit_While(self, node): - current_bb = self.builder.get_insertion_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + liveins = self.lscope.copy() + prev_defs = self.local_defs.copy() + self.local_defs = {} - def continue_fn(): - cond = self.visit(node.test) - return self.builder.cond_br(cond.handle, loop_bb, next_bb) + insert_block = self.builder.get_insertion_block() - continue_fn() - self.builder.set_insert_block(loop_bb) + # condition (the before region) + cond_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(cond_block) + cond = self.visit(node.test) + + # loop body (the after region) + loop_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(loop_block) self.visit_compound_statement(node.body) - continue_fn() - stop_bb = self.builder.get_insertion_block() - self.builder.set_insert_block(next_bb) + loop_defs = self.local_defs + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + yields = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr (?) 
+ assert self.is_triton_tensor(loop_defs[name]) + assert self.is_triton_tensor(liveins[name]) + if loop_defs[name].type == liveins[name].type: + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type.to_ir(self.builder)) + init_args.append(liveins[name]) + yields.append(loop_defs[name]) + + self.builder.set_insertion_point_to_end(insert_block) + while_op = self.builder.create_while_op(ret_types, init_args) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before()) + cond_block.merge_block_before(before_block) + self.builder.set_insertion_point_to_end(before_block) + self.builder.create_condtion_op(cond.handle, []) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after()) + loop_block.merge_block_before(after_block) + self.builder.set_insertion_point_to_end(after_block) + self.builder.create_yield_op([y.handle for y in yields]) + + self.builder.set_insertion_point_to_end(insert_block) + self.lscope = liveins + self.local_defs = prev_defs + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def for stmt in node.orelse: + assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) def visit_Subscript(self, node): diff --git a/rewrite-test/jit/if-else/vecadd-cond.py b/rewrite-test/jit/if-else/if-else.py similarity index 100% rename from rewrite-test/jit/if-else/vecadd-cond.py rename to rewrite-test/jit/if-else/if-else.py diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py new file mode 100644 index 000000000..e3d6f3f9a --- /dev/null +++ b/rewrite-test/jit/vecadd.py @@ -0,0 +1,43 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector + 
y_ptr, # *Pointer* to second input vector + output_ptr, # *Pointer* to output vector + n_elements, # Size of the vector + # BLOCK_SIZE: tl.constexpr, # Number of elements each program should process + # # NOTE: `constexpr` so it can be used as a shape value +): + # There are multiple 'program's processing different data. We identify which program + # we are here + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 + # This program will process inputs that are offset from the initial data. + # for instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers + block_start = pid * 256 + offsets = block_start + tl.arange(0, 256) + # Create a mask to guard memory operations against out-of-bounds accesses + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + output = x + y + # Write x + y back to DRAM + tl.store(output_ptr + offsets, output, mask=mask) + +size = 1024 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') +z = torch.empty_like(x) +# add_kernel[(1,)](x, y, z, size, 256) +# print(add_kernel[(1,)].kernel.compile_to_ttir()) +mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, grid=(1,)) +mod.get_context() +mod.dump() +# print(mod) diff --git a/rewrite-test/jit/while.py b/rewrite-test/jit/while.py new file mode 100644 index 000000000..25f41e86d --- /dev/null +++ b/rewrite-test/jit/while.py @@ -0,0 +1,18 @@ +import triton +import triton.language as tl +import torch + +@triton.jit +def atomic(lock): + while tl.atomic_cas(lock, 0, 1) == 1: + pass + +@triton.jit +def generic_while(lb, value): + c = -1 + while c <= 0: + c += 1 + +locks = torch.zeros(32, dtype=torch.int32, device='cuda') +mod_atomic, 
ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,)) +mod_atomic.dump()