[Triton-MLIR] Fix side effects (#906)

Add proper side-effect declarations for Triton operations.

If side effects are not defined properly, the CSE pass may fail, hang,
or produce incorrect IR.

For instance, suppose we have two shared memory tensors:

```
%a = triton_gpu.alloc_tensor shape0, share_encoding0
%b = triton_gpu.alloc_tensor shape0, share_encoding0
```

The CSE pass will consider `%a` and `%b` to be the same value and
eliminate one of them, leading to incorrect results.
This commit is contained in:
Keren Zhou
2022-11-22 23:29:18 -08:00
committed by GitHub
parent 037f9efa95
commit 2e33352419
4 changed files with 15 additions and 6 deletions

View File

```diff
@@ -56,8 +56,12 @@ private:
   bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
     return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
                                  allocation) ||
-           /*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
-                                 allocation);
+           /*WAR*/
+           isIntersected(syncReadBuffers, other.syncWriteBuffers,
+                         allocation) ||
+           /*WAW*/
+           isIntersected(syncWriteBuffers, other.syncWriteBuffers,
+                         allocation);
   }

   /// Clears the buffers because a barrier is inserted.
```

View File

```diff
@@ -187,6 +187,7 @@ def TT_StoreOp : TT_Op<"store",
 //
 def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
                                           SameOperandsAndResultEncoding,
+                                          MemoryEffects<[MemRead]>,
                                           MemoryEffects<[MemWrite]>,
                                           TypesMatchWith<"infer ptr type from value type",
                                                          "val", "ptr",
```
```diff
@@ -208,7 +209,9 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
   let results = (outs TT_Type:$result);
 }

-def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
+def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
+                                          MemoryEffects<[MemWrite]>,
+                                          SameOperandsAndResultShape,
                                           SameOperandsAndResultEncoding]> {
   let summary = "atomic cas";
```

View File

```diff
@@ -79,8 +79,7 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
 def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
                                     [AttrSizedOperandSegments,
                                      ResultsAreSharedEncoding,
-                                     // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
-                                     NoSideEffect,
+                                     MemoryEffects<[MemRead]>,
                                      TypesMatchWith<"infer mask type from src type",
                                                     "src", "mask", "getI1SameShape($_self)",
                                                     "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
```
```diff
@@ -158,7 +157,8 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
   let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
 }

-def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> {
+def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
+                                                ResultsAreSharedEncoding]> {
   let summary = "allocate tensor";
   let description = [{
```

View File

```diff
@@ -172,6 +172,8 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 128, False, False],
     [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
     # K-Forloop
+    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
+    [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k
     [64, 32, 128, 4, 64, 32, 64, False, False],
     [128, 16, 128, 4, 128, 16, 32, False, False],
     [32, 16, 128, 4, 32, 16, 32, False, False],
```