more comments to TypeConverter & update warpTileSize
This commit is contained in:
@@ -177,9 +177,6 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
|||||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||||
);
|
);
|
||||||
// auto newDot = rewriter.create<triton::DotOp>(op.getLoc(), retType,
|
|
||||||
// a, b, adaptor.c(), adaptor.allowTF32());
|
|
||||||
// rewriter.replaceOp(op, {newDot});
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -2,7 +2,6 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -23,6 +22,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
int64_t rank = tensorType.getRank();
|
int64_t rank = tensorType.getRank();
|
||||||
int64_t numElements = tensorType.getNumElements();
|
int64_t numElements = tensorType.getNumElements();
|
||||||
|
|
||||||
|
// TODO: are there any better ways to raise this error?
|
||||||
if (!(numElements >= numThreads)) {
|
if (!(numElements >= numThreads)) {
|
||||||
SmallVector<char> buffer;
|
SmallVector<char> buffer;
|
||||||
llvm::raw_svector_ostream os(buffer);
|
llvm::raw_svector_ostream os(buffer);
|
||||||
@@ -38,16 +38,18 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
// Now we assume:
|
// Now we assume:
|
||||||
// contiguous = 1, order = 0, 1, 2, ...,
|
// contiguous = 1, order = 0, 1, 2, ...,
|
||||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||||
// TODO: compute warpTileSize.
|
|
||||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||||
llvm::SmallVector<unsigned> order(rank);
|
llvm::SmallVector<unsigned> order(rank);
|
||||||
int remainingThreads = numThreads;
|
int remainingThreads = numThreads;
|
||||||
|
int remainingLanes = /*warp size*/32;
|
||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
||||||
|
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
|
||||||
order[dim] = dim;
|
order[dim] = dim;
|
||||||
|
|
||||||
remainingThreads /= blockTileSize[dim];
|
remainingThreads /= blockTileSize[dim];
|
||||||
|
remainingLanes /= warpTileSize[dim];
|
||||||
// TODO: will we need repetition?
|
// TODO: will we need repetition?
|
||||||
}
|
}
|
||||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||||
@@ -55,17 +57,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
return RankedTensorType::get(shape, elementType, encoding);
|
return RankedTensorType::get(shape, elementType, encoding);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
//
|
||||||
// materailizations
|
// materailizations
|
||||||
|
//
|
||||||
|
// This will be called when (newArgType != origArgType)
|
||||||
|
// This will create newArg, and map(origArg, newArg)
|
||||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// If the origValue still has live user(s), use this to
|
||||||
|
// convert origValue to newValue
|
||||||
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// This will be called when (desiredType != newOperandType)
|
||||||
|
// where, desiredType = typeConverter->convertType(origType)
|
||||||
|
// NOTE: only for remapped values.
|
||||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
@@ -104,9 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|||||||
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||||
return true;
|
return true;
|
||||||
// TODO: we should delete this
|
// // TODO: we should delete this
|
||||||
if (this->typeConverter.isLegal(dotOp))
|
// if (this->typeConverter.isLegal(dotOp))
|
||||||
return true;
|
// return true;
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user