Add ReduceOp

This commit is contained in:
Yan Da
2022-05-25 14:15:36 +08:00
parent a2c9f919a8
commit 9b670cfb9f
6 changed files with 29 additions and 8 deletions

View File

@@ -1,7 +1,6 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"

View File

@@ -2,7 +2,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
@@ -24,7 +23,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int64_t numElements = tensorType.getNumElements();
// TODO: we should raise exception here.
assert(numElements > numThreads);
if (!(numElements >= numThreads)) {
llvm::errs() << tensorType << " has " << numElements << " numElements "
<< " smaller than numThreads (" << numThreads << ")";
assert(false);
}
assert(numElements % numThreads == 0);
// or assert no encoding?

View File

@@ -40,6 +40,7 @@ private:
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
}
Attribute cLayout;
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
llvm::SmallVector<char>{'c', 'd'})) {
Type type = std::get<0>(it);
@@ -48,8 +49,13 @@ private:
Attribute encoding = tensorType.getEncoding();
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>())
return dotOp.emitError() << name << " should be of mma layout";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;
else if (encoding != cLayout)
return dotOp.emitError() << "d & c should have the same layout";
} else
return dotOp.emitError() << name
<< "'s type should be of RankedTensorType";