Fix visit_While issues

This commit is contained in:
Yan Da
2022-04-10 16:16:13 +08:00
parent 19f81b7dea
commit fcbbb3c10e
4 changed files with 47 additions and 18 deletions

View File

@@ -32,7 +32,7 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
TODO
Triton PointerType
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);

View File

@@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
self.local_defs = {}
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_end(else_block)
@@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor):
else_defs = {}
for name in then_defs:
if name in liveins:
# TODO: what if this is constexpr?
assert self.is_triton_tensor(then_defs[name])
assert self.is_triton_tensor(liveins[name])
else_defs[name] = liveins[name]
@@ -452,16 +452,16 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
# update global uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
for stmt in node.orelse:
assert False, "Not implemented"

View File

@@ -13,9 +13,25 @@ def generic_while(lb, value):
while c <= 0:
c += 1
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
mod_atomic.dump()
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
# mod_atomic.dump()
mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
mod_generic_while.dump()
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
# mod_generic_while.dump()
@triton.jit
def nested_cf(X, lb, ub, Z):
a = 0.0
if lb < ub:
for z in range(0, Z):
a += 2.0
# a += 2.0
else:
# a *= 2.0
while a < 1.2:
a *= 2.0
a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
assert mod.verify(), mod.str()

View File

@@ -146,7 +146,17 @@ def test_nested():
}
scf.yield %10 : f32
} else {
scf.yield %arg5 : f32
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
%cst_1 = arith.constant 1.200000e+00 : f32
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
scf.condition(%8) %arg6 : f32
} do {
^bb0(%arg6: f32):
%cst_1 = arith.constant 2.000000e+00 : f32
%8 = arith.mulf %arg6, %cst_1 : f32
scf.yield %8 : f32
}
scf.yield %7 : f32
}
scf.yield %6 : f32
}
@@ -163,11 +173,14 @@ def test_nested():
if lb < ub:
for z in range(0, Z):
a += 2.0
else:
while a < 1.2:
a *= 2.0
a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
generated_ir = mod.str()
assert mod.verify()
assert mod.verify(), generated_ir
assert ref_ir == generated_ir
def test_matmul():