#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// NOTE(review): the two system header names were stripped from this file.
// <numeric> is required by std::iota below; <algorithm> is a best guess for
// the other one -- confirm against version control.
#include <algorithm>
#include <numeric>

using namespace mlir;
using namespace mlir::triton::gpu;

// NOTE(review): every template-argument list in this file had been stripped
// (e.g. `builder.create(loc, ...)`, `.cast()`, `addIllegalOp()`), leaving code
// that cannot compile. They are restored below. Arguments that could not be
// derived from the surrounding code carry their own NOTE(review) and must be
// confirmed against the original revision before merging.

//
// TypeConverter
//

// Registers the type conversions and materialization hooks of the converter.
//
// context:  MLIR context used to build the default layout encoding.
// numWarps: forwarded to BlockedEncodingAttr::get for that encoding.
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps)
    : context(context), numWarps(numWarps) {
  // TODO: how does MLIR pick the right conversion?
  // Fallback: any type is returned unchanged.
  addConversion([](Type type) { return type; });

  // Ranked tensors without an encoding get a default blocked encoding.
  addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
    // Types with an encoding are already in the right format.
    // TODO: check for layout encodings specifically
    if (tensorType.getEncoding())
      return tensorType;
    // Pessimistic values for attributes:
    //   - 1 element per thread
    //   - order = arange(rank)
    ArrayRef<int64_t> shape = tensorType.getShape();
    int rank = shape.size();
    // NOTE(review): element type `unsigned` is assumed from the parameters of
    // BlockedEncodingAttr::get -- confirm.
    llvm::SmallVector<unsigned> order(rank);
    std::iota(order.begin(), order.end(), 0);
    llvm::SmallVector<unsigned> sizePerThread(rank, 1);
    Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
        this->context, shape, sizePerThread, order, this->numWarps);
    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
  });

  //
  // Materializations
  //

  // This will be called when (newArgType != origArgType).
  // This will create newArg, and map(origArg, newArg).
  // Not supported: reaching this callback is treated as a programming error.
  addArgumentMaterialization([](OpBuilder &builder,
                                RankedTensorType tensorType, ValueRange inputs,
                                Location loc) {
    llvm_unreachable("Argument rematerialization not implemented");
    return llvm::None;
  });

  // If the origValue still has live user(s), use this to convert origValue to
  // newValue.
  // Not supported: reaching this callback is treated as a programming error.
  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
                              ValueRange inputs, Location loc) {
    llvm_unreachable("Source rematerialization not implemented");
    return llvm::None;
  });

  // This will be called when (desiredType != newOperandType)
  // where, desiredType = typeConverter->convertType(origType).
  // NOTE: only for remapped values.
  addTargetMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
                              ValueRange inputs, Location loc) {
    // NOTE(review): the op type was stripped; ConvertLayoutOp is assumed to be
    // the op that casts `inputs` to `tensorType` -- confirm.
    auto cast =
        builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
    return Optional<Value>(cast.getResult());
  });
}

//
// TritonGPUConversion
//

// Declares which operations are legal once conversion has run.
//
// context:       MLIR context forwarded to the ConversionTarget base.
// typeConverter: captured by reference in a legality callback, so it must
//                outlive this conversion target.
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context) {
  // TODO: we should also verify ops of TritonGPUDialect
  // NOTE(review): dialect argument was stripped; inferred from the TODO above
  // -- confirm.
  addLegalDialect<triton::gpu::TritonGPUDialect>();

  // Some ops from SCF are illegal.
  // NOTE(review): op list was stripped; restored from memory of upstream
  // Triton, not derivable from this file -- confirm.
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();

  // An op of these dialects is legal iff the type converter reports it legal.
  // NOTE(review): dialect list was stripped; restored from memory of upstream
  // Triton, not derivable from this file -- confirm.
  addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
                             triton::TritonDialect, StandardOpsDialect,
                             scf::SCFDialect>(
      [&](Operation *op) -> bool { return typeConverter.isLegal(op); });

  // We have requirements for the data layouts: a dot is legal only when both
  // operands carry the required encoding.
  addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
    // Both encodings are fetched before testing, as in the original, so the
    // casts on `a` and `b` always run.
    Attribute aEncoding =
        dotOp.a().getType().cast<RankedTensorType>().getEncoding();
    Attribute bEncoding =
        dotOp.b().getType().cast<RankedTensorType>().getEncoding();
    // NOTE(review): the attribute type passed to isa<> was stripped;
    // DotOperandEncodingAttr is assumed (older revisions may have used
    // SharedEncodingAttr) -- confirm.
    return aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
           bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>();
  });
}