[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

14
.gitignore vendored
View File

@@ -1,12 +1,20 @@
# Triton builds
build/
__pycache__
.pytest_cache
# Triton Python module builds
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
# Python caches
__pycache__
.pytest_cache
# VS Code project files
.vscode
.vs
# JetBrains project files
.idea
cmake-build-*

View File

@@ -22,8 +22,8 @@ struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building a ASM program, the objective of PTXBuilder is to give a
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
@@ -147,7 +147,7 @@ struct PTXBuilder {
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int v);
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
@@ -172,6 +172,22 @@ private:
return argArchive.back().get();
}
// Make the oprands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}
friend struct PTXInstr;
friend struct PTXInstrCommon;

View File

@@ -10,6 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
@@ -72,17 +73,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8, BF8).
Floating point casting for custom types (F8).
F8 <-> BF8, FP16, FP32
BF8 <-> F8, FP16, FP32
F8 <-> FP16, BF16, FP32, FP64
}];
let arguments = (ins TT_FloatLike:$from);

View File

@@ -14,9 +14,8 @@ class TritonTypeDef<string name, string _mnemonic>
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

View File

@@ -45,7 +45,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
return argArchive.back().get();
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
std::stringstream ss;
ss << "0x" << std::hex << v;
return newConstantOperand(ss.str());
@@ -130,8 +130,18 @@ std::string PTXBuilder::dump() const {
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
if (onlyAttachMLIRArgs) {
// Nearly impossible to make the $0,$1 in two PTX code snippets to point to
// the same MLIR values in onlyAttachMLIRArgs mode.
assert(builder->executions.empty() &&
"builder can only hold a single execution when onlyAttachMIIRArgs "
"is true.");
builder->reorderArgArchive(oprs);
}
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
return *builder->executions.back();
}

View File

@@ -64,9 +64,8 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
rewriter.getF32FloatAttr(v));
}
// Create a index type constant.
// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
@@ -127,6 +126,7 @@ void llPrintf(StringRef msg, ValueRange args,
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
@@ -339,7 +339,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
for (const auto& v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
@@ -699,7 +700,7 @@ public:
// [elemsPerThread X rank] index matrix.
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization. We might
// implement a indiceCache if necessary.
// implement an indexCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
@@ -953,7 +954,8 @@ struct LoadOpConversion
// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
@@ -1147,7 +1149,8 @@ struct StoreOpConversion
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1734,8 +1737,8 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// We cannot directly run
// `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
@@ -2096,6 +2099,330 @@ struct ExtractSliceOpConversion
}
};
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
static SmallVector<Value> convertFp8x4ToFp16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
}
static SmallVector<Value> convertFp16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}
static SmallVector<Value> convertFp8x4ToBf16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
};
}
static SmallVector<Value> convertBf16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
"mov.u32 fp8_max, 0x3ff03ff0; \n"
"mov.u32 rn_, 0x80008; \n"
"mov.u32 zero, 0; \n"
"and.b32 sign0, $1, 0x80008000; \n"
"and.b32 sign1, $2, 0x80008000; \n"
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n"
"and.b32 nosign1, $2, 0x7fff7fff; \n"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n"
"add.u32 nosign1, nosign1, rn_; \n"
"sub.u32 nosign0, nosign0, 0x38003800; \n"
"sub.u32 nosign1, nosign1, 0x38003800; \n"
"shr.u32 nosign0, nosign0, 4; \n"
"shr.u32 nosign1, nosign1, 4; \n"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}
static SmallVector<Value> convertFp8x4ToFp32x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
};
}
static SmallVector<Value> convertFp32x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value> convertFp8x4ToFp64x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
};
}
static SmallVector<Value> convertFp64x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
assert(srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>());
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto convertedDstEleType =
this->getTypeConverter()->convertType(dstEltType);
// Select convertor
std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
const Value&, const Value&,
const Value&, const Value&)> convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported type casting");
}
// Vectorized casting
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
SmallVector<Value> resultVals;
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter,
elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
assert(resultVals.size() == elems);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
rewriter.replaceOp(op, result);
return success();
}
};
// A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class ElementwiseOpConversionBase
@@ -2309,7 +2636,7 @@ private:
Value smemBase) const;
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank confict.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
@@ -3265,7 +3592,7 @@ struct DotOpMmaV1ConversionHelper {
int NK = shape[1];
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecA = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3283,7 +3610,7 @@ struct DotOpMmaV1ConversionHelper {
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3387,7 +3714,7 @@ struct DotOpMmaV2ConversionHelper {
return Type{};
}
// The type of a matrix that loaded by either a ldmatrix or composed lds.
// The type of matrix that loaded by either a ldmatrix or composed lds.
Type getMatType() const {
Type fp32Ty = type::f32Ty(ctx);
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
@@ -3583,7 +3910,7 @@ private:
"mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
// vector length per ldmatrix (16*8/element_size_in_bits)
inline static const std::map<TensorCoreType, uint8_t> mmaInstrVec = {
{TensorCoreType::FP32_FP16_FP16_FP32, 8},
{TensorCoreType::FP32_BF16_BF16_FP32, 8},
@@ -3723,7 +4050,7 @@ struct MMA16816ConversionHelper {
// load from smem
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
@@ -3754,7 +4081,7 @@ struct MMA16816ConversionHelper {
auto loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
@@ -4713,14 +5040,15 @@ public:
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
}
Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
@@ -4753,7 +5081,7 @@ public:
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
size_t fcSize = 4 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fcSize, type.getElementType()));
ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
}
if (mmaLayout.getVersion() == 1) {
@@ -4762,7 +5090,7 @@ public:
int repN = helper.getRepN(shape[1]);
int elems = 8 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, type.getElementType()));
ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
}
llvm::errs()
@@ -4773,7 +5101,7 @@ public:
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
Type elemTy = convertType(type.getElementType());
auto vecSize = 1;
if (elemTy.getIntOrFloatBitWidth() == 16) {
vecSize = 2;
@@ -5324,6 +5652,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
@@ -5369,12 +5699,12 @@ public:
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// step 1: Allocate shared memories and insert barriers
// setp 2: Convert SCF to CFG
// step 2: Convert SCF to CFG
// step 3: Convert FuncOp to LLVMFuncOp via partial conversion
// step 4: Convert the rest of ops via partial conversion
// The reason for putting step 1 before step 2 is that the membar analysis
// currently only supports SCF but not CFG.
// The reason for a seperation between 1/4 is that, step 3 is out of
// The reason for a separation between 1/4 is that, step 3 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem
// is not revised during the conversion of step 4.
Allocation allocation(mod);

View File

@@ -371,6 +371,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::FpToFpOp>,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,

View File

@@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
namespace mlir {
namespace triton {
//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() ||
dstEltType.isF32() || dstEltType.isF64();
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {

View File

@@ -44,7 +44,9 @@ namespace gpu {
// TODO: Inheritation of layout attributes
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
if (type.isIntOrIndexOrFloat() ||
type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
auto layout = tensorType.getEncoding();

View File

@@ -32,7 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
auto pointeeType = ptrType.getPointeeType();
unsigned numBits =
pointeeType.isa<triton::Float8Type>() ?
8 : pointeeType.getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);

View File

@@ -140,7 +140,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass());
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());

View File

@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def(
"get_half_ty",
[](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
})
// Cast instructions
// Conversions for custom FP types (FP8)
.def("create_fp_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
})
// Conversions for standard LLVM builtin types
.def("create_bitcast",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {

View File

@@ -780,88 +780,88 @@ def test_store_bool():
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# def test_f8_xf16_roundtrip(dtype):
# """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
# check_type_supported(dtype)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
# f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
# f8 = triton.reinterpret(f8_tensor, tl.float8)
# n_elements = f8_tensor.numel()
# xf16 = torch.empty_like(f8_tensor, dtype=dtype)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
# f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
# assert torch.all(f8_tensor == f8_output_tensor)
assert torch.all(f8_tensor == f8_output_tensor)
# def test_f16_to_f8_rounding():
# """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
# error is the minimum over all float8.
# Or the same explanation a bit mathier:
# for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
def test_f16_to_f8_rounding():
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
# # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
# f16_input_np = (
# np.array(
# range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
# )
# .view(np.float16)
# )
# f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
# n_elements = f16_input.numel()
# f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
f16_input_np = (
np.array(
range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
)
.view(np.float16)
)
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
# f16_output = torch.empty_like(f16_input, dtype=torch.float16)
# copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
f16_output = torch.empty_like(f16_input, dtype=torch.float16)
copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
# abs_error = torch.abs(f16_input - f16_output)
abs_error = torch.abs(f16_input - f16_output)
# all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
# all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
# all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
# copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
# all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
# torch.isfinite(all_f8_vals_in_f16)
# ]
all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
torch.isfinite(all_f8_vals_in_f16)
]
# min_error = torch.min(
# torch.abs(
# f16_input.reshape((-1, 1))
# - all_finite_f8_vals_in_f16.reshape((1, -1))
# ),
# dim=1,
# )[0]
# # 1.9375 is float8 max
# mismatch = torch.logical_and(
# abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
# )
# assert torch.all(
# torch.logical_not(mismatch)
# ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
min_error = torch.min(
torch.abs(
f16_input.reshape((-1, 1))
- all_finite_f8_vals_in_f16.reshape((1, -1))
),
dim=1,
)[0]
# 1.9375 is float8 max
mismatch = torch.logical_and(
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
)
assert torch.all(
torch.logical_not(mismatch)
), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
# # ---------------

View File

@@ -48,6 +48,8 @@ class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
CUSTOMIZED_FP_TYPES = ['fp8']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']
class SIGNEDNESS(Enum):
@@ -129,6 +131,12 @@ class dtype:
def is_floating(self):
return self.name in dtype.FP_TYPES
def is_customized_floating(self):
return self.name in dtype.CUSTOMIZED_FP_TYPES
def is_standard_floating(self):
return self.name in dtype.STANDARD_FP_TYPES
def is_int_signed(self):
return self.name in dtype.SINT_TYPES

View File

@@ -613,39 +613,45 @@ def cast(input: tl.tensor,
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
# Casting types of the same bit width: fp16 <=> bf16
if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \
(src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()):
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
# FP Truncation
# Standard floating types' casting: truncation
# fp64 => fp32, fp16, bf16
# fp32 => fp16, bf16
truncate_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# FP Extension
# Standard floating types' casting: extension
# fp32 => fp64
# fp16 => fp32, fp64
# bf16 => fp32, fp64
ext_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# Int cast
# Casting between integer types
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
@@ -658,8 +664,8 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
# Float to Int
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
# Casting standard floating types to integer types
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
if dst_sca_ty.is_bool():
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
@@ -673,8 +679,8 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder)),
dst_ty)
# int => float
if src_sca_ty.is_int() and dst_sca_ty.is_floating():
# Casting integer types to standard floating types
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
@@ -684,7 +690,7 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder)),
dst_ty)
# ptr => int
# Casting pointer types to integer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
@@ -695,19 +701,14 @@ def cast(input: tl.tensor,
tl.tensor(builder.get_int64(0), tl.int64),
builder)
if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
# Casting integer types to pointer types
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Ptr . Ptr
# Casting pointer types to pointer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
# * . Bool
if dst_sca_ty.is_bool():
if src_sca_ty.is_ptr():
input = cast(input, tl.int64, builder)
other = builder.get_int64(0)
if src_ty.is_bool():
other = builder.create_splat(other, src_ty.get_block_shapes())
return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
assert False, f'cannot cast {input} to {dst_ty}'
# ===----------------------------------------------------------------------===//

View File

@@ -176,6 +176,9 @@ class JITFunction(KernelInterface):
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
triton.language.float16: 'fp16',
triton.language.bfloat16: 'bf16',
triton.language.float32: 'fp32',
}[key]
return f'*{ty}'
if key is None:

View File

@@ -6,8 +6,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
// CHECK: !tt.ptr<f32> -> i64
%1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
// CHECK: f32 -> f16
%2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
// CHECK: f32 to f16
%2 = arith.truncf %scalar_f32 : f32 to f16
// 0D tensor -> 0D tensor
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
@@ -18,8 +18,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
// CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
%4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
// CHECK: tensor<f32> -> tensor<f16>
%5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
// CHECK: tensor<f32> to tensor<f16>
%5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>
// 1D tensor -> 1D tensor
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
@@ -30,8 +30,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
// CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
// CHECK: tensor<16xf32> -> tensor<16xf16>
%8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
// CHECK: tensor<16xf32> to tensor<16xf16>
%8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
return
}

View File

@@ -125,15 +125,21 @@ TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
PTXBuilder builder;
const char *ptxCode =
".param .b64 param0;\n" // prepare param0 (format string)
"st.param.b64 [param0], %0;\n";
"st.param.b64 [param0], %0;\n"
"st.param.b64 [param0], %1;\n"
"st.param.b64 [param0], %2;\n";
auto &ptxSnippet = *builder.create(ptxCode);
auto *opr = builder.newOperand(v[0], "r");
ptxSnippet({opr}, true);
auto *opr0 = builder.newOperand(v[0], "r");
auto *opr1 = builder.newOperand(v[1], "r");
auto *opr2 = builder.newOperand(v[2], "r");
ptxSnippet({opr1, opr2, opr0}, true);
EXPECT_EQ(builder.dump(), ptxCode);
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]);
ASSERT_EQ(builder.getAllMLIRArgs().size(), 1);
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[1]);
ASSERT_EQ(builder.getAllMLIRArgs()[1], v[2]);
ASSERT_EQ(builder.getAllMLIRArgs()[2], v[0]);
ASSERT_EQ(builder.getAllMLIRArgs().size(), 3);
}
} // namespace triton