now kernel functions return nothing (instead of none)
This commit is contained in:
@@ -24,7 +24,7 @@ def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
|||||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||||
|
|
||||||
// IntegerType
|
// IntegerType
|
||||||
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
|
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||||
def TT_I1Tensor : TensorOf<[I1]>;
|
def TT_I1Tensor : TensorOf<[I1]>;
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
|||||||
let results = (outs I32:$result);
|
let results = (outs I32:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
|
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
||||||
let summary = "dot";
|
let summary = "dot";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@@ -964,6 +964,11 @@ void init_triton_ir(py::module &&m) {
|
|||||||
})
|
})
|
||||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||||
// .def("create_downcast", &ir::builder::create_downcast)
|
// .def("create_downcast", &ir::builder::create_downcast)
|
||||||
|
.def("create_to_index", [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||||
|
auto loc = self.getUnknownLoc();
|
||||||
|
return self.create<mlir::arith::IndexCastOp>(loc, input, self.getIndexType());
|
||||||
|
})
|
||||||
|
|
||||||
.def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
.def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);
|
return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);
|
||||||
|
@@ -502,6 +502,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||||
|
|
||||||
|
lb = self.builder.create_to_index(lb)
|
||||||
|
ub = self.builder.create_to_index(ub)
|
||||||
|
step = self.builder.create_to_index(step)
|
||||||
|
|
||||||
insert_block = self.builder.get_insertion_block()
|
insert_block = self.builder.get_insertion_block()
|
||||||
|
|
||||||
block = self.builder.create_block()
|
block = self.builder.create_block()
|
||||||
@@ -1210,8 +1214,8 @@ class JITFunction:
|
|||||||
context.load_triton()
|
context.load_triton()
|
||||||
# get just-in-time proto-type of kernel
|
# get just-in-time proto-type of kernel
|
||||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||||
ret_type = triton.language.void
|
ret_types = []
|
||||||
prototype = triton.language.function_type([ret_type], arg_types)
|
prototype = triton.language.function_type(ret_types, arg_types)
|
||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
# export symbols visible from self into code-generator object
|
# export symbols visible from self into code-generator object
|
||||||
gscope = self.__globals__
|
gscope = self.__globals__
|
||||||
|
Reference in New Issue
Block a user