#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "mlir/IR/BlockAndValueMapping.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include using namespace mlir; using namespace mlir::triton::gpu; // // TypeConverter // TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, int numThreads) : context(context), numThreads(numThreads) { // TODO: how does MLIR pick the right conversion? addConversion([](Type type) { return type; }); addConversion([this](RankedTensorType tensorType) -> RankedTensorType { MLIRContext *context = this->context; int numThreads = this->numThreads; llvm::ArrayRef shape = tensorType.getShape(); Type elementType = tensorType.getElementType(); int64_t rank = tensorType.getRank(); int64_t numElements = tensorType.getNumElements(); // TODO: are there any better ways to raise this error? if (!(numElements >= numThreads)) { SmallVector buffer; llvm::raw_svector_ostream os(buffer); os << tensorType << " has " << numElements << " numElements " << " smaller than numThreads (" << numThreads << ")\n" << "consider using smaller num-warps\n"; llvm::report_fatal_error(os.str()); } assert(numElements % numThreads == 0); // or assert no encoding? // Now we assume: // contiguous = 1, order = 0, 1, 2, ..., llvm::SmallVector threadTileSize(rank, 1); // naive layout llvm::SmallVector warpTileSize(rank, 1); llvm::SmallVector blockTileSize(rank); llvm::SmallVector order(rank); llvm::SmallVector broadcastAxis; int remainingThreads = numThreads; int remainingLanes = /*warp size*/ 32; for (int64_t dim = 0; dim < rank; ++dim) { blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim])); warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim])); order[dim] = dim; remainingThreads /= blockTileSize[dim]; remainingLanes /= warpTileSize[dim]; // TODO: will we need repetition? } Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis); return RankedTensorType::get(shape, elementType, encoding); }); // // materailizations // // This will be called when (newArgType != origArgType) // This will create newArg, and map(origArg, newArg) addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { llvm_unreachable("Not implemented"); return llvm::None; }); // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { llvm_unreachable("Not implemented"); return llvm::None; }); // This will be called when (desiredType != newOperandType) // where, desiredType = typeConverter->convertType(origType) // NOTE: only for remapped values. 
  addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
}

//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context), typeConverter(typeConverter) {
  // TODO: we should also verify ops of TritonGPUDialect
  addLegalDialect<triton::gpu::TritonGPUDialect>();

  // Some ops from SCF are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();

  addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
                             StandardOpsDialect, scf::SCFDialect>(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  // We have requirements for the data layouts
  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
    Attribute aEncoding =
        dotOp.a().getType().cast<RankedTensorType>().getEncoding();
    Attribute bEncoding =
        dotOp.b().getType().cast<RankedTensorType>().getEncoding();
    if (aEncoding &&
        aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
      return true;
    // // TODO: we should delete this
    // if (this->typeConverter.isLegal(dotOp))
    //   return true;
    return false;
  });
}

// %dst = tt.broadcast %src
//   =>
// %newSrc = convert_layout %src
// %bcst   = tt.broadcast %newSrc
// %dst    = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
                                                       int numThreads) {
  // collect broadcasts
  SmallVector<triton::BroadcastOp> broadcasts;
  mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });

  BlockAndValueMapping mapping;
  for (auto broadcast : broadcasts) {
    OpBuilder builder(broadcast);
    Value src = mapping.lookupOrDefault(broadcast.src());
    Type originSrcType = src.getType();
    Type originDstType = broadcast.getType();
    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
    unsigned dstRank = originDstTensorType.getRank();

    // compute newSrcType & broadcastAxis
    Type newSrcType;
    SmallVector<unsigned> broadcastAxis;
    bool isSrcScalar = false;
    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
      assert(tensorType.getRank() == dstRank &&
             "src & dst should have the same rank (the verifier should catch "
             "a mismatch)");
      for (unsigned ax = 0; ax < dstRank; ++ax)
        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
          broadcastAxis.push_back(ax);

      Attribute originSrcEnc = tensorType.getEncoding();
      if (auto blockedEnc =
              originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
            blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
            blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
            blockedEnc.getOrder(), broadcastAxis);
        newSrcType = RankedTensorType::get(
            tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
      } else
        llvm_unreachable("src of broadcast should have a blocked encoding");
    } else {
      // scalar src: every dst axis is a broadcast axis
      for (unsigned ax = 0; ax < dstRank; ++ax)
        broadcastAxis.push_back(ax);
      newSrcType = originSrcType;
      isSrcScalar = true;
    }

    // create the new src
    if (!isSrcScalar) // we don't need to convert the layout of scalar values
      src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
                                                         newSrcType, src);

    // create the new broadcast;
    // compute the new type (encoding)
    auto originDstEnc = originDstTensorType.getEncoding()
                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
    auto newEnc = TritonGPUBlockedEncodingAttr::get(
        originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
        originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
        originDstEnc.getOrder(), broadcastAxis);
    auto newType =
        RankedTensorType::get(originDstTensorType.getShape(),
                              originDstTensorType.getElementType(), newEnc);
    Value newBroadcast =
        builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
    // we don't want to change the encoding of the result
    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
        broadcast.getLoc(), originDstType, newBroadcast);
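    // Rewire all users of the original broadcast to the new chain, and record
    // the replacement in `mapping` so that a later broadcast whose src() is
    // this broadcast picks up the rewritten value via lookupOrDefault above.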
    broadcast.replaceAllUsesWith(newDst);
    mapping.map(broadcast, newDst);
  }
  return success();
}
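
// Usage sketch -- a hypothetical pass body, not part of this file: the
// converter and target are typically wired into MLIR's dialect-conversion
// driver roughly as below. The pattern-population step is elided; numThreads
// is presumably 32 * num-warps (the warp size is hard-coded to 32 above).
//
//   void runOnOperation() {
//     MLIRContext *context = &getContext();
//     TritonGPUTypeConverter typeConverter(context, /*numThreads=*/128);
//     TritonGPUConversionTarget target(*context, typeConverter);
//     RewritePatternSet patterns(context);
//     // ... populate conversion patterns that use `typeConverter` ...
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }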