[TritonIR] Make mask operand optional (#74)

Author: Shintaro Iwasaki
Date: 2022-08-22 22:00:17 -07:00
Committed by: GitHub
Parent commit: de2dd04c8a
Commit: 0ebef11c77
14 changed files with 113 additions and 102 deletions


@@ -22,7 +22,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
//
// Use cast ops in arith:
// bitcast
// fptoui, fptosi, uitofp, sitofp,
// extf, tructf,
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameVariadicOperandSize,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from result type",
"result", "mask",
"getI1SameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other",
"$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
// for args with default values
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
let assemblyFormat = "operands attr-dict `:` type($result)";
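
With mask and other now optional operands, tt.load accepts anywhere from one to three operands under the generic operands assembly format. A minimal sketch of the three forms, assuming illustrative values %ptrs : tensor<128x!tt.ptr<f32>>, %mask : tensor<128xi1>, and %other : tensor<128xf32> defined earlier:

// masked load with an explicit other value (as in the updated tests below)
%a = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
// masked load, other omitted
%b = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
// fully unmasked load
%c = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>

The operand-count guards added to the TypesMatchWith traits are what keep type inference from firing when mask or other is absent.
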
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask",
"getI1SameShape($_self)">]> {
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
let assemblyFormat = "operands attr-dict `:` type($value)";
}
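
With the mask switched to Optional<I1Tensor> and the assembly format moved to the generic operands directive, a store drops the old trailing comma and can omit the mask entirely. A hedged sketch with illustrative names:

// masked store, matching the updated tests below
tt.store %ptrs, %vals, %mask : tensor<128xf32>
// unmasked store, assuming the optional mask is simply left out
tt.store %ptrs, %vals : tensor<128xf32>
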
def TT_GEPOp : TT_Op<"getelementptr",
@@ -257,7 +259,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
let summary = "atomic cas";
let description = [{
compare $cmp with data $old at location $ptr,
if $old == $cmp, store $val to $ptr,


@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
[MemoryEffects<[MemRead, MemWrite]>,
SameVariadicOperandSize,
TypesMatchWith<"infer mask type from ptr type",
"ptr", "mask",
"getI1SameShape($_self)">,
"ptr", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from ptr type",
"ptr", "other",
"getPointeeType($_self)">]> {
"ptr", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "copy async";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
let results = (outs TT_Type:$result);
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
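
triton_gpu.copy_async gets the same optional-mask treatment. A sketch of both forms, assuming #blocked and #shared are hypothetical layouts defined elsewhere (the verifier requires the result to carry a shared layout):

%a = triton_gpu.copy_async %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #blocked> -> tensor<16x16xf16, #shared>
%b = triton_gpu.copy_async %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #blocked> -> tensor<16x16xf16, #shared>
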


@@ -226,10 +226,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
adaptor.evict(), adaptor.isVolatile());
return success();
}
};


@@ -49,58 +49,43 @@ namespace mlir {
namespace triton {
//-- StoreOp --
// Default mask
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
auto shape = ptrType.getShape();
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
state.addOperands(ptr);
state.addOperands(value);
state.addOperands(mask);
StoreOp::build(builder, state, ptr, value, mlir::Value());
}
//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
// mask
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
Type resultType = RankedTensorType::get(shape, elementType);
state.addOperands(ptr);
state.addOperands(mask);
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
Type resultType = RankedTensorType::get(shape, elementType);
state.addOperands(ptr);
state.addOperands(mask);
if (mask) {
state.addOperands(mask);
if (other) {
state.addOperands(other);
}
}
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
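
The practical effect of the simplified builders on the emitted IR, sketched under the assumption that the default (ptr-only) LoadOp builder is used and with illustrative names: the old builder materialized an all-true mask constant, while the new one simply leaves the operand out.

// before this change: the default builder created a dense<true> mask
%true = arith.constant dense<true> : tensor<128xi1>
%x = tt.load %ptrs, %true {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
// after this change: the mask operand is absent
%y = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
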


@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
} // anonymous namespace
// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
public:
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
{triton::LoadOp::getOperationName()}) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
if (!selectOp)
return mlir::failure();
mlir::Value trueValue = selectOp.getTrueValue();
mlir::Value falseValue = selectOp.getFalseValue();
auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
if (!loadOp)
return mlir::failure();
mlir::Value mask = loadOp.mask();
if (!mask)
return mlir::failure();
auto *broadcastOpCandidate = mask.getDefiningOp();
auto broadcastOp =
llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
if (!broadcastOp)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());
return mlir::success();
}
};
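
At the IR level, the rewrite implemented above looks roughly like the following hedged sketch (names and shapes are illustrative): a select whose true value is a load masked by a broadcast of the same condition is folded so that the select's false value becomes the load's other operand.

%mask = tt.broadcast %cond : (i1) -> tensor<128xi1>
// before: masked load wrapped in a select
%x = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
%y = select %cond, %x, %other : tensor<128xf32>
// after: the false value moves into the load and the select goes away
%z = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>

The early return on a missing mask is what the newly optional operand makes necessary; an unmasked load has no mask to combine with the select condition.
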
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"


@@ -37,12 +37,6 @@ def CombineGEPPattern : Pat<
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
def CombineSelectMaskedLoadPattern : Pat<
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
def CombineBroadcastConstantPattern : Pat<


@@ -391,14 +391,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
}
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
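
What the guarded block above emits when the load does carry a mask, sketched with illustrative names and shapes (%nextLoopCond is the scalar loop condition computed earlier):

%splatCond = tt.splat %nextLoopCond : (i1) -> tensor<128xi1>
%newMask = arith.andi %splatCond, %mask : tensor<128xi1>

If the load has no mask, the splat-and-AND step is skipped and the mapping is left untouched.
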


@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = ptrType.getShape();
mlir::Type elementType = ptrType.getElementType()
.dyn_cast<mlir::triton::PointerType>()
.getPointeeType();
mlir::Type resultType =
mlir::RankedTensorType::get(shape, elementType);
if (other.has_value()) {
return self.create<mlir::triton::LoadOp>(
loc, resultType, ptrs, mask, other.value(), cacheModifier,
evictionPolicy, isVolatile);
} else {
return self.create<mlir::triton::LoadOp>(
loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
}
return self.create<mlir::triton::LoadOp>(
loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
evictionPolicy, isVolatile);
})
.def("create_masked_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,


@@ -47,6 +47,6 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst, : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
return
}
}


@@ -31,7 +31,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
tt.store %ptrs, %vs, %mask : tensor<128xf32>
return
}


@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
}
}


@@ -38,7 +38,7 @@ module {
}
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
tt.store %17, %15#0, %6, : tensor<256xf32>
tt.store %17, %15#0, %6 : tensor<256xf32>
return
}
}
@@ -124,7 +124,7 @@ module {
// }
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// return
// }
// }


@@ -3,7 +3,7 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -15,9 +15,9 @@
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
@@ -41,6 +41,6 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
return
}
}


@@ -86,7 +86,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
// CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -117,7 +117,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
%24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
%26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
return
}
@@ -170,6 +170,6 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
return
}
}