[OPTIMIZER] Pipeline async buffer (#110)
@@ -34,6 +34,14 @@ class LoopPipeliner {
   DenseMap<Value, Value> loadsMapping;
   /// load => buffer
   DenseMap<Value, Value> loadsBuffer;
+  /// load => buffer at stage N
+  DenseMap<Value, SmallVector<Value>> loadStageBuffer;
+  /// load => after extract
+  DenseMap<Value, Value> loadsExtract;
+  ///
+  Value pipelineIterIdx;
+  ///
+  Value loopIterIdx;

   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
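
The new members carry the double-buffering state between emitPrologue() and createNewForOp(). A short summary of their roles, as we read them off the rest of the diff (the commit itself leaves the last two doc comments empty):

    // loadsBuffer:     load => its multi-stage staging buffer
    // loadStageBuffer: load => SSA value of that buffer after each prologue stage
    // loadsExtract:    load => the extract_slice that reads the stage-0 tensor
    // pipelineIterIdx: i32 counter selecting the insert_slice_async slot
    // loopIterIdx:     i32 counter selecting the extract_slice slot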
@@ -237,6 +245,7 @@ void LoopPipeliner::emitPrologue() {

   // prologue from [0, numStages-1)
   Value iv = forOp.getLowerBound();
+  pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (int stage = 0; stage < numStages - 1; ++stage) {
     // special handling for induction variable as the increment is implicit
     if (stage != 0)
@@ -261,22 +270,21 @@ void LoopPipeliner::emitPrologue() {
       Operation *newOp = nullptr;
       if (loads.contains(op->getResult(0))) {
         // Allocate empty buffer
-        if (stage == 0)
+        if (stage == 0) {
           loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
+          loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
+        }
         // load => copy async
         // TODO: check if the hardware supports copyasync
         if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
-          Value sliceIndex = builder.create<arith::IndexCastOp>(
-              iv.getLoc(), builder.getI32Type(), iv);
-          Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+          newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
               op->getLoc(), loadsBuffer[loadOp].getType(),
-              lookupOrDefault(loadOp.ptr(), stage), loadsBuffer[loadOp],
-              sliceIndex, lookupOrDefault(loadOp.mask(), stage),
+              lookupOrDefault(loadOp.ptr(), stage),
+              loadStageBuffer[loadOp][stage], pipelineIterIdx,
+              lookupOrDefault(loadOp.mask(), stage),
               lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
               loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
-          newOp = builder.create<triton::gpu::ExtractSliceOp>(
-              op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-              sliceIndex, /*axis*/ 0);
+          loadStageBuffer[loadOp].push_back(newOp->getResult(0));
         } else
           llvm_unreachable("This should be LoadOp");
       } else {
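
Worth noting for the next hunk: with the staging buffer and slot index now passed as the second and third SSA operands, the mask moves from operand 1 (plain tt.load) to operand 3 (insert_slice_async), which the getOperand(3) change below accounts for. The operand order as read off this call site (our reading of the diff, not authoritative op documentation):

    // insert_slice_async SSA operands at this call site:
    //   0: ptr    lookupOrDefault(loadOp.ptr(), stage)
    //   1: dst    loadStageBuffer[loadOp][stage]  (buffer from the prior stage)
    //   2: index  pipelineIterIdx                 (slot to fill)
    //   3: mask   lookupOrDefault(loadOp.mask(), stage)
    //   4: other  lookupOrDefault(loadOp.other(), stage)
    // followed by the cache / evict / isVolatile / axis attributes.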
@@ -294,9 +302,11 @@ void LoopPipeliner::emitPrologue() {

       // If this is a load/async_copy, we need to update the mask
       if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
-        Value mask = newOp->getOperand(1);
+        Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
+                                                      : newOp->getOperand(3);
         // assert(I1 or TensorOf<[I1]>);
         OpBuilder::InsertionGuard g(builder);
+        // TODO: move this out of the loop
         builder.setInsertionPoint(newOp);
         Value splatCond = builder.create<triton::SplatOp>(
             mask.getLoc(), mask.getType(), loopCond);
@@ -313,8 +323,11 @@ void LoopPipeliner::emitPrologue() {
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
         Value originalResult = op->getResult(dstIdx);
         // copy_async will update the value of its only use
-        if (loads.contains(originalResult))
-          originalResult = loadsMapping[originalResult];
+        // TODO: load should not be used in the preheader?
+        if (loads.contains(originalResult)) {
+          break;
+          // originalResult = loadsMapping[originalResult];
+        }
         setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
         // update mapping for loop-carried values (args)
         for (OpOperand &operand : yieldOp->getOpOperands()) {
@@ -325,18 +338,25 @@ void LoopPipeliner::emitPrologue() {
         }
       }
     }
-  }
+    pipelineIterIdx = builder.create<arith::AddIOp>(
+        iv.getLoc(), pipelineIterIdx,
+        builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
+  } // for (int stage = 0; stage < numStages - 1; ++stage)

   // async.wait & extract_slice
   Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
       loads[0].getLoc(), loads.size() * (numStages - 2));
-  for (int i = numStages - 2; i >= 0; --i) {
-    for (auto it = loads.rbegin(); it != loads.rend(); ++it) {
-      // move extract_slice after asyncWait
-      Value load = *it;
-      valueMapping[loadsMapping[load]][i].getDefiningOp()->moveAfter(asyncWait);
-    }
+  for (Value loadOp : loads) {
+    Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
+        loadOp.getLoc(), loadsMapping[loadOp].getType(),
+        loadStageBuffer[loadOp][numStages - 1],
+        builder.create<arith::ConstantIntOp>(loadOp.getLoc(), 0, 32),
+        /*axis*/ 0);
+    loadsExtract[loadOp] = extractSlice;
   }

+  loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
 }

 scf::ForOp LoopPipeliner::createNewForOp() {
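
One way to sanity-check the async_wait argument: the prologue issues one async copy group per load per prologue stage, and the wait only needs the oldest stage's copies to have landed before the first extract_slice. A standalone C++ model of that arithmetic (our sketch; numLoads and numStages are illustrative values, not from the commit):

    #include <cstdio>

    int main() {
      const int numLoads = 2;  // e.g. the A and B tiles of a matmul
      const int numStages = 3; // pipeline depth

      int issued = numLoads * (numStages - 1);        // copy groups issued in the prologue
      int stillInFlight = numLoads * (numStages - 2); // the async_wait `num` operand
      // Waiting until at most `stillInFlight` groups remain guarantees the
      // oldest stage's copies have completed before the first extract_slice.
      printf("issued=%d wait-num=%d guaranteed-complete=%d\n", issued,
             stillInFlight, issued - stillInFlight);
      return 0;
    }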
@@ -344,10 +364,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {

   // order of new args:
   // (original args),
-  // for each load result (after layout conversion) x:
-  //   (x at stage[0, numStages-1))
+  // (insertSliceAsync buffer at stage numStages - 1) for each load
+  // (extracted tensor) for each load
   // (depArgs at stage numStages-1)
   // (iv at stage numStages-1)
+  // (pipeline iteration index)
+  // (loop iteration index)
   SmallVector<Value> newLoopArgs;
   // We need this to update operands for yield
   // original block arg => new arg's idx
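
The index bookkeeping below (bufferIdx, loadIdx, depArgsBeginIdx, nextIVIdx) follows directly from this ordering. A tiny standalone model of the layout arithmetic (sizes are illustrative assumptions, not values from the pass):

    #include <cstdio>

    int main() {
      const int numOriginalArgs = 3, numLoads = 2, numDepArgs = 1;

      int bufferIdx = numOriginalArgs;              // one staging buffer per load
      int loadIdx = bufferIdx + numLoads;           // one extracted tensor per load
      int depArgsBeginIdx = loadIdx + numLoads;     // then the loop-carried dep args
      int nextIVIdx = depArgsBeginIdx + numDepArgs; // then iv + the two counters
      printf("bufferIdx=%d loadIdx=%d depArgsBeginIdx=%d nextIVIdx=%d total=%d\n",
             bufferIdx, loadIdx, depArgsBeginIdx, nextIVIdx, nextIVIdx + 3);
      return 0;
    }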
@@ -355,10 +377,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (auto v : forOp.getIterOperands())
     newLoopArgs.push_back(v);

+  size_t bufferIdx = newLoopArgs.size();
+  for (Value loadOp : loads)
+    newLoopArgs.push_back(loadStageBuffer[loadOp].back());
   size_t loadIdx = newLoopArgs.size();
   for (Value loadOp : loads)
-    for (int i = 0; i < numStages - 1; ++i)
-      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
+    newLoopArgs.push_back(loadsExtract[loadOp]);

   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
@@ -368,6 +392,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {

   size_t nextIVIdx = newLoopArgs.size();
   newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
+  newLoopArgs.push_back(pipelineIterIdx);
+  newLoopArgs.push_back(loopIterIdx);

   for (size_t i = 0; i < newLoopArgs.size(); ++i)
     assert(newLoopArgs[i]);
@@ -399,7 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
           "we assume that this load has one use (ConvertLayout)");
    Value loadUse = load.getUsers().begin()->getResult(0);
    mapping.lookup(loadUse).replaceAllUsesWith(
-        newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
+        newForOp.getRegionIterArgs()[loadIdx + idx]);
    // delete old load and layout conversion
    mapping.lookup(loadUse).getDefiningOp()->erase();
    mapping.lookup(load).getDefiningOp()->erase();
@@ -432,12 +458,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       nextIV, newForOp.getUpperBound());

   // slice index
+  SmallVector<Value> nextBuffers;
   SmallVector<Value> extractSlices;
-  Value sliceIndex = builder.create<arith::IndexCastOp>(
-      nextIV.getLoc(), builder.getI32Type(), nextIV);
-  sliceIndex = builder.create<arith::RemSIOp>(
-      nextIV.getLoc(), sliceIndex,
+  pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
+  Value insertSliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
+  loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
+  Value extractSliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));

   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
     // TODO(da): does this work if loadOp has no mask?
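
The two RemSIOps implement a circular buffer: the prologue leaves pipelineIterIdx numStages - 1 iterations ahead of loopIterIdx, and both advance in lockstep, so the slot being filled never collides with the slot being read. A standalone sketch of that invariant (assumed semantics, not the pass itself):

    #include <cassert>

    int main() {
      const int numStages = 3;
      int pipelineIterIdx = numStages - 1; // prologue already issued numStages - 1 copies
      int loopIterIdx = 0;                 // first iteration of the rewritten loop

      for (int iter = 0; iter < 16; ++iter) {
        int insertSliceIndex = pipelineIterIdx % numStages; // slot being filled
        int extractSliceIndex = loopIterIdx % numStages;    // slot being consumed
        assert(insertSliceIndex != extractSliceIndex);      // never read a slot mid-write
        ++pipelineIterIdx;
        ++loopIterIdx;
      }
      return 0;
    }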
@@ -457,13 +489,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       }
       Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
           op->getLoc(), loadsBuffer[loadOp].getType(),
-          nextMapping.lookupOrDefault(loadOp.ptr()), loadsBuffer[loadOp],
-          sliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
+          nextMapping.lookupOrDefault(loadOp.ptr()),
+          newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
+          insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
           loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+      nextBuffers.push_back(insertAsyncOp);
       nextOp = builder.create<triton::gpu::ExtractSliceOp>(
           op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-          sliceIndex, /*axis*/ 0);
+          extractSliceIndex, /*axis*/ 0);
       extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
@@ -491,25 +525,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       it->getDefiningOp()->moveAfter(asyncWait);
   }

+  // bump iteration count
+  pipelineIterIdx = builder.create<arith::AddIOp>(
+      nextIV.getLoc(), pipelineIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
+  loopIterIdx = builder.create<arith::AddIOp>(
+      nextIV.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));

   // Finally, the YieldOp, need to sync with the order of newLoopArgs
   SmallVector<Value> yieldValues;
   for (Value v : forOp.getBody()->getTerminator()->getOperands())
     yieldValues.push_back(mapping.lookup(v));
-  // shift pipelined args by 1
-  for (size_t idx = 0; idx < loads.size(); ++idx) {
-    Value load = loads[idx];
-    for (int stage = 1; stage < numStages - 1; ++stage) {
-      yieldValues.push_back(
-          newForOp
-              .getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
-    }
-    yieldValues.push_back(nextMapping.lookup(load));
-  }
+  for (Value nextBuffer : nextBuffers)
+    yieldValues.push_back(nextBuffer);
+  for (Value nextSlice : extractSlices)
+    yieldValues.push_back(nextSlice);

   for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
     yieldValues.push_back(
         depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
   yieldValues.push_back(nextIV);
+  yieldValues.push_back(pipelineIterIdx);
+  yieldValues.push_back(loopIterIdx);

   builder.setInsertionPointToEnd(newForOp.getBody());
   auto test = builder.create<scf::YieldOp>(
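
As the comment says, the yield operands must line up with newLoopArgs: original operands, then the buffers, then the extracted slices, then the dep args, the induction variable, and the two bumped counters. A throwaway C++ sketch of that ordering (the names are placeholders, not values from the pass):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> yieldValues = {"orig0", "orig1", "orig2"};
      const std::vector<std::string> nextBuffers = {"aBuffer", "bBuffer"};
      const std::vector<std::string> extractSlices = {"aNext", "bNext"};

      for (const auto &b : nextBuffers) yieldValues.push_back(b);   // buffers first
      for (const auto &s : extractSlices) yieldValues.push_back(s); // then slices
      yieldValues.push_back("depArg0");         // loop-carried dep args
      yieldValues.push_back("nextIV");          // induction variable
      yieldValues.push_back("pipelineIterIdx"); // bumped counters last
      yieldValues.push_back("loopIterIdx");

      for (const auto &v : yieldValues) printf("%s ", v.c_str());
      printf("\n");
      return 0;
    }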
@@ -16,18 +16,16 @@
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
-// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
-// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
-// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
+// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]]
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -66,18 +64,16 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
-// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
-// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
-// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
+// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]]
 func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   scf.for %iv0 = %lb to %ub step %step {
     %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -114,14 +110,13 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
 // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
-// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}})
 // CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
-// CHECK: scf.yield {{.*}}, {{.*}}, %[[arg_b1]], %[[NEXT_B]]
+// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]]
 func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>