[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions


@@ -45,7 +45,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
return argArchive.back().get();
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
std::stringstream ss;
ss << "0x" << std::hex << v;
return newConstantOperand(ss.str());
@@ -130,8 +130,18 @@ std::string PTXBuilder::dump() const {
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
if (onlyAttachMLIRArgs) {
// It is nearly impossible to make the $0, $1 in two PTX code snippets
// point to the same MLIR values in onlyAttachMLIRArgs mode.
assert(builder->executions.empty() &&
"builder can only hold a single execution when onlyAttachMLIRArgs "
"is true.");
builder->reorderArgArchive(oprs);
}
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
return *builder->executions.back();
}
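In onlyAttachMLIRArgs mode the $N placeholders are written verbatim in the asm string and bound to operands by position, which is why the new assert restricts the builder to a single execution. A minimal usage sketch of the pattern, mirroring the FpToFpOp conversion added below; rewriter, loc, and the i32_ty macro are assumed from the surrounding pass, and someI32 is a hypothetical i32 Value:

PTXBuilder builder;
// The asm string spells out its own $N placeholders.
auto &call = *builder.create("add.u32 $0, $1, $1;");
auto *res = builder.newOperand("=r");          // $0: output register
auto *arg = builder.newOperand(someI32, "r");  // $1: an existing MLIR value
call({res, arg}, /* onlyAttachMLIRArgs */ true);
Value out = builder.launch(rewriter, loc, i32_ty, /* hasSideEffect */ false);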


@@ -64,9 +64,8 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
rewriter.getF32FloatAttr(v));
}
// Create a index type constant.
// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
@@ -127,6 +126,7 @@ void llPrintf(StringRef msg, ValueRange args,
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
@@ -339,7 +339,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
for (const auto& v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
@@ -699,7 +700,7 @@ public:
// [elemsPerThread X rank] index matrix.
// TODO: [goostavz] Double-check that the redundant index calculations will
// be eliminated in the subsequent MLIR/LLVM optimization. We might
// implement a indiceCache if necessary.
// implement an indexCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
@@ -953,7 +954,8 @@ struct LoadOpConversion
// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
@@ -1147,7 +1149,8 @@ struct StoreOpConversion
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1734,8 +1737,8 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// We cannot directly run
// `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
@@ -2096,6 +2099,330 @@ struct ExtractSliceOpConversion
}
};
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
static SmallVector<Value> convertFp8x4ToFp16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
}
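For reference, the PTX above amounts to the following per-element bit manipulation: each fp8 byte is placed in the high byte of a 16-bit lane, the sign bit is preserved, and the remaining seven bits shift down by one into the fp16 exponent/mantissa field. A scalar C++ sketch (illustrative only; the real code converts four packed values per inline-asm call, and the helper name is not part of this diff):

#include <cstdint>

uint16_t fp8_to_fp16_bits(uint8_t x) {
  uint16_t sign = (uint16_t)(x & 0x80) << 8; // sign: bit 7 -> bit 15
  uint16_t mag = (uint16_t)(x & 0x7f) << 7;  // exponent+mantissa, shifted by 7
  return sign | mag;
}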
static SmallVector<Value> convertFp16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}
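The inverse direction shifts the fp16 magnitude back up, rounds to nearest by adding half of the discarded low byte, and keeps the high byte. A scalar sketch of one lane, continuing the one above; note this path has no clamping (unlike the bf16 path below), so an out-of-range magnitude can carry into the sign bit exactly as in the PTX:

uint8_t fp16_to_fp8_bits(uint16_t h) {
  uint16_t mag = (uint16_t)(((h & 0x3fff) << 1) + 0x80); // shift + round
  return (uint8_t)(((h & 0x8000) | mag) >> 8);           // keep the high byte
}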
static SmallVector<Value> convertFp8x4ToBf16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
};
}
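Because bf16 has an 8-bit exponent field, this variant shifts the magnitude by 4 instead of 1 and re-biases the exponent by adding 0x3800. A scalar sketch of one lane (like the PTX, it does not special-case zero inputs):

uint16_t fp8_to_bf16_bits(uint8_t x) {
  uint16_t sign = (uint16_t)(x & 0x80) << 8;           // sign: bit 7 -> bit 15
  uint16_t mag = (uint16_t)((x & 0x7f) << 4) + 0x3800; // shift + re-bias
  return sign | mag;
}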
static SmallVector<Value> convertBf16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
"mov.u32 fp8_max, 0x3ff03ff0; \n"
"mov.u32 rn_, 0x80008; \n"
"mov.u32 zero, 0; \n"
"and.b32 sign0, $1, 0x80008000; \n"
"and.b32 sign1, $2, 0x80008000; \n"
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n"
"and.b32 nosign1, $2, 0x7fff7fff; \n"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n"
"add.u32 nosign1, nosign1, rn_; \n"
"sub.u32 nosign0, nosign0, 0x38003800; \n"
"sub.u32 nosign1, nosign1, 0x38003800; \n"
"shr.u32 nosign0, nosign0, 4; \n"
"shr.u32 nosign1, nosign1, 4; \n"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}
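The bf16-to-fp8 direction is the only one with explicit saturation: the magnitude is clamped to the encodable range [0x3800, 0x3ff0] (the fp8_min/fp8_max registers above) before the rounding add (rn_ is 8 per 16-bit lane), un-biasing, and truncating shift. A scalar sketch of one lane:

uint8_t bf16_to_fp8_bits(uint16_t h) {
  uint16_t mag = h & 0x7fff;
  if (mag < 0x3800) mag = 0x3800;            // clamp to fp8_min
  if (mag > 0x3ff0) mag = 0x3ff0;            // clamp to fp8_max
  mag = (uint16_t)((mag + 8 - 0x3800) >> 4); // round, un-bias, truncate
  return (uint8_t)(((h >> 8) & 0x80) | mag); // reattach the sign
}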
static SmallVector<Value> convertFp8x4ToFp32x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
};
}
static SmallVector<Value> convertFp32x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value> convertFp8x4ToFp64x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
};
}
static SmallVector<Value> convertFp64x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
assert(srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>());
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto convertedDstEleType =
this->getTypeConverter()->convertType(dstEltType);
// Select convertor
std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
const Value&, const Value&,
const Value&, const Value&)> convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported type casting");
}
// Vectorized casting
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
assert(elems % 4 == 0 &&
"FP8 casting only supports tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
SmallVector<Value> resultVals;
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter,
elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
assert(resultVals.size() == elems);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
rewriter.replaceOp(op, result);
return success();
}
};
// A CRTP-style base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class ElementwiseOpConversionBase
@@ -2309,7 +2636,7 @@ private:
Value smemBase) const;
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank confict.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
@@ -3265,7 +3592,7 @@ struct DotOpMmaV1ConversionHelper {
int NK = shape[1];
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecA = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3283,7 +3610,7 @@ struct DotOpMmaV1ConversionHelper {
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3387,7 +3714,7 @@ struct DotOpMmaV2ConversionHelper {
return Type{};
}
// The type of a matrix that loaded by either a ldmatrix or composed lds.
// The type of matrix loaded by either an ldmatrix or composed lds.
Type getMatType() const {
Type fp32Ty = type::f32Ty(ctx);
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
@@ -3583,7 +3910,7 @@ private:
"mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
// vector length per ldmatrix (16*8/element_size_in_bits)
inline static const std::map<TensorCoreType, uint8_t> mmaInstrVec = {
{TensorCoreType::FP32_FP16_FP16_FP32, 8},
{TensorCoreType::FP32_BF16_BF16_FP32, 8},
@@ -3723,7 +4050,7 @@ struct MMA16816ConversionHelper {
// load from smem
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
@@ -3754,7 +4081,7 @@ struct MMA16816ConversionHelper {
auto loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
@@ -4713,14 +5040,15 @@ public:
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
}
Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
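Converting the pointee recursively matters for the new fp8 type: Float8Type is stored internally as i8 (see the addConversion rule above), so a pointer to fp8 now lowers to an LLVM pointer to i8 rather than wrapping the unconverted Triton type.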
@@ -4753,7 +5081,7 @@ public:
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
size_t fcSize = 4 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fcSize, type.getElementType()));
ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
}
if (mmaLayout.getVersion() == 1) {
@@ -4762,7 +5090,7 @@ public:
int repN = helper.getRepN(shape[1]);
int elems = 8 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, type.getElementType()));
ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
}
llvm::errs()
@@ -4773,7 +5101,7 @@ public:
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
Type elemTy = convertType(type.getElementType());
auto vecSize = 1;
if (elemTy.getIntOrFloatBitWidth() == 16) {
vecSize = 2;
@@ -5324,6 +5652,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
@@ -5369,12 +5699,12 @@ public:
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// step 1: Allocate shared memories and insert barriers
// setp 2: Convert SCF to CFG
// step 2: Convert SCF to CFG
// step 3: Convert FuncOp to LLVMFuncOp via partial conversion
// step 4: Convert the rest of ops via partial conversion
// The reason for putting step 1 before step 2 is that the membar analysis
// currently only supports SCF but not CFG.
// The reason for a seperation between 1/4 is that, step 3 is out of
// The reason for a separation between steps 1 and 4 is that step 3 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem
// is not revised during the conversion of step 4.
Allocation allocation(mod);