the pipeline pass now generates and accepts valid IR
@@ -2,6 +2,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include <llvm-6.0/llvm/Support/ErrorHandling.h>
 
 using namespace mlir;
 
@@ -22,11 +23,13 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     int64_t rank = tensorType.getRank();
     int64_t numElements = tensorType.getNumElements();
 
-    // TODO: we should raise exception here.
     if (!(numElements >= numThreads)) {
-      llvm::errs() << tensorType << " has " << numElements << " numElements "
-                   << " smaller than numThreads (" << numThreads << ")";
-      assert(false);
+      SmallVector<char> buffer;
+      llvm::raw_svector_ostream os(buffer);
+      os << tensorType << " has " << numElements << " numElements "
+         << " smaller than numThreads (" << numThreads << ")\n"
+         << "consider using smaller num-warps\n";
+      llvm::report_fatal_error(os.str());
     }
     assert(numElements % numThreads == 0);
 
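The hunk above swaps an assert(false), which compiles away in release (NDEBUG) builds and would let the converter silently emit invalid IR, for llvm::report_fatal_error, which always prints the message and aborts. A minimal standalone sketch of the idiom, assuming nothing from Triton (the helper name and the plain-string message are illustrative, not from the source):

#include <cstdint>
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: format a diagnostic into a small stack buffer,
// then abort with the message even in release builds.
static void checkNumElements(int64_t numElements, int64_t numThreads) {
  if (numElements < numThreads) {
    llvm::SmallVector<char> buffer;
    llvm::raw_svector_ostream os(buffer); // streams into buffer, no heap I/O
    os << "tensor has " << numElements << " elements, smaller than numThreads ("
       << numThreads << ")\nconsider using smaller num-warps\n";
    llvm::report_fatal_error(os.str()); // prints the message and aborts
  }
}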
@@ -35,6 +38,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     // Now we assume:
     // contiguous = 1, order = 0, 1, 2, ...,
     llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
     // TODO: compute warpTileSize.
+    llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
@@ -93,17 +97,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });
 
 
-  // // We have requirements for the data layouts
-  // addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
-  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
-  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
-  //     return true;
-  //   // TODO: we should delete this
-  //   if (this->typeConverter.isLegal(dotOp))
-  //     return true;
-  //   return false;
-  // });
+  // We have requirements for the data layouts
+  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+      return true;
+    // TODO: we should delete this
+    if (this->typeConverter.isLegal(dotOp))
+      return true;
+    return false;
+  });
 
 }
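Enabling the triton::DotOp rule above makes the conversion target reject dot operands that lack a TritonGPU shared-memory encoding, which is what lets the pipeline pass both generate and accept valid IR. A hypothetical sketch of how such a converter/target pair is typically driven, using MLIR's standard partial-conversion entry point (the TritonGPU constructor signatures are inferred from the hunk headers above; runConversion and the pattern population are illustrative):

#include "mlir/Transforms/DialectConversion.h"

// Drive the conversion: any op the target marks illegal must be rewritten
// by some pattern, otherwise applyPartialConversion signals failure.
static mlir::LogicalResult runConversion(mlir::ModuleOp module,
                                         mlir::MLIRContext *context,
                                         int numThreads) {
  TritonGPUTypeConverter typeConverter(context, numThreads);
  TritonGPUConversionTarget target(*context, typeConverter);
  mlir::RewritePatternSet patterns(context);
  // ... populate patterns that attach TritonGPU encodings to tensor ops ...
  return mlir::applyPartialConversion(module, target, std::move(patterns));
}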