[TritonMLIR] Disallow 0D tensor (#788)

This commit is contained in:
Shintaro Iwasaki
2022-10-19 10:34:32 -07:00
committed by GitHub
parent 4464646efb
commit 0d22d2bc03
8 changed files with 203 additions and 87 deletions

View File

@@ -108,8 +108,7 @@ def TT_LoadOp : TT_Op<"load",
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
"result", "ptr", "getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
@@ -118,7 +117,7 @@ def TT_LoadOp : TT_Op<"load",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
@@ -147,13 +146,13 @@ def TT_StoreOp : TT_Op<"store",
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
@@ -318,7 +317,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,