Fix OpBuilder

This commit is contained in:
Yan Da
2022-04-07 20:01:31 +08:00
parent 6b4da6f016
commit 040a2b6c75
5 changed files with 41 additions and 20 deletions

View File

@@ -1155,9 +1155,12 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
})
// // Input/Output
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs,
mlir::triton::CacheModifier cacheModifer,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::LoadOp>(loc, ptrs);
return self.create<mlir::triton::LoadOp>(loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
})
.def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
auto loc = self.getUnknownLoc();
@@ -1200,8 +1203,7 @@ void init_triton_ir(py::module &&m) {
})
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
// TODO: should be scalar type here
auto argType = arg.getType();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
return self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg
);
@@ -1246,9 +1248,9 @@ void init_triton_ir(py::module &&m) {
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
);
})
.def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
.def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c, bool allowTF32) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c, allowTF32);
})
// .def("create_exp", &ir::builder::create_exp, ret::reference)
// .def("create_cos", &ir::builder::create_cos, ret::reference)
@@ -1257,7 +1259,11 @@ void init_triton_ir(py::module &&m) {
// .def("create_trans", &ir::builder::create_trans, ret::reference)
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
// .def("create_select", &ir::builder::create_select, ret::reference)
.def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::SelectOp>(loc, condition, trueValue, falseValue);
})
// // Intrinsics
// // These have no place in the IR, and hopefully they can be removed at some point
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)

View File

@@ -1177,7 +1177,16 @@ class JITFunction:
# Compile to ttir, for the propose of testing MLIR rewriting
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# TODO: share code with _compile & __call__
# handle arguments passed by name
kwargs = {self.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
if len(wargs) != len(self.arg_names):
raise TypeError(f"Function takes {len(self.arg_names)} positional arguments but {len(wargs)} were given")
# handle annotations
for pos, _type in self.annotations.items():
wargs[pos] = _type(wargs[pos])
# preparing args
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
@@ -1191,7 +1200,7 @@ class JITFunction:
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]