the pipeline pass now generates and accepts valid IR

This commit is contained in:
Yan Da
2022-06-07 19:34:59 +08:00
parent 560e29229b
commit 7b09b5f9e9
4 changed files with 82 additions and 37 deletions

View File

@@ -163,16 +163,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
// if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
// a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
// }
// if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
// b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
// }
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
);
@@ -323,14 +324,17 @@ void populateSCFPatterns(
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
// Constructs the pass for a given warp count. The value is stored in the
// numWarps member (`this->` is required because the parameter shadows it);
// runOnOperation() later multiplies it by 32 to obtain the thread count that
// is handed to the TritonGPUTypeConverter.
ConvertTritonToTritonGPU(int numWarps) {
this->numWarps = numWarps;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
// int numThreads = mod.getAttr();
int numThreads = numWarps * 32;
// type converter
TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -350,6 +354,6 @@ public:
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
return std::make_unique<::ConvertTritonToTritonGPU>();
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
}