[TritonGPU] Improved documentation and semantics of layout encodings (#30)

This commit is contained in:
Philippe Tillet
2022-07-31 13:59:44 -07:00
committed by GitHub
parent e02c82c765
commit d1593e6ca8
17 changed files with 399 additions and 566 deletions

View File

@@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
}
};
struct TritonBroadcastPattern
: public OpConversionPattern<triton::BroadcastOp> {
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
// This creates a tensor with the new shape but the argument's layout
LogicalResult
matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = adaptor.src().getType().cast<RankedTensorType>();
auto srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
auto opType = op.getType().cast<RankedTensorType>();
Type retType = RankedTensorType::get(opType.getShape(),
opType.getElementType(), srcEncoding);
// Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
adaptor.getOperands());
return success();
}
};
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
@@ -234,12 +256,12 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::SplatOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
}
//
@@ -317,9 +339,8 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
int numThreads = numWarps * 32;
// type converter
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUTypeConverter typeConverter(context, numWarps);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -335,8 +356,8 @@ public:
// update layouts
// broadcast src => multicast, dst => broadcasted
if (failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
// if (failed(target.refineLayouts(mod, numWarps)))
// return signalPassFailure();
}
};