[Triton-MLIR][BACKEND] Pass compute capability from the frontend and code cleanup (#961)

This commit is contained in:
Keren Zhou
2022-12-07 15:03:46 -08:00
committed by GitHub
parent 4eab9dcedf
commit 18e683d9bb
4 changed files with 17 additions and 20 deletions

View File

@@ -32,6 +32,12 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let arguments = (ins I32Attr:$num); let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict"; let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
// Reports whether async_wait may be kept for a device with the given compute
// capability. Anything below the cutoff of 80 is reported as unsupported.
static bool isSupported(int computeCapability) {
  const int minSupportedCapability = 80;
  return !(computeCapability < minSupportedCapability);
}
}];
} }
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. // Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
@@ -152,7 +158,13 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
//}]; //}];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// Returns the set of load byte widths that are eligible for
// insert_slice_async on a device with the given compute capability.
// The result is empty when computeCapability is below 80; callers test
// membership with contains(byteWidth) to decide whether the op can be kept.
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability); static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
// Starts out empty, which is what is returned for capabilities below 80.
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
// Eligible widths, in bytes, for capability 80 and above.
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}]; }];
// The custom parser could be replaced with oilist in LLVM-16 // The custom parser could be replaced with oilist in LLVM-16

View File

@@ -4681,8 +4681,7 @@ private:
// capability does not support async copy, then we do decompose // capability does not support async copy, then we do decompose
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth( if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
computeCapability) computeCapability)
.contains(byteWidth) && .contains(byteWidth))
computeCapability >= 80)
return; return;
// load // load
@@ -4716,13 +4715,8 @@ private:
// async wait is supported in Ampere and later // async wait is supported in Ampere and later
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void { mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
if (computeCapability < 80) { if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
asyncWaitOp.erase(); decomposed) {
} else if (decomposed) {
OpBuilder builder(asyncWaitOp);
// Wait for all previous async ops
auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
asyncWaitOp.erase(); asyncWaitOp.erase();
} }
}); });

View File

@@ -659,15 +659,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
} }
// Returns the load byte widths for which insert_slice_async is eligible on a
// device with the given compute capability.
DenseSet<unsigned>
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
  // Below compute capability 80 no width qualifies, so the set is empty.
  if (computeCapability < 80)
    return DenseSet<unsigned>();
  // From 80 upwards, 4-, 8- and 16-byte loads are eligible.
  return DenseSet<unsigned>{4, 8, 16};
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias) // ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@@ -134,7 +134,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass()); pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability. pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.