diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 1fe76624d..452a49e68 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -32,6 +32,12 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { let arguments = (ins I32Attr:$num); let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; } // Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. @@ -152,7 +158,13 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", //}]; let extraClassDeclaration = [{ - static DenseSet getEligibleLoadByteWidth(int computeCapability); + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } }]; // The custom parser could be replaced with oilist in LLVM-16 diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index dc32f7825..534ea9b01 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4681,8 +4681,7 @@ private: // capability does not support async copy, then we do decompose if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth( computeCapability) - .contains(byteWidth) && - computeCapability >= 80) + .contains(byteWidth)) return; // load @@ -4716,13 +4715,8 @@ private: // async wait is supported in Ampere and later mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void { - if (computeCapability < 80) { - asyncWaitOp.erase(); - } else if (decomposed) { - OpBuilder builder(asyncWaitOp); - // Wait for all previous async ops - auto newAsyncWaitOp = builder.create( - asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0)); + if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) || + decomposed) { asyncWaitOp.erase(); } }); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 3592c52d4..2649be1f0 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -659,15 +659,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer, printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); } -DenseSet -InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) { - DenseSet validLoadBytes; - if (computeCapability >= 80) { - validLoadBytes = {4, 8, 16}; - } - return validLoadBytes; -} - //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 4ab7876d6..307099584 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -134,7 +134,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); - pm.addPass(createConvertTritonGPUToLLVMPass()); + pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability)); // Canonicalize to eliminate the remaining UnrealizedConversionCastOp pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.