[BACKEND] add dot conversion (mma version=2) (#672)

LLVM Conversion for Dot op.

Due to the lack of `convert_layout`, currently, the dot only supports
the following combination of operands

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout(but only Splat-like, leaving the generic cases to
`convert_layout`)

This PR focus on `mma.16816` related logic support, leaving the other
cases to the following PR.

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Yan Chunwei
2022-09-23 11:43:54 +08:00
committed by GitHub
parent 23f424c660
commit 922155f1d2
7 changed files with 1033 additions and 116 deletions

View File

@@ -10,14 +10,10 @@ namespace triton {
namespace type {
// Integer types
Type i32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Signed);
}
Type i8Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 8, IntegerType::Signed);
}
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Signless);
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
@@ -27,6 +23,7 @@ Type u1Ty(MLIRContext *ctx) {
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();