From 2e3335241932d540a8146c6aa9da794437a5bd73 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 22 Nov 2022 23:29:18 -0800 Subject: [PATCH] [Triton-MLIR] Fix side effects (#906) Try to add proper side effects for triton operations. The CSE pass could fail, hang, or output incorrect IRs for unknown reasons, if side effects are not defined properly. For instance, suppose we have two shared memory tensors: ``` %a = triton_gpu.alloc_tensor shape0, share_encoding0 %b = triton_gpu.alloc_tensor shape0, share_encoding0 ``` The CSE pass will consider `%a` and `%b` are the same thing and eliminate one of them, resulting in mysterious outcomes. --- include/triton/Analysis/Membar.h | 8 ++++++-- include/triton/Dialect/Triton/IR/TritonOps.td | 5 ++++- include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td | 6 +++--- python/tests/test_gemm.py | 2 ++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index 7929eea03..242b54ecc 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -56,8 +56,12 @@ private: bool isIntersected(const RegionInfo &other, Allocation *allocation) const { return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers, allocation) || - /*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers, - allocation); + /*WAR*/ + isIntersected(syncReadBuffers, other.syncWriteBuffers, + allocation) || + /*WAW*/ + isIntersected(syncWriteBuffers, other.syncWriteBuffers, + allocation); } /// Clears the buffers because a barrier is inserted. diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d0981ce8f..97a015882 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -187,6 +187,7 @@ def TT_StoreOp : TT_Op<"store", // def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, MemoryEffects<[MemWrite]>, TypesMatchWith<"infer ptr type from value type", "val", "ptr", @@ -208,7 +209,9 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, let results = (outs TT_Type:$result); } -def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape, +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "atomic cas"; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 488c6a72d..e5b1da097 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -79,8 +79,7 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> { def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", [AttrSizedOperandSegments, ResultsAreSharedEncoding, - // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should? - NoSideEffect, + MemoryEffects<[MemRead]>, TypesMatchWith<"infer mask type from src type", "src", "mask", "getI1SameShape($_self)", "($_op.getOperands().size() <= 3) || std::equal_to<>()">, @@ -158,7 +157,8 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", let printer = [{ return printInsertSliceAsyncOp(p, *this); }]; } -def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> { +def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory + ResultsAreSharedEncoding]> { let summary = "allocate tensor"; let description = [{ diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 4deff76a8..b12bd6fad 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -172,6 +172,8 @@ def get_proper_err(a, b, golden): [128, 64, 128, 4, 128, 64, 128, False, False], [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue # K-Forloop + [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding + [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k [64, 32, 128, 4, 64, 32, 64, False, False], [128, 16, 128, 4, 128, 16, 32, False, False], [32, 16, 128, 4, 32, 16, 32, False, False],