Fix visit_While issues
This commit is contained in:
@@ -32,7 +32,7 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
Triton PointerType
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||
|
@@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.lscope = liveins
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
@@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
# TODO: what if this is constexpr?
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
@@ -452,16 +452,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Not implemented"
|
||||
|
@@ -13,9 +13,25 @@ def generic_while(lb, value):
|
||||
while c <= 0:
|
||||
c += 1
|
||||
|
||||
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
mod_atomic.dump()
|
||||
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
# mod_atomic.dump()
|
||||
|
||||
mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||
mod_generic_while.dump()
|
||||
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||
# mod_generic_while.dump()
|
||||
|
||||
@triton.jit
|
||||
def nested_cf(X, lb, ub, Z):
|
||||
a = 0.0
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
# a += 2.0
|
||||
else:
|
||||
# a *= 2.0
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
assert mod.verify(), mod.str()
|
||||
|
@@ -146,7 +146,17 @@ def test_nested():
|
||||
}
|
||||
scf.yield %10 : f32
|
||||
} else {
|
||||
scf.yield %arg5 : f32
|
||||
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
|
||||
%cst_1 = arith.constant 1.200000e+00 : f32
|
||||
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
|
||||
scf.condition(%8) %arg6 : f32
|
||||
} do {
|
||||
^bb0(%arg6: f32):
|
||||
%cst_1 = arith.constant 2.000000e+00 : f32
|
||||
%8 = arith.mulf %arg6, %cst_1 : f32
|
||||
scf.yield %8 : f32
|
||||
}
|
||||
scf.yield %7 : f32
|
||||
}
|
||||
scf.yield %6 : f32
|
||||
}
|
||||
@@ -163,11 +173,14 @@ def test_nested():
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
else:
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert mod.verify(), generated_ir
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_matmul():
|
||||
|
Reference in New Issue
Block a user