diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 06c5ec519..8763bbe6c 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -12,6 +12,7 @@ namespace mlir { namespace test { +void registerTestAliasPass(); void registerTestAlignmentPass(); void registerTestAllocationPass(); void registerTestMembarPass(); @@ -22,6 +23,7 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::registerTritonPasses(); mlir::registerTritonGPUPasses(); + mlir::test::registerTestAliasPass(); mlir::test::registerTestAlignmentPass(); mlir::test::registerTestAllocationPass(); mlir::test::registerTestMembarPass(); diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h new file mode 100644 index 000000000..fa6b906fc --- /dev/null +++ b/include/triton/Analysis/Alias.h @@ -0,0 +1,80 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet<Value> &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias information + static AliasInfo getPessimisticValueState(MLIRContext *context) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased values produced in the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6, + /// and v2's liveness range is the union of v4 and v5. + DenseSet<Value> allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> { +public: + using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + /// Computes whether the alloc set of the results has changed. + ChangeResult + visitOperation(Operation *op, + ArrayRef<LatticeElement<AliasInfo> *> operands) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H
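To make the lattice semantics above concrete, here is a minimal sketch (illustration only, not part of the patch) that replays the v1..v6 example from the header comment in terms of the `AliasInfo` API: a lattice value is just a set of allocation `Value`s, and `join` takes the set union at control-flow merge points.

```cpp
// Sketch: AliasInfo at a control-flow join, mirroring the v1..v6 example above.
#include "triton/Analysis/Alias.h"
#include <cassert>

void aliasInfoJoinExample(mlir::Value v1, mlir::Value v2) {
  mlir::AliasInfo stateV3(v1); // v3 may only alias the v1 allocation
  mlir::AliasInfo stateV5(v2); // v5 may only alias the v2 allocation
  // v4 merges both incoming states, so its alloc set becomes {v1, v2}.
  mlir::AliasInfo stateV4 = mlir::AliasInfo::join(stateV3, stateV5);
  assert(stateV4.getAllocs().size() == 2);
  // The pessimistic (unknown) state is simply the empty set.
  assert(mlir::AliasInfo::getPessimisticValueState(v1).getAllocs().empty());
}
```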
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 64c870cd0..5f465f693 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -3,6 +3,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/raw_ostream.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -46,6 +47,8 @@ class Allocation { public: /// A unique identifier for shared memory buffers using BufferId = size_t; + using BufferIdSetT = DenseSet<BufferId>; + static constexpr BufferId InvalidBufferId = std::numeric_limits<BufferId>::max(); @@ -67,6 +70,9 @@ public: } /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. BufferId getBufferId(Value value) const { if (valueBuffer.count(value)) { return valueBuffer.lookup(value)->id; } else { return InvalidBufferId; } } + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + /// Returns the scratch buffer id of the given value. BufferId getBufferId(Operation *operation) const { if (opScratch.count(operation)) { @@ -126,7 +145,9 @@ private: /// Op -> Scratch Buffer using OpScratchMapT = DenseMap<Operation *, BufferT *>; /// Value -> Explicit Buffer - using ValueBufferMapT = DenseMap<Value, BufferT *>; + using ValueBufferMapT = llvm::MapVector<Value, BufferT *>; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>; /// BufferId -> Buffer using BufferSetT = DenseMap<BufferId, BufferT>; /// Runs allocation analysis on the given top-level operation.
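The distinction the new comment draws is worth spelling out with a usage sketch (assumed usage, not part of the patch): a loop-carried block argument usually has no buffer of its own, so `getBufferId` reports `InvalidBufferId`, while `getBufferIds` also walks the alias map filled in by `addAlias` and returns the buffers the argument may refer to.

```cpp
// Sketch: querying an Allocation for a value that only aliases buffers.
#include "triton/Analysis/Allocation.h"

void inspectLoopArg(const mlir::Allocation &allocation, mlir::Value loopArg) {
  // No buffer was allocated for the block argument itself...
  if (allocation.getBufferId(loopArg) == mlir::Allocation::InvalidBufferId) {
    // ...but it may alias one or more allocated buffers (e.g. both scf.for
    // init values), and a client such as the membar pass must sync all of them.
    for (mlir::Allocation::BufferId id : allocation.getBufferIds(loopArg)) {
      (void)id; // every id here refers to a real allocation
    }
  }
}
```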
@@ -144,10 +165,15 @@ private: } } + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + private: Operation *operation; OpScratchMapT opScratch; ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; BufferSetT bufferSet; size_t sharedMemorySize = 0; diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index 383c7a053..7929eea03 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -33,7 +33,7 @@ public: private: struct RegionInfo { - using BufferIdSetT = DenseSet<Allocation::BufferId>; + using BufferIdSetT = Allocation::BufferIdSetT; BufferIdSetT syncReadBuffers; BufferIdSetT syncWriteBuffers; diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h new file mode 100644 index 000000000..31869c10f --- /dev/null +++ b/include/triton/Analysis/Utility.h @@ -0,0 +1,16 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include <string> +namespace mlir { + +bool isSharedEncoding(Value value); + +bool maybeSharedAllocationOp(Operation *op); + +std::string getValueOperandName(Value value, AsmState &state); + +} // namespace mlir + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp new file mode 100644 index 000000000..3cc7968d1 --- /dev/null +++ b/lib/Analysis/Alias.cpp @@ -0,0 +1,66 @@ +#include "triton/Analysis/Alias.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +ChangeResult SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) { + AliasInfo aliasInfo; + bool pessimistic = true; + if (maybeSharedAllocationOp(op)) { + // These ops may allocate a new shared memory buffer. + auto result = op->getResult(0); + if (isSharedEncoding(result)) { + aliasInfo.insert(result); + pessimistic = false; + } + } else { + llvm::errs() << "op: " << op->getName() << "\n"; + } + // XXX(Keren): triton ops don't support aliasing yet. + // else if (auto viewOp = dyn_cast(op) || + // dyn_cast(op)) { + // // These ops will create a new view of the same shared memory buffer.
// auto result = op->getResult(0); + // if (isSharedEncoding(result)) { + // aliasInfo = AliasInfo(operands[0]->getValue()); + // pessimistic = false; + // } + //} + if (pessimistic) { + return markAllPessimisticFixpoint(op->getResults()); + } + // Join all lattice elements + ChangeResult result = ChangeResult::NoChange; + for (Value value : op->getResults()) { + result |= getLatticeElement(value).join(aliasInfo); + } + return result; +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 80266ff29..3ea775415 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -1,6 +1,8 @@ #include "triton/Analysis/Allocation.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Analysis/Alias.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" @@ -37,17 +39,22 @@ private: /// Initializes explicitly defined shared memory values for a given operation. void getExplicitValueSize(Operation *op) { + /// Values returned from scf.yield will not be allocated even though they + /// have the shared encoding. + /// For example: %a = scf.if -> yield + /// %a must be allocated elsewhere by other operations. + if (!maybeSharedAllocationOp(op)) { + return; + } + for (Value result : op->getResults()) { - auto type = result.getType(); - if (auto tensorType = type.dyn_cast<RankedTensorType>()) { - auto encoding = tensorType.getEncoding(); - if (encoding && encoding.isa<triton::gpu::SharedEncodingAttr>()) { - // Bytes could be a different value once we support padding or other - // allocation policies. - auto bytes = tensorType.getNumElements() * - tensorType.getElementTypeBitWidth() / 8; - allocation->addBuffer(result, bytes); - } + if (isSharedEncoding(result)) { + // Bytes could be a different value once we support padding or other + // allocation policies. + auto tensorType = result.getType().dyn_cast<RankedTensorType>(); + auto bytes = tensorType.getNumElements() * + tensorType.getElementTypeBitWidth() / 8; + allocation->addBuffer(result, bytes); } } } @@ -67,12 +74,86 @@ private: } } + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + LatticeElement<AliasInfo> *latticeElement = + analysis.lookupLatticeElement(value); + if (latticeElement) { + auto &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + /// Extract all shared memory values and their sizes void getValuesAndSizes() { + // Get the alloc values operation->walk([&](Operation *op) { getExplicitValueSize(op); getScratchValueSize(op); }); + // Get the alias values + SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext()); + aliasAnalysis.run(operation); + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, aliasAnalysis); + } + }); + }
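As a sanity check on the size computation in `getExplicitValueSize` above (and on the numbers asserted in test-allocation.mlir further down), the current policy is simply element count times element width with no padding; a self-contained sketch:

```cpp
// Illustration only: the byte-size formula currently used for explicit buffers.
#include <cassert>
#include <cstdint>

int64_t sharedBytes(int64_t numElements, int64_t elementBitWidth) {
  return numElements * elementBitWidth / 8;
}

int main() {
  assert(sharedBytes(128 * 32, /*f16=*/16) == 8192); // the "size = 8192" buffers in the tests
  assert(sharedBytes(16 * 16, /*f16=*/16) == 512);   // the 512-byte buffers in the if tests
  return 0;
}
```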
+ + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref<Range<size_t>(Value value)> getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value); + } + } + + /// Extends the liveness range by unioning in the liveness ranges of the aliased + /// values, because each allocated buffer could be an alias of others when block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref<Range<size_t>(Value value)> getLiveness) { + for (auto aliasBufferIter : allocation->aliasBuffer) { + auto value = aliasBufferIter.first; + auto buffers = aliasBufferIter.second; + auto range = getLiveness(value); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Range(minId, maxId); + } + } + } + + /// Computes the liveness range of scratch buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap<Operation *, size_t> &operationId) { + // Analyze liveness of scratch buffers + for (auto opScratchIter : allocation->opScratch) { + // Any scratch memory's live range is the current operation's live + // range. + auto *op = opScratchIter.first; + auto *buffer = opScratchIter.second; + bufferRange.insert( + {buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)}); + } } /// Resolves liveness of all values involved under the root operation. @@ -83,34 +164,27 @@ private: operation->walk( [&](Operation *op) { operationId[op] = operationId.size(); }); + // Analyze liveness of explicit buffers Liveness liveness(operation); - operation->walk([&](Operation *op) { - for (Value result : op->getResults()) { - auto liveOperations = liveness.resolveLiveness(result); - auto minId = std::numeric_limits<size_t>::max(); - auto maxId = std::numeric_limits<size_t>::min(); - std::for_each(liveOperations.begin(), liveOperations.end(), - [&](Operation *liveOp) { - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if (operationId[liveOp] > maxId) { - maxId = operationId[liveOp]; - } - }); - if (allocation->valueBuffer.count(result)) { - auto *buffer = allocation->valueBuffer[result]; - bufferRange.insert({buffer, Range(minId, maxId + 1)}); - } - } - if (allocation->opScratch.count(op)) { - // Any scratch memory's live range is the current operation's live - // range. - auto *buffer = allocation->opScratch[op]; - bufferRange.insert( - {buffer, Range(operationId[op], operationId[op] + 1)}); - } - }); + auto getValueLivenessRange = [&](Value value) { + auto liveOperations = liveness.resolveLiveness(value); + auto minId = std::numeric_limits<size_t>::max(); + auto maxId = std::numeric_limits<size_t>::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Range(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); } /// Computes the shared memory offsets for all related values.
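To see what `resolveAliasBufferLiveness` buys, consider one buffer with two users. Below is a standalone sketch (assumption: a plain re-implementation of the min/max logic for illustration, not the class above): if the defining alloc is live over operation ids [2, 5) and a block argument aliasing it is live over [4, 9), the buffer's final range becomes [2, 9), so no other buffer may reuse that slot in between.

```cpp
// Sketch: extending a buffer's live range with the ranges of its aliases.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>

using LiveRange = std::pair<std::size_t, std::size_t>; // [start, end)

LiveRange unionRanges(LiveRange buffer, LiveRange alias) {
  return {std::min(buffer.first, alias.first),
          std::max(buffer.second, alias.second)};
}

int main() {
  LiveRange allocRange{2, 5}; // liveness of the defining alloc
  LiveRange aliasRange{4, 9}; // liveness of a block argument aliasing it
  assert(unionRanges(allocRange, aliasRange) == LiveRange(2, 9));
  return 0;
}
```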
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index a4acf328f..34059e9fe 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_library(TritonAnalysis AxisInfo.cpp Allocation.cpp Membar.cpp + Alias.cpp + Utility.cpp DEPENDS TritonGPUAttrDefsIncGen diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index ab3b0d68c..b3939fc2d 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -46,6 +46,12 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, if (op->getNumResults() < 1) return; + if (dyn_cast<scf::ForOp>(op) || dyn_cast<scf::IfOp>(op) || + dyn_cast<scf::YieldOp>(op)) { + // Do not insert barriers before control flow operations. + return; + } + if (dyn_cast<gpu::BarrierOp>(op)) { // If the current op is a barrier, we sync previous reads and writes regionInfo->sync(); @@ -62,24 +68,28 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, return; } - auto addBuffer = [&](RegionInfo::BufferIdSetT &bufferSet, - Allocation::BufferId bufferId) { - if (bufferId != Allocation::InvalidBufferId) { - bufferSet.insert(bufferId); - } - }; - RegionInfo curRegionInfo; for (Value value : op->getOperands()) { // ConvertLayoutOp: shared memory -> registers - addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(value)); + // Need to consider all alias buffers + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + curRegionInfo.syncReadBuffers.insert(bufferId); + } + } } for (Value value : op->getResults()) { // ConvertLayoutOp: registers -> shared memory - addBuffer(curRegionInfo.syncWriteBuffers, allocation->getBufferId(value)); + auto bufferId = allocation->getBufferId(value); + if (bufferId != Allocation::InvalidBufferId) { + curRegionInfo.syncWriteBuffers.insert(bufferId); + } } // Scratch buffer is considered as a shared memory read - addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(op)); + auto bufferId = allocation->getBufferId(op); + if (bufferId != Allocation::InvalidBufferId) { + curRegionInfo.syncReadBuffers.insert(bufferId); + } if (regionInfo->isIntersected(curRegionInfo, allocation)) { OpBuilder::InsertionGuard g(*builder); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..0dc827857 --- /dev/null +++ b/lib/Analysis/Utility.cpp @@ -0,0 +1,38 @@ +#include "triton/Analysis/Utility.h" +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +bool isSharedEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + auto encoding = tensorType.getEncoding(); + return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>(); + } + return false; +} + +bool maybeSharedAllocationOp(Operation *op) { + // TODO(Keren): This function can be replaced by adding + // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to + // query the memory effects of the op. + auto *dialect = op->getDialect(); + return dialect && + (dialect->getTypeID() == + mlir::TypeID::get<triton::gpu::TritonGPUDialect>() || + dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() || + dialect->getTypeID() == + mlir::TypeID::get<arith::ArithmeticDialect>()); +} + +std::string getValueOperandName(Value value, AsmState &state) { + auto *op = value.getDefiningOp(); + std::string opName; + llvm::raw_string_ostream ss(opName); + value.printAsOperand(ss, state); + return std::move(opName); +} + +} // namespace mlir
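A short sketch of how the two predicates in Utility.cpp are meant to compose (assumed usage, not part of the patch): `maybeSharedAllocationOp` is only a coarse dialect-level filter, and `isSharedEncoding` then decides per result whether the value actually lives in shared memory, which is the pattern `getExplicitValueSize` and the alias analysis follow.

```cpp
// Sketch: collecting shared-memory results with the helpers declared above.
#include "triton/Analysis/Utility.h"
#include "llvm/ADT/SmallVector.h"

void collectSharedResults(mlir::Operation *root,
                          llvm::SmallVectorImpl<mlir::Value> &sharedValues) {
  root->walk([&](mlir::Operation *op) {
    if (!mlir::maybeSharedAllocationOp(op)) // cheap dialect check first
      return;
    for (mlir::Value result : op->getResults())
      if (mlir::isSharedEncoding(result)) // only #triton_gpu.shared tensors
        sharedValues.push_back(result);
  });
}
```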
diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir new file mode 100644 index 000000000..643a7dfc4 --- /dev/null +++ b/test/Analysis/test-alias.mlir @@ -0,0 +1,216 @@ +// RUN: triton-opt %s --mlir-disable-threading -test-print-alias -split-input-file 2>&1 | FileCheck %s + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> + +// CHECK-LABEL: matmul_loop +func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL> + %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + // CHECK: %4 -> %4 + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + // CHECK-NEXT: %6 -> %6 + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL> + %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C> + } + return +} + +// CHECK-LABEL: alloc +func @alloc(%A : !tt.ptr<f16>) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + return +} + +// CHECK-LABEL: convert
func @convert(%A : !tt.ptr<f16>) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: %0 -> %0 + %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16,
#A> + return +} + +// CHECK-LABEL: copy_async +func @copy_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: %2 -> %2 + %a = triton_gpu.copy_async %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<16x16xf16, #A> + return +} + +// COM: Enable the following test once we support view on shared memory tensors +// COM: // CHECK-LABEL: view +// COM: func @view(%A : !tt.ptr) { +// COM: // CHECK: res0:0 -> 0 +// COM: %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> +// COM: // CHECK-NEXT: res1:0 -> 0 +// COM: %cst1 = tt.view %cst0 : (tensor<16x16xf16, #A>) -> tensor<32x8xf16, #A> +// COM: return +// COM: } + +// CHECK-LABEL: if_cat +func @if_cat(%i1 : i1) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK: %cst_0 -> %cst_0 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK: %0 -> %1,%1 + %cst2 = scf.if %i1 -> tensor<32x16xf16, #A> { + // CHECK: %1 -> %1 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield %a : tensor<32x16xf16, #A> + } else { + // CHECK: %1 -> %1 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield %b : tensor<32x16xf16, #A> + } + return +} + +// CHECK-LABEL: if_alias +func @if_alias(%i1 : i1) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: %cst_0 -> %cst_0 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: %0 -> %cst,%cst_0 + %cst2 = scf.if %i1 -> tensor<16x16xf16, #A> { + scf.yield %cst0 : tensor<16x16xf16, #A> + } else { + scf.yield %cst1 : tensor<16x16xf16, #A> + } + return +} + +// CHECK-LABEL: for +func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_0 -> %cst_0 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_1 -> %cst_1 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %arg6 -> %cst + // CHECK-NEXT: %arg7 -> %cst_0 + // CHECK-NEXT: %arg8 -> %cst_1 + // CHECK-NEXT: %0#0 -> %cst,%cst_0 + // CHECK-NEXT: %0#1 -> %cst,%cst_0 + // CHECK-NEXT: %0#2 -> %cst,%cst_0 + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return +} + +// COM: // Enable the following test once we support view on shared memory tensors +// COM: // CHECK-LABEL: for_if +// COM: func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +// COM: // CHECK: res0:0 -> 0 +// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: res1:0 -> 1 +// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: res2:0 -> 2 +// COM: 
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: arg3:0 -> 0 +// COM: // CHECK-NEXT: arg3:1 -> 1 +// COM: // CHECK-NEXT: arg3:2 -> 2 +// COM: // CHECK-NEXT: res3:0 -> 0,1 +// COM: // CHECK-NEXT: res3:1 -> 0,1 +// COM: // CHECK-NEXT: res3:2 -> 0,1 +// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { +// COM: scf.if %i1 { +// COM: // CHECK-NEXT: res5:0 -> 0,1 +// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<32x128xf16, #A> +// COM: scf.yield +// COM: } +// COM: scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +// COM: } +// COM: return +// COM: } + +// COM: // Enable the following test once we support view on shared memory tensors +// COM: // CHECK-LABEL: for_if_else +// COM: func @for_if_else(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +// COM: // CHECK: res0:0 -> 0 +// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: res1:0 -> 1 +// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: res2:0 -> 2 +// COM: %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: // CHECK-NEXT: arg3:0 -> 0 +// COM: // CHECK-NEXT: arg3:1 -> 1 +// COM: // CHECK-NEXT: arg3:2 -> 2 +// COM: // CHECK-NEXT: res3:0 -> 0 +// COM: // CHECK-NEXT: res3:1 -> 1 +// COM: // CHECK-NEXT: res3:2 -> 0,7 +// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { +// COM: // CHECK-NEXT: res4:0 -> 0,7 +// COM: %c_shared_next = scf.if %i1 -> tensor<128x32xf16, #A> { +// COM: // CHECK-NEXT: res5:0 -> 0 +// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A> +// COM: scf.yield %cst0 : tensor<128x32xf16, #A> +// COM: } else { +// COM: // CHECK-NEXT: res7:0 -> 7 +// COM: %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> +// COM: scf.yield %cst0 : tensor<128x32xf16, #A> +// COM: } +// COM: scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +// COM: } +// COM: return +// COM: } + +// CHECK-LABEL: for_if_for +func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_0 -> %cst_0 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_1 -> %cst_1 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %arg7 -> %cst + // CHECK-NEXT: %arg8 -> %cst_0 + // CHECK-NEXT: %arg9 -> %cst_1 + // CHECK-NEXT: %0#0 -> %cst + // CHECK-NEXT: %0#1 -> %cst_0 + // CHECK-NEXT: %0#2 -> %cst_2,%cst_2 + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2 + // CHECK-NEXT: %1 -> %cst_2,%cst_2 + 
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A>) { + // CHECK-NEXT: %2 -> %cst_2,%cst_2 + %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A> { + // CHECK-NEXT: %cst_2 -> %cst_2 + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + scf.yield %cst0 : tensor<128x32xf16, #A> + } else { + // CHECK-NEXT: %cst_2 -> %cst_2 + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + scf.yield %cst0 : tensor<128x32xf16, #A> + } + scf.yield %c_shared_next_next : tensor<128x32xf16, #A> + } + scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index ba38ee9d7..d07d959fa 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -151,21 +151,17 @@ func @longlive(%A : !tt.ptr) { // CHECK-LABEL: scratch func @scratch() { - // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK-NEXT: offset = 1056, size = 1024 - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK-NEXT: scratch offset = 32, size = 1024 - // CHECK-NEXT: offset = 0, size = 32 - %b = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: scratch offset = 0, size = 512 + %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #AL> return - // CHECK-NEXT: size = 2080 + // CHECK-NEXT: size = 512 } // B0 -> (B1) -> B0 // Memory used by B1 can be reused by B0. -// CHECK-LABEL: multi_blocks_reuse -func @multi_blocks_reuse(%i1 : i1) { +// CHECK-LABEL: if +func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> // CHECK-NEXT: offset = 512, size = 512 @@ -188,8 +184,8 @@ func @multi_blocks_reuse(%i1 : i1) { // B0 -> (B1) -> (B2) -> B0 // Memory used by B0 cannot be reused by B1 or B2. -// CHECK-LABEL: multi_blocks_noreuse -func @multi_blocks_noreuse(%i1 : i1) { +// CHECK-LABEL: if_else +func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> // CHECK-NEXT: offset = 512, size = 512 @@ -212,3 +208,51 @@ func @multi_blocks_noreuse(%i1 : i1) { return // CHECK-NEXT: size = 3072 } + +// Block arguments and yields are memory aliases that do not trigger a new +// allocation. 
+// CHECK-LABEL: for +func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 8192, size = 8192 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 16384, size = 8192 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return + // CHECK-NEXT: size = 24576 +} + +// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2. +// So they cannot be reused by cst0 and cst1, but can be reused by cst2. +// CHECK-LABEL: for_if_for +func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 8192, size = 8192 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 16384, size = 8192 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A>) { + %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A> { + // CHECK-NEXT: offset = 24576, size = 8192 + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + scf.yield %cst0 : tensor<128x32xf16, #A> + } else { + // CHECK-NEXT: offset = 32768, size = 8192 + %cst1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + scf.yield %cst1 : tensor<128x32xf16, #A> + } + scf.yield %c_shared_next_next : tensor<128x32xf16, #A> + } + scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + // CHECK-NEXT: offset = 0, size = 8192 + %cst2 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + return + // CHECK-NEXT: size = 40960 +} diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 23ccfcdda..238c57b07 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -176,3 +176,36 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> return } + +// CHECK-LABEL: for +func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // 
CHECK-NEXT: Membar 3 + %cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return +} + +// Although a_shared and b_shared are synced before entering the loop, +// they are reassociated with aliases (c_shared) and thus require a barrier. +// CHECK-LABEL: for_alias +func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: Membar 2 + %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> + // CHECK-NEXT: Membar 6 + %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> + scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + // CHECK-NEXT: Membar 9 + %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A>, tensor<256x32xf16, #A>) -> tensor<512x32xf16, #A> + return +} diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index 992d687e0..3b21835b7 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(TritonTestAnalysis + TestAlias.cpp TestAxisInfo.cpp TestAllocation.cpp TestMembar.cpp diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp new file mode 100644 index 000000000..b8fef4e93 --- /dev/null +++ b/test/lib/Analysis/TestAlias.cpp @@ -0,0 +1,92 @@ +#include "mlir/IR/AsmState.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Alias.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace { + +struct TestAliasPass + : public PassWrapper> { + + // LLVM15+ + // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); + static void print(StringRef name, SmallVector &vals, + raw_ostream &os) { + if (vals.empty()) + return; + os << name << " -> "; + size_t i = 0; + for (auto val : vals) { + if (i != 0) + os << ","; + os << val; + ++i; + } + os << "\n"; + } + + StringRef getArgument() const final { return "test-print-alias"; } + StringRef getDescription() const final { + return "print the result of the alias analysis pass"; + } + + void runOnOperation() override { + Operation *operation = getOperation(); + auto &os = llvm::errs(); + auto op_name = SymbolTable::getSymbolName(operation).getValue().str(); + os << op_name << "\n"; + + SharedMemoryAliasAnalysis analysis(&getContext()); + analysis.run(operation); + + AsmState state(operation->getParentOfType()); + // Get operation ids of value's aliases + auto getAllocOpNames = [&](Value value) { + LatticeElement *latticeElement = + analysis.lookupLatticeElement(value); + SmallVector opNames; + if 
(latticeElement) { + auto &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto &alias : info.getAllocs()) { + auto opName = + getValueOperandName(alias.getDefiningOp()->getResult(0), state); + opNames.push_back(std::move(opName)); + } + } + } + // Ensure deterministic output + std::sort(opNames.begin(), opNames.end()); + return opNames; + }; + + operation->walk([&](Operation *op) { + if (op->getNumResults() < 1) + return; + if (auto forOp = dyn_cast<scf::ForOp>(op)) { + for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { + auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get(); + auto opNames = getAllocOpNames(operand); + auto argName = getValueOperandName(arg.value(), state); + print(argName, opNames, os); + } + } + for (auto result : llvm::enumerate(op->getResults())) { + auto opNames = getAllocOpNames(result.value()); + auto resultName = getValueOperandName(result.value(), state); + print(resultName, opNames, os); + } + }); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestAliasPass() { PassRegistration<TestAliasPass>(); } +} // namespace test +} // namespace mlir
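Finally, a hedged sketch of how a future client could drive `SharedMemoryAliasAnalysis` directly, using the same setup as the test pass above and `AllocationAnalysis::getValueAlias`. The check itself is exercised through the RUN line at the top of test/Analysis/test-alias.mlir (`triton-opt ... -test-print-alias`), and every printed line has the form `<value name> -> <comma-separated allocation names>`. Note that the MLIR-AliasAnalysis-compatible entry points are still conservative stubs in this patch, so they always answer MayAlias / ModAndRef.

```cpp
// Sketch, not part of the patch: running the analysis and reading its lattice.
#include "mlir/IR/BuiltinOps.h"
#include "triton/Analysis/Alias.h"

void queryAliases(mlir::ModuleOp module, mlir::Value a, mlir::Value b) {
  mlir::SharedMemoryAliasAnalysis analysis(module.getContext());
  analysis.run(module.getOperation());

  // Per-value results, as consumed by AllocationAnalysis::getValueAlias.
  if (auto *lattice = analysis.lookupLatticeElement(a))
    for (mlir::Value alloc : lattice->getValue().getAllocs())
      (void)alloc; // each shared allocation that `a` may refer to

  // Conservative stubs for now: always MayAlias / ModAndRef.
  mlir::AliasResult aliasKind = analysis.alias(a, b);
  mlir::ModRefResult modRef = analysis.getModRef(module.getOperation(), a);
  (void)aliasKind;
  (void)modRef;
}
```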