From 16aed94ff5cc7056e567ea0e19d4ec3175070ee3 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Fri, 9 Sep 2022 12:03:41 -0700
Subject: [PATCH] [Analysis/Allocation] Allocation passes now assume that slices always alias (#108)

The code in this branch assumes that the `src` operand of
`insert_slice_async` always aliases the result. This doesn't hold in
general; it is just a workaround to make the pipeline pass work.

I'm also working on the complete analysis in another
[branch](https://github.com/openai/triton-mlir/tree/keren/analyze-slice).
---
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      |  57 ++-------
 .../Dialect/TritonGPU/Transforms/Passes.td    |   3 -
 lib/Analysis/Alias.cpp                        |  27 ++---
 lib/Analysis/Allocation.cpp                   |  12 +-
 lib/Analysis/Membar.cpp                       |  16 ++-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  52 +-------
 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp |   8 +-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  |   4 +-
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp |   2 +-
 test/Analysis/test-alias.mlir                 | 114 +++++++-----------
 test/Analysis/test-allocation.mlir            |  54 +++++++++
 test/Analysis/test-membar.mlir                |  37 ++++++
 test/TritonGPU/matmul.mlir                    | 106 ++++++++++++++++
 test/lib/Analysis/TestMembar.cpp              |   2 +-
 14 files changed, 299 insertions(+), 195 deletions(-)
 create mode 100644 test/TritonGPU/matmul.mlir

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index ef5b0914e..8f9e4b513 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -38,37 +38,6 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
   let assemblyFormat = "attr-dict";
 }
 
-def TTG_CopyAsyncOp : TTG_Op<"copy_async",
-                             [MemoryEffects<[MemRead, MemWrite]>,
-                              SameVariadicOperandSize,
-                              TypesMatchWith<"infer mask type from ptr type",
-                                             "ptr", "mask", "getI1SameShape($_self)",
-                                             "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
-                              TypesMatchWith<"infer other type from ptr type",
-                                             "ptr", "other", "getPointeeType($_self)",
-                                             "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
-  let summary = "copy async";
-
-  let arguments = (ins TT_PtrTensor:$ptr, Optional:$mask, Optional:$other,
-                   TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
-                   BoolAttr:$isVolatile);
-
-  let builders = [
-    OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
-  ];
-
-  let results = (outs TT_Tensor:$result);
-
-  // let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
-  let parser = [{ return parseCopyAsyncOp(parser, result); }];
-
-  let printer = [{ return printCopyAsyncOp(p, *this); }];
-
-  // result needs to be of shared layout
-  let verifier = [{ return ::verify(*this); }];
-}
-
 // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
 // This is needed because Arith's Cmp ops don't
 // handle encodings
@@ -110,7 +79,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
   let description = [{
     This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
-    `$offset` argument and `$axis` attribute.
+    `$index` argument and `$axis` attribute.
     It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
     This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
 
@@ -119,7 +88,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
     * src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted. - * offset: the offset of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into + * index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into * mask: optional tensor-rank number of boolean masks which specify which elements of the `$src` tensor are inserted into the `$dst` tensor. * other: optional tensor-rank number of other tensors which specify what @@ -136,24 +105,24 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", ``` %1 = triton_gpu.alloc_tensor : tensor<2x32xf32> - %2 = triton_gpu.insert_slice_async %0, %1, %offset { axis = 0 } : tensor<32x!tt.ptr, #AL> -> tensor<2x32xf32, #A> + %2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr, #AL> -> tensor<2x32xf32, #A> triiton_gpu.async_wait { num = 0 : i32 } ``` }]; - let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$offset, + let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index, Optional:$mask, Optional:$other, TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile, I32Attr:$axis); let builders = [ - OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset, + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, - OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset, "Value":$mask, + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, - OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset, + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, @@ -163,7 +132,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", //let assemblyFormat = [{ // $src `,` $dst `` - // $offset, $mask, $other + // $index, $mask, $other // attr-dict `:` type($src) `->` type($dst) //}]; @@ -180,26 +149,26 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInter let summary = "extract slice"; let description = [{ The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as - specified by the operation's `$offset` and `$axis` arguments. + specified by the operation's `$index` and `$axis` arguments. The extract_slice operation supports the following arguments: * src: the tensor that is extracted from. - * offset: the offset at the given `$axis` from which the `$src` tensor is extracted + * index: the index at the given `$axis` from which the `$src` tensor is extracted Example: ``` // Rank-reducing extract_slice. 
- %1 = tensor.extract_slice %0, %offset {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32> + %1 = tensor.extract_slice %0, %index {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32> ``` }]; - let arguments = (ins TT_Tensor:$src, I32:$offset, I32Attr:$axis); + let arguments = (ins TT_Tensor:$src, I32:$index, I32Attr:$axis); let results = (outs TT_Tensor:$result); - let assemblyFormat = [{$src `,` $offset attr-dict `:` type($src) `->` type($result)}]; + let assemblyFormat = [{$src `,` $index attr-dict `:` type($src) `->` type($result)}]; let extraClassDeclaration = [{ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 22a34f43d..a9d501527 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -39,9 +39,6 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { let summary = "combine triton gpu ops"; let description = [{ - convert_layout(load(%ptr, %mask, %other), #SMEM_LAYOUT) => - copy_async(%ptr, %mask, %other), barrier - convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) => convert_layout(%src, #LAYOUT_1) diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index ab3ca544e..31287f5c0 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -22,25 +22,24 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( AliasInfo aliasInfo; bool pessimistic = true; if (maybeSharedAllocationOp(op)) { - // These ops will allocate a new shared memory buffer. + // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); if (isSharedEncoding(result)) { - aliasInfo.insert(result); + // FIXME(Keren): extract and insert are always alias for now + if (auto extractSliceOp = dyn_cast(op)) { + // extract_slice %src, %index + aliasInfo = AliasInfo(operands[0]->getValue()); + } else if (auto insertSliceOp = + dyn_cast(op)) { + // insert_slice_async %src, %dst, %index + aliasInfo = AliasInfo(operands[1]->getValue()); + } else { + aliasInfo.insert(result); + } pessimistic = false; - } else { - llvm::errs() << "op: " << op->getName() << "\n"; } } - // XXX(Keren): triton ops don't support aliasing yet. - // else if (auto viewOp = dyn_cast(op) || - // dyn_cast(op)) { - // // These ops will reate a new view of the same shared memory buffer. - // auto result = op->getResult(0); - // if (isSharedEncoding(result)) { - // aliasInfo = AliasInfo(operands[0]->getValue()); - // pessimistic = false; - // } - //} + if (pessimistic) { return markAllPessimisticFixpoint(op->getResults()); } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 7a48ae778..7d883546a 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -39,11 +39,13 @@ private: /// Initializes explicitly defined shared memory values for a given operation. void getExplicitValueSize(Operation *op) { - /// Values returned from scf.yield will not be allocated even though they - /// have the shared encoding. - /// For example: %a = scf.if -> yield - /// %a must be allocated elsewhere by other operations. - if (!maybeSharedAllocationOp(op)) { + // Values returned from scf.yield will not be allocated even though they + // have the shared encoding. + // For example: %a = scf.if -> yield + // %a must be allocated elsewhere by other operations. 
+ // FIXME(Keren): extract and insert are always alias for now + if (!maybeSharedAllocationOp(op) || isa(op) || + isa(op)) { return; } diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index f3fb1b083..e2d386fd4 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -45,19 +45,25 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, if (op->getNumResults() < 1) return; - if (dyn_cast(op) || dyn_cast(op) || - dyn_cast(op)) { - // Do not insert barriers before control flow operations. + if (isa(op) || isa(op) || isa(op) || + isa(op) || + isa(op) || + isa(op)) { + // Do not insert barriers before control flow operations and + // alloc/extract/insert + // alloc is an allocation op without memory write. + // In contrast, arith.constant is an allocation op with memory write. + // FIXME(Keren): extract and insert are always alias for now return; } - if (dyn_cast(op)) { + if (isa(op)) { // If the current op is a barrier, we sync previous reads and writes regionInfo->sync(); return; } - if (dyn_cast(op)) { + if (isa(op)) { // If the current op is an async wait, we insert a barrier op and sync // previous reads and writes. OpBuilder::InsertionGuard g(*builder); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 35f5aad0a..d337ed662 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -292,44 +292,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { << "}>"; } -//===----------------------------------------------------------------------===// -// CopyAsyncOp -//===----------------------------------------------------------------------===// - -ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) { - SmallVector allOperands; - Type resultTypes[1], ptrType; - SMLoc allOperandLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(allOperands) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() || - parser.parseCustomTypeWithFallback(resultTypes[0])) - return failure(); - result.addTypes(resultTypes); - - SmallVector operandTypes; - operandTypes.push_back(ptrType); // ptr - if (allOperands.size() >= 2) - operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask - if (allOperands.size() >= 3) - operandTypes.push_back(triton::getPointeeType(ptrType)); // other - - if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, - result.operands)) - return failure(); - return success(); -} - -void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) { - printer << " "; - printer << copyAsyncOp.getOperation()->getOperands(); - printer.printOptionalAttrDict(copyAsyncOp->getAttrs(), /*elidedAttrs=*/{}); - printer << " : "; - printer.printStrippedAttrOrType(copyAsyncOp.ptr().getType()); - printer << " -> "; - printer.printStrippedAttrOrType(copyAsyncOp.result().getType()); -} - //===----------------------------------------------------------------------===// // InsertSliceAsyncOp //===----------------------------------------------------------------------===// @@ -350,7 +312,7 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, operandTypes.push_back(srcType); // src operandTypes.push_back(dstType); // dst operandTypes.push_back( - IntegerType::get(parser.getBuilder().getContext(), 32)); // offset + IntegerType::get(parser.getBuilder().getContext(), 32)); // index if (allOperands.size() >= 4) 
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask if (allOperands.size() >= 5) @@ -389,6 +351,8 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes( auto axis = attributes.get("axis").cast().getInt(); if (axis < 0 || axis > srcShape.size()) return failure(); + // Since we only extract a slice from a certain index on the axis, + // the dims before the axis can be dropped. auto dstShape = srcShape.drop_front(axis + 1); auto returnType = RankedTensorType::get(dstShape, srcType.getElementType(), encoding); @@ -438,16 +402,10 @@ void TritonGPUDialect::initialize() { // Verification //===----------------------------------------------------------------------===// -static LogicalResult verify(CopyAsyncOp op) { - if (!isSharedEncoding(op.getResult())) { - return op.emitOpError("copy_async should return a shared memory tensor"); - } - return success(); -} - static LogicalResult verify(InsertSliceAsyncOp op) { if (!isSharedEncoding(op.getResult())) { - return op.emitOpError("copy_async should return a shared memory tensor"); + return op.emitOpError( + "insert_slice_async should return a shared memory tensor"); } return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 704a89734..d00c00c2d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -69,7 +69,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { // convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { - bool is_async = std::is_same::value; + bool is_async = std::is_same::value; newTypes.push_back(is_async ? t : convertType(t)); } // construct new op with the new encoding @@ -106,9 +106,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { builder.setInsertionPoint(curr); if (auto load = dyn_cast(curr)) coalesceOp(axisInfo, curr, load.ptr(), builder); - if (auto load = dyn_cast(curr)) - coalesceOp(axisInfo, curr, load.ptr(), - builder); + if (auto load = dyn_cast(curr)) + coalesceOp(axisInfo, curr, load.src(), + builder); if (auto store = dyn_cast(curr)) coalesceOp(axisInfo, curr, store.ptr(), builder); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 32ff7e89e..42a8f61dd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -119,7 +119,9 @@ public: return mlir::failure(); auto blacklist = [](Operation *op) { - if (isa(op)) + if (isa( + op)) return true; if (isa(op)) return true; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index ea39bff54..8021f672e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -275,7 +275,7 @@ void LoopPipeliner::emitPrologue() { loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]}; } // load => copy async - // TODO: check if the hardware supports copyasync + // TODO: check if the hardware supports async copy if (auto loadOp = llvm::dyn_cast(op)) { newOp = builder.create( op->getLoc(), loadsBuffer[loadOp].getType(), diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 643a7dfc4..10f5c1e42 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -37,7 +37,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B func @alloc(%A : !tt.ptr) { // CHECK: %cst -> %cst %cst0 = arith.constant 
dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK: %0 -> %0 + %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> return } @@ -49,25 +51,28 @@ func @convert(%A : !tt.ptr) { return } -// CHECK-LABEL: copy_async -func @copy_async(%A : !tt.ptr, %i1 : i1) { +// CHECK-LABEL: insert_slice_async +func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - // CHECK: %2 -> %2 - %a = triton_gpu.copy_async %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<16x16xf16, #A> + // CHECK: %cst_0 -> %cst_0 + %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + // CHECK: %2 -> %cst_0 + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> return } -// COM: Enable the following test once we support view on shared memory tensors -// COM: // CHECK-LABEL: view -// COM: func @view(%A : !tt.ptr) { -// COM: // CHECK: res0:0 -> 0 -// COM: %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> -// COM: // CHECK-NEXT: res1:0 -> 0 -// COM: %cst1 = tt.view %cst0 : (tensor<16x16xf16, #A>) -> tensor<32x8xf16, #A> -// COM: return -// COM: } +// CHECK-LABEL: extract_slice +func @extract_slice(%A : !tt.ptr) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + // CHECK-NEXT: %0 -> %cst + %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + return +} // CHECK-LABEL: if_cat func @if_cat(%i1 : i1) { @@ -123,62 +128,31 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p return } -// COM: // Enable the following test once we support view on shared memory tensors -// COM: // CHECK-LABEL: for_if -// COM: func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { -// COM: // CHECK: res0:0 -> 0 -// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: res1:0 -> 1 -// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: res2:0 -> 2 -// COM: %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: arg3:0 -> 0 -// COM: // CHECK-NEXT: arg3:1 -> 1 -// COM: // CHECK-NEXT: arg3:2 -> 2 -// COM: // CHECK-NEXT: res3:0 -> 0,1 -// COM: // CHECK-NEXT: res3:1 -> 0,1 -// COM: // CHECK-NEXT: res3:2 -> 0,1 -// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { -// COM: scf.if %i1 { -// COM: // CHECK-NEXT: res5:0 -> 0,1 -// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<32x128xf16, #A> -// COM: scf.yield -// COM: } -// COM: scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> -// COM: } 
-// COM: return -// COM: } - -// COM: // Enable the following test once we support view on shared memory tensors -// COM: // CHECK-LABEL: for_if_else -// COM: func @for_if_else(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { -// COM: // CHECK: res0:0 -> 0 -// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: res1:0 -> 1 -// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: res2:0 -> 2 -// COM: %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: // CHECK-NEXT: arg3:0 -> 0 -// COM: // CHECK-NEXT: arg3:1 -> 1 -// COM: // CHECK-NEXT: arg3:2 -> 2 -// COM: // CHECK-NEXT: res3:0 -> 0 -// COM: // CHECK-NEXT: res3:1 -> 1 -// COM: // CHECK-NEXT: res3:2 -> 0,7 -// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { -// COM: // CHECK-NEXT: res4:0 -> 0,7 -// COM: %c_shared_next = scf.if %i1 -> tensor<128x32xf16, #A> { -// COM: // CHECK-NEXT: res5:0 -> 0 -// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A> -// COM: scf.yield %cst0 : tensor<128x32xf16, #A> -// COM: } else { -// COM: // CHECK-NEXT: res7:0 -> 7 -// COM: %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> -// COM: scf.yield %cst0 : tensor<128x32xf16, #A> -// COM: } -// COM: scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> -// COM: } -// COM: return -// COM: } +// CHECK-LABEL: for_if +func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_0 -> %cst_0 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %cst_1 -> %cst_1 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: %arg7 -> %cst + // CHECK-NEXT: %arg8 -> %cst_0 + // CHECK-NEXT: %arg9 -> %cst_1 + // CHECK-NEXT: %0#0 -> %cst,%cst_0 + // CHECK-NEXT: %0#1 -> %cst,%cst_0 + // CHECK-NEXT: %0#2 -> %cst,%cst_0 + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + scf.if %i1 { + %index = arith.constant 8 : i32 + // CHECK-NEXT: %1 -> %cst,%cst_0 + %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A> + scf.yield + } + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return +} // CHECK-LABEL: for_if_for func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index d07d959fa..1daaceec7 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -149,6 +149,17 @@ func @longlive(%A : !tt.ptr) { // CHECK-NEXT: size = 2560 } +// CHECK-LABEL: alloc +func @alloc(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant 
dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK-NEXT: offset = 0, size = 512 + %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> + return + // CHECK-NEXT: size = 512 +} + // CHECK-LABEL: scratch func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> @@ -158,6 +169,29 @@ func @scratch() { // CHECK-NEXT: size = 512 } +// CHECK-LABEL: insert_slice_async +func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: offset = 0, size = 512 + %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + return + // CHECK-NEXT: size = 512 +} + +// CHECK-LABEL: extract_slice +func @extract_slice(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + return + // CHECK-NEXT: size = 512 +} + // B0 -> (B1) -> B0 // Memory used by B1 can be reused by B0. // CHECK-LABEL: if @@ -226,6 +260,26 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p // CHECK-NEXT: size = 24576 } +// CHECK-LABEL: for_if_slice +func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 8192, size = 8192 + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 16384, size = 8192 + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + scf.if %i1 { + %index = arith.constant 8 : i32 + %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A> + scf.yield + } + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + return + // CHECK-NEXT: size = 24576 +} + // a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2. // So they cannot be reused by cst0 and cst1, but can be reused by cst2. // CHECK-LABEL: for_if_for diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 238c57b07..f611b4824 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -56,6 +56,8 @@ func @war_single_block(%A : !tt.ptr) { %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> // CHECK: Membar 5 %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL> + // a2's liveness range ends here, and a3 and a2 have the same address range. + // So it makes sense to have a WAR dependency between a2 and a3. 
// CHECK-NEXT: Membar 7 %a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> return @@ -82,6 +84,41 @@ func @async_wait() { return } +// CHECK-LABEL: alloc +func @alloc() { + %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + // CHECK: Membar 2 + %b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + return +} + +// CHECK-LABEL: extract_slice +func @extract_slice() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + // CHECK: Membar 3 + %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + // CHECK-NEXT: Membar 5 + %cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> + return +} + +// CHECK-LABEL: insert_slice_async +func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A> + %index = arith.constant 0 : i32 + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A> + // CHECK: Membar 7 + %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A> + return +} + // If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks func @multi_blocks(%i1 : i1) { diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir new file mode 100644 index 000000000..ed48ac371 --- /dev/null +++ b/test/TritonGPU/matmul.mlir @@ -0,0 +1,106 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s + +// CHECK: offset = 0, size = 49152 +// CHECK: offset = 49152, size = 49152 +// CHECK: size = 98304 +module { +func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { + %cst = arith.constant dense : tensor<64x64xi1> + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c63_i32 : i32 + %4 = arith.divsi %3, %c64_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = 
arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.cmpi slt, %8, %c8_i32 : i32 + %10 = select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c64_i32 : i32 + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %17 = tt.splat %15 : (i32) -> tensor<64xi32> + %18 = arith.addi %17, %16 : tensor<64xi32> + %19 = arith.muli %14, %c64_i32 : i32 + %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %21 = tt.splat %19 : (i32) -> tensor<64xi32> + %22 = arith.addi %21, %20 : tensor<64xi32> + %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %24 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> + %25 = tt.splat %arg6 : (i32) -> tensor<64x1xi32> + %26 = arith.muli %24, %25 : tensor<64x1xi32> + %27 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %28 = tt.splat %arg7 : (i32) -> tensor<1x64xi32> + %29 = arith.muli %27, %28 : tensor<1x64xi32> + %30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x64xi32> + %31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %32 = arith.addi %30, %31 : tensor<64x64xi32> + %33 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x64x!tt.ptr> + %34 = tt.getelementptr %33, %32 : tensor<64x64x!tt.ptr> + %35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> + %36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32> + %37 = arith.muli %35, %36 : tensor<64x1xi32> + %38 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %39 = tt.splat %arg9 : (i32) -> tensor<1x64xi32> + %40 = arith.muli %38, %39 : tensor<1x64xi32> + %41 = tt.broadcast %37 : (tensor<64x1xi32>) -> tensor<64x64xi32> + %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %43 = arith.addi %41, %42 : tensor<64x64xi32> + %44 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x64x!tt.ptr> + %45 = tt.getelementptr %44, %43 : tensor<64x64x!tt.ptr> + %46 = arith.index_cast %arg5 : i32 to index + %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) { + %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32> + %77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32> + %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> + %79 = arith.addf %arg13, %78 : tensor<64x64xf32> + %80 = arith.muli %arg7, %c64_i32 : i32 + %81 = tt.splat %80 : (i32) -> tensor<64x64xi32> + %82 = tt.getelementptr %arg14, %81 : tensor<64x64x!tt.ptr> + %83 = arith.muli %arg8, %c64_i32 : i32 + %84 = tt.splat %83 : (i32) -> tensor<64x64xi32> + %85 = tt.getelementptr %arg15, %84 : tensor<64x64x!tt.ptr> + scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr> + } + %48 = arith.muli %12, %c64_i32 : i32 + %49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %50 = tt.splat %48 : (i32) -> tensor<64xi32> + %51 = arith.addi %50, %49 : tensor<64xi32> + %52 = arith.muli %14, %c64_i32 : i32 + %53 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %54 = tt.splat %52 : (i32) -> tensor<64xi32> + %55 = arith.addi %54, %53 : tensor<64xi32> + %56 = tt.expand_dims %51 {axis = 1 : i32} : 
(tensor<64xi32>) -> tensor<64x1xi32> + %57 = tt.splat %arg10 : (i32) -> tensor<64x1xi32> + %58 = arith.muli %57, %56 : tensor<64x1xi32> + %59 = tt.expand_dims %55 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %60 = tt.splat %arg11 : (i32) -> tensor<1x64xi32> + %61 = arith.muli %59, %60 : tensor<1x64xi32> + %62 = tt.broadcast %58 : (tensor<64x1xi32>) -> tensor<64x64xi32> + %63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32> + %64 = arith.addi %62, %63 : tensor<64x64xi32> + %65 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x64x!tt.ptr> + %66 = tt.getelementptr %65, %64 : tensor<64x64x!tt.ptr> + %67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> + %68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32> + %69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32> + %70 = tt.expand_dims %55 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %71 = tt.splat %arg4 : (i32) -> tensor<1x64xi32> + %72 = arith.cmpi slt, %70, %71 : tensor<1x64xi32> + %73 = tt.broadcast %69 : (tensor<64x1xi1>) -> tensor<64x64xi1> + %74 = tt.broadcast %72 : (tensor<1x64xi1>) -> tensor<64x64xi1> + %75 = arith.andi %73, %74 : tensor<64x64xi1> + tt.store %66, %47#0, %75 : tensor<64x64xf32> + return + } +} \ No newline at end of file diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index ec8c53948..6a4c7ed0a 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -29,7 +29,7 @@ struct TestMembarPass MembarAnalysis analysis(&allocation); size_t operationId = 0; operation->walk([&](Operation *op) { - if (dyn_cast(op)) { + if (isa(op)) { os << "Membar " << operationId << "\n"; } if (op->getNumRegions() == 0) {