now kernel functions return nothing (instead of none)
This commit is contained in:
@@ -24,7 +24,7 @@ def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
|
||||
// IntegerType
|
||||
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
def TT_I1Tensor : TensorOf<[I1]>;
|
||||
|
||||
@@ -194,7 +194,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
|
@@ -964,6 +964,11 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||
// .def("create_downcast", &ir::builder::create_downcast)
|
||||
.def("create_to_index", [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input, self.getIndexType());
|
||||
})
|
||||
|
||||
.def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);
|
||||
|
@@ -502,6 +502,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||
|
||||
lb = self.builder.create_to_index(lb)
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
|
||||
block = self.builder.create_block()
|
||||
@@ -1210,8 +1214,8 @@ class JITFunction:
|
||||
context.load_triton()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
ret_types = []
|
||||
prototype = triton.language.function_type(ret_types, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
|
Reference in New Issue
Block a user