From 62f772123c22589e5cbb96b4cb4375e3c667c030 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 7 Apr 2022 20:22:17 +0800 Subject: [PATCH] now kernel functions return nothing (instead of none) --- include/triton/ir/TritonOps.td | 4 ++-- python/src/triton.cc | 5 +++++ python/triton/code_gen.py | 8 ++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index e2ce750af..0b4516ff5 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -24,7 +24,7 @@ def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : TensorOf<[TT_Float]>; // IntegerType -def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">; +def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; def TT_IntegerTensor : TensorOf<[TT_Int]>; def TT_I1Tensor : TensorOf<[I1]>; @@ -194,7 +194,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { let results = (outs I32:$result); } -def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> { +def TT_DotOp : TT_Op<"dot", [NoSideEffect]> { let summary = "dot"; let description = [{ diff --git a/python/src/triton.cc b/python/src/triton.cc index bf7299429..57c0e875e 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -964,6 +964,11 @@ void init_triton_ir(py::module &&m) { }) // .def("create_int_cast", &ir::builder::create_int_cast) // .def("create_downcast", &ir::builder::create_downcast) + .def("create_to_index", [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, input, self.getIndexType()); + }) + .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); return self.create(loc, lhs, rhs); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index f8a1bbef9..1f99a155a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -502,6 +502,10 @@ class CodeGenerator(ast.NodeVisitor): ub = triton.language.core._to_tensor(ub, self.builder).handle step = triton.language.core._to_tensor(step, self.builder).handle + lb = self.builder.create_to_index(lb) + ub = self.builder.create_to_index(ub) + step = self.builder.create_to_index(step) + insert_block = self.builder.get_insertion_block() block = self.builder.create_block() @@ -1210,8 +1214,8 @@ class JITFunction: context.load_triton() # get just-in-time proto-type of kernel arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type([ret_type], arg_types) + ret_types = [] + prototype = triton.language.function_type(ret_types, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__