more comments to TypeConverter & update warpTileSize

This commit is contained in:
Yan Da
2022-06-08 16:20:07 +08:00
parent 49d1821149
commit 117a402c1b
2 changed files with 18 additions and 8 deletions

View File

@@ -2,7 +2,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
@@ -23,6 +22,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int64_t rank = tensorType.getRank();
int64_t numElements = tensorType.getNumElements();
// TODO: are there any better ways to raise this error?
if (!(numElements >= numThreads)) {
SmallVector<char> buffer;
llvm::raw_svector_ostream os(buffer);
@@ -38,16 +38,18 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
// TODO: compute warpTileSize.
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32;
for (int64_t dim = 0; dim < rank; ++dim) {
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
order[dim] = dim;
remainingThreads /= blockTileSize[dim];
remainingLanes /= warpTileSize[dim];
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
@@ -55,17 +57,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
return RankedTensorType::get(shape, elementType, encoding);
});
//
// materializations
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
// This will be called when (desiredType != newOperandType)
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
@@ -104,9 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// TODO: we should delete this
if (this->typeConverter.isLegal(dotOp))
return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
return false;
});