diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 8dd03be9e..1115fa200 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -15,6 +15,14 @@ using namespace mlir;
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
+static Type getI1SameShape(Value v) {
+  Type vType = v.getType();
+  auto i1Type = IntegerType::get(vType.getContext(), 1);
+  auto tensorType = vType.cast<RankedTensorType>();
+  return RankedTensorType::get(tensorType.getShape(), i1Type,
+                               tensorType.getEncoding());
+}
+
 namespace {
 class LoopPipeliner {
   /// cache forOp we are working on
@@ -262,13 +270,23 @@ void LoopPipeliner::emitPrologue() {
           loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
         }
         // load => copy async
-        // TODO: check if the hardware supports async copy
         if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
+          Value mask = lookupOrDefault(loadOp.mask(), stage);
+          Value newMask;
+          if (mask) {
+            Value splatCond = builder.create<triton::SplatOp>(
+                mask.getLoc(), mask.getType(), loopCond);
+            newMask =
+                builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
+          } else {
+            newMask = builder.create<triton::SplatOp>(
+                loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
+          }
+          // TODO: check if the hardware supports async copy
           newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
               op->getLoc(), loadsBuffer[loadOp].getType(),
               lookupOrDefault(loadOp.ptr(), stage),
-              loadStageBuffer[loadOp][stage], pipelineIterIdx,
-              lookupOrDefault(loadOp.mask(), stage),
+              loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
              lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
               loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
           loadStageBuffer[loadOp].push_back(newOp->getResult(0));
@@ -287,33 +305,6 @@ void LoopPipeliner::emitPrologue() {
         }
       }
 
-      // If this is a load/async_copy, we need to update the mask
-      if (Value mask = [&]() {
-            if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
-              return loadOp.mask();
-            } else if (auto insertSliceAsyncOp =
-                           llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
-                               newOp)) {
-              return insertSliceAsyncOp.mask();
-            } else {
-              return mlir::Value();
-            }
-          }()) {
-        // assert(I1 or TensorOf<[I1]>);
-        OpBuilder::InsertionGuard g(builder);
-        // TODO: move this out of the loop
-        builder.setInsertionPoint(newOp);
-        Value splatCond = builder.create<triton::SplatOp>(
-            mask.getLoc(), mask.getType(), loopCond);
-        Value newMask =
-            builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
-        // TODO: better way to do this?
-        if (llvm::isa<triton::LoadOp>(newOp))
-          newOp->setOperand(1, newMask);
-        else // InsertSliceAsyncOp
-          newOp->setOperand(3, newMask);
-      }
-
       // update mapping of results
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
         Value originalResult = op->getResult(dstIdx);
@@ -332,7 +323,7 @@ void LoopPipeliner::emitPrologue() {
                 newOp->getResult(dstIdx), stage + 1);
         }
       }
-    }
+    } // for (Operation *op : orderedDeps)
 
     pipelineIterIdx = builder.create<arith::AddIOp>(
         iv.getLoc(), pipelineIterIdx,
@@ -490,26 +481,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
-    // TODO(da): does this work if loadOp has no mask?
     // update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
+      Value newMask;
       if (mask) {
         Value splatCond = builder.create<triton::SplatOp>(
             mask.getLoc(), mask.getType(), nextLoopCond);
-        Value newMask = builder.create<arith::AndIOp>(
+        newMask = builder.create<arith::AndIOp>(
             mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
         // if mask is defined outside the loop, don't update the map more than
         // once
         if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
           nextMapping.map(mask, newMask);
-      }
+        newMask = nextMapping.lookupOrDefault(loadOp.mask());
+      } else
+        newMask = builder.create<triton::SplatOp>(
+            loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
       Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
           op->getLoc(), loadsBuffer[loadOp].getType(),
           nextMapping.lookupOrDefault(loadOp.ptr()),
           newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
-          insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
+          insertSliceIndex, newMask,
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
           loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
       nextBuffers.push_back(insertAsyncOp);
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index 731916e8f..fe9a10e27 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -13,12 +13,19 @@
 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
 // CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
 // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
-// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
+// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]]
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]]
 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
-// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
-// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
-// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
+// CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]]
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]]
+// CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]]
+// CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]]
+// CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]]
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_A]]
+// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]]
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
 // CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
@@ -49,7 +56,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
 
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
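
For reference, a minimal MLIR sketch of the prologue IR the new CHECK lines above are matching; the value names and the #AL layout are illustrative assumptions, not actual pass output. The loop condition is splatted into an i1 tensor of the load's shape (the type getI1SameShape computes) and becomes the mask operand of the async copy, so a prologue stage whose iteration would step past the loop bound has its copy predicated off:

// Illustrative sketch only; %lb, %ub and the 128x32 #AL shape follow the test above.
%cond = arith.cmpi slt, %lb, %ub : index
%mask = tt.splat %cond : (i1) -> tensor<128x32xi1, #AL>
// %mask is then passed as the mask operand of triton_gpu.insert_slice_async,
// where the CHECK lines expect %[[LOOP_COND_0_SPLAT_A]].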