[TritonIR] Make mask operand optional (#74)

This commit is contained in:
Shintaro Iwasaki
2022-08-22 22:00:17 -07:00
committed by GitHub
parent de2dd04c8a
commit 0ebef11c77
14 changed files with 113 additions and 102 deletions

View File

@@ -22,7 +22,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
//
// Use cast ops in arith:
// bitcast
// fptoui, fptosi, uitofp, sitofp,
// fptoui, fptosi, uitofp, sitofp,
// extf, tructf,
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameVariadicOperandSize,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from result type",
"result", "mask",
"getI1SameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other",
"$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
// for args with default values
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
let assemblyFormat = "operands attr-dict `:` type($result)";
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask",
"getI1SameShape($_self)">]> {
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
let assemblyFormat = "operands attr-dict `:` type($value)";
}
def TT_GEPOp : TT_Op<"getelementptr",
@@ -257,7 +259,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
let summary = "atomic cas";
let description = [{
compare $cmp with data $old at location $ptr,
compare $cmp with data $old at location $ptr,
if $old == $cmp, store $val to $ptr,

View File

@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
[MemoryEffects<[MemRead, MemWrite]>,
SameVariadicOperandSize,
TypesMatchWith<"infer mask type from ptr type",
"ptr", "mask",
"getI1SameShape($_self)">,
"ptr", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from ptr type",
"ptr", "other",
"getPointeeType($_self)">]> {
"ptr", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "copy async";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
let results = (outs TT_Type:$result);
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];