[BACKEND] add dot conversion (mma version=2) (#672)
LLVM conversion for the Dot op. Due to the lack of `convert_layout`, the dot currently only supports the following combination of operands:

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout (splat-like only; the generic cases are left to `convert_layout`)

This PR focuses on the `mma.16816`-related logic, leaving the other cases to a follow-up PR.

Co-authored-by: Philippe Tillet <phil@openai.com>
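For reference, the `mma.16816` shape mentioned above corresponds to PTX's `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`, the same mnemonic the new lit test below checks for. A minimal, standalone CUDA sketch of that instruction follows; it is an illustration only, not code from this PR (the actual lowering emits the instruction through LLVM inline asm), and the helper name and plain-array operand packing are made up for the example.

```cuda
// Illustration only: the PTX instruction behind "mma version=2" / mma.16816.
// One instruction multiplies a 16x16 f16 A-tile by a 16x8 f16 B-tile and
// accumulates into a 16x8 f32 tile, with operands spread across a warp:
// per thread, A uses four .b32 registers (two packed f16 each), B uses two,
// and C/D use four f32 registers. Requires sm_80.
__device__ void mma_16816_f16f16f32(float d[4], const unsigned a[4],
                                    const unsigned b[2], const float c[4]) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```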
@@ -10,14 +10,10 @@ namespace triton {
 namespace type {
 
 // Integer types
-Type i32Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 32, IntegerType::Signed);
-}
-Type i8Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 8, IntegerType::Signed);
-}
+Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
+Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
 Type u32Ty(MLIRContext *ctx) {
-  return IntegerType::get(ctx, 32, IntegerType::Signless);
+  return IntegerType::get(ctx, 32, IntegerType::Unsigned);
 }
 Type u1Ty(MLIRContext *ctx) {
   return IntegerType::get(ctx, 1, IntegerType::Unsigned);

@@ -27,6 +23,7 @@ Type u1Ty(MLIRContext *ctx) {
 Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
 Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
 Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
+Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
 
 static bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128();

@@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
       return 0;
     }
   };
+  // blocked -> blocked
   if (srcLayout.isa<BlockedEncodingAttr>() &&
       dstLayout.isa<BlockedEncodingAttr>()) {
     auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();

@@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
     }
     paddedRepShape[outOrd[0]] += pad;
   }
+  // blocked -> shared
+  if (srcLayout.isa<BlockedEncodingAttr>() &&
+      dstLayout.isa<SharedEncodingAttr>()) {
+    auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
+    for (int v : dstTy.getShape())
+      paddedRepShape.push_back(v);
+  }
+
   return paddedRepShape;
 }
 

@@ -131,9 +140,8 @@ private:
     auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
     auto srcEncoding = srcTy.getEncoding();
     auto dstEncoding = dstTy.getEncoding();
-    if (srcEncoding.isa<SharedEncodingAttr>() ||
-        dstEncoding.isa<SharedEncodingAttr>()) {
-      // Only blocked -> blocked conversion requires for scratch allocation
+    if (srcEncoding.isa<SharedEncodingAttr>()) {
+      // only block->block and block->shared is supported now
       return;
     }
     // ConvertLayoutOp with both input/output non-shared_layout

(File diff suppressed because it is too large.)

@@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  // TODO:
-  assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
-  return 0;
+  int threads = product(getWarpsPerCTA());
+  int numElem = product(shape);
+  return numElem / threads;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

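With this change, `getElemsPerThread` for an MMA-encoded tensor returns product(shape) / product(warpsPerCTA); for the 16x16 f32 accumulator in the new lit test (warpsPerCTA = [1, 1]) that is 256 / 1 = 256, rather than tripping the old assert.
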
@@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
     // maxntid
     if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
       auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
-      meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
+      meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
       hasMetadata = true;
     }
 

@@ -22,9 +22,11 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
     %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-    // CHECK: offset = 0, size = 8192
+    // CHECK: scratch offset = 8192, size = 0
+    // CHECK-NEXT: offset = 0, size = 8192
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    // CHECK-NEXT: scratch offset = 16384, size = 0
     // CHECK-NEXT: offset = 8192, size = 8192
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
 

@@ -50,16 +52,20 @@ func @reusable(%A : !tt.ptr<f16>) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
   %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK: offset = 0, size = 8192
+  // CHECK: scratch offset = 8192, size = 0
+  // CHECK-NEXT: offset = 0, size = 8192
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK-NEXT: scratch offset = 16384, size = 0
   // CHECK-NEXT: offset = 8192, size = 8192
   %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  // CHECK-NEXT: scratch offset = 24576, size = 0
   // CHECK-NEXT: offset = 16384, size = 8192
   %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
   %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK-NEXT: scratch offset = 8192, size = 0
   // CHECK-NEXT: offset = 0, size = 8192
   %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

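As a sanity check on the sizes in these CHECK lines: each converted operand is a 128x32 (or 32x128) f16 tensor, i.e. 4096 elements x 2 bytes = 8192 bytes, which matches the `size = 8192` buffers the allocator reports.
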
@@ -3,7 +3,7 @@
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
   // Here the 128 comes from the 4 in module attribute multiples 32
-  // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
+  // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
   func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
     // CHECK: llvm.return
     return

@@ -422,6 +422,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
+#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot
+  func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
+    %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
+
+    return
+  }
+}
+
+
 // TODO: problems in MLIR's parser on slice layout
 // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 // module attributes {"triton_gpu.num-warps" = 1 : i32} {

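The CHECK pattern above is consistent with how a 16x16x16 `tt.dot` decomposes onto mma.16816: one `ldmatrix.sync.aligned.m8n8.x4` per operand brings a full 16x16 f16 tile (four 8x8 fragments) into registers, and two `mma.sync.aligned.m16n8k16` instructions cover the 16-wide N dimension (two 8-column halves).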