[Triton-MLIR] Support FP8 (#864)
Co-authored-by: Superjomn <yanchunwei@outlook.com>
.gitignore
@@ -1,12 +1,20 @@
+# Triton builds
 build/

-__pycache__
-.pytest_cache
-
+# Triton Python module builds
 python/build/
 python/triton.egg-info/
 python/triton/_C/libtriton.pyd
 python/triton/_C/libtriton.so

+# Python caches
+__pycache__
+.pytest_cache
+
+# VS Code project files
 .vscode
 .vs
+
+# JetBrains project files
+.idea
+cmake-build-*
@@ -22,8 +22,8 @@ struct PTXInstrExecution;
 // PTXBuilder helps to manage a PTX asm program consists of one or multiple
 // instructions.
 //
-// A helper for building a ASM program, the objective of PTXBuilder is to give a
-// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
+// A helper for building an ASM program, the objective of PTXBuilder is to give
+// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
 // Currently, several factors are introduced to reduce the need for mixing
 // string and C++ if-else code.
 //
@@ -147,7 +147,7 @@ struct PTXBuilder {
   Operand *newOperand(StringRef constraint);

   // Create a constant integer operand.
-  Operand *newConstantOperand(int v);
+  Operand *newConstantOperand(int64_t v);
   // Create a constant operand with explicit code specified.
   Operand *newConstantOperand(const std::string &v);

@@ -172,6 +172,22 @@ private:
     return argArchive.back().get();
   }

+  // Make the oprands in argArchive follow the provided \param order.
+  void reorderArgArchive(ArrayRef<Operand *> order) {
+    assert(order.size() == argArchive.size());
+    // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
+    // it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
+    // determined by PTX code snippet passed from external.
+    sort(argArchive.begin(), argArchive.end(),
+         [&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
+           auto ida = std::find(order.begin(), order.end(), a.get());
+           auto idb = std::find(order.begin(), order.end(), b.get());
+           assert(ida != order.end());
+           assert(idb != order.end());
+           return ida < idb;
+         });
+  }
+
   friend struct PTXInstr;
   friend struct PTXInstrCommon;

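Note: `reorderArgArchive` only matters when `onlyAttachMLIRArgs` is true, because then the `$0, $1, ...` placeholders come from an externally written PTX snippet and the stored MLIR operands must be permuted to match the caller-supplied order. A rough Python sketch of the same idea (hypothetical names, illustration only, not code from this commit):

def reorder_arg_archive(arg_archive, order):
    # Put the stored operands into the exact order the caller passed them in,
    # so that operand k binds to placeholder $k in the PTX snippet.
    assert len(order) == len(arg_archive)
    arg_archive.sort(key=lambda operand: order.index(operand))

ops = ['o0', 'o1', 'i']
archive = ['i', 'o1', 'o0']
reorder_arg_archive(archive, ops)
assert archive == ops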
@@ -10,6 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
 include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
+include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface

 def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
 def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
@@ -72,17 +73,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
     // TODO: Add verifier
 }

-def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
+def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
                                      SameOperandsAndResultEncoding,
                                      NoSideEffect,
-                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
+                                     DeclareOpInterfaceMethods<CastOpInterface>]> {
     let summary = "Floating point casting for custom types";

     let description = [{
-        Floating point casting for custom types (F8, BF8).
+        Floating point casting for custom types (F8).

-        F8 <-> BF8, FP16, FP32
-        BF8 <-> F8, FP16, FP32
+        F8 <-> FP16, BF16, FP32, FP64
     }];

     let arguments = (ins TT_FloatLike:$from);
@@ -14,9 +14,8 @@ class TritonTypeDef<string name, string _mnemonic>

 // Floating-point Type
 def F8 : TritonTypeDef<"Float8", "f8">;
-def BF8 : TritonTypeDef<"BFloat8", "bf8">;

-def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : TensorOf<[TT_Float]>;
 def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

@@ -45,7 +45,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
   return argArchive.back().get();
 }

-PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
+PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
   std::stringstream ss;
   ss << "0x" << std::hex << v;
   return newConstantOperand(ss.str());
@@ -130,8 +130,18 @@ std::string PTXBuilder::dump() const {

 PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
                                         bool onlyAttachMLIRArgs) {
+  if (onlyAttachMLIRArgs) {
+    // Nearly impossible to make the $0,$1 in two PTX code snippets to point to
+    // the same MLIR values in onlyAttachMLIRArgs mode.
+    assert(builder->executions.empty() &&
+           "builder can only hold a single execution when onlyAttachMIIRArgs "
+           "is true.");
+    builder->reorderArgArchive(oprs);
+  }
+
   builder->executions.emplace_back(
       std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));

   return *builder->executions.back();
 }

@@ -64,9 +64,8 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
                                          rewriter.getF32FloatAttr(v));
 }

-// Create a index type constant.
+// Create an index type constant.
 static Value createIndexConstant(OpBuilder &builder, Location loc,
-
                                  TypeConverter *converter, int64_t value) {
   Type ty = converter->convertType(builder.getIndexType());
   return builder.create<LLVM::ConstantOp>(loc, ty,
@@ -127,6 +126,7 @@ void llPrintf(StringRef msg, ValueRange args,
 #define i32_ty rewriter.getIntegerType(32)
 #define ui32_ty rewriter.getIntegerType(32, false)
 #define f16_ty rewriter.getF16Type()
+#define bf16_ty rewriter.getBF16Type()
 #define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
 #define f64_ty rewriter.getF64Type()
@@ -339,7 +339,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   }

   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-  for (auto v : llvm::enumerate(resultVals)) {
+  for (const auto& v : llvm::enumerate(resultVals)) {
+    assert(v.value() && "can not insert null values");
     llvmStruct = insert_val(structType, llvmStruct, v.value(),
                             rewriter.getI64ArrayAttr(v.index()));
   }
@@ -699,7 +700,7 @@ public:
   // [elemsPerThread X rank] index matrix.
   // TODO: [goostavz] Double confirm the redundant indices calculations will
   //       be eliminated in the consequent MLIR/LLVM optimization. We might
-  //       implement a indiceCache if necessary.
+  //       implement an indexCache if necessary.
   SmallVector<SmallVector<Value>>
   emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
                               const BlockedEncodingAttr &blockedLayout,
@@ -953,7 +954,8 @@ struct LoadOpConversion

     // Determine the vectorization size
     Type valueTy = op.getResult().getType();
-    Type valueElemTy = getElementTypeOrSelf(valueTy);
+    Type valueElemTy = typeConverter->convertType(
+        getElementTypeOrSelf(valueTy));
     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
     if (llMask)
@@ -1147,7 +1149,8 @@ struct StoreOpConversion
     MLIRContext *ctx = rewriter.getContext();

     auto valueTy = value.getType();
-    Type valueElemTy = getElementTypeOrSelf(valueTy);
+    Type valueElemTy = typeConverter->convertType(
+        getElementTypeOrSelf(valueTy));

     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1734,8 +1737,8 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
   LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // We cannot directly
-    //   rewriter.replaceOp(op, adaptor.src());
+    // We cannot directly run
+    //   `rewriter.replaceOp(op, adaptor.src())`
     // due to MLIR's restrictions
     Location loc = op->getLoc();
     auto resultTy = op.getType().template cast<RankedTensorType>();
@@ -2096,6 +2099,330 @@ struct ExtractSliceOpConversion
   }
 };

+struct FpToFpOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  static SmallVector<Value> convertFp8x4ToFp16x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto ctx = rewriter.getContext();
+    auto fp8x4VecTy = vec_ty(i8_ty, 4);
+    Value fp8x4Vec = undef(fp8x4VecTy);
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
+    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
+
+    PTXBuilder builder;
+    auto *ptxAsm =
+        "{ \n"
+        ".reg .b32 a<2>, b<2>; \n"
+        "prmt.b32 a0, 0, $2, 0x5040; \n"
+        "prmt.b32 a1, 0, $2, 0x7060; \n"
+        "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
+        "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
+        "shr.b32 b0, b0, 1; \n"
+        "shr.b32 b1, b1, 1; \n"
+        "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
+        "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
+        "}";
+    auto &call = *builder.create(ptxAsm);
+
+    auto *o0 = builder.newOperand("=r");
+    auto *o1 = builder.newOperand("=r");
+    auto *i = builder.newOperand(fp8x4Vec, "r");
+    call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
+
+    auto fp16x2VecTy = vec_ty(f16_ty, 2);
+    auto fp16x2x2StructTy =
+        struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
+    auto fp16x2x2Struct =
+        builder.launch(rewriter, loc, fp16x2x2StructTy, false);
+    auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
+                                  rewriter.getI32ArrayAttr({0}));
+    auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
+                                  rewriter.getI32ArrayAttr({1}));
+    return {
+        extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
+        extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
+        extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
+        extract_element(f16_ty, fp16x2Vec1, i32_val(1))
+    };
+  }
+
+  static SmallVector<Value> convertFp16x4ToFp8x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto ctx = rewriter.getContext();
+    auto fp16x2VecTy = vec_ty(f16_ty, 2);
+    Value fp16x2Vec0 = undef(fp16x2VecTy);
+    Value fp16x2Vec1 = undef(fp16x2VecTy);
+    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
+    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
+    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
+    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
+    fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
+    fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
+
+    PTXBuilder builder;
+    auto *ptxAsm =
+        "{ \n"
+        ".reg .b32 a<2>, b<2>; \n"
+        "shl.b32 a0, $1, 1; \n"
+        "shl.b32 a1, $2, 1; \n"
+        "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
+        "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
+        "add.u32 a0, a0, 0x00800080; \n"
+        "add.u32 a1, a1, 0x00800080; \n"
+        "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
+        "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
+        "prmt.b32 $0, b0, b1, 0x7531; \n"
+        "}";
+    auto &call = *builder.create(ptxAsm);
+
+    auto *o = builder.newOperand("=r");
+    auto *i0 = builder.newOperand(fp16x2Vec0, "r");
+    auto *i1 = builder.newOperand(fp16x2Vec1, "r");
+    call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
+
+    auto fp8x4VecTy = vec_ty(i8_ty, 4);
+    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
+    return {
+        extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(1)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(2)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(3))
+    };
+  }
+
+  static SmallVector<Value> convertFp8x4ToBf16x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto ctx = rewriter.getContext();
+    auto fp8x4VecTy = vec_ty(i8_ty, 4);
+    Value fp8x4Vec = undef(fp8x4VecTy);
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
+    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
+    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
+
+    PTXBuilder builder;
+    auto *ptxAsm =
+        "{ \n"
+        ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
+        "prmt.b32 a0, 0, $2, 0x5040; \n"
+        "prmt.b32 a1, 0, $2, 0x7060; \n"
+        "and.b32 sign0, a0, 0x80008000; \n"
+        "and.b32 sign1, a1, 0x80008000; \n"
+        "and.b32 nosign0, a0, 0x7fff7fff; \n"
+        "and.b32 nosign1, a1, 0x7fff7fff; \n"
+        "shr.b32 nosign0, nosign0, 4; \n"
+        "shr.b32 nosign1, nosign1, 4; \n"
+        "add.u32 nosign0, nosign0, 0x38003800; \n"
+        "add.u32 nosign1, nosign1, 0x38003800; \n"
+        "or.b32 $0, sign0, nosign0; \n"
+        "or.b32 $1, sign1, nosign1; \n"
+        "}";
+    auto &call = *builder.create(ptxAsm);
+
+    auto *o0 = builder.newOperand("=r");
+    auto *o1 = builder.newOperand("=r");
+    auto *i = builder.newOperand(fp8x4Vec, "r");
+    call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
+
+    auto bf16x2VecTy = vec_ty(bf16_ty, 2);
+    auto bf16x2x2StructTy =
+        struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
+    auto bf16x2x2Struct =
+        builder.launch(rewriter, loc, bf16x2x2StructTy, false);
+    auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
+                                  rewriter.getI32ArrayAttr({0}));
+    auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
+                                  rewriter.getI32ArrayAttr({1}));
+    return {
+        extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
+        extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
+        extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
+        extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
+    };
+  }
+
+  static SmallVector<Value> convertBf16x4ToFp8x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto ctx = rewriter.getContext();
+    auto bf16x2VecTy = vec_ty(bf16_ty, 2);
+    Value bf16x2Vec0 = undef(bf16x2VecTy);
+    Value bf16x2Vec1 = undef(bf16x2VecTy);
+    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
+    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
+    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
+    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
+    bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
+    bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
+
+    PTXBuilder builder;
+    auto *ptxAsm =
+        "{ \n"
+        ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
+        ".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
+        "mov.u32 fp8_min, 0x38003800; \n"
+        "mov.u32 fp8_max, 0x3ff03ff0; \n"
+        "mov.u32 rn_, 0x80008; \n"
+        "mov.u32 zero, 0; \n"
+        "and.b32 sign0, $1, 0x80008000; \n"
+        "and.b32 sign1, $2, 0x80008000; \n"
+        "prmt.b32 sign, sign0, sign1, 0x7531; \n"
+        "and.b32 nosign0, $1, 0x7fff7fff; \n"
+        "and.b32 nosign1, $2, 0x7fff7fff; \n"
+        ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
+        "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
+        "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
+        "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
+        "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
+        "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
+        "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
+        "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
+        "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
+        "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
+        "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
+        "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
+        "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
+        "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
+        "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
+        "add.u32 nosign0, nosign0, rn_; \n"
+        "add.u32 nosign1, nosign1, rn_; \n"
+        "sub.u32 nosign0, nosign0, 0x38003800; \n"
+        "sub.u32 nosign1, nosign1, 0x38003800; \n"
+        "shr.u32 nosign0, nosign0, 4; \n"
+        "shr.u32 nosign1, nosign1, 4; \n"
+        "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
+        "or.b32 $0, nosign, sign; \n"
+        "}";
+    auto &call = *builder.create(ptxAsm);
+
+    auto *o = builder.newOperand("=r");
+    auto *i0 = builder.newOperand(bf16x2Vec0, "r");
+    auto *i1 = builder.newOperand(bf16x2Vec1, "r");
+    call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);
+
+    auto fp8x4VecTy = vec_ty(i8_ty, 4);
+    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
+    return {
+        extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(1)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(2)),
+        extract_element(i8_ty, fp8x4Vec, i32_val(3))
+    };
+  }
+
+  static SmallVector<Value> convertFp8x4ToFp32x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
+    return {
+        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
+        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
+        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
+        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
+    };
+  }
+
+  static SmallVector<Value> convertFp32x4ToFp8x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
+    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
+    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
+    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
+    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
+  }
+
+  static SmallVector<Value> convertFp8x4ToFp64x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
+    return {
+        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
+        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
+        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
+        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
+    };
+  }
+
+  static SmallVector<Value> convertFp64x4ToFp8x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
+    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
+    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
+    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
+    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
+    auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
+    auto srcEltType = srcTensorType.getElementType();
+    auto dstEltType = dstTensorType.getElementType();
+    assert(srcEltType.isa<triton::Float8Type>() ||
+           dstEltType.isa<triton::Float8Type>());
+    auto convertedDstTensorType =
+        this->getTypeConverter()->convertType(dstTensorType);
+    auto convertedDstEleType =
+        this->getTypeConverter()->convertType(dstEltType);
+
+    // Select convertor
+    std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
+                                     const Value&, const Value&,
+                                     const Value&, const Value&)> convertor;
+    if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
+      convertor = convertFp8x4ToFp16x4;
+    } else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
+      convertor = convertFp16x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
+      convertor = convertFp8x4ToBf16x4;
+    } else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
+      convertor = convertBf16x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
+      convertor = convertFp8x4ToFp32x4;
+    } else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
+      convertor = convertFp32x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
+      convertor = convertFp8x4ToFp64x4;
+    } else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
+      convertor = convertFp64x4ToFp8x4;
+    } else {
+      assert(false && "unsupported type casting");
+    }
+
+    // Vectorized casting
+    auto loc = op->getLoc();
+    auto elems = getElemsPerThread(dstTensorType);
+    assert(elems % 4 == 0 &&
+           "FP8 casting only support tensors with 4-aligned sizes");
+    auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
+    SmallVector<Value> resultVals;
+    for (size_t i = 0; i < elems; i += 4) {
+      auto converted = convertor(loc, rewriter,
+                                 elements[i], elements[i + 1],
+                                 elements[i + 2], elements[i + 3]);
+      resultVals.append(converted);
+    }
+    assert(resultVals.size() == elems);
+    auto result = getStructFromElements(loc, resultVals, rewriter,
+                                        convertedDstTensorType);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 // A CRTP style of base class.
 template <typename SourceOp, typename DestOp, typename ConcreteT>
 class ElementwiseOpConversionBase
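Note: the inline PTX above packs four fp8 values into a 32-bit register and converts them two 16-bit lanes at a time. Below is a NumPy sketch of how I read the fp8x4 <-> fp16x4 sequences (prmt places each fp8 byte into the high byte of a 16-bit lane, lop3/shr move the exponent and mantissa bits, and the sign bit is re-attached); the function names are mine and this is an illustration only, not code from the commit:

import numpy as np

def fp8x4_to_fp16x4(fp8_bytes):
    # prmt 0x5040 / 0x7060: each fp8 byte becomes the high byte of a 16-bit lane
    h = fp8_bytes.astype(np.uint16) << 8
    # lop3(0xc0) = a & b, shr 1, lop3(0xf8) = a | (b & c):
    # shift exponent+mantissa into the fp16 field positions, then restore the sign bit
    out = ((h & 0x7fff) >> 1) | (h & 0x8000)
    return out.view(np.float16)

def fp16x4_to_fp8x4(f16_vals):
    h = np.asarray(f16_vals, dtype=np.float16).view(np.uint16).astype(np.uint32)
    a = ((h << 1) & 0x7fff) + 0x80        # shl + mask, then round in the low byte
    b = (h & 0x8000) | a                  # lop3(0xea) = (a & b) | c: re-attach the sign
    return (b >> 8).astype(np.uint8)      # prmt 0x7531: keep the high byte of each lane

# Round trip over all 256 fp8 bit patterns is lossless:
codes = np.arange(256, dtype=np.uint8)
assert (fp16x4_to_fp8x4(fp8x4_to_fp16x4(codes)) == codes).all()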
@@ -2309,7 +2636,7 @@ private:
                 Value smemBase) const;

   // blocked/mma -> blocked/mma.
-  // Data padding in shared memory to avoid bank confict.
+  // Data padding in shared memory to avoid bank conflict.
   LogicalResult
   lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
                                 OpAdaptor adaptor,
@@ -3265,7 +3592,7 @@ struct DotOpMmaV1ConversionHelper {
     int NK = shape[1];
     unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);

-    // NOTE We cound't get the vec from the shared layout.
+    // NOTE: We couldn't get the vec from the shared layout.
     // int vecA = sharedLayout.getVec();
     // TODO[Superjomn]: Consider the case when vecA > 4
     bool vecGt4 = false;
@@ -3283,7 +3610,7 @@ struct DotOpMmaV1ConversionHelper {
     SmallVector<int> fpw({2, 2, 1});
     SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
     SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
-    // NOTE We cound't get the vec from the shared layout.
+    // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
     // TODO[Superjomn]: Consider the case when vecA > 4
     bool vecGt4 = false;
@@ -3387,7 +3714,7 @@ struct DotOpMmaV2ConversionHelper {
     return Type{};
   }

-  // The type of a matrix that loaded by either a ldmatrix or composed lds.
+  // The type of matrix that loaded by either a ldmatrix or composed lds.
   Type getMatType() const {
     Type fp32Ty = type::f32Ty(ctx);
     Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
@@ -3583,7 +3910,7 @@ private:
       "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
   };

-  // vector length per ldmatrix (16*8/elelment_size_in_bits)
+  // vector length per ldmatrix (16*8/element_size_in_bits)
   inline static const std::map<TensorCoreType, uint8_t> mmaInstrVec = {
       {TensorCoreType::FP32_FP16_FP16_FP32, 8},
       {TensorCoreType::FP32_BF16_BF16_FP32, 8},
@@ -3723,7 +4050,7 @@ struct MMA16816ConversionHelper {
       // load from smem
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
-          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
+          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
           {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
       // load from registers, used in gemm fuse
@@ -3754,7 +4081,7 @@ struct MMA16816ConversionHelper {

     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
-        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
+        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

     for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
@@ -4713,14 +5040,15 @@ public:
     addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
       return convertTritonTensorType(type);
     });
-    // internally store bfloat16 as int16
-    addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
-      return IntegerType::get(type.getContext(), 16);
+    // Internally store float8 as int8
+    addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
+      return IntegerType::get(type.getContext(), 8);
     });
   }

   Type convertTritonPointerType(triton::PointerType type) {
-    return LLVM::LLVMPointerType::get(type.getPointeeType(),
+    // Recursively translate pointee type
+    return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
                                       type.getAddressSpace());
   }

@@ -4753,7 +5081,7 @@ public:
       auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
       size_t fcSize = 4 * repM * repN;
       return LLVM::LLVMStructType::getLiteral(
-          ctx, SmallVector<Type>(fcSize, type.getElementType()));
+          ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
     }

     if (mmaLayout.getVersion() == 1) {
@@ -4762,7 +5090,7 @@ public:
       int repN = helper.getRepN(shape[1]);
       int elems = 8 * repM * repN;
       return LLVM::LLVMStructType::getLiteral(
-          ctx, SmallVector<Type>(elems, type.getElementType()));
+          ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
     }

     llvm::errs()
@@ -4773,7 +5101,7 @@ public:
             layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
       auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
       auto wpt = mmaLayout.getWarpsPerCTA();
-      Type elemTy = type.getElementType();
+      Type elemTy = convertType(type.getElementType());
       auto vecSize = 1;
       if (elemTy.getIntOrFloatBitWidth() == 16) {
         vecSize = 2;
@@ -5324,6 +5652,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
 #undef POPULATE_UNARY_OP

+  patterns.add<FpToFpOpConversion>(typeConverter, benefit);
+
   patterns.add<FDivOpConversion>(typeConverter, benefit);

   patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
@@ -5369,12 +5699,12 @@ public:
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

     // step 1: Allocate shared memories and insert barriers
-    // setp 2: Convert SCF to CFG
+    // step 2: Convert SCF to CFG
     // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
     // step 4: Convert the rest of ops via partial conversion
     // The reason for putting step 1 before step 2 is that the membar analysis
     // currently only supports SCF but not CFG.
-    // The reason for a seperation between 1/4 is that, step 3 is out of
+    // The reason for a separation between 1/4 is that, step 3 is out of
     // the scope of Dialect Conversion, thus we need to make sure the smem
     // is not revised during the conversion of step 4.
     Allocation allocation(mod);
@@ -371,6 +371,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
   patterns.add< // TODO: view should have custom pattern that views the layout
       TritonGenericPattern<triton::ViewOp>,
       TritonGenericPattern<triton::BitcastOp>,
+      TritonGenericPattern<triton::FpToFpOp>,
       TritonGenericPattern<triton::IntToPtrOp>,
       TritonGenericPattern<triton::PtrToIntOp>,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
@@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
 namespace mlir {
 namespace triton {

+//-- FpToFpOp --
+bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
+                                 ::mlir::TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  auto srcEltType = inputs.front();
+  auto dstEltType = outputs.front();
+  auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
+  auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
+  if (srcTensorType && dstTensorType) {
+    srcEltType = srcTensorType.getElementType();
+    dstEltType = dstTensorType.getElementType();
+  }
+  // Check whether fp8 <=> fp16, bf16, f32, f64
+  // Make `srcEltType` always the fp8 side
+  if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
+    std::swap(srcEltType, dstEltType);
+  if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
+    return false;
+  return dstEltType.isF16() || dstEltType.isBF16() ||
+         dstEltType.isF32() || dstEltType.isF64();
+}
+
 //-- StoreOp --
 void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                     ::mlir::Value ptr, ::mlir::Value value) {
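Note: `areCastCompatible` admits a cast only when exactly one side is fp8 and the other is a standard float type (after unwrapping tensor element types). A plain-Python restatement of the element-type predicate, as an illustration with a hypothetical function name:

def are_cast_compatible(src: str, dst: str) -> bool:
    # Normalize so that `src` is the fp8 side, mirroring the std::swap above.
    if dst == 'fp8':
        src, dst = dst, src
    return src == 'fp8' and dst in ('fp16', 'bf16', 'fp32', 'fp64')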
@@ -44,7 +44,9 @@ namespace gpu {

 // TODO: Inheritation of layout attributes
 unsigned getElemsPerThread(Type type) {
-  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
+  if (type.isIntOrIndexOrFloat() ||
+      type.isa<triton::Float8Type>() ||
+      type.isa<triton::PointerType>())
     return 1;
   auto tensorType = type.cast<RankedTensorType>();
   auto layout = tensorType.getEncoding();
@@ -32,7 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     // Thread tile size depends on memory alignment
     SmallVector<unsigned, 4> sizePerThread(rank, 1);
     PointerType ptrType = origType.getElementType().cast<PointerType>();
-    unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
+    auto pointeeType = ptrType.getPointeeType();
+    unsigned numBits =
+        pointeeType.isa<triton::Float8Type>() ?
+        8 : pointeeType.getIntOrFloatBitWidth();
     unsigned maxMultiple = info.getDivisibility(order[0]);
     unsigned maxContig = info.getContiguity(order[0]);
     unsigned alignment = std::min(maxMultiple, maxContig);
@@ -140,7 +140,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
       /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

   pm.addPass(createConvertTritonGPUToLLVMPass());
-  // Conanicalize to eliminate the remaining UnrealizedConversionCastOp
+  // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
   pm.addPass(mlir::createSymbolDCEPass());
@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
            [](mlir::OpBuilder &self) -> mlir::Type {
              return self.getType<mlir::triton::Float8Type>();
            })
-      .def("get_bf8_ty",
-           [](mlir::OpBuilder &self) -> mlir::Type {
-             return self.getType<mlir::triton::BFloat8Type>();
-           })
       .def(
           "get_half_ty",
           [](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
            })

       // Cast instructions
+      // Conversions for custom FP types (FP8)
+      .def("create_fp_to_fp",
+           [](mlir::OpBuilder &self, mlir::Value &src,
+              mlir::Type &dstType) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
+           })
+      // Conversions for standard LLVM builtin types
       .def("create_bitcast",
            [](mlir::OpBuilder &self, mlir::Value &src,
               mlir::Type &dstType) -> mlir::Value {
              auto loc = self.getUnknownLoc();
              return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
            })
-      // .def("create_cast", &ir::builder::create_cast)
-      // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
       .def("create_si_to_fp",
            [](mlir::OpBuilder &self, mlir::Value &src,
               mlir::Type &dstType) -> mlir::Value {
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
              return self.create<mlir::arith::IndexCastOp>(loc, input,
                                                           self.getI32Type());
            })
-
       .def("create_fmul",
            [](mlir::OpBuilder &self, mlir::Value &lhs,
               mlir::Value &rhs) -> mlir::Value {
@@ -780,88 +780,88 @@ def test_store_bool():
     assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()


-# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# def test_f8_xf16_roundtrip(dtype):
-#     """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
-#     check_type_supported(dtype)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_f8_xf16_roundtrip(dtype):
+    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
+    check_type_supported(dtype)

-#     @triton.jit
-#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         mask = offsets < n_elements
-#         input = tl.load(input_ptr + offsets, mask=mask)
-#         output = input
-#         tl.store(output_ptr + offsets, output, mask=mask)
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)

-#     f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
-#     f8 = triton.reinterpret(f8_tensor, tl.float8)
-#     n_elements = f8_tensor.numel()
-#     xf16 = torch.empty_like(f8_tensor, dtype=dtype)
-#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-#     copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
+    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
+    f8 = triton.reinterpret(f8_tensor, tl.float8)
+    n_elements = f8_tensor.numel()
+    xf16 = torch.empty_like(f8_tensor, dtype=dtype)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)

-#     f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
-#     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
-#     copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
+    f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
+    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
+    copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)

-#     assert torch.all(f8_tensor == f8_output_tensor)
+    assert torch.all(f8_tensor == f8_output_tensor)


-# def test_f16_to_f8_rounding():
-#     """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
-#     error is the minimum over all float8.
-#     Or the same explanation a bit mathier:
-#     for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
-#     @triton.jit
-#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         mask = offsets < n_elements
-#         input = tl.load(input_ptr + offsets, mask=mask)
-#         output = input
-#         tl.store(output_ptr + offsets, output, mask=mask)
+def test_f16_to_f8_rounding():
+    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
+    error is the minimum over all float8.
+    Or the same explanation a bit mathier:
+    for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)

-#     # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
-#     f16_input_np = (
-#         np.array(
-#             range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
-#         )
-#         .view(np.float16)
-#     )
-#     f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
-#     n_elements = f16_input.numel()
-#     f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
-#     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
-#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-#     copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
+    # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
+    f16_input_np = (
+        np.array(
+            range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
+        )
+        .view(np.float16)
+    )
+    f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
+    n_elements = f16_input.numel()
+    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
+    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

-#     f16_output = torch.empty_like(f16_input, dtype=torch.float16)
-#     copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
+    f16_output = torch.empty_like(f16_input, dtype=torch.float16)
+    copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)

-#     abs_error = torch.abs(f16_input - f16_output)
+    abs_error = torch.abs(f16_input - f16_output)

-#     all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
-#     all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
-#     all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
-#     copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
+    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
+    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
+    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
+    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

-#     all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
-#         torch.isfinite(all_f8_vals_in_f16)
-#     ]
+    all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
+        torch.isfinite(all_f8_vals_in_f16)
+    ]

-#     min_error = torch.min(
-#         torch.abs(
-#             f16_input.reshape((-1, 1))
-#             - all_finite_f8_vals_in_f16.reshape((1, -1))
-#         ),
-#         dim=1,
-#     )[0]
-#     # 1.9375 is float8 max
-#     mismatch = torch.logical_and(
-#         abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
-#     )
-#     assert torch.all(
-#         torch.logical_not(mismatch)
-#     ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
+    min_error = torch.min(
+        torch.abs(
+            f16_input.reshape((-1, 1))
+            - all_finite_f8_vals_in_f16.reshape((1, -1))
+        ),
+        dim=1,
+    )[0]
+    # 1.9375 is float8 max
+    mismatch = torch.logical_and(
+        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
+    )
+    assert torch.all(
+        torch.logical_not(mismatch)
+    ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"


 # # ---------------
@@ -48,6 +48,8 @@ class dtype:
     SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
     UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
     FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
+    CUSTOMIZED_FP_TYPES = ['fp8']
+    STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
     OTHER_TYPES = ['void']

     class SIGNEDNESS(Enum):
|
|||||||
def is_floating(self):
|
def is_floating(self):
|
||||||
return self.name in dtype.FP_TYPES
|
return self.name in dtype.FP_TYPES
|
||||||
|
|
||||||
|
def is_customized_floating(self):
|
||||||
|
return self.name in dtype.CUSTOMIZED_FP_TYPES
|
||||||
|
|
||||||
|
def is_standard_floating(self):
|
||||||
|
return self.name in dtype.STANDARD_FP_TYPES
|
||||||
|
|
||||||
def is_int_signed(self):
|
def is_int_signed(self):
|
||||||
return self.name in dtype.SINT_TYPES
|
return self.name in dtype.SINT_TYPES
|
||||||
|
|
||||||
|
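Note: the two new predicates split FP_TYPES into the custom fp8 type and the standard IEEE-backed types, and `cast` below dispatches on them. A quick illustration, assuming a Triton build that includes this change:

import triton.language as tl

assert tl.float8.is_customized_floating()
assert tl.float16.is_standard_floating()
assert not tl.float16.is_customized_floating()
assert tl.float8.is_floating()  # fp8 is still part of FP_TYPES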
@@ -613,39 +613,45 @@ def cast(input: tl.tensor,
         dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
     if src_ty == dst_ty:
         return input
 
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
-    # fp8 <=> bf16/fp16
-    if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
-        return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
-                         dst_ty)
-    if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
-        return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
-                         dst_ty)
-    # bf16 <=> (not fp32)
-    if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
-       (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
+    # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
+    if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
+       (src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
+        return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
+                         dst_ty)
+    # Casting types of the same bit width: fp16 <=> bf16
+    if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \
+       (src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()):
         return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
 
-    # FP Truncation
+    # Standard floating types' casting: truncation
+    # fp64 => fp32, fp16, bf16
+    # fp32 => fp16, bf16
     truncate_fp = src_sca_ty.is_floating() and \
         dst_sca_ty.is_floating() and \
-        src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
+        src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
     if truncate_fp:
         return tl.tensor(builder.create_fp_trunc(input.handle,
                                                  dst_ty.to_ir(builder)),
                          dst_ty)
 
-    # FP Extension
+    # Standard floating types' casting: extension
+    # fp32 => fp64
+    # fp16 => fp32, fp64
+    # bf16 => fp32, fp64
     ext_fp = src_sca_ty.is_floating() and \
         dst_sca_ty.is_floating() and \
-        src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
+        src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
     if ext_fp:
         return tl.tensor(builder.create_fp_ext(input.handle,
                                                dst_ty.to_ir(builder)),
                          dst_ty)
 
-    # Int cast
+    # Casting between integer types
     if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
        (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
         sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
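Two things are worth pulling out of this hunk: fp8 casts are now routed through a single builder.create_fp_to_fp call whenever either side is a customized floating type, and the truncation/extension tests compare primitive_bitwidth instead of fp_mantissa_width. fp16 and bf16 are both 16-bit formats (with 10 and 7 explicitly stored mantissa bits, respectively), so under the bit-width rule they compare equal and are handled by the dedicated same-bit-width branch that routes through fp32. A small illustrative sketch of the resulting dispatch for standard floating types (the width table is an assumption standing in for dtype.primitive_bitwidth):

# Illustrative sketch of the new dispatch order for standard FP casts.
# The bit widths are assumed constants here; the real code reads them from
# the dtype objects (primitive_bitwidth).
PRIMITIVE_BITWIDTH = {'fp16': 16, 'bf16': 16, 'fp32': 32, 'fp64': 64}


def fp_cast_kind(src, dst):
    # fp16 <=> bf16: same width, different layout -> go through fp32.
    if {src, dst} == {'fp16', 'bf16'}:
        return 'via_fp32'
    if PRIMITIVE_BITWIDTH[src] > PRIMITIVE_BITWIDTH[dst]:
        return 'fp_trunc'   # e.g. fp64 -> fp32, fp32 -> bf16
    if PRIMITIVE_BITWIDTH[src] < PRIMITIVE_BITWIDTH[dst]:
        return 'fp_ext'     # e.g. fp16 -> fp32, bf16 -> fp64
    return 'noop'


assert fp_cast_kind('fp32', 'fp16') == 'fp_trunc'
assert fp_cast_kind('bf16', 'fp64') == 'fp_ext'
assert fp_cast_kind('fp16', 'bf16') == 'via_fp32'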
@@ -658,8 +664,8 @@ def cast(input: tl.tensor,
                                                   dst_ty.to_ir(builder), sign_extend),
                          dst_ty)
 
-    # Float to Int
-    if src_sca_ty.is_floating() and dst_sca_ty.is_int():
+    # Casting standard floating types to integer types
+    if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
         if dst_sca_ty.is_bool():
             ty = input.dtype.to_ir(builder)
             _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
@@ -673,8 +679,8 @@ def cast(input: tl.tensor,
                                                  dst_ty.to_ir(builder)),
                          dst_ty)
 
-    # int => float
-    if src_sca_ty.is_int() and dst_sca_ty.is_floating():
+    # Casting integer types to standard floating types
+    if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
         if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
             return tl.tensor(builder.create_ui_to_fp(input.handle,
                                                      dst_ty.to_ir(builder)),
@@ -684,7 +690,7 @@ def cast(input: tl.tensor,
                                                  dst_ty.to_ir(builder)),
                          dst_ty)
 
-    # ptr => int
+    # Casting pointer types to integer types
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
@@ -695,19 +701,14 @@ def cast(input: tl.tensor,
                              tl.tensor(builder.get_int64(0), tl.int64),
                              builder)
 
-    if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
+    # Casting integer types to pointer types
+    if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
         return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
-    # Ptr . Ptr
+    # Casting pointer types to pointer types
     if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
         return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
-    # * . Bool
-    if dst_sca_ty.is_bool():
-        if src_sca_ty.is_ptr():
-            input = cast(input, tl.int64, builder)
-        other = builder.get_int64(0)
-        if src_ty.is_bool():
-            other = builder.create_splat(other, src_ty.get_block_shapes())
-        return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
     assert False, f'cannot cast {input} to {dst_ty}'
 
 # ===----------------------------------------------------------------------===//
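Note how the pointer-cast guard is tightened here: the old branch accepted any non-pointer source (not src_sca_ty.is_ptr()), while the new one requires an integer source, and the old "* => Bool" fallback is removed, so a cast that matches none of the remaining branches now reaches the final assert. A tiny illustrative sketch of the predicate change (plain booleans stand in for the dtype methods used in the hunk):

# Illustrative only: the guard for the int -> pointer branch, before and after.
def old_guard(src_is_ptr, src_is_int, dst_is_ptr):
    return (not src_is_ptr) and dst_is_ptr   # any non-pointer source qualified


def new_guard(src_is_ptr, src_is_int, dst_is_ptr):
    return src_is_int and dst_is_ptr         # only integer sources qualify


# A floating-point source casting to a pointer matched the old guard
# but falls through (eventually to the assert) under the new one.
assert old_guard(src_is_ptr=False, src_is_int=False, dst_is_ptr=True)
assert not new_guard(src_is_ptr=False, src_is_int=False, dst_is_ptr=True)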
@@ -176,6 +176,9 @@ class JITFunction(KernelInterface):
             triton.language.uint32: 'u32',
             triton.language.uint64: 'u64',
             triton.language.float8: 'fp8',
+            triton.language.float16: 'fp16',
+            triton.language.bfloat16: 'bf16',
+            triton.language.float32: 'fp32',
         }[key]
         return f'*{ty}'
     if key is None:
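The new entries give fp16, bf16, and fp32 the same short signature names that fp8 already had; as the surrounding context shows, the abbreviation is then prefixed with `*` for pointer arguments. A minimal sketch of that lookup (illustrative only; the string keys stand in for the triton.language dtype objects used in the real mapping):

# Sketch only: abbreviating a pointer's element type for a signature string,
# mirroring the mapping extended in the hunk above.
_TYPE_ABBREV = {
    'uint32': 'u32',
    'uint64': 'u64',
    'float8': 'fp8',
    'float16': 'fp16',
    'bfloat16': 'bf16',
    'float32': 'fp32',
}


def pointer_sig(element_type_name):
    ty = _TYPE_ABBREV[element_type_name]
    return f'*{ty}'


assert pointer_sig('float16') == '*fp16'
assert pointer_sig('bfloat16') == '*bf16'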
@@ -6,8 +6,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
   %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
   // CHECK: !tt.ptr<f32> -> i64
   %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
-  // CHECK: f32 -> f16
-  %2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
+  // CHECK: f32 to f16
+  %2 = arith.truncf %scalar_f32 : f32 to f16
 
   // 0D tensor -> 0D tensor
   %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
@@ -18,8 +18,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
   %3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
   // CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
   %4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
-  // CHECK: tensor<f32> -> tensor<f16>
-  %5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
+  // CHECK: tensor<f32> to tensor<f16>
+  %5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>
 
   // 1D tensor -> 1D tensor
   %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
@@ -30,8 +30,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
   %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
   // CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
   %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
-  // CHECK: tensor<16xf32> -> tensor<16xf16>
-  %8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
+  // CHECK: tensor<16xf32> to tensor<16xf16>
+  %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
   return
 }
 
@@ -125,15 +125,21 @@ TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
   PTXBuilder builder;
   const char *ptxCode =
       ".param .b64 param0;\n" // prepare param0 (format string)
-      "st.param.b64 [param0], %0;\n";
+      "st.param.b64 [param0], %0;\n"
+      "st.param.b64 [param0], %1;\n"
+      "st.param.b64 [param0], %2;\n";
 
   auto &ptxSnippet = *builder.create(ptxCode);
-  auto *opr = builder.newOperand(v[0], "r");
-  ptxSnippet({opr}, true);
+  auto *opr0 = builder.newOperand(v[0], "r");
+  auto *opr1 = builder.newOperand(v[1], "r");
+  auto *opr2 = builder.newOperand(v[2], "r");
+  ptxSnippet({opr1, opr2, opr0}, true);
 
   EXPECT_EQ(builder.dump(), ptxCode);
-  ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]);
-  ASSERT_EQ(builder.getAllMLIRArgs().size(), 1);
+  ASSERT_EQ(builder.getAllMLIRArgs()[0], v[1]);
+  ASSERT_EQ(builder.getAllMLIRArgs()[1], v[2]);
+  ASSERT_EQ(builder.getAllMLIRArgs()[2], v[0]);
+  ASSERT_EQ(builder.getAllMLIRArgs().size(), 3);
 }
 
 } // namespace triton
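In the updated test the snippet is invoked with the operands in the order {opr1, opr2, opr0}, so %0, %1, and %2 in the PTX text bind v[1], v[2], and v[0], and getAllMLIRArgs() is expected to come back in that same positional order. A small Python stand-in for that expectation (illustrative only; the authoritative check is the C++ test above):

# Sketch only: the MLIR arguments follow the $0, $1, ... positions of the
# operand list handed to the snippet, not the order the operands were created.
v = [10, 20, 30]                                    # hypothetical MLIR values
operands = {'opr0': v[0], 'opr1': v[1], 'opr2': v[2]}
snippet_order = ['opr1', 'opr2', 'opr0']            # order passed to the snippet

all_mlir_args = [operands[name] for name in snippet_order]

assert all_mlir_args[0] == v[1]
assert all_mlir_args[1] == v[2]
assert all_mlir_args[2] == v[0]
assert len(all_mlir_args) == 3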