Fix visit_While issues
This commit is contained in:
@@ -32,7 +32,7 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
Triton PointerType
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||
|
@@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.lscope = liveins
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
@@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
# TODO: what if this is constexpr?
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
|
@@ -13,9 +13,25 @@ def generic_while(lb, value):
|
||||
while c <= 0:
|
||||
c += 1
|
||||
|
||||
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
mod_atomic.dump()
|
||||
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
# mod_atomic.dump()
|
||||
|
||||
mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||
mod_generic_while.dump()
|
||||
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||
# mod_generic_while.dump()
|
||||
|
||||
@triton.jit
|
||||
def nested_cf(X, lb, ub, Z):
|
||||
a = 0.0
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
# a += 2.0
|
||||
else:
|
||||
# a *= 2.0
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
assert mod.verify(), mod.str()
|
||||
|
@@ -146,7 +146,17 @@ def test_nested():
|
||||
}
|
||||
scf.yield %10 : f32
|
||||
} else {
|
||||
scf.yield %arg5 : f32
|
||||
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
|
||||
%cst_1 = arith.constant 1.200000e+00 : f32
|
||||
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
|
||||
scf.condition(%8) %arg6 : f32
|
||||
} do {
|
||||
^bb0(%arg6: f32):
|
||||
%cst_1 = arith.constant 2.000000e+00 : f32
|
||||
%8 = arith.mulf %arg6, %cst_1 : f32
|
||||
scf.yield %8 : f32
|
||||
}
|
||||
scf.yield %7 : f32
|
||||
}
|
||||
scf.yield %6 : f32
|
||||
}
|
||||
@@ -163,11 +173,14 @@ def test_nested():
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
else:
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert mod.verify(), generated_ir
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_matmul():
|
||||
|
Reference in New Issue
Block a user