From 040a2b6c75a35bcff68d9359e16e6e25bd5d97bd Mon Sep 17 00:00:00 2001
From: Yan Da
Date: Thu, 7 Apr 2022 20:01:31 +0800
Subject: [PATCH] Fix OpBuilder

---
 include/triton/ir/TritonOps.td |  8 +++++---
 lib/ir/Ops.cpp                 |  6 +++++-
 python/src/triton.cc           | 20 +++++++++++++-------
 python/triton/code_gen.py      | 13 +++++++++++--
 rewrite-test/jit/vecadd.py     | 14 +++++++-------
 5 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td
index 620556287..e2ce750af 100644
--- a/include/triton/ir/TritonOps.td
+++ b/include/triton/ir/TritonOps.td
@@ -123,8 +123,10 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
 
   let builders = [
       // for args with default values
-      OpBuilder<(ins "Value":$ptr)>,
-      OpBuilder<(ins "Value":$ptr, "Value":$mask)>
+      OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
+                 "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
+      OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
+                 "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
   ];
 }
 
@@ -199,7 +201,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
     $d = matrix_multiply($a, $b) + $c
   }];
 
-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
 
   let results = (outs TT_FpIntTensor:$d);
 }
diff --git a/lib/ir/Ops.cpp b/lib/ir/Ops.cpp
index e5e7be3f0..3eb33cbbe 100644
--- a/lib/ir/Ops.cpp
+++ b/lib/ir/Ops.cpp
@@ -33,7 +33,8 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, :
 }
 
 //-- LoadOp --
-void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) {
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
+                   ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
@@ -57,6 +58,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
   state.addOperands(ptr);
   state.addOperands(mask);
   state.addOperands(other);
+  state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
+  state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
+  state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
   state.addTypes({resultType});
 }
 
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 6e25f898a..bf7299429 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1155,9 +1155,12 @@ void init_triton_ir(py::module &&m) {
         return self.create(loc, lhs, rhs);
       })
      // // Input/Output
-      .def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
+      .def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs,
+                             mlir::triton::CacheModifier cacheModifier,
+                             mlir::triton::EvictionPolicy evictionPolicy,
+                             bool isVolatile) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        return self.create<mlir::triton::LoadOp>(loc, ptrs);
+        return self.create<mlir::triton::LoadOp>(loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
       })
      .def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
         auto loc = self.getUnknownLoc();
@@ -1200,8 +1203,7 @@ void init_triton_ir(py::module &&m) {
       })
      .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        // TODO: should be scalar type here
-        auto argType = arg.getType();
+        auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
         return self.create<mlir::triton::BroadcastOp>(
           loc, mlir::RankedTensorType::get(shape, argType), arg
         );
@@ -1246,9 +1248,9 @@ void init_triton_ir(py::module &&m) {
           loc, self.getI32Type(), self.getI32IntegerAttr(axis)
         );
       })
-      .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
+      .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c, bool allowTF32) -> mlir::Value {
         auto loc = self.getUnknownLoc();
-        return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
+        return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c, allowTF32);
       })
      // .def("create_exp", &ir::builder::create_exp, ret::reference)
      // .def("create_cos", &ir::builder::create_cos, ret::reference)
@@ -1257,7 +1259,11 @@ void init_triton_ir(py::module &&m) {
      // .def("create_trans", &ir::builder::create_trans, ret::reference)
      // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
      // .def("create_reduce", &ir::builder::create_reduce, ret::reference)
-      // .def("create_select", &ir::builder::create_select, ret::reference)
+      .def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
+                               mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
+        auto loc = self.getUnknownLoc();
+        return self.create<mlir::SelectOp>(loc, condition, trueValue, falseValue);
+      })
      // // Intrinsics
      // // These have no place in the IR, and hopefully they can be removed at some point
      // .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index d0dd38bf6..f8a1bbef9 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1177,7 +1177,16 @@ class JITFunction:
     # Compile to ttir, for the propose of testing MLIR rewriting
     def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # TODO: share code with _compile & __call__
-
+        # handle arguments passed by name
+        kwargs = {self.arg_names.index(name): value for name, value in kwargs.items()}
+        wargs = list(wargs)
+        for i, pos in enumerate(sorted(kwargs)):
+            wargs.insert(pos + i, kwargs[pos])
+        if len(wargs) != len(self.arg_names):
+            raise TypeError(f"Function takes {len(self.arg_names)} positional arguments but {len(wargs)} were given")
+        # handle annotations
+        for pos, _type in self.annotations.items():
+            wargs[pos] = _type(wargs[pos])
         # preparing args
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         # attributes
@@ -1191,7 +1200,7 @@ class JITFunction:
                 attributes[i] = min(Kernel.pow2_divisor(addr), Kernel.pow2_divisor(range_size))
 
         # transforms ints whose value is one into constants for just-in-time compilation
-        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
+        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
         constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
         constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
         arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py
index e3d6f3f9a..34a7fc4f1 100644
--- a/rewrite-test/jit/vecadd.py
+++ b/rewrite-test/jit/vecadd.py
@@ -1,3 +1,4 @@
+from tarfile import BLOCKSIZE
 import torch
 import triton
 import triton.language as tl
@@ -9,8 +10,8 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
-    # BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
-    # # NOTE: `constexpr` so it can be used as a shape value
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
+    # NOTE: `constexpr` so it can be used as a shape value
 ):
     # There are multiple 'program's processing different data. We identify which program
     # we are here
@@ -19,8 +20,8 @@ def add_kernel(
     # for instance, if you had a vector of length 256 and block_size of 64, the programs
     # would each access the elements [0:64, 64:128, 128:192, 192:256].
     # Note that offsets is a list of pointers
-    block_start = pid * 256
-    offsets = block_start + tl.arange(0, 256)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a
@@ -37,7 +38,6 @@ y = torch.rand(size, device='cuda')
 z = torch.empty_like(x)
 # add_kernel[(1,)](x, y, z, size, 256)
 # print(add_kernel[(1,)].kernel.compile_to_ttir())
-mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, grid=(1,))
-mod.get_context()
+# print(add_kernel.annotations)
+mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
 mod.dump()
-# print(mod)
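
Note (illustrative sketch, not part of the patch): the keyword handling added to
compile_to_ttir maps each keyword argument to its positional index via arg_names and
splices it back into wargs before specialization, apparently duplicating the logic of
_compile/__call__ that the TODO above it refers to. Below is a minimal standalone Python
sketch of that merge; the argument names and values are made up to mirror the vecadd.py
test and are not taken from the patch itself.

    # Standalone sketch of the kwargs merge performed by compile_to_ttir.
    # The names and values below are illustrative only.
    arg_names = ["x_ptr", "y_ptr", "output_ptr", "n_elements", "BLOCK_SIZE"]
    wargs = ["x", "y", "z", 98432]      # arguments passed positionally
    kwargs = {"BLOCK_SIZE": 256}        # arguments passed by name

    # map keyword arguments to their positional indices, then insert them in order
    kwargs = {arg_names.index(name): value for name, value in kwargs.items()}
    wargs = list(wargs)
    for i, pos in enumerate(sorted(kwargs)):
        wargs.insert(pos + i, kwargs[pos])
    if len(wargs) != len(arg_names):
        raise TypeError(f"Function takes {len(arg_names)} positional arguments "
                        f"but {len(wargs)} were given")

    print(wargs)  # ['x', 'y', 'z', 98432, 256]

After this merge, the annotation loop applies each annotated type, so a tl.constexpr
parameter such as BLOCK_SIZE should end up wrapped as a triton.language.constexpr and be
picked up by the constants dict rather than contributing to arg_types.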