diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 7e4a49b16..5c7dd6e63 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -4,8 +4,8 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-// #include "mlir/IR/BlockAndValueMapping.h"
 #include "../PassDetail.h"
+#include
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -155,9 +155,31 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
     Type retType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<triton::DotOp>(
-      op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
+    // a & b must be of smem layout
+    auto aType = adaptor.a().getType().cast<RankedTensorType>();
+    auto bType = adaptor.b().getType().cast<RankedTensorType>();
+    Attribute aEncoding = aType.getEncoding();
+    Attribute bEncoding = bType.getEncoding();
+    if (!aEncoding || !bEncoding)
+      return failure();
+    Value a = adaptor.a();
+    Value b = adaptor.b();
+    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
+      auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
+      a = rewriter.create(a.getLoc(), dstType, a);
+    }
+    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
+      auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
+      b = rewriter.create(b.getLoc(), dstType, b);
+    }
+    auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
+      op, retType, a, b, adaptor.c(), adaptor.allowTF32()
     );
+    // auto newDot = rewriter.create<triton::DotOp>(op.getLoc(), retType,
+    //   a, b, adaptor.c(), adaptor.allowTF32());
+    // rewriter.replaceOp(op, {newDot});
     return success();
   }
 };
@@ -182,7 +204,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<triton::StoreOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
       op, adaptor.ptr(), adaptor.value(), adaptor.mask()
     );
     return success();
   }
 };
@@ -220,26 +242,24 @@ void populateTritonPatterns(
 //
 // SCF patterns
 //
+// This is borrowed from ConvertForOpTypes in
+// SCF/Transforms/StructuralTypeConversions.cpp
 struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
   using OpConversionPattern<scf::ForOp>::OpConversionPattern;
 
   // Ref: ConvertForOpTypes
   LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> newResultTypes;
-    for (Type type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
-    }
-
     auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                 newOp.getLoopBody().end());
 
     // Now, update all the types.
-    // Convert the type of the entry block of the ForOp's body.
+    // Convert the types of block arguments within the given region. This
+    // replaces each block with a new block containing the updated signature. The
+    // entry block may have a special conversion if `entryConversion` is
+    // provided. On success, the new entry block to the region is returned for
+    // convenience. Otherwise, failure is returned.
     if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                            *getTypeConverter()))) {
       return rewriter.notifyMatchFailure(op, "could not convert body types");
@@ -248,11 +268,17 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
     // a BlockAndValueMapping, but this seems a bit more direct.
     newOp->setOperands(adaptor.getOperands());
     // Update the result types to the new converted types.
+    SmallVector<Type> newResultTypes;
+    for (Type type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
       std::get<0>(t).setType(std::get<1>(t));
 
     rewriter.replaceOp(op, newOp.getResults());
-    return success();
 
     return success();
   }
@@ -277,8 +303,7 @@ void populateSCFPatterns(
   TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
 ) {
   MLIRContext *context = patterns.getContext();
-  patterns.add(typeConverter, context);
 }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index c1fb8a44a..d8249ba97 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -41,7 +41,11 @@ TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
 }
 
 void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
-  llvm_unreachable("Not implemented");
+  printer << "<"
+          // << "threadTileSize = " << getThreadTileSize()
+          // << ", blockTileSize = " << getBlockTileSize()
+          // << ", order = " << getOrder()
+          << ">";
 }
 
 void TritonGPUDialect::initialize() {
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index 1d1e35a61..622b0eeaf 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -50,9 +50,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // materailizations
   addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                  ValueRange inputs, Location loc) {
-    llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
-                 << "in: \n";
-    inputs[0].dyn_cast().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
@@ -63,8 +60,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   });
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
+    assert(inputs.size() == 1);
     llvm_unreachable("Not implemented");
-    // llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
     return llvm::None;
   });
 }
@@ -75,13 +72,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 TritonGPUConversionTarget::TritonGPUConversionTarget(
     MLIRContext &context, TritonGPUTypeConverter &typeConverter)
     : ConversionTarget(context), typeConverter(typeConverter) {
+  // TODO: we should also verify ops of TritonGPUDialect
+  addLegalDialect();
+
   // Some ops from SCF are illegal
   addIllegalOp();
-
+
   addDynamicallyLegalDialect([&](Operation *op) {
     if (typeConverter.isLegal(op))
@@ -89,14 +88,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
     return false;
   });
 
-  // // We have requirements for the data layouts
-  // addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
-  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
-  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
-  //     return true;
-  //   return false;
-  // });
+
+  // We have requirements for the data layouts
+  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+      return true;
+    // TODO: we should delete this
+    if (this->typeConverter.isLegal(dotOp))
+      return true;
+    return false;
+  });
 }