From 117a402c1b5bdcceeca600d8cc0e5ca81fd10ba9 Mon Sep 17 00:00:00 2001
From: Yan Da
Date: Wed, 8 Jun 2022 16:20:07 +0800
Subject: [PATCH] more comments to TypeConverter & update warpTileSize

---
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   |  3 ---
 .../Transforms/TritonGPUConversion.cpp        | 23 +++++++++++++++----
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index ee35a6eb8..591af75b5 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -177,9 +177,6 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
     auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
       op, retType, a, b, adaptor.c(), adaptor.allowTF32()
     );
-    // auto newDot = rewriter.create<triton::DotOp>(op.getLoc(), retType,
-    //   a, b, adaptor.c(), adaptor.allowTF32());
-    // rewriter.replaceOp(op, {newDot});
     return success();
   }
 };
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index e2d6a6687..970ecbedc 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -2,7 +2,6 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
-#include <numeric>
 
 using namespace mlir;
 
@@ -23,6 +22,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     int64_t rank = tensorType.getRank();
     int64_t numElements = tensorType.getNumElements();
 
+    // TODO: are there any better ways to raise this error?
     if (!(numElements >= numThreads)) {
       SmallVector<char> buffer;
       llvm::raw_svector_ostream os(buffer);
@@ -38,16 +38,18 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     // Now we assume:
     //   contiguous = 1, order = 0, 1, 2, ...
     llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
-    // TODO: compute warpTileSize.
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
 
     int remainingThreads = numThreads;
+    int remainingLanes = /*warp size*/32;
     for (int64_t dim = 0; dim < rank; ++dim) {
       blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
+      warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
       order[dim] = dim;
       remainingThreads /= blockTileSize[dim];
+      remainingLanes /= warpTileSize[dim];
       // TODO: will we need repetition?
     }
     Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
@@ -54,17 +56,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
         context, threadTileSize, warpTileSize, blockTileSize, order);
     return RankedTensorType::get(shape, elementType, encoding);
   });
+  //
+  // materializations
+  //
+  // This will be called when (newArgType != origArgType).
+  // This will create newArg and map(origArg, newArg).
   addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                  ValueRange inputs, Location loc) {
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
+
+  // If origValue still has live user(s), use this to
+  // convert origValue to newValue.
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
 
+  // This will be called when (desiredType != newOperandType),
+  //   where desiredType = typeConverter->convertType(origType).
+  // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
     assert(inputs.size() == 1);
@@ -104,9 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
       if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
           bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
         return true;
-      // TODO: we should delete this
-      if (this->typeConverter.isLegal(dotOp))
-        return true;
+      // // TODO: we should delete this
+      // if (this->typeConverter.isLegal(dotOp))
+      //   return true;
       return false;
     });
 
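Reviewer note, not part of the patch: the updated loop walks dimensions in order 0, 1, 2, ... and lets blockTileSize consume the thread budget while the new warpTileSize consumes the 32 lanes of a warp, clamping each dimension to the tensor shape. A minimal standalone sketch of that arithmetic; numThreads = 128 and shape = {16, 64} are made-up example values:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {16, 64};   // example tensor shape
  int remainingThreads = 128;              // assumed threads per block
  int remainingLanes = /*warp size*/ 32;
  std::vector<int> blockTileSize(shape.size()), warpTileSize(shape.size());
  for (std::size_t dim = 0; dim < shape.size(); ++dim) {
    // Give this dimension as much of the remaining budget as it can hold.
    blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
    warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
    remainingThreads /= blockTileSize[dim];
    remainingLanes /= warpTileSize[dim];
    std::printf("dim %zu: blockTileSize=%d warpTileSize=%d\n",
                dim, blockTileSize[dim], warpTileSize[dim]);
  }
  // Prints: dim 0: blockTileSize=16 warpTileSize=16
  //         dim 1: blockTileSize=8  warpTileSize=2
  return 0;
}

The products track the budgets, 16*8 = 128 threads and 16*2 = 32 lanes, so the warp tile spans a full warp whenever the tensor has at least 32 elements; the numElements >= numThreads guard above already enforces that for the usual case where numThreads is a multiple of the warp size.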
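One more note for context: the new comment on addTargetMaterialization documents when the hook fires, but the hunk cuts off after the assert, so the body is not shown here. As a hedged sketch only, such a hook would typically materialize the desired layout by inserting a conversion op; the use of triton::gpu::ConvertLayoutOp and the exact builder signature below are assumptions, not something this diff confirms:

  // Hedged sketch: bridge (newOperandType -> desiredType) with a layout
  // conversion; assumes triton::gpu::ConvertLayoutOp takes (loc, type, value).
  addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs,
                               Location loc) -> llvm::Optional<Value> {
    assert(inputs.size() == 1); // exactly one remapped value at a time
    Value converted = builder.create<triton::gpu::ConvertLayoutOp>(
        loc, tensorType, inputs[0]);
    return converted; // the framework then uses the layout-converted value
  });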