More progress on WhileOp codegen
@@ -638,17 +638,6 @@ void init_triton_ir(py::module &&m) {
  // // py::class_<ir::undef_value, ir::constant>(m, "undef")
  // //     .def("get", &ir::undef_value::get, ret::reference);

  py::class_<mlir::ModuleOp>(m, "module")
      // .def("set_attr")
      .def("dump", [](mlir::ModuleOp &self) -> void {
        self.dump();
      })
      .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
        self.push_back(funcOp);
      })
      .def("get_context", &mlir::ModuleOp::getContext)
      ;

  py::class_<mlir::Type>(m, "type")
      .def("is_integer", &mlir::Type::isInteger)
      .def("is_fp16", &mlir::Type::isF16)
@@ -753,13 +742,27 @@ void init_triton_ir(py::module &&m) {
      .def("get_else_yield", &mlir::scf::IfOp::elseYield)
      ;
  py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
  py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
      .def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
      .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
  py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");

  py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
      // .def("set_attr")
      .def("dump", [](mlir::ModuleOp &self) -> void {
        self.dump();
      })
      .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
        self.push_back(funcOp);
      })
      ;

  py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");

  py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
      .def(py::init<mlir::MLIRContext *>())
      // // getters
      // .def_property_readonly("context", &ir::builder::get_context, ret::reference);
      .def_property_readonly("context", &mlir::OpBuilder::getContext, ret::reference)
      .def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
        auto loc = self.getUnknownLoc();
        return self.create<mlir::ModuleOp>(loc);
@@ -883,6 +886,9 @@ void init_triton_ir(py::module &&m) {
        mlir::Region *parent = self.getBlock()->getParent();
        return self.createBlock(parent);
      }, ret::reference)
      .def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
        return self.createBlock(&parent);
      }, ret::reference)
      .def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
        return new mlir::Block();
      }, ret::reference)
@@ -900,7 +906,16 @@ void init_triton_ir(py::module &&m) {
        auto loc = self.getUnknownLoc();
        return self.create<mlir::scf::YieldOp>(loc, yields);
      })
      // // .def("create_while")
      .def("create_while_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes,
                                 std::vector<mlir::Value> &initArgs) -> mlir::scf::WhileOp {
        auto loc = self.getUnknownLoc();
        return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
      })
      .def("create_condition_op", [](mlir::OpBuilder &self, mlir::Value &cond,
                                     std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
        auto loc = self.getUnknownLoc();
        return self.create<mlir::scf::ConditionOp>(loc, cond, args);
      })

      // miscellaneous
      .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
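Taken together, the bindings above give the Python frontend everything it needs to assemble an `scf.while`: create the op, attach its two regions, and terminate them with `scf.condition` and `scf.yield`. A minimal sketch of the intended call sequence (illustrative only; `builder` is an instance of the `builder` class bound above, and `i32_ty`, `cond_val`, `init_val`, `next_val` are placeholder names, not part of this commit):

```python
# Illustrative sketch, not runnable on its own: the *_ty / *_val names
# are placeholders for values produced earlier by the same builder.
while_op = builder.create_while_op([i32_ty], [init_val])

# "before" region: compute the condition, then branch via scf.condition.
before = builder.create_block_with_parent(while_op.get_before())
builder.set_insertion_point_to_end(before)
builder.create_condition_op(cond_val, [])

# "after" region: the loop body, terminated by scf.yield carrying the
# next iteration's loop-carried values.
after = builder.create_block_with_parent(while_op.get_after())
builder.set_insertion_point_to_end(after)
builder.create_yield_op([next_val])
```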
@@ -46,6 +46,7 @@ class CodeGenerator(ast.NodeVisitor):
         # SSA-construction
         # name => triton.language.tensor
         self.local_defs: Dict[str, triton.language.tensor] = {}
         self.global_uses: Dict[str, triton.language.tensor] = {}

     def get_value(self, name):
         ''' This function:
@@ -57,6 +58,8 @@ class CodeGenerator(ast.NodeVisitor):
         ret = None
         if name in self.lscope:
             ret = self.lscope[name]
             if name not in self.local_defs:
                 self.global_uses[name] = ret
         # search node.id in global scope
         elif name in self.gscope:
             ret = self.gscope[name]
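The new `global_uses` map records names that are read before being (re)defined in the current region; `visit_While` later treats these as loop live-ins. A toy model of the bookkeeping (an assumption for illustration: plain dicts stand in for the real symbol tables):

```python
# Toy model of the use-before-def tracking added to get_value.
lscope = {"c": "tensor_c"}   # value visible from an enclosing region
local_defs = {}              # nothing (re)defined locally yet
global_uses = {}

def get_value(name):
    ret = lscope[name]
    if name not in local_defs:   # read without a local definition...
        global_uses[name] = ret  # ...so it is a use of an outer value
    return ret

get_value("c")
assert global_uses == {"c": "tensor_c"}
```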
@@ -263,27 +266,6 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_If(self, node):
         cond = self.visit(node.test)
         if isinstance(cond, triton.language.tensor):
             # cond = cond.to(triton.language.int1, _builder=self.builder)
             # current_bb = self.builder.get_insertion_block()
             # then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
             # else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
             # endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
             # if else_bb:
             #     self.builder.cond_br(cond.handle, then_bb, else_bb)
             # else:
             #     self.builder.cond_br(cond.handle, then_bb, endif_bb)
             # self.builder.set_insert_block(then_bb)
             # is_terminator = self.visit_compound_statement(node.body)
             # # TODO: last statement is a terminator?
             # if not is_terminator:
             #     self.builder.br(endif_bb)
             # if else_bb:
             #     self.builder.set_insert_block(else_bb)
             #     is_terminator = self.visit_compound_statement(node.orelse)
             #     # TODO: last statement is a terminator?
             #     if not is_terminator:
             #         self.builder.br(endif_bb)
             # self.builder.set_insert_block(endif_bb)
             cond = cond.to(triton.language.int1, _builder=self.builder)
             liveins = self.lscope.copy()
             parent_defs = self.local_defs.copy()
@@ -413,22 +395,64 @@ class CodeGenerator(ast.NodeVisitor):
             return getattr(op, fn)()

     def visit_While(self, node):
         current_bb = self.builder.get_insertion_block()
         loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
         next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
         liveins = self.lscope.copy()
         prev_defs = self.local_defs.copy()
         self.local_defs = {}

         def continue_fn():
             cond = self.visit(node.test)
             return self.builder.cond_br(cond.handle, loop_bb, next_bb)
         insert_block = self.builder.get_insertion_block()

         continue_fn()
         self.builder.set_insert_block(loop_bb)
         # condition (the before region)
         cond_block = self.builder.create_block()
         self.builder.set_insertion_point_to_start(cond_block)
         cond = self.visit(node.test)

         # loop body (the after region)
         loop_block = self.builder.create_block()
         self.builder.set_insertion_point_to_start(loop_block)
         self.visit_compound_statement(node.body)
         continue_fn()
         stop_bb = self.builder.get_insertion_block()
         self.builder.set_insert_block(next_bb)
         loop_defs = self.local_defs

         # collect loop-carried values
         names = []
         ret_types = []
         init_args = []
         yields = []
         for name in loop_defs:
             if name in liveins:
                 # We should not def new constexpr (?)
                 assert self.is_triton_tensor(loop_defs[name])
                 assert self.is_triton_tensor(liveins[name])
                 if loop_defs[name].type == liveins[name].type:
                     # these are loop-carried values
                     names.append(name)
                     ret_types.append(loop_defs[name].type.to_ir(self.builder))
                     init_args.append(liveins[name])
                     yields.append(loop_defs[name])

         self.builder.set_insertion_point_to_end(insert_block)
         while_op = self.builder.create_while_op(ret_types, init_args)
         # merge the condition region
         before_block = self.builder.create_block_with_parent(while_op.get_before())
         cond_block.merge_block_before(before_block)
         self.builder.set_insertion_point_to_end(before_block)
         self.builder.create_condition_op(cond.handle, [])
         # merge the loop body
         after_block = self.builder.create_block_with_parent(while_op.get_after())
         loop_block.merge_block_before(after_block)
         self.builder.set_insertion_point_to_end(after_block)
         self.builder.create_yield_op([y.handle for y in yields])

         self.builder.set_insertion_point_to_end(insert_block)
         self.lscope = liveins
         self.local_defs = prev_defs
         # WhileOp defines new values, update the symbol table (lscope, local_defs)
         for i, name in enumerate(names):
             new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
             self.lscope[name] = new_def
             self.local_defs[name] = new_def

         for stmt in node.orelse:
             assert False, "Not implemented"
             ast.NodeVisitor.generic_visit(self, stmt)

     def visit_Subscript(self, node):
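To make the loop-carried detection above concrete: a name is carried only if it is redefined in the body, was already live before the loop, and keeps the same type. A toy run of that filter (hypothetical values; plain tuples stand in for tensors and their types):

```python
# Hypothetical walk-through of the loop-carried filter in visit_While.
liveins = {"c": ("tensor", "i32")}          # defined before the loop
loop_defs = {"c": ("tensor", "i32"),        # redefined in the body -> carried
             "tmp": ("tensor", "i32")}      # body-local only -> not carried

names = [n for n in loop_defs
         if n in liveins and loop_defs[n][1] == liveins[n][1]]
assert names == ["c"]  # only `c` becomes a WhileOp result / block argument
```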
43
rewrite-test/jit/vecadd.py
Normal file
@@ -0,0 +1,43 @@
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    # BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'program's processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers
    block_start = pid * 256
    offsets = block_start + tl.arange(0, 256)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)


size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, grid=(1,))
mod.get_context()
mod.dump()
# print(mod)
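Note that with `BLOCK_SIZE` commented out, the block size is hard-coded to 256, so `grid=(1,)` only covers the first 256 of the 1024 elements; covering the whole vector would take `ceil(size / 256)` program instances. A quick check of that arithmetic (illustrative only; the test above intentionally compiles a single-program grid):

```python
# Illustrative grid-size arithmetic for the hard-coded block of 256.
size = 1024
BLOCK = 256
num_programs = (size + BLOCK - 1) // BLOCK  # ceiling division
assert num_programs == 4  # grid=(4,) would cover all 1024 elements
```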
18
rewrite-test/jit/while.py
Normal file
@@ -0,0 +1,18 @@
import triton
import triton.language as tl
import torch


@triton.jit
def atomic(lock):
    while tl.atomic_cas(lock, 0, 1) == 1:
        pass


@triton.jit
def generic_while(lb, value):
    c = -1
    while c <= 0:
        c += 1


locks = torch.zeros(32, dtype=torch.int32, device='cuda')
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
mod_atomic.dump()
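The `generic_while` kernel is defined but not yet compiled here; presumably it could be exercised the same way once scalar arguments flow through (a hypothetical follow-up check, mirroring the `atomic` test above, not part of this commit):

```python
# Hypothetical: argument values are placeholders, and scalar-argument
# handling for this path may not be wired up yet.
mod_generic, ctx_generic = generic_while.compile_to_ttir(0, 10, grid=(1,))
mod_generic.dump()
```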