[TritonIR] Make mask operand optional (#74)
This commit is contained in:
@@ -22,7 +22,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
|
||||
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameVariadicOperandSize,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
TypesMatchWith<"infer mask type from result type",
|
||||
"result", "mask",
|
||||
"getI1SameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from result type or none",
|
||||
"result", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from result type or none",
|
||||
"result", "other",
|
||||
"$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
|
||||
"result", "other", "$_self",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
|
||||
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
|
||||
"value", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"value", "mask",
|
||||
"getI1SameShape($_self)">]> {
|
||||
"value", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
|
||||
let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr",
|
||||
@@ -257,7 +259,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
compare $cmp with data $old at location $ptr,
|
||||
compare $cmp with data $old at location $ptr,
|
||||
|
||||
if $old == $cmp, store $val to $ptr,
|
||||
|
||||
|
@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
|
||||
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
[MemoryEffects<[MemRead, MemWrite]>,
|
||||
SameVariadicOperandSize,
|
||||
TypesMatchWith<"infer mask type from ptr type",
|
||||
"ptr", "mask",
|
||||
"getI1SameShape($_self)">,
|
||||
"ptr", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from ptr type",
|
||||
"ptr", "other",
|
||||
"getPointeeType($_self)">]> {
|
||||
"ptr", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "copy async";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
|
||||
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
|
||||
let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
|
||||
|
||||
// result needs to be of shared layout
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
Reference in New Issue
Block a user