From 9dafa0e2e35b6e3c9fe96c2fa4a8755ce727c321 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Fri, 1 Apr 2022 20:16:07 +0800 Subject: [PATCH] Update trtion dependencies --- include/triton/ir/TritonDialect.td | 1 + python/triton/code_gen.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/include/triton/ir/TritonDialect.td b/include/triton/ir/TritonDialect.td index c23075d33..2b2a57124 100644 --- a/include/triton/ir/TritonDialect.td +++ b/include/triton/ir/TritonDialect.td @@ -28,6 +28,7 @@ def Triton_Dialect : Dialect { "arith::ArithmeticDialect", "tensor::TensorDialect", "StandardOpsDialect", + "scf::SCFDialect" // Since LLVM 15 // "cf::ControlFlowDialect", diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index faf264de2..690f6cfc4 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -410,6 +410,10 @@ class CodeGenerator(ast.NodeVisitor): lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) + # TODO: better way to do this? + lb = triton.language.core._to_tensor(lb, self.builder).handle + ub = triton.language.core._to_tensor(ub, self.builder).handle + step = triton.language.core._to_tensor(step, self.builder).handle loop_body = self.builder.create_block() self.builder.set_insertion_point_to_start(loop_body) @@ -422,16 +426,16 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) # TODO: update insertion point - init_args = {} - yields = {} + init_args = [] + yields = [] for name in self.local_defs: if name in liveins: assert self.is_triton_tensor(self.local_defs[name]) assert self.is_triton_tensor(liveins[name]) if self.local_defs[name].type == liveins[name].type: - init_args[name] = liveins[name] - yields[name] = self.local_defs[name] - # for_op = self.builder.create_for_op(lb, ub, step, [init_args]) + init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle) + yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle) + for_op = self.builder.create_for_op(lb, ub, step, init_args) self.lscope = liveins self.local_defs = prev_defs