[Triton-MLIR][BACKEND] Pass compute capability from the frontend and code cleanup (#961)
This commit is contained in:
@@ -32,6 +32,12 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
|||||||
let arguments = (ins I32Attr:$num);
|
let arguments = (ins I32Attr:$num);
|
||||||
|
|
||||||
let assemblyFormat = "attr-dict";
|
let assemblyFormat = "attr-dict";
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static bool isSupported(int computeCapability) {
|
||||||
|
return computeCapability >= 80;
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||||
@@ -152,7 +158,13 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
|||||||
//}];
|
//}];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
|
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||||
|
DenseSet<unsigned> validLoadBytes;
|
||||||
|
if (computeCapability >= 80) {
|
||||||
|
validLoadBytes = {4, 8, 16};
|
||||||
|
}
|
||||||
|
return validLoadBytes;
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// The custom parser could be replaced with oilist in LLVM-16
|
// The custom parser could be replaced with oilist in LLVM-16
|
||||||
|
@@ -4681,8 +4681,7 @@ private:
|
|||||||
// capability does not support async copy, then we do decompose
|
// capability does not support async copy, then we do decompose
|
||||||
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||||
computeCapability)
|
computeCapability)
|
||||||
.contains(byteWidth) &&
|
.contains(byteWidth))
|
||||||
computeCapability >= 80)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// load
|
// load
|
||||||
@@ -4716,13 +4715,8 @@ private:
|
|||||||
|
|
||||||
// async wait is supported in Ampere and later
|
// async wait is supported in Ampere and later
|
||||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||||
if (computeCapability < 80) {
|
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
|
||||||
asyncWaitOp.erase();
|
decomposed) {
|
||||||
} else if (decomposed) {
|
|
||||||
OpBuilder builder(asyncWaitOp);
|
|
||||||
// Wait for all previous async ops
|
|
||||||
auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
|
|
||||||
asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
|
|
||||||
asyncWaitOp.erase();
|
asyncWaitOp.erase();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@@ -659,15 +659,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
|||||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseSet<unsigned>
|
|
||||||
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
|
|
||||||
DenseSet<unsigned> validLoadBytes;
|
|
||||||
if (computeCapability >= 80) {
|
|
||||||
validLoadBytes = {4, 8, 16};
|
|
||||||
}
|
|
||||||
return validLoadBytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ASM Interface (i.e.: alias)
|
// ASM Interface (i.e.: alias)
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -134,7 +134,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
|||||||
/*printAfterOnlyOnChange=*/true,
|
/*printAfterOnlyOnChange=*/true,
|
||||||
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
||||||
|
|
||||||
pm.addPass(createConvertTritonGPUToLLVMPass());
|
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
|
||||||
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
|
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
|
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
|
||||||
|
Reference in New Issue
Block a user