[Triton-MLIR][BACKEND] insert_slice_async on GPUs < sm80 (#908)

`insert_slice_async` is decomposed into `load + insert_slice` in the
backend.

Not sure if V100 perf can match the master branch though in this way.
Maybe the performance can be improved if instructions are arranged in
the following form:

```
%0 = load
%1 = load 
%2 = load 
...
insert_slice %0
insert_slice %1
insert_slice %2
```

Tested on A100 when manually enabling this decomposition.
Tests on V100 haven't been integrated yet, we can divide the tests into
two phases:
1. Test only load, insert_slice, and insert_slice_async, given TritonGPU
IRs in `test_backend.py`.
2. End to end gemm tests on V100.
This commit is contained in:
Keren Zhou
2022-11-24 14:05:54 -08:00
committed by GitHub
parent f98aed1258
commit 153aecb339
16 changed files with 351 additions and 137 deletions

View File

@@ -28,7 +28,7 @@ namespace mlir {
namespace triton {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
@@ -155,8 +155,7 @@ private:
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
return;
}
@@ -210,9 +209,9 @@ private:
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
elems * kPtrBitWidth / 8 :
elems * srcTy.getElementTypeBitWidth() / 8;
auto bytes = srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);