[BACKEND] add dot conversion (mma version=2) (#672)

LLVM conversion for the Dot op.

Since `convert_layout` is not implemented yet, dot currently supports only the
following combination of operand layouts:

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout (splat-like only; the generic cases are left to
`convert_layout`)

This PR focuses on the `mma.16816`-related logic, leaving the other cases to
follow-up PRs.
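For context, `mma.16816` is shorthand for the Ampere-generation `mma.sync.aligned.m16n8k16` tensor-core instruction that this lowering emits (see the lit test at the bottom of this diff). Below is a minimal, hypothetical sketch of the operand-layout guard described above; the function name, accessors, and includes are illustrative, not necessarily what this PR uses:

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Hypothetical guard: succeed only for the one supported combination,
// shared * shared -> MMA (splat-like accumulator).
LogicalResult matchDotLayouts(triton::DotOp op) {
  auto aEnc = op.a().getType().cast<RankedTensorType>().getEncoding();
  auto bEnc = op.b().getType().cast<RankedTensorType>().getEncoding();
  auto cEnc = op.c().getType().cast<RankedTensorType>().getEncoding();
  if (aEnc.isa<SharedEncodingAttr>() && bEnc.isa<SharedEncodingAttr>() &&
      cEnc.isa<MmaEncodingAttr>())
    return success();
  return failure(); // everything else waits for convert_layout
}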

Co-authored-by: Philippe Tillet <phil@openai.com>
Authored by Yan Chunwei on 2022-09-23 11:43:54 +08:00; committed by GitHub
parent 23f424c660
commit 922155f1d2
7 changed files with 1033 additions and 116 deletions

View File

@@ -10,14 +10,10 @@ namespace triton {
namespace type {
// Integer types
-Type i32Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 32, IntegerType::Signed);
-}
-Type i8Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 8, IntegerType::Signed);
-}
+Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
+Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
Type u32Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 32, IntegerType::Signless);
+  return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
@@ -27,6 +23,7 @@ Type u1Ty(MLIRContext *ctx) {
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
+Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();

View File

@@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return 0;
}
};
+  // blocked -> blocked
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<BlockedEncodingAttr>()) {
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
@@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
}
paddedRepShape[outOrd[0]] += pad;
}
+  // blocked -> shared
+  if (srcLayout.isa<BlockedEncodingAttr>() &&
+      dstLayout.isa<SharedEncodingAttr>()) {
+    auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
+    for (int v : dstTy.getShape())
+      paddedRepShape.push_back(v);
+  }
return paddedRepShape;
}
@@ -131,9 +140,8 @@ private:
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
-  if (srcEncoding.isa<SharedEncodingAttr>() ||
-      dstEncoding.isa<SharedEncodingAttr>()) {
-    // Only blocked -> blocked conversion requires for scratch allocation
+  if (srcEncoding.isa<SharedEncodingAttr>()) {
+    // only block->block and block->shared is supported now
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
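Net effect: a blocked -> shared `convert_layout` now reserves scratch for the whole destination tensor, so a `tensor<128x32xf16>` costs 128 * 32 * 2 = 8192 bytes, exactly the `size = 8192` values the updated allocation test below expects. A rough sketch of that computation (the helper is mine, assuming plain MLIR APIs and no padding):

#include "mlir/IR/BuiltinTypes.h"

// Illustrative only: scratch bytes for blocked -> shared, i.e. the
// byte size of the full destination tensor.
unsigned scratchBytesBlockedToShared(mlir::RankedTensorType dstTy) {
  unsigned elems = 1;
  for (int64_t dim : dstTy.getShape())
    elems *= dim;                                    // 128 * 32 = 4096
  return elems * dstTy.getElementTypeBitWidth() / 8; // 4096 * 16 / 8 = 8192
}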

File diff suppressed because it is too large.

View File

@@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  // TODO:
-  assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
-  return 0;
+  int threads = product(getWarpsPerCTA());
+  int numElem = product(shape);
+  return numElem / threads;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
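Worked example: for the 16x16 MMA-layout tensor in the lit test below, with warpsPerCTA = [1, 1], this returns 256 / 1 = 256. Note that the divisor is product(getWarpsPerCTA()), the number of warps per CTA rather than the number of lanes, so the value is really elements per warp here.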

View File

@@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
// maxntid
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
-    meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
+    meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
hasMetadata = true;
}
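This goes hand in hand with the type-helper change above: i32Ty is now built without an explicit signedness, so nvvm.maxntid is emitted as a signless i32 instead of si32, and the attribute must be read back with getInt() rather than getSInt(). The si32 -> i32 update in the lit test below checks the same thing.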

View File

@@ -22,9 +22,11 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-    // CHECK: offset = 0, size = 8192
+    // CHECK: scratch offset = 8192, size = 0
+    // CHECK-NEXT: offset = 0, size = 8192
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    // CHECK-NEXT: scratch offset = 16384, size = 0
// CHECK-NEXT: offset = 8192, size = 8192
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
@@ -50,16 +52,20 @@ func @reusable(%A : !tt.ptr<f16>) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK: offset = 0, size = 8192
+  // CHECK: scratch offset = 8192, size = 0
+  // CHECK-NEXT: offset = 0, size = 8192
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK-NEXT: scratch offset = 16384, size = 0
// CHECK-NEXT: offset = 8192, size = 8192
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  // CHECK-NEXT: scratch offset = 24576, size = 0
// CHECK-NEXT: offset = 16384, size = 8192
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK-NEXT: scratch offset = 8192, size = 0
// CHECK-NEXT: offset = 0, size = 8192
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

View File

@@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
-// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
+// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
@@ -422,6 +422,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
}
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0]}>
+#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot
+  func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
+    %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
+    return
+  }
+}
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -429,4 +456,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
-// }
\ No newline at end of file
+// }
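A note on the convert_dot CHECK lines above: for a 16x16x16 f16 dot, each operand is fetched with a single ldmatrix.sync.aligned.m8n8.x4 (four 8x8 fragments covering a 16x16 half-precision tile), and since one mma.sync.aligned.m16n8k16 produces only a 16x8 slice of the 16x16 accumulator, the lowering must issue two MMA instructions side by side along N; hence exactly two ldmatrix and two mma.sync matches.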