now kernel functions return nothing (instead of none)

This commit is contained in:
Yan Da
2022-04-07 20:22:17 +08:00
parent 040a2b6c75
commit 62f772123c
3 changed files with 13 additions and 4 deletions

View File

@@ -24,7 +24,7 @@ def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
// IntegerType
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_I1Tensor : TensorOf<[I1]>;
@@ -194,7 +194,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
let results = (outs I32:$result);
}
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
let summary = "dot";
let description = [{

View File

@@ -964,6 +964,11 @@ void init_triton_ir(py::module &&m) {
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)
.def("create_to_index", [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::IndexCastOp>(loc, input, self.getIndexType());
})
.def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);

View File

@@ -502,6 +502,10 @@ class CodeGenerator(ast.NodeVisitor):
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
insert_block = self.builder.get_insertion_block()
block = self.builder.create_block()
@@ -1210,8 +1214,8 @@ class JITFunction:
context.load_triton()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
ret_types = []
prototype = triton.language.function_type(ret_types, arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__