[TritonGPU] Improved documentation and semantics of layout encodings (#30)
This commit is contained in:
@@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonBroadcastPattern
|
||||
: public OpConversionPattern<triton::BroadcastOp> {
|
||||
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
||||
|
||||
// This creates a tensor with the new shape but the argument's layout
|
||||
LogicalResult
|
||||
matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = adaptor.src().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
auto opType = op.getType().cast<RankedTensorType>();
|
||||
Type retType = RankedTensorType::get(opType.getShape(),
|
||||
opType.getElementType(), srcEncoding);
|
||||
// Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
|
||||
adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
@@ -234,12 +256,12 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::SplatOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -317,9 +339,8 @@ public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
int numThreads = numWarps * 32;
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, numThreads);
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
@@ -335,8 +356,8 @@ public:
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
if (failed(target.refineLayouts(mod, numWarps)))
|
||||
return signalPassFailure();
|
||||
// if (failed(target.refineLayouts(mod, numWarps)))
|
||||
// return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user