Fix visit_While issues
This commit is contained in:
@@ -32,7 +32,7 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
|||||||
let summary = "pointer type";
|
let summary = "pointer type";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
TODO
|
Triton PointerType
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||||
|
@@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
if then_defs or node.orelse:
|
if then_defs or node.orelse:
|
||||||
if node.orelse:
|
if node.orelse:
|
||||||
|
self.lscope = liveins
|
||||||
self.local_defs = {}
|
self.local_defs = {}
|
||||||
else_block = self.builder.create_block()
|
else_block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_end(else_block)
|
self.builder.set_insertion_point_to_end(else_block)
|
||||||
@@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
else_defs = {}
|
else_defs = {}
|
||||||
for name in then_defs:
|
for name in then_defs:
|
||||||
if name in liveins:
|
if name in liveins:
|
||||||
# TODO: what if this is constexpr?
|
|
||||||
assert self.is_triton_tensor(then_defs[name])
|
assert self.is_triton_tensor(then_defs[name])
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
else_defs[name] = liveins[name]
|
else_defs[name] = liveins[name]
|
||||||
@@ -452,16 +452,16 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.builder.set_insertion_point_to_end(after_block)
|
self.builder.set_insertion_point_to_end(after_block)
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
|
||||||
# update global uses in while_op
|
# update global uses in while_op
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||||
|
|
||||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||||
self.lscope[name] = new_def
|
self.lscope[name] = new_def
|
||||||
self.local_defs[name] = new_def
|
self.local_defs[name] = new_def
|
||||||
|
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
assert False, "Not implemented"
|
assert False, "Not implemented"
|
||||||
|
@@ -13,9 +13,25 @@ def generic_while(lb, value):
|
|||||||
while c <= 0:
|
while c <= 0:
|
||||||
c += 1
|
c += 1
|
||||||
|
|
||||||
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||||
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||||
mod_atomic.dump()
|
# mod_atomic.dump()
|
||||||
|
|
||||||
mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||||
mod_generic_while.dump()
|
# mod_generic_while.dump()
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def nested_cf(X, lb, ub, Z):
|
||||||
|
a = 0.0
|
||||||
|
if lb < ub:
|
||||||
|
for z in range(0, Z):
|
||||||
|
a += 2.0
|
||||||
|
# a += 2.0
|
||||||
|
else:
|
||||||
|
# a *= 2.0
|
||||||
|
while a < 1.2:
|
||||||
|
a *= 2.0
|
||||||
|
a -= 1.0
|
||||||
|
|
||||||
|
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||||
|
assert mod.verify(), mod.str()
|
||||||
|
@@ -146,7 +146,17 @@ def test_nested():
|
|||||||
}
|
}
|
||||||
scf.yield %10 : f32
|
scf.yield %10 : f32
|
||||||
} else {
|
} else {
|
||||||
scf.yield %arg5 : f32
|
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
|
||||||
|
%cst_1 = arith.constant 1.200000e+00 : f32
|
||||||
|
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
|
||||||
|
scf.condition(%8) %arg6 : f32
|
||||||
|
} do {
|
||||||
|
^bb0(%arg6: f32):
|
||||||
|
%cst_1 = arith.constant 2.000000e+00 : f32
|
||||||
|
%8 = arith.mulf %arg6, %cst_1 : f32
|
||||||
|
scf.yield %8 : f32
|
||||||
|
}
|
||||||
|
scf.yield %7 : f32
|
||||||
}
|
}
|
||||||
scf.yield %6 : f32
|
scf.yield %6 : f32
|
||||||
}
|
}
|
||||||
@@ -163,11 +173,14 @@ def test_nested():
|
|||||||
if lb < ub:
|
if lb < ub:
|
||||||
for z in range(0, Z):
|
for z in range(0, Z):
|
||||||
a += 2.0
|
a += 2.0
|
||||||
|
else:
|
||||||
|
while a < 1.2:
|
||||||
|
a *= 2.0
|
||||||
a -= 1.0
|
a -= 1.0
|
||||||
|
|
||||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||||
generated_ir = mod.str()
|
generated_ir = mod.str()
|
||||||
assert mod.verify()
|
assert mod.verify(), generated_ir
|
||||||
assert ref_ir == generated_ir
|
assert ref_ir == generated_ir
|
||||||
|
|
||||||
def test_matmul():
|
def test_matmul():
|
||||||
|
Reference in New Issue
Block a user