[OPTIMIZER] Fix the load-mask issue with the pipeline pass (#857)
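Note (not part of the commit message): the change makes both the prologue (emitPrologue) and the rewritten loop body (createNewForOp) build the mask of each pipelined load explicitly from the loop condition, instead of patching the mask operand after the fact and silently skipping loads that have no mask. Below is a minimal sketch of the rule both hunks implement; the helper name buildPipelineLoadMask is hypothetical, and the sketch assumes the pass's own includes, `using namespace mlir;`, and the getI1SameShape helper added in the first hunk.

// Hypothetical helper sketching the masking rule applied in both hunks:
// combine the scalar loop condition with the load's mask (if any) so the
// async copy is predicated off once the induction variable passes the bound.
static Value buildPipelineLoadMask(OpBuilder &builder, triton::LoadOp loadOp,
                                   Value origMask, Value loopCond) {
  if (!origMask)
    // Unmasked load: the splatted loop condition becomes the whole mask.
    return builder.create<triton::SplatOp>(loopCond.getLoc(),
                                           getI1SameShape(loadOp), loopCond);
  // Masked load: splat the i1 loop condition to the mask's tensor type and
  // AND it element-wise with the original mask.
  Value splatCond = builder.create<triton::SplatOp>(
      origMask.getLoc(), origMask.getType(), loopCond);
  return builder.create<arith::AndIOp>(origMask.getLoc(), origMask, splatCond);
}

In the actual diff this logic is inlined at each insert_slice_async creation site rather than factored into a shared helper.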
@@ -15,6 +15,14 @@ using namespace mlir;
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
+static Type getI1SameShape(Value v) {
+  Type vType = v.getType();
+  auto i1Type = IntegerType::get(vType.getContext(), 1);
+  auto tensorType = vType.cast<RankedTensorType>();
+  return RankedTensorType::get(tensorType.getShape(), i1Type,
+                               tensorType.getEncoding());
+}
+
 namespace {
 class LoopPipeliner {
   /// cache forOp we are working on
@@ -262,13 +270,23 @@ void LoopPipeliner::emitPrologue() {
         loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
       }
       // load => copy async
-      // TODO: check if the hardware supports async copy
       if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
+        Value mask = lookupOrDefault(loadOp.mask(), stage);
+        Value newMask;
+        if (mask) {
+          Value splatCond = builder.create<triton::SplatOp>(
+              mask.getLoc(), mask.getType(), loopCond);
+          newMask =
+              builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
+        } else {
+          newMask = builder.create<triton::SplatOp>(
+              loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
+        }
+        // TODO: check if the hardware supports async copy
         newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
             op->getLoc(), loadsBuffer[loadOp].getType(),
             lookupOrDefault(loadOp.ptr(), stage),
-            loadStageBuffer[loadOp][stage], pipelineIterIdx,
-            lookupOrDefault(loadOp.mask(), stage),
+            loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
             lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
             loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
         loadStageBuffer[loadOp].push_back(newOp->getResult(0));
@@ -287,33 +305,6 @@ void LoopPipeliner::emitPrologue() {
         }
       }
 
-      // If this is a load/async_copy, we need to update the mask
-      if (Value mask = [&]() {
-            if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
-              return loadOp.mask();
-            } else if (auto insertSliceAsyncOp =
-                           llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
-                               newOp)) {
-              return insertSliceAsyncOp.mask();
-            } else {
-              return mlir::Value();
-            }
-          }()) {
-        // assert(I1 or TensorOf<[I1]>);
-        OpBuilder::InsertionGuard g(builder);
-        // TODO: move this out of the loop
-        builder.setInsertionPoint(newOp);
-        Value splatCond = builder.create<triton::SplatOp>(
-            mask.getLoc(), mask.getType(), loopCond);
-        Value newMask =
-            builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
-        // TODO: better way to do this?
-        if (llvm::isa<triton::LoadOp>(newOp))
-          newOp->setOperand(1, newMask);
-        else // InsertSliceAsyncOp
-          newOp->setOperand(3, newMask);
-      }
-
       // update mapping of results
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
         Value originalResult = op->getResult(dstIdx);
@@ -332,7 +323,7 @@ void LoopPipeliner::emitPrologue() {
                 newOp->getResult(dstIdx), stage + 1);
         }
       }
-    }
+    } // for (Operation *op : orderedDeps)
 
     pipelineIterIdx = builder.create<arith::AddIOp>(
         iv.getLoc(), pipelineIterIdx,
@@ -490,26 +481,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
-    // TODO(da): does this work if loadOp has no mask?
     // update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
+      Value newMask;
       if (mask) {
         Value splatCond = builder.create<triton::SplatOp>(
             mask.getLoc(), mask.getType(), nextLoopCond);
-        Value newMask = builder.create<arith::AndIOp>(
+        newMask = builder.create<arith::AndIOp>(
            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
        // if mask is defined outside the loop, don't update the map more than
        // once
        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
          nextMapping.map(mask, newMask);
-      }
+        newMask = nextMapping.lookupOrDefault(loadOp.mask());
+      } else
+        newMask = builder.create<triton::SplatOp>(
+            loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
       Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
           op->getLoc(), loadsBuffer[loadOp].getType(),
           nextMapping.lookupOrDefault(loadOp.ptr()),
           newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
-          insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
+          insertSliceIndex, newMask,
          nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
          loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
       nextBuffers.push_back(insertAsyncOp);
@@ -13,12 +13,19 @@
 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
 // CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
 // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
-// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
+// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]]
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]]
 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
-// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
-// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
-// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
+// CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]]
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]]
+// CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]]
+// CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]]
+// CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]]
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_A]]
+// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]]
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
 // CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
@@ -49,7 +56,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
 
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>