diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 3f66b6d65..b96ae5116 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -32,7 +32,7 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> { let summary = "pointer type"; let description = [{ - TODO + Triton PointerType }]; let parameters = (ins "Type":$pointeeType, "int":$addressSpace); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 32c7ce7f5..a7664a896 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor): if then_defs or node.orelse: if node.orelse: + self.lscope = liveins self.local_defs = {} else_block = self.builder.create_block() self.builder.set_insertion_point_to_end(else_block) @@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor): else_defs = {} for name in then_defs: if name in liveins: - # TODO: what if this is constexpr? assert self.is_triton_tensor(then_defs[name]) assert self.is_triton_tensor(liveins[name]) else_defs[name] = liveins[name] @@ -452,16 +452,16 @@ class CodeGenerator(ast.NodeVisitor): self.builder.set_insertion_point_to_end(after_block) self.builder.create_yield_op([y.handle for y in yields]) - # update global uses in while_op - for i, name in enumerate(names): - before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) - after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) + # update global uses in while_op + for i, name in enumerate(names): + before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) + after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) - # WhileOp defines new values, update the symbol table (lscope, local_defs) - for i, name in enumerate(names): - new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) - self.lscope[name] = new_def - self.local_defs[name] = new_def + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def for stmt in node.orelse: assert False, "Not implemented" diff --git a/rewrite-test/jit/while.py b/rewrite-test/jit/while.py index 15d49c75b..5e27025af 100644 --- a/rewrite-test/jit/while.py +++ b/rewrite-test/jit/while.py @@ -13,9 +13,25 @@ def generic_while(lb, value): while c <= 0: c += 1 -locks = torch.zeros(32, dtype=torch.int32, device='cuda') -mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,)) -mod_atomic.dump() +# locks = torch.zeros(32, dtype=torch.int32, device='cuda') +# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,)) +# mod_atomic.dump() -mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,)) -mod_generic_while.dump() +# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,)) +# mod_generic_while.dump() + +@triton.jit +def nested_cf(X, lb, ub, Z): + a = 0.0 + if lb < ub: + for z in range(0, Z): + a += 2.0 + # a += 2.0 + else: + # a *= 2.0 + while a < 1.2: + a *= 2.0 + a -= 1.0 + +mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) +assert mod.verify(), mod.str() diff --git a/rewrite-test/scf_tests.py b/rewrite-test/scf_tests.py index 751f990a6..26fdf82cf 100644 --- a/rewrite-test/scf_tests.py +++ b/rewrite-test/scf_tests.py @@ -146,7 +146,17 @@ def test_nested(): } scf.yield %10 : f32 } else { - scf.yield %arg5 : f32 + %7 = scf.while (%arg6 = %arg5) : (f32) -> f32 { + %cst_1 = arith.constant 1.200000e+00 : f32 + %8 = arith.cmpf olt, %arg6, %cst_1 : f32 + scf.condition(%8) %arg6 : f32 + } do { + ^bb0(%arg6: f32): + %cst_1 = arith.constant 2.000000e+00 : f32 + %8 = arith.mulf %arg6, %cst_1 : f32 + scf.yield %8 : f32 + } + scf.yield %7 : f32 } scf.yield %6 : f32 } @@ -163,11 +173,14 @@ def test_nested(): if lb < ub: for z in range(0, Z): a += 2.0 + else: + while a < 1.2: + a *= 2.0 a -= 1.0 mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) generated_ir = mod.str() - assert mod.verify() + assert mod.verify(), generated_ir assert ref_ir == generated_ir def test_matmul():