Add ReduceOp
This commit is contained in:
@@ -205,7 +205,13 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis);
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Type:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
// let builders = [
|
||||
// OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
|
||||
// ];
|
||||
}
|
||||
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
|
@@ -5,7 +5,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
@@ -1,7 +1,6 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
|
||||
|
@@ -2,7 +2,6 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -24,7 +23,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int64_t numElements = tensorType.getNumElements();
|
||||
|
||||
// TODO: we should raise exception here.
|
||||
assert(numElements > numThreads);
|
||||
if (!(numElements >= numThreads)) {
|
||||
llvm::errs() << tensorType << " has " << numElements << " numElements "
|
||||
<< " smaller than numThreads (" << numThreads << ")";
|
||||
assert(false);
|
||||
}
|
||||
assert(numElements % numThreads == 0);
|
||||
|
||||
// or assert no encoding?
|
||||
|
@@ -40,6 +40,7 @@ private:
|
||||
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
|
||||
}
|
||||
|
||||
Attribute cLayout;
|
||||
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
|
||||
llvm::SmallVector<char>{'c', 'd'})) {
|
||||
Type type = std::get<0>(it);
|
||||
@@ -48,8 +49,13 @@ private:
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of mma layout";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of distributed layout";
|
||||
if (name == 'c')
|
||||
cLayout = encoding;
|
||||
else if (encoding != cLayout)
|
||||
return dotOp.emitError() << "d & c should have the same layout";
|
||||
} else
|
||||
return dotOp.emitError() << name
|
||||
<< "'s type should be of RankedTensorType";
|
||||
|
@@ -1312,7 +1312,15 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_log", &ir::builder::create_log, ret::reference)
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
.def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand,
|
||||
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto inputTensorType = operand.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = inputTensorType.getShape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto resType = mlir::RankedTensorType::get(shape, inputTensorType.getElementType());
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp, operand, axis);
|
||||
})
|
||||
.def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
Reference in New Issue
Block a user