Files
triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Philippe Tillet 20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00

866 lines
37 KiB
C++

#include "ElementwiseOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
// Lowers triton::FpToFpOp. Casts that involve the 8-bit float type are done
// four elements at a time with inline PTX bit manipulation; BF16 <-> FP32
// casts use the PTX `cvt` instruction one element at a time. BF16 values are
// carried as i16 on the LLVM side throughout this struct.
struct FpToFpOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;

  // Converts four fp8 values (given as i8) to four f16 values.
  //
  // The inputs are packed into one 32-bit word. The PTX then:
  //   1. uses `prmt` to place each input byte in the high byte of a 16-bit
  //      lane (two lanes per output register),
  //   2. clears the sign bit of each lane and shifts the rest right by one,
  //   3. ORs the original sign bit back in.
  // The two result registers are read back as two <2 x f16> vectors.
  static SmallVector<Value>
  convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    // `ctx` is referenced by name from helper macros (presumably struct_ty).
    auto ctx = rewriter.getContext();
    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    Value fp8x4Vec = undef(fp8x4VecTy);
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

    PTXBuilder builder;
    // $0/$1: outputs, $2: packed input.
    // lop3 LUT 0xc0 computes (a & b); LUT 0xf8 computes (a | (b & c)).
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, b<2>; \n"
                   "prmt.b32 a0, 0, $2, 0x5040; \n"
                   "prmt.b32 a1, 0, $2, 0x7060; \n"
                   "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
                   "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
                   "shr.b32 b0, b0, 1; \n"
                   "shr.b32 b1, b1, 1; \n"
                   "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
                   "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
                   "}";
    auto &call = *builder.create(ptxAsm);

    auto *o0 = builder.newOperand("=r");
    auto *o1 = builder.newOperand("=r");
    auto *i = builder.newOperand(fp8x4Vec, "r");
    call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);

    auto fp16x2VecTy = vec_ty(f16_ty, 2);
    auto fp16x2x2StructTy =
        struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
    auto fp16x2x2Struct =
        builder.launch(rewriter, loc, fp16x2x2StructTy, false);
    auto fp16x2Vec0 =
        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
    auto fp16x2Vec1 =
        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
    return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
            extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
            extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
            extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
  }

  // Converts four f16 values to four fp8 values (returned as i8).
  //
  // The inputs are packed pairwise into two 32-bit words. The PTX then:
  //   1. shifts each word left by one and clears bit 15 of every 16-bit lane,
  //   2. adds 0x0080 to every lane (rounding before the low byte is dropped),
  //   3. ORs the original sign bit of every lane back in,
  //   4. uses `prmt` to gather the high byte of each of the four lanes.
  static SmallVector<Value>
  convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    // NOTE(review): `ctx` looks unused here unless a helper macro expands to
    // it -- confirm before removing.
    auto ctx = rewriter.getContext();
    auto fp16x2VecTy = vec_ty(f16_ty, 2);
    Value fp16x2Vec0 = undef(fp16x2VecTy);
    Value fp16x2Vec1 = undef(fp16x2VecTy);
    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
    fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
    fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

    PTXBuilder builder;
    // $0: output, $1/$2: packed inputs.
    // lop3 LUT 0xc0 computes (a & b); LUT 0xea computes ((a & b) | c).
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, b<2>; \n"
                   "shl.b32 a0, $1, 1; \n"
                   "shl.b32 a1, $2, 1; \n"
                   "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
                   "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
                   "add.u32 a0, a0, 0x00800080; \n"
                   "add.u32 a1, a1, 0x00800080; \n"
                   "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
                   "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
                   "prmt.b32 $0, b0, b1, 0x7531; \n"
                   "}";
    auto &call = *builder.create(ptxAsm);

    auto *o = builder.newOperand("=r");
    auto *i0 = builder.newOperand(fp16x2Vec0, "r");
    auto *i1 = builder.newOperand(fp16x2Vec1, "r");
    call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
  }

  // Converts four fp8 values (given as i8) to four bf16 values (returned as
  // i16 bit patterns).
  //
  // The PTX places each input byte in the high byte of a 16-bit lane, splits
  // every lane into its sign bit and the remaining 15 bits, shifts the
  // remaining bits right by four, adds 0x3800 per lane, and finally ORs the
  // sign bit back in.
  static SmallVector<Value>
  convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    // `ctx` is referenced by name from helper macros (presumably struct_ty).
    auto ctx = rewriter.getContext();
    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    Value fp8x4Vec = undef(fp8x4VecTy);
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

    PTXBuilder builder;
    // $0/$1: outputs, $2: packed input.
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
                   "prmt.b32 a0, 0, $2, 0x5040; \n"
                   "prmt.b32 a1, 0, $2, 0x7060; \n"
                   "and.b32 sign0, a0, 0x80008000; \n"
                   "and.b32 sign1, a1, 0x80008000; \n"
                   "and.b32 nosign0, a0, 0x7fff7fff; \n"
                   "and.b32 nosign1, a1, 0x7fff7fff; \n"
                   "shr.b32 nosign0, nosign0, 4; \n"
                   "shr.b32 nosign1, nosign1, 4; \n"
                   "add.u32 nosign0, nosign0, 0x38003800; \n"
                   "add.u32 nosign1, nosign1, 0x38003800; \n"
                   "or.b32 $0, sign0, nosign0; \n"
                   "or.b32 $1, sign1, nosign1; \n"
                   "}";
    auto &call = *builder.create(ptxAsm);

    auto *o0 = builder.newOperand("=r");
    auto *o1 = builder.newOperand("=r");
    auto *i = builder.newOperand(fp8x4Vec, "r");
    call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);

    auto bf16x2VecTy = vec_ty(i16_ty, 2);
    auto bf16x2x2StructTy =
        struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
    auto bf16x2x2Struct =
        builder.launch(rewriter, loc, bf16x2x2StructTy, false);
    auto bf16x2Vec0 =
        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
    auto bf16x2Vec1 =
        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
    return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
            extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
            extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
            extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
  }

  // Converts four bf16 values (given as i16 bit patterns) to four fp8 values
  // (returned as i8).
  //
  // The PTX separates sign and magnitude, clamps the magnitude of every
  // 16-bit lane to [0x3800, 0x3ff0], adds 0x0008 per lane for rounding,
  // subtracts 0x3800 per lane, shifts right by four, and gathers the low byte
  // of each lane. The sign bytes are gathered separately and ORed in.
  static SmallVector<Value>
  convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    // NOTE(review): `ctx` looks unused here unless a helper macro expands to
    // it -- confirm before removing.
    auto ctx = rewriter.getContext();
    auto bf16x2VecTy = vec_ty(i16_ty, 2);
    Value bf16x2Vec0 = undef(bf16x2VecTy);
    Value bf16x2Vec1 = undef(bf16x2VecTy);
    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
    bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
    bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);

    PTXBuilder builder;
    // $0: output, $1/$2: packed inputs.
    auto *ptxAsm = "{ \n"
                   ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
                   ".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
                   "mov.u32 fp8_min, 0x38003800; \n"
                   "mov.u32 fp8_max, 0x3ff03ff0; \n"
                   "mov.u32 rn_, 0x80008; \n"
                   "mov.u32 zero, 0; \n"
                   "and.b32 sign0, $1, 0x80008000; \n"
                   "and.b32 sign1, $2, 0x80008000; \n"
                   "prmt.b32 sign, sign0, sign1, 0x7531; \n"
                   "and.b32 nosign0, $1, 0x7fff7fff; \n"
                   "and.b32 nosign1, $2, 0x7fff7fff; \n"
                   ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
                   "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
                   "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
                   "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
                   "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
                   "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
                   "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
                   "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
                   "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
                   "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
                   "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
                   "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
                   "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
                   "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
                   "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
                   "add.u32 nosign0, nosign0, rn_; \n"
                   "add.u32 nosign1, nosign1, rn_; \n"
                   "sub.u32 nosign0, nosign0, 0x38003800; \n"
                   "sub.u32 nosign1, nosign1, 0x38003800; \n"
                   "shr.u32 nosign0, nosign0, 4; \n"
                   "shr.u32 nosign1, nosign1, 4; \n"
                   "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
                   "or.b32 $0, nosign, sign; \n"
                   "}";
    auto &call = *builder.create(ptxAsm);

    auto *o = builder.newOperand("=r");
    auto *i0 = builder.newOperand(bf16x2Vec0, "r");
    auto *i1 = builder.newOperand(bf16x2Vec1, "r");
    call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
  }

  // fp8 -> f32: goes through f16, then extends each value with LLVM fpext.
  static SmallVector<Value>
  convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
    return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
  }

  // f32 -> fp8: truncates each value to f16 with LLVM fptrunc, then reuses
  // the f16 -> fp8 path.
  static SmallVector<Value>
  convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
  }

  // fp8 -> f64: goes through f16, then extends each value with LLVM fpext.
  static SmallVector<Value>
  convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
    return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
  }

  // f64 -> fp8: truncates each value to f16 with LLVM fptrunc, then reuses
  // the f16 -> fp8 path.
  static SmallVector<Value>
  convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
  }

  // Converts a single bf16 scalar (16-bit "h" register operand) to f32 with
  // `cvt.rn.f32.bf16`.
  static Value convertBf16ToFp32(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
    PTXBuilder builder;
    auto &cvt = *builder.create("cvt.rn.f32.bf16");
    auto res = builder.newOperand("=r");
    auto operand = builder.newOperand(v, "h");
    cvt(res, operand);
    return builder.launch(rewriter, loc, f32_ty, false);
  }

  // Converts a single f32 scalar to bf16 with `cvt.rn.bf16.f32`. The result
  // is typed i16 on the LLVM side.
  static Value convertFp32ToBf16(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
    PTXBuilder builder;
    auto &cvt = *builder.create("cvt.rn.bf16.f32");
    auto res = builder.newOperand("=h");
    auto operand = builder.newOperand(v, "r");
    cvt(res, operand);
    // TODO: This is a hack to get the right type. We should be able to invoke
    // the type converter
    return builder.launch(rewriter, loc, i16_ty, false);
  }

  // Rewrites `op` into per-element conversions and repacks the results into
  // the converted destination struct.
  //
  // Returns failure() for unsupported source/destination combinations (after
  // asserting in builds with assertions enabled).
  LogicalResult
  matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
    auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
    auto srcEltType = srcTensorType.getElementType();
    auto dstEltType = dstTensorType.getElementType();
    auto loc = op->getLoc();
    auto elems = getElemsPerThread(dstTensorType);
    SmallVector<Value> resultVals;

    // Select convertor
    if (srcEltType.isa<triton::Float8Type>() ||
        dstEltType.isa<triton::Float8Type>()) {
      std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
                                       const Value &, const Value &,
                                       const Value &, const Value &)>
          convertor;
      if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
        convertor = convertFp8x4ToFp16x4;
      } else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
        convertor = convertFp16x4ToFp8x4;
      } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
        convertor = convertFp8x4ToBf16x4;
      } else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
        convertor = convertBf16x4ToFp8x4;
      } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
        convertor = convertFp8x4ToFp32x4;
      } else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
        convertor = convertFp32x4ToFp8x4;
      } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
        convertor = convertFp8x4ToFp64x4;
      } else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
        convertor = convertFp64x4ToFp8x4;
      } else {
        assert(false && "unsupported fp8 casting");
        // Without assertions, never fall through to an empty std::function.
        return failure();
      }

      // Vectorized casting
      assert(elems % 4 == 0 &&
             "FP8 casting only support tensors with 4-aligned sizes");
      auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
      for (size_t i = 0; i < elems; i += 4) {
        auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
                                   elements[i + 2], elements[i + 3]);
        resultVals.append(converted);
      }
    } else if (srcEltType.isBF16() && dstEltType.isF32()) {
      // The scalar converter takes one element at a time, so the packed
      // operand has to be unpacked first.
      for (Value element :
           getElementsFromStruct(loc, adaptor.from(), rewriter))
        resultVals.emplace_back(convertBf16ToFp32(loc, rewriter, element));
    } else if (srcEltType.isF32() && dstEltType.isBF16()) {
      for (Value element :
           getElementsFromStruct(loc, adaptor.from(), rewriter))
        resultVals.emplace_back(convertFp32ToBf16(loc, rewriter, element));
    } else {
      assert(false && "unsupported type casting");
      // Without assertions, do not pack an empty result list.
      return failure();
    }

    assert(resultVals.size() == elems);
    auto convertedDstTensorType =
        this->getTypeConverter()->convertType(dstTensorType);
    auto result = getStructFromElements(loc, resultVals, rewriter,
                                        convertedDstTensorType);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// CRTP base for elementwise lowerings.
//
// matchAndRewrite unpacks every operand into its per-thread elements, calls
// `ConcreteT::createDestOp` once per element position, and repacks the
// results into the converted result struct.
//
// ConcreteT must provide:
//   createDestOp(SourceOp op, OpAdaptor adaptor,
//                ConversionPatternRewriter &rewriter, Type elemTy,
//                ValueRange operands, Location loc)
// returning something convertible to Value. A null Value makes the whole
// rewrite fail, which lets another pattern handle the op.
template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase
    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
                                       PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultTy = op.getType();
    Location loc = op->getLoc();

    unsigned elems = getElemsPerThread(resultTy);
    auto resultElementTy = getElementTypeOrSelf(resultTy);
    Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
    Type structTy = this->getTypeConverter()->convertType(resultTy);

    auto *concreteThis = static_cast<const ConcreteT *>(this);
    auto operands = getOperands(rewriter, adaptor, elems, loc);
    SmallVector<Value> resultVals(elems);
    for (unsigned i = 0; i < elems; ++i) {
      resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
                                                 operands[i], loc);
      // A null result means the concrete pattern declined this op.
      if (!bool(resultVals[i]))
        return failure();
    }
    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
  }

protected:
  // Transposes the operands: entry `i` of the result holds the i-th unpacked
  // element of every operand, in operand order.
  SmallVector<SmallVector<Value>>
  getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
              const unsigned elems, Location loc) const {
    SmallVector<SmallVector<Value>> operands(elems);
    for (auto operand : adaptor.getOperands()) {
      auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
      for (size_t i = 0; i < elems; ++i) {
        operands[i].push_back(sub_operands[i]);
      }
    }
    return operands;
  }
};
// Generic one-to-one elementwise lowering: every element of SourceOp is
// rewritten to a DestOp built from the result element type, the per-element
// operands, and the source op's attributes.
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
    : public ElementwiseOpConversionBase<
          SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
  using Base =
      ElementwiseOpConversionBase<SourceOp,
                                  ElementwiseOpConversion<SourceOp, DestOp>>;
  using Base::Base;
  using OpAdaptor = typename Base::OpAdaptor;

  explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
      : Base(typeConverter, benefit) {}

  // Hook called by the base class once per element position.
  DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Type elemTy,
                      ValueRange operands, Location loc) const {
    auto forwardedAttrs = adaptor.getAttributes().getValue();
    return rewriter.create<DestOp>(loc, elemTy, operands, forwardedAttrs);
  }
};
// Lowers triton::gpu::CmpIOp to LLVM::ICmpOp, one element at a time.
struct CmpIOpConversion
    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
                                         CmpIOpConversion> {
  using Base =
      ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Hook called by the base class: compares operands[0] with operands[1]
  // using the LLVM predicate that matches the source predicate.
  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Type elemTy,
                            ValueRange operands, Location loc) const {
    LLVM::ICmpPredicate llvmPredicate =
        ArithCmpIPredicateToLLVM(op.predicate());
    return rewriter.create<LLVM::ICmpOp>(loc, elemTy, llvmPredicate,
                                         operands[0], operands[1]);
  }

  // Maps an arith integer-compare predicate to the LLVM predicate of the
  // same name. Falls back to `eq` for values outside the listed cases.
  static LLVM::ICmpPredicate
  ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) {
    switch (predicate) {
    case arith::CmpIPredicate::eq:
      return LLVM::ICmpPredicate::eq;
    case arith::CmpIPredicate::ne:
      return LLVM::ICmpPredicate::ne;
    case arith::CmpIPredicate::sgt:
      return LLVM::ICmpPredicate::sgt;
    case arith::CmpIPredicate::sge:
      return LLVM::ICmpPredicate::sge;
    case arith::CmpIPredicate::slt:
      return LLVM::ICmpPredicate::slt;
    case arith::CmpIPredicate::sle:
      return LLVM::ICmpPredicate::sle;
    case arith::CmpIPredicate::ugt:
      return LLVM::ICmpPredicate::ugt;
    case arith::CmpIPredicate::uge:
      return LLVM::ICmpPredicate::uge;
    case arith::CmpIPredicate::ult:
      return LLVM::ICmpPredicate::ult;
    case arith::CmpIPredicate::ule:
      return LLVM::ICmpPredicate::ule;
    }
    return LLVM::ICmpPredicate::eq;
  }
};
// Lowers triton::gpu::CmpFOp to LLVM::FCmpOp, one element at a time.
struct CmpFOpConversion
    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
                                         CmpFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Hook called by the base class: compares operands[0] with operands[1]
  // using the LLVM predicate that matches the source predicate.
  static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, ValueRange operands,
                                   Location loc) {
    LLVM::FCmpPredicate llvmPredicate =
        ArithCmpFPredicateToLLVM(op.predicate());
    return rewriter.create<LLVM::FCmpOp>(loc, elemTy, llvmPredicate,
                                         operands[0], operands[1]);
  }

  // Maps an arith float-compare predicate to its LLVM counterpart. Falls
  // back to `_true` for values outside the listed cases.
  static LLVM::FCmpPredicate
  ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) {
    switch (predicate) {
    case arith::CmpFPredicate::OEQ:
      return LLVM::FCmpPredicate::oeq;
    case arith::CmpFPredicate::ONE:
      return LLVM::FCmpPredicate::one;
    case arith::CmpFPredicate::OGT:
      return LLVM::FCmpPredicate::ogt;
    case arith::CmpFPredicate::OGE:
      return LLVM::FCmpPredicate::oge;
    case arith::CmpFPredicate::OLT:
      return LLVM::FCmpPredicate::olt;
    case arith::CmpFPredicate::OLE:
      return LLVM::FCmpPredicate::ole;
    case arith::CmpFPredicate::ORD:
      return LLVM::FCmpPredicate::ord;
    case arith::CmpFPredicate::UEQ:
      return LLVM::FCmpPredicate::ueq;
    case arith::CmpFPredicate::UGT:
      return LLVM::FCmpPredicate::ugt;
    case arith::CmpFPredicate::UGE:
      return LLVM::FCmpPredicate::uge;
    case arith::CmpFPredicate::ULT:
      return LLVM::FCmpPredicate::ult;
    case arith::CmpFPredicate::ULE:
      return LLVM::FCmpPredicate::ule;
    case arith::CmpFPredicate::UNE:
      return LLVM::FCmpPredicate::une;
    case arith::CmpFPredicate::UNO:
      return LLVM::FCmpPredicate::uno;
    case arith::CmpFPredicate::AlwaysTrue:
      return LLVM::FCmpPredicate::_true;
    case arith::CmpFPredicate::AlwaysFalse:
      return LLVM::FCmpPredicate::_false;
    }
    return LLVM::FCmpPredicate::_true;
  }
};
// Lowers triton::ExtElemwiseOp to a call of the external function named by
// the op's `symbol`, one call per element. The callee is declared on demand
// and tagged with the op's `libname` / `libpath`.
struct ExtElemwiseOpConversion
    : public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
                                         ExtElemwiseOpConversion> {
  using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
                                           ExtElemwiseOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Hook called by the base class. Returns the call's first result, or a
  // null Value (which fails the rewrite) when the op carries no symbol name.
  Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    StringRef funcName = op.symbol();
    if (funcName.empty()) {
      // A function cannot be declared or called without a name; report it on
      // the op instead of emitting a call to an empty symbol.
      op.emitError("ExtElemwiseOpConversion: empty external function symbol");
      return {};
    }

    Type funcType = getFunctionType(elemTy, operands);
    LLVM::LLVMFuncOp funcOp =
        appendOrGetFuncOp(rewriter, op, funcName, funcType);
    return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
  }

private:
  // Builds the LLVM function type `resultType(operand types...)`.
  Type getFunctionType(Type resultType, ValueRange operands) const {
    SmallVector<Type> operandTypes(operands.getTypes());
    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
  }

  // Returns the function named `funcName` if a symbol lookup starting at
  // `op` finds one; otherwise creates it with a builder positioned at the
  // LLVM function enclosing `op`, and records libname / libpath on it.
  LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
                                     triton::ExtElemwiseOp op,
                                     StringRef funcName, Type funcType) const {
    using LLVM::LLVMFuncOp;

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
    if (funcOp)
      return cast<LLVMFuncOp>(*funcOp);

    mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
    auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
    ret.getOperation()->setAttr(
        "libname", StringAttr::get(op->getContext(), op.libname()));
    ret.getOperation()->setAttr(
        "libpath", StringAttr::get(op->getContext(), op.libpath()));
    return ret;
  }
};
// Lowers arith.divf to the PTX `div` instruction: `div.full.f32` for 32-bit
// elements and `div.rn.f64` for 64-bit elements.
struct FDivOpConversion
    : ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Hook called by the base class. Returns a null Value (which fails the
  // rewrite) for element widths other than 32 and 64.
  Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
    if (bitwidth != 32 && bitwidth != 64) {
      assert(false && "not supported");
      // Without assertions, fail the rewrite instead of emitting a `div`
      // with no type suffix.
      return {};
    }

    PTXBuilder ptxBuilder;
    auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
    if (32 == bitwidth) {
      fdiv.o("full").o("f32");
    } else {
      fdiv.o("rn").o("f64");
    }

    // 32-bit values use "r" register constraints, 64-bit values use "l".
    auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
    auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
    auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
    fdiv(res, lhs, rhs);

    Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
    return ret;
  }
};
// Lowers arith.mulf. When both operands are bf16 the product is computed
// with a PTX bf16 fused multiply-add; otherwise LLVM::FMulOp is used.
struct FMulOpConversion
    : ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::MulFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsTy = getElementType(op.getLhs());
    auto rhsTy = getElementType(op.getRhs());
    const bool bothBf16 = lhsTy.isBF16() && rhsTy.isBF16();
    if (!bothBf16)
      return rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0],
                                           operands[1]);

    // lhs * rhs + c, where c is the bf16 bit pattern 0x8000 (sign bit only,
    // i.e. -0.0).
    PTXBuilder ptx;
    auto mulAsm = " { .reg .b16 c; \n"
                  " mov.b16 c, 0x8000U; \n" // -0.0
                  " fma.rn.bf16 $0, $1, $2, c; } \n";
    auto &fmaInstr = *ptx.create<PTXInstr>(mulAsm);
    auto dst = ptx.newOperand("=h");
    auto a = ptx.newOperand(operands[0], "h");
    auto b = ptx.newOperand(operands[1], "h");
    fmaInstr({dst, a, b}, /*onlyAttachMLIRArgs=*/true);
    // The bf16 result is typed i16 on the LLVM side.
    return ptx.launch(rewriter, loc, i16_ty, false);
  }
};
// Lowers arith.addf. When both operands are bf16 the sum is computed with a
// PTX bf16 fused multiply-add; otherwise LLVM::FAddOp is used.
struct FAddOpConversion
    : ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::AddFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsTy = getElementType(op.getLhs());
    auto rhsTy = getElementType(op.getRhs());
    const bool bothBf16 = lhsTy.isBF16() && rhsTy.isBF16();
    if (!bothBf16)
      return rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0],
                                           operands[1]);

    // lhs * c + rhs, where c is the bf16 bit pattern 0x3f80 (1.0).
    PTXBuilder ptx;
    auto addAsm = "{ .reg .b16 c; \n"
                  " mov.b16 c, 0x3f80U; \n" // 1.0
                  " fma.rn.bf16 $0, $1, c, $2; } \n";
    auto &fmaInstr = *ptx.create<PTXInstr>(addAsm);
    auto dst = ptx.newOperand("=h");
    auto a = ptx.newOperand(operands[0], "h");
    auto b = ptx.newOperand(operands[1], "h");
    fmaInstr({dst, a, b}, /*onlyAttachMLIRArgs=*/true);
    // The bf16 result is typed i16 on the LLVM side.
    return ptx.launch(rewriter, loc, i16_ty, false);
  }
};
// Lowers arith.subf. When both operands are bf16 the difference is computed
// with a PTX bf16 fused multiply-add; otherwise LLVM::FSubOp is used.
struct FSubOpConversion
    : ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::SubFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsTy = getElementType(op.getLhs());
    auto rhsTy = getElementType(op.getRhs());
    const bool bothBf16 = lhsTy.isBF16() && rhsTy.isBF16();
    if (!bothBf16)
      return rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0],
                                           operands[1]);

    // rhs * c + lhs, where c is the bf16 bit pattern 0xbf80 (-1.0).
    PTXBuilder ptx;
    auto subAsm = " { .reg .b16 c; \n"
                  " mov.b16 c, 0xbf80U; \n" // -1.0
                  " fma.rn.bf16 $0, $2, c, $1;} \n";
    auto &fmaInstr = *ptx.create<PTXInstr>(subAsm);
    auto dst = ptx.newOperand("=h");
    auto a = ptx.newOperand(operands[0], "h");
    auto b = ptx.newOperand(operands[1], "h");
    fmaInstr({dst, a, b}, /*onlyAttachMLIRArgs=*/true);
    // The bf16 result is typed i16 on the LLVM side.
    return ptx.launch(rewriter, loc, i16_ty, false);
  }
};
// Lowers arith.sitofp. A bf16 result is produced by converting to f32 first
// and then narrowing with FpToFpOpConversion::convertFp32ToBf16; every other
// result type maps directly to LLVM::SIToFPOp.
struct SIToFPOpConversion
    : ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::SIToFPOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto dstTy = getElementType(op.getOut());
    if (!dstTy.isBF16())
      return rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0]);

    auto asF32 = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0]);
    return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, asF32);
  }
};
// Lowers arith.fptosi. A bf16 input is widened to f32 with
// FpToFpOpConversion::convertBf16ToFp32 before LLVM::FPToSIOp is applied;
// every other input type is converted directly.
struct FPToSIOpConversion
    : ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::FPToSIOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto srcTy = getElementType(op.getIn());
    if (!srcTy.isBF16())
      return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0]);

    auto asF32 =
        FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
    return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, asF32);
  }
};
// Lowers arith.extf. A bf16 input is only supported when extending to f32
// and goes through FpToFpOpConversion::convertBf16ToFp32; every other input
// type maps to LLVM::FPExtOp.
struct ExtFOpConversion
    : ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::ExtFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto srcTy = getElementType(op.getIn());
    if (!srcTy.isBF16())
      return rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0]);

    auto dstTy = getElementType(op.getOut());
    assert(dstTy.isF32() && "unsupported conversion");
    return FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
  }
};
// Lowers arith.truncf. A bf16 result is only supported when truncating from
// f32 and goes through FpToFpOpConversion::convertFp32ToBf16; every other
// result type maps to LLVM::FPTruncOp.
struct TruncFOpConversion
    : ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::TruncFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto dstTy = getElementType(op.getOut());
    if (!dstTy.isBF16())
      return rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0]);

    auto srcTy = getElementType(op.getIn());
    assert(srcTy.isF32() && "unsupported conversion");
    return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0]);
  }
};
// Lowers math.exp on f32 elements to `ex2.approx.f32`, using
// exp(x) == 2^(x * log2(e)).
struct ExpOpConversionApprox
    : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
  using Base =
      ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Hook called by the base class. Returns a null Value (which fails this
  // pattern) for every element type that is not 32 bits wide.
  Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    // Everything below is hard-wired to f32 (f32 multiply, ex2.approx.f32,
    // "f" register constraints), so decline any other width and leave the op
    // to another pattern.
    if (elemTy.getIntOrFloatBitWidth() != 32)
      return {};

    const double log2e = 1.4426950408889634;
    Value prod = fmul(f32_ty, operands[0], f32_val(log2e));

    PTXBuilder ptxBuilder;
    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
    auto output = ptxBuilder.newOperand("=f");
    auto input = ptxBuilder.newOperand(prod, "f");
    exp2(output, input);
    return ptxBuilder.launch(rewriter, loc, f32_ty, false);
  }
};
// Registers every elementwise lowering pattern of this file in `patterns`,
// all with the same `benefit`. The parameters numWarps, axisInfoAnalysis,
// allocation and smem are not used by this function.
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         int numWarps,
                                         AxisInfoAnalysis &axisInfoAnalysis,
                                         const Allocation *allocation,
                                         Value smem, PatternBenefit benefit) {
  // Ternary op.
  patterns.add<ElementwiseOpConversion<triton::gpu::SelectOp, LLVM::SelectOp>>(
      typeConverter, benefit);

  // Binary ops.
  patterns.add<ElementwiseOpConversion<arith::SubIOp, LLVM::SubOp>,    // -
               ElementwiseOpConversion<arith::AddIOp, LLVM::AddOp>,    // +
               ElementwiseOpConversion<arith::MulIOp, LLVM::MulOp>,    // *
               ElementwiseOpConversion<arith::DivSIOp, LLVM::SDivOp>,  // /
               ElementwiseOpConversion<arith::DivUIOp, LLVM::UDivOp>,  // /
               ElementwiseOpConversion<arith::RemFOp, LLVM::FRemOp>,   // %
               ElementwiseOpConversion<arith::RemSIOp, LLVM::SRemOp>,  // %
               ElementwiseOpConversion<arith::RemUIOp, LLVM::URemOp>,  // %
               ElementwiseOpConversion<arith::AndIOp, LLVM::AndOp>,    // &
               ElementwiseOpConversion<arith::OrIOp, LLVM::OrOp>,      // |
               ElementwiseOpConversion<arith::XOrIOp, LLVM::XOrOp>,    // ^
               ElementwiseOpConversion<arith::ShLIOp, LLVM::ShlOp>,    // <<
               ElementwiseOpConversion<arith::ShRSIOp, LLVM::AShrOp>,  // >>
               ElementwiseOpConversion<arith::ShRUIOp, LLVM::LShrOp>>( // >>
      typeConverter, benefit);

  // Unary ops.
  patterns.add<ElementwiseOpConversion<arith::TruncIOp, LLVM::TruncOp>,
               ElementwiseOpConversion<arith::ExtSIOp, LLVM::SExtOp>,
               ElementwiseOpConversion<arith::ExtUIOp, LLVM::ZExtOp>,
               ElementwiseOpConversion<arith::FPToUIOp, LLVM::FPToUIOp>,
               ElementwiseOpConversion<arith::UIToFPOp, LLVM::UIToFPOp>,
               ElementwiseOpConversion<math::LogOp, math::LogOp>,
               ElementwiseOpConversion<math::CosOp, math::CosOp>,
               ElementwiseOpConversion<math::SinOp, math::SinOp>,
               ElementwiseOpConversion<math::SqrtOp, math::SqrtOp>,
               ElementwiseOpConversion<math::ExpOp, math::ExpOp>,
               ElementwiseOpConversion<triton::BitcastOp, LLVM::BitcastOp>,
               ElementwiseOpConversion<triton::IntToPtrOp, LLVM::IntToPtrOp>,
               ElementwiseOpConversion<triton::PtrToIntOp, LLVM::PtrToIntOp>>(
      typeConverter, benefit);

  // Dedicated patterns.
  patterns.add<CmpIOpConversion, CmpFOpConversion, FDivOpConversion,
               FSubOpConversion, FAddOpConversion, FMulOpConversion,
               ExtFOpConversion, TruncFOpConversion, FPToSIOpConversion,
               SIToFPOpConversion, FpToFpOpConversion,
               ExtElemwiseOpConversion>(typeConverter, benefit);

  // ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
  // For FP64 input type, ExpOpConversionApprox will return failure and
  // ElementwiseOpConversion<math::ExpOp, math::ExpOp> (registered above) will
  // call __nv_expf for higher-precision calculation
  patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}