From 5898352f97fdb236a6809d833688f6378b133842 Mon Sep 17 00:00:00 2001
From: Shintaro Iwasaki
Date: Thu, 13 Oct 2022 18:53:00 -0700
Subject: [PATCH] [Triton-IR] Fix LoadOp definition (#771) (#777)

---
 include/triton/Dialect/Triton/IR/TritonOps.td |  2 +-
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   | 15 ++++-----------
 lib/Dialect/Triton/IR/Ops.cpp                 | 22 +++++++++++++++++++---
 test/Conversion/triton_to_tritongpu.mlir      | 19 ++++++++++++++++++-
 4 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index c65d92be6..b3f5f82ec 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -105,7 +105,7 @@ def TT_AddPtrOp : TT_Op<"addptr",
 def TT_LoadOp : TT_Op<"load",
                       [SameOperandsAndResultShape,
                        SameOperandsAndResultEncoding,
-                       SameVariadicOperandSize,
+                       AttrSizedOperandSegments,
                        MemoryEffects<[MemRead]>,
                        TypesMatchWith<"infer ptr type from result type",
                                       "result", "ptr",
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index b580f5971..788d20eaa 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -242,17 +242,10 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getNumOperands() == 2) { // ptr & mask
-      rewriter.replaceOpWithNewOp<triton::LoadOp>(
-          op, typeConverter->convertType(op.getType()), adaptor.ptr(),
-          adaptor.getOperands()[1], adaptor.other(), adaptor.cache(),
-          adaptor.evict(), adaptor.isVolatile());
-    } else {
-      rewriter.replaceOpWithNewOp<triton::LoadOp>(
-          op, typeConverter->convertType(op.getType()), adaptor.ptr(),
-          adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
-          adaptor.isVolatile());
-    }
+    rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
+        adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
+        adaptor.isVolatile());
     return success();
   }
 };
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 9a47829c5..630260988 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -50,21 +50,34 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
 
   SmallVector<Type> operandTypes;
   operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
-  if (allOperands.size() >= 2)
+  int hasMask = 0, hasOther = 0;
+  if (allOperands.size() >= 2) {
     operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
-  if (allOperands.size() >= 3)
+    hasMask = 1;
+  }
+  if (allOperands.size() >= 3) {
     operandTypes.push_back(resultTypes[0]); // other
+    hasOther = 1;
+  }
 
   if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                              result.operands))
     return failure();
+  // Deduce operand_segment_sizes from the number of the operands.
+  auto operand_segment_sizesAttrName =
+      LoadOp::operand_segment_sizesAttrName(result.name);
+  result.addAttribute(
+      operand_segment_sizesAttrName,
+      parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
   return success();
 }
 
 void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
   printer << " ";
   printer << loadOp.getOperation()->getOperands();
-  printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
+  // "operand_segment_sizes" can be deduced, so we don't print it.
+  printer.printOptionalAttrDict(loadOp->getAttrs(),
+                                {loadOp.operand_segment_sizesAttrName()});
   printer << " : ";
   printer.printStrippedAttrOrType(loadOp.result().getType());
 }
@@ -148,6 +161,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
       state.addOperands(other);
     }
   }
+  state.addAttribute(
+      operand_segment_sizesAttrName(state.name),
+      builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
   state.addAttribute(
       cacheAttrName(state.name),
       ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir
index bd04b60dd..7beb42b69 100644
--- a/test/Conversion/triton_to_tritongpu.mlir
+++ b/test/Conversion/triton_to_tritongpu.mlir
@@ -1,10 +1,27 @@
 // RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
 
 func @ops() {
-// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
+  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
   %a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
   %b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
   %c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
   %0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
   return
 }
+
+func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+  // Test if LoadOp is lowered properly (see #771)
+  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
+  %mask = arith.constant dense<true> : tensor<128xi1>
+  %other = arith.constant dense<0.0e+0> : tensor<128xf32>
+  // CHECK: %{{.*}} = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
+  %a = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
+  // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
+  %b = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
+  // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
+  %c = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
+  tt.store %ptrs, %a : tensor<128xf32>
+  tt.store %ptrs, %b : tensor<128xf32>
+  tt.store %ptrs, %c : tensor<128xf32>
+  return
+}
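
Illustrative note (not part of the diff above): a minimal sketch of what the deduced
attribute looks like once tt.load uses AttrSizedOperandSegments. The generic-form
spelling below is an assumption based on common MLIR printing conventions, and the
value names (%ptrs, %mask, %other) are reused from the test file in this patch; the
segment layout {1, hasMask, hasOther} follows parseLoadOp and LoadOp::build above.

    // Generic (non-pretty) form of the three-operand load; the attribute is
    // deduced by the custom parser and elided by the custom printer.
    %c = "tt.load"(%ptrs, %mask, %other)
           {cache = 1 : i32, evict = 1 : i32, isVolatile = true,
            operand_segment_sizes = dense<[1, 1, 1]> : vector<3xi32>}
         : (tensor<128x!tt.ptr<f32>>, tensor<128xi1>, tensor<128xf32>) -> tensor<128xf32>

With only a pointer operand the segment sizes would be dense<[1, 0, 0]>, and with
ptr + mask they would be dense<[1, 1, 0]>, matching getI32VectorAttr({1, hasMask, hasOther}).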