[Triton-MLIR][BACKEND] insert_slice_async on GPUs < sm80 (#908)

`insert_slice_async` is decomposed into `load + insert_slice` in the
backend.

Not sure if V100 perf can match the master branch though in this way.
Maybe the performance can be improved if instructions are arranged in
the following form:

```
%0 = load
%1 = load 
%2 = load 
...
insert_slice %0
insert_slice %1
insert_slice %2
```

Tested on A100 when manually enabling this decomposition.
Tests on V100 haven't been integrated yet, we can divide the tests into
two phases:
1. Test only load, insert_slice, and insert_slice_async, given TritonGPU
IRs in `test_backend.py`.
2. End to end gemm tests on V100.
This commit is contained in:
Keren Zhou
2022-11-24 14:05:54 -08:00
committed by GitHub
parent f98aed1258
commit 153aecb339
16 changed files with 351 additions and 137 deletions

View File

@@ -12,6 +12,8 @@ bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
std::string getValueOperandName(Value value, AsmState &state);
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {

View File

@@ -43,6 +43,12 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::StandardOpsDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif

View File

@@ -33,7 +33,8 @@ struct NVVMMetadataField {
static constexpr char Kernel[] = "nvvm.kernel";
};
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton

View File

@@ -25,7 +25,8 @@ void addExternalLibs(mlir::ModuleOp &module,
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module);
mlir::ModuleOp module,
int computeCapability);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>