[BACKEND] add dot conversion (mma version=2) (#672)
LLVM Conversion for Dot op. Due to the lack of `convert_layout`, currently, the dot only supports the following combination of operands - `$a` in shared layout - `$b` in shared layout - `$c` in MMA layout(but only Splat-like, leaving the generic cases to `convert_layout`) This PR focus on `mma.16816` related logic support, leaving the other cases to the following PR. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -10,14 +10,10 @@ namespace triton {
|
||||
namespace type {
|
||||
|
||||
// Integer types
|
||||
Type i32Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 32, IntegerType::Signed);
|
||||
}
|
||||
Type i8Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 8, IntegerType::Signed);
|
||||
}
|
||||
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
|
||||
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
|
||||
Type u32Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 32, IntegerType::Signless);
|
||||
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
|
||||
}
|
||||
Type u1Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
|
||||
@@ -27,6 +23,7 @@ Type u1Ty(MLIRContext *ctx) {
|
||||
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
|
||||
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
|
||||
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
|
||||
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
|
||||
|
||||
static bool isFloat(Type type) {
|
||||
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
|
||||
|
@@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
// blocked -> blocked
|
||||
if (srcLayout.isa<BlockedEncodingAttr>() &&
|
||||
dstLayout.isa<BlockedEncodingAttr>()) {
|
||||
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
|
||||
@@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
}
|
||||
paddedRepShape[outOrd[0]] += pad;
|
||||
}
|
||||
// blocked -> shared
|
||||
if (srcLayout.isa<BlockedEncodingAttr>() &&
|
||||
dstLayout.isa<SharedEncodingAttr>()) {
|
||||
auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
|
||||
for (int v : dstTy.getShape())
|
||||
paddedRepShape.push_back(v);
|
||||
}
|
||||
|
||||
return paddedRepShape;
|
||||
}
|
||||
|
||||
@@ -131,9 +140,8 @@ private:
|
||||
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
auto dstEncoding = dstTy.getEncoding();
|
||||
if (srcEncoding.isa<SharedEncodingAttr>() ||
|
||||
dstEncoding.isa<SharedEncodingAttr>()) {
|
||||
// Only blocked -> blocked conversion requires for scratch allocation
|
||||
if (srcEncoding.isa<SharedEncodingAttr>()) {
|
||||
// only block->block and block->shared is supported now
|
||||
return;
|
||||
}
|
||||
// ConvertLayoutOp with both input/output non-shared_layout
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
}
|
||||
|
||||
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
// TODO:
|
||||
assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
|
||||
return 0;
|
||||
int threads = product(getWarpsPerCTA());
|
||||
int numElem = product(shape);
|
||||
return numElem / threads;
|
||||
}
|
||||
|
||||
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
|
@@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
||||
// maxntid
|
||||
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
||||
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
||||
hasMetadata = true;
|
||||
}
|
||||
|
||||
|
@@ -22,9 +22,11 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK: offset = 0, size = 8192
|
||||
// CHECK: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||
// CHECK-NEXT: scratch offset = 16384, size = 0
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||
|
||||
@@ -50,16 +52,20 @@ func @reusable(%A : !tt.ptr<f16>) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
|
||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK: offset = 0, size = 8192
|
||||
// CHECK: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 16384, size = 0
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 24576, size = 0
|
||||
// CHECK-NEXT: offset = 16384, size = 8192
|
||||
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
@@ -3,7 +3,7 @@
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
|
||||
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
// CHECK: llvm.return
|
||||
return
|
||||
@@ -422,6 +422,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
|
||||
#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_dot
|
||||
func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
||||
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
|
||||
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
|
||||
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO: problems in MLIR's parser on slice layout
|
||||
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
@@ -429,4 +456,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
Reference in New Issue
Block a user