@@ -105,7 +105,7 @@ def TT_AddPtrOp : TT_Op<"addptr",
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize,
|
||||
AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr",
|
||||
|
@@ -242,17 +242,10 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (op.getNumOperands() == 2) { // ptr & mask
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
|
||||
adaptor.getOperands()[1], adaptor.other(), adaptor.cache(),
|
||||
adaptor.evict(), adaptor.isVolatile());
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
|
||||
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
|
||||
adaptor.isVolatile());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
|
||||
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
|
||||
adaptor.isVolatile());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@@ -50,21 +50,34 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
|
||||
if (allOperands.size() >= 2)
|
||||
int hasMask = 0, hasOther = 0;
|
||||
if (allOperands.size() >= 2) {
|
||||
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
|
||||
if (allOperands.size() >= 3)
|
||||
hasMask = 1;
|
||||
}
|
||||
if (allOperands.size() >= 3) {
|
||||
operandTypes.push_back(resultTypes[0]); // other
|
||||
hasOther = 1;
|
||||
}
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
auto operand_segment_sizesAttrName =
|
||||
LoadOp::operand_segment_sizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
operand_segment_sizesAttrName,
|
||||
parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
|
||||
return success();
|
||||
}
|
||||
|
||||
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
|
||||
printer << " ";
|
||||
printer << loadOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(),
|
||||
{loadOp.operand_segment_sizesAttrName()});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(loadOp.result().getType());
|
||||
}
|
||||
@@ -148,6 +161,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
state.addOperands(other);
|
||||
}
|
||||
}
|
||||
state.addAttribute(
|
||||
operand_segment_sizesAttrName(state.name),
|
||||
builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
|
@@ -1,10 +1,27 @@
|
||||
// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
func @ops() {
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||
%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if LoadOp is lowered properly (see #771)
|
||||
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
%mask = arith.constant dense<true> : tensor<128xi1>
|
||||
%other = arith.constant dense<0.0e+0> : tensor<128xf32>
|
||||
// CHECK: %{{.*}} = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
|
||||
%a = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
|
||||
// CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
|
||||
%b = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
|
||||
// CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
|
||||
%c = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
|
||||
tt.store %ptrs, %a : tensor<128xf32>
|
||||
tt.store %ptrs, %b : tensor<128xf32>
|
||||
tt.store %ptrs, %c : tensor<128xf32>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user