From 9b670cfb9f2872afa5e48e79c43f84587a1cb141 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 25 May 2022 14:15:36 +0800 Subject: [PATCH] Add ReduceOp --- include/triton/Dialect/Triton/IR/TritonOps.td | 8 +++++++- lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp | 1 - lib/Dialect/TritonGPU/IR/Dialect.cpp | 1 - .../TritonGPU/Transforms/TritonGPUConversion.cpp | 7 +++++-- lib/Dialect/TritonGPU/Transforms/Verifier.cpp | 10 ++++++++-- python/src/triton.cc | 10 +++++++++- 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index f137bf1ef..445d2f4f5 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -205,7 +205,13 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, def TT_ReduceOp : TT_Op<"reduce"> { let summary = "reduce"; - let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis); + let arguments = (ins TT_RedOpAttr:$redOp, TT_Type:$operand, I32Attr:$axis); + + let results = (outs TT_Type:$result); + + // let builders = [ + // OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>, + // ]; } def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 5c7dd6e63..224ddd22a 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -5,7 +5,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "../PassDetail.h" -#include using namespace mlir; using namespace mlir::triton; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 0bb7ece43..140bce5ce 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,7 +1,6 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" -#include #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 622b0eeaf..b1de660f6 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -2,7 +2,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include -#include using namespace mlir; @@ -24,7 +23,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, int64_t numElements = tensorType.getNumElements(); // TODO: we should raise exception here. - assert(numElements > numThreads); + if (!(numElements >= numThreads)) { + llvm::errs() << tensorType << " has " << numElements << " numElements " + << " smaller than numThreads (" << numThreads << ")"; + assert(false); + } assert(numElements % numThreads == 0); // or assert no encoding? diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index 139d3f3ae..3749de00e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -40,6 +40,7 @@ private: return dotOp.emitError() << name << "'s type should be of RankedTensorType"; } + Attribute cLayout; for (auto it : llvm::zip(llvm::SmallVector{cType, dType}, llvm::SmallVector{'c', 'd'})) { Type type = std::get<0>(it); @@ -48,8 +49,13 @@ private: Attribute encoding = tensorType.getEncoding(); if (!encoding) return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa()) - return dotOp.emitError() << name << " should be of mma layout"; + if (!encoding.isa() && + !encoding.isa()) + return dotOp.emitError() << name << " should be of distributed layout"; + if (name == 'c') + cLayout = encoding; + else if (encoding != cLayout) + return dotOp.emitError() << "d & c should have the same layout"; } else return dotOp.emitError() << name << "'s type should be of RankedTensorType"; diff --git a/python/src/triton.cc b/python/src/triton.cc index 42d7fd312..b583be726 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1312,7 +1312,15 @@ void init_triton_ir(py::module &&m) { // .def("create_log", &ir::builder::create_log, ret::reference) // .def("create_trans", &ir::builder::create_trans, ret::reference) // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) - // .def("create_reduce", &ir::builder::create_reduce, ret::reference) + .def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand, + mlir::triton::RedOp redOp, int axis) -> mlir::Value { + auto loc = self.getUnknownLoc(); + auto inputTensorType = operand.getType().dyn_cast(); + std::vector shape = inputTensorType.getShape(); + shape.erase(shape.begin() + axis); + auto resType = mlir::RankedTensorType::get(shape, inputTensorType.getElementType()); + return self.create(loc, resType, redOp, operand, axis); + }) .def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition, mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value { auto loc = self.getUnknownLoc();