[Triton-MLIR] Fix side effects (#906)
Add proper side effects to Triton operations. If side effects are not declared correctly, the CSE pass can fail, hang, or emit incorrect IR. For instance, suppose we have two shared memory tensors:
```
%a = triton_gpu.alloc_tensor shape0, share_encoding0
%b = triton_gpu.alloc_tensor shape0, share_encoding0
```
When `alloc_tensor` is marked as having no side effects, the CSE pass treats `%a` and `%b` as the same value and eliminates one of them, resulting in mysterious outcomes.
This commit is contained in:
@@ -56,8 +56,12 @@ private:
|
|||||||
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
|
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
|
||||||
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
|
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
|
||||||
allocation) ||
|
allocation) ||
|
||||||
/*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
|
/*WAR*/
|
||||||
allocation);
|
isIntersected(syncReadBuffers, other.syncWriteBuffers,
|
||||||
|
allocation) ||
|
||||||
|
/*WAW*/
|
||||||
|
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
|
||||||
|
allocation);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clears the buffers because a barrier is inserted.
|
/// Clears the buffers because a barrier is inserted.
|
||||||
|
@@ -187,6 +187,7 @@ def TT_StoreOp : TT_Op<"store",
|
|||||||
//
|
//
|
||||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||||
SameOperandsAndResultEncoding,
|
SameOperandsAndResultEncoding,
|
||||||
|
MemoryEffects<[MemRead]>,
|
||||||
MemoryEffects<[MemWrite]>,
|
MemoryEffects<[MemWrite]>,
|
||||||
TypesMatchWith<"infer ptr type from value type",
|
TypesMatchWith<"infer ptr type from value type",
|
||||||
"val", "ptr",
|
"val", "ptr",
|
||||||
@@ -208,7 +209,9 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
|||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
|
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||||
|
MemoryEffects<[MemWrite]>,
|
||||||
|
SameOperandsAndResultShape,
|
||||||
SameOperandsAndResultEncoding]> {
|
SameOperandsAndResultEncoding]> {
|
||||||
let summary = "atomic cas";
|
let summary = "atomic cas";
|
||||||
|
|
||||||
|
@@ -79,8 +79,7 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
|
|||||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||||
[AttrSizedOperandSegments,
|
[AttrSizedOperandSegments,
|
||||||
ResultsAreSharedEncoding,
|
ResultsAreSharedEncoding,
|
||||||
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
|
MemoryEffects<[MemRead]>,
|
||||||
NoSideEffect,
|
|
||||||
TypesMatchWith<"infer mask type from src type",
|
TypesMatchWith<"infer mask type from src type",
|
||||||
"src", "mask", "getI1SameShape($_self)",
|
"src", "mask", "getI1SameShape($_self)",
|
||||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||||
@@ -158,7 +157,8 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
|||||||
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> {
|
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
||||||
|
ResultsAreSharedEncoding]> {
|
||||||
let summary = "allocate tensor";
|
let summary = "allocate tensor";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@@ -172,6 +172,8 @@ def get_proper_err(a, b, golden):
|
|||||||
[128, 64, 128, 4, 128, 64, 128, False, False],
|
[128, 64, 128, 4, 128, 64, 128, False, False],
|
||||||
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
|
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
|
||||||
# K-Forloop
|
# K-Forloop
|
||||||
|
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||||
|
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
|
||||||
[64, 32, 128, 4, 64, 32, 64, False, False],
|
[64, 32, 128, 4, 64, 32, 64, False, False],
|
||||||
[128, 16, 128, 4, 128, 16, 32, False, False],
|
[128, 16, 128, 4, 128, 16, 32, False, False],
|
||||||
[32, 16, 128, 4, 32, 16, 32, False, False],
|
[32, 16, 128, 4, 32, 16, 32, False, False],
|
||||||
|
Reference in New Issue
Block a user