#include "ElementwiseOpToLLVM.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;

/// Lowers triton::FpToFpOp (floating-point to floating-point casts) to LLVM.
///
/// Supported conversions:
///   * fp8 <-> {f16, bf16, f32, f64}: processed four elements at a time with
///     inline PTX;
///   * bf16 <-> f32: processed one element at a time with a PTX `cvt`.
///
/// In this lowering fp8 values are carried as i8 and bf16 values as i16 (see
/// the i8_ty / i16_ty types used below).
struct FpToFpOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;

  /// Signature shared by every four-element fp8 converter below.
  using Fp8x4ConvertorFn = SmallVector<Value> (*)(
      Location, ConversionPatternRewriter &, const Value &, const Value &,
      const Value &, const Value &);

  /// Converts four fp8 values (carried as i8) into four f16 values.
  ///
  /// The four bytes are packed into one i32. The PTX then:
  ///   * moves each byte into the high byte of a 16-bit lane (`prmt`);
  ///   * clears the lane sign bits and shifts the magnitude right by one;
  ///   * ORs the original sign bits back in (`lop3` with LUT 0xf8).
  static SmallVector<Value>
  convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto ctx = rewriter.getContext();
    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    Value fp8x4Vec = undef(fp8x4VecTy);
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

    PTXBuilder builder;
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, b<2>; \n"
                   "prmt.b32 a0, 0, $2, 0x5040; \n"
                   "prmt.b32 a1, 0, $2, 0x7060; \n"
                   "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
                   "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
                   "shr.b32 b0, b0, 1; \n"
                   "shr.b32 b1, b1, 1; \n"
                   "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
                   "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
                   "}";
    auto &call = *builder.create<>(ptxAsm);

    auto *o0 = builder.newOperand("=r");
    auto *o1 = builder.newOperand("=r");
    auto *i = builder.newOperand(fp8x4Vec, "r");
    call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);

    // The two outputs are each a packed pair of f16 values.
    auto fp16x2VecTy = vec_ty(f16_ty, 2);
    auto fp16x2x2StructTy =
        struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
    auto fp16x2x2Struct =
        builder.launch(rewriter, loc, fp16x2x2StructTy, false);
    auto fp16x2Vec0 =
        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
    auto fp16x2Vec1 =
        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
    return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
            extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
            extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
            extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
  }

  /// Converts four f16 values into four fp8 values (carried as i8).
  ///
  /// Each f16 pair is bitcast to an i32. The PTX then:
  ///   * shifts the magnitude left by one and clears the lane sign bits;
  ///   * adds 0x0080 per lane (round half up on the retained high byte);
  ///   * merges the original sign back in (`lop3` with LUT 0xea);
  ///   * gathers the high byte of every lane into the result (`prmt 0x7531`).
  static SmallVector<Value>
  convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto ctx = rewriter.getContext();
    auto fp16x2VecTy = vec_ty(f16_ty, 2);
    Value fp16x2Vec0 = undef(fp16x2VecTy);
    Value fp16x2Vec1 = undef(fp16x2VecTy);
    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
    fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
    fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
    fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
    fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

    PTXBuilder builder;
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, b<2>; \n"
                   "shl.b32 a0, $1, 1; \n"
                   "shl.b32 a1, $2, 1; \n"
                   "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
                   "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
                   "add.u32 a0, a0, 0x00800080; \n"
                   "add.u32 a1, a1, 0x00800080; \n"
                   "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
                   "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
                   "prmt.b32 $0, b0, b1, 0x7531; \n"
                   "}";
    auto &call = *builder.create<>(ptxAsm);

    auto *o = builder.newOperand("=r");
    auto *i0 = builder.newOperand(fp16x2Vec0, "r");
    auto *i1 = builder.newOperand(fp16x2Vec1, "r");
    call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
  }

  /// Converts four fp8 values (carried as i8) into four bf16 values (carried
  /// as i16).
  ///
  /// The PTX places each byte in the high byte of a 16-bit lane, splits off
  /// the sign, shifts the magnitude right by 4, adds 0x3800 per lane
  /// (presumably the exponent re-bias), and ORs the sign back in.
  static SmallVector<Value>
  convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto ctx = rewriter.getContext();
    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    Value fp8x4Vec = undef(fp8x4VecTy);
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
    fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
    fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

    PTXBuilder builder;
    auto *ptxAsm = "{ \n"
                   ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
                   "prmt.b32 a0, 0, $2, 0x5040; \n"
                   "prmt.b32 a1, 0, $2, 0x7060; \n"
                   "and.b32 sign0, a0, 0x80008000; \n"
                   "and.b32 sign1, a1, 0x80008000; \n"
                   "and.b32 nosign0, a0, 0x7fff7fff; \n"
                   "and.b32 nosign1, a1, 0x7fff7fff; \n"
                   "shr.b32 nosign0, nosign0, 4; \n"
                   "shr.b32 nosign1, nosign1, 4; \n"
                   "add.u32 nosign0, nosign0, 0x38003800; \n"
                   "add.u32 nosign1, nosign1, 0x38003800; \n"
                   "or.b32 $0, sign0, nosign0; \n"
                   "or.b32 $1, sign1, nosign1; \n"
                   "}";
    auto &call = *builder.create<>(ptxAsm);

    auto *o0 = builder.newOperand("=r");
    auto *o1 = builder.newOperand("=r");
    auto *i = builder.newOperand(fp8x4Vec, "r");
    call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);

    auto bf16x2VecTy = vec_ty(i16_ty, 2);
    auto bf16x2x2StructTy =
        struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
    auto bf16x2x2Struct =
        builder.launch(rewriter, loc, bf16x2x2StructTy, false);
    auto bf16x2Vec0 =
        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
    auto bf16x2Vec1 =
        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
    return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
            extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
            extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
            extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
  }

  /// Converts four bf16 values (carried as i16) into four fp8 values (carried
  /// as i8).
  ///
  /// The PTX proceeds in four steps:
  ///   * split off the sign bits;
  ///   * clamp every lane magnitude to [0x3800, 0x3ff0];
  ///   * add 0x0008 (round half up), subtract 0x3800, and shift right by 4;
  ///   * gather the low byte of every lane and OR the gathered signs back in.
  static SmallVector<Value>
  convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto ctx = rewriter.getContext();
    auto bf16x2VecTy = vec_ty(i16_ty, 2);
    Value bf16x2Vec0 = undef(bf16x2VecTy);
    Value bf16x2Vec1 = undef(bf16x2VecTy);
    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
    bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
    bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
    bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
    bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);

    PTXBuilder builder;
    auto *ptxAsm = "{ \n"
                   ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
                   ".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
                   "mov.u32 fp8_min, 0x38003800; \n"
                   "mov.u32 fp8_max, 0x3ff03ff0; \n"
                   "mov.u32 rn_, 0x80008; \n"
                   "mov.u32 zero, 0; \n"
                   "and.b32 sign0, $1, 0x80008000; \n"
                   "and.b32 sign1, $2, 0x80008000; \n"
                   "prmt.b32 sign, sign0, sign1, 0x7531; \n"
                   "and.b32 nosign0, $1, 0x7fff7fff; \n"
                   "and.b32 nosign1, $2, 0x7fff7fff; \n"
                   ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
                   "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
                   "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
                   "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
                   "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
                   "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
                   "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
                   "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
                   "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
                   "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
                   "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
                   "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
                   "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
                   "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
                   "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
                   "add.u32 nosign0, nosign0, rn_; \n"
                   "add.u32 nosign1, nosign1, rn_; \n"
                   "sub.u32 nosign0, nosign0, 0x38003800; \n"
                   "sub.u32 nosign1, nosign1, 0x38003800; \n"
                   "shr.u32 nosign0, nosign0, 4; \n"
                   "shr.u32 nosign1, nosign1, 4; \n"
                   "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
                   "or.b32 $0, nosign, sign; \n"
                   "}";
    auto &call = *builder.create<>(ptxAsm);

    auto *o = builder.newOperand("=r");
    auto *i0 = builder.newOperand(bf16x2Vec0, "r");
    auto *i1 = builder.newOperand(bf16x2Vec1, "r");
    call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

    auto fp8x4VecTy = vec_ty(i8_ty, 4);
    auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
  }

  /// fp8 -> f32 via fp8 -> f16 followed by an exact `fpext`.
  static SmallVector<Value>
  convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
    return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
  }

  /// f32 -> fp8 via `fptrunc` to f16 followed by f16 -> fp8.
  /// NOTE: this rounds twice (once in the truncation, once in the fp8 step).
  static SmallVector<Value>
  convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
  }

  /// fp8 -> f64 via fp8 -> f16 followed by an exact `fpext`.
  static SmallVector<Value>
  convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
    return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
  }

  /// f64 -> fp8 via `fptrunc` to f16 followed by f16 -> fp8.
  /// NOTE: this rounds twice (once in the truncation, once in the fp8 step).
  static SmallVector<Value>
  convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                       const Value &v0, const Value &v1, const Value &v2,
                       const Value &v3) {
    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
  }

  /// Converts a single bf16 value (carried in a 16-bit register) to f32 with
  /// `cvt.rn.f32.bf16`.
  static Value convertBf16ToFp32(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
    PTXBuilder builder;
    auto &cvt = *builder.create<PTXInstr>("cvt.rn.f32.bf16");
    auto res = builder.newOperand("=r");
    auto operand = builder.newOperand(v, "h");
    cvt(res, operand);
    return builder.launch(rewriter, loc, f32_ty, false);
  }

  /// Converts a single f32 value to bf16 with `cvt.rn.bf16.f32`; the result is
  /// returned as an i16.
  static Value convertFp32ToBf16(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 const Value &v) {
    PTXBuilder builder;
    auto &cvt = *builder.create<PTXInstr>("cvt.rn.bf16.f32");
    auto res = builder.newOperand("=h");
    auto operand = builder.newOperand(v, "r");
    cvt(res, operand);
    // TODO: This is a hack to get the right type. We should be able to invoke
    // the type converter
    return builder.launch(rewriter, loc, i16_ty, false);
  }

  /// Returns the four-element converter for a cast that involves fp8, or
  /// nullptr when the (src, dst) element-type pair is not supported.
  static Fp8x4ConvertorFn getFp8x4Convertor(Type srcEltType, Type dstEltType) {
    if (srcEltType.isa<triton::Float8Type>()) {
      if (dstEltType.isF16())
        return convertFp8x4ToFp16x4;
      if (dstEltType.isBF16())
        return convertFp8x4ToBf16x4;
      if (dstEltType.isF32())
        return convertFp8x4ToFp32x4;
      if (dstEltType.isF64())
        return convertFp8x4ToFp64x4;
    }
    if (dstEltType.isa<triton::Float8Type>()) {
      if (srcEltType.isF16())
        return convertFp16x4ToFp8x4;
      if (srcEltType.isBF16())
        return convertBf16x4ToFp8x4;
      if (srcEltType.isF32())
        return convertFp32x4ToFp8x4;
      if (srcEltType.isF64())
        return convertFp64x4ToFp8x4;
    }
    return nullptr;
  }

  /// Unpacks the per-thread source elements, converts them, and packs the
  /// results into the converted destination struct.
  ///
  /// Fails (without touching the IR) for unsupported type pairs, and for fp8
  /// casts whose per-thread element count is not a multiple of 4.
  LogicalResult
  matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
    auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
    auto srcEltType = srcTensorType.getElementType();
    auto dstEltType = dstTensorType.getElementType();
    auto loc = op->getLoc();
    auto elems = getElemsPerThread(dstTensorType);
    SmallVector<Value> resultVals;
    resultVals.reserve(elems);

    if (srcEltType.isa<triton::Float8Type>() ||
        dstEltType.isa<triton::Float8Type>()) {
      // Validate before emitting any IR so a failed match leaves no residue.
      Fp8x4ConvertorFn convertor = getFp8x4Convertor(srcEltType, dstEltType);
      if (!convertor)
        return rewriter.notifyMatchFailure(op, "unsupported fp8 casting");
      if (elems % 4 != 0)
        return rewriter.notifyMatchFailure(
            op, "FP8 casting only supports tensors with 4-aligned sizes");

      // Vectorized casting: the fp8 converters consume four elements at once.
      auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
      for (size_t i = 0; i < elems; i += 4) {
        auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
                                   elements[i + 2], elements[i + 3]);
        resultVals.append(converted);
      }
    } else if (srcEltType.isBF16() && dstEltType.isF32()) {
      // The adaptor operand is the packed per-thread struct: convert every
      // element, not the struct itself.
      for (Value element : getElementsFromStruct(loc, adaptor.from(), rewriter))
        resultVals.push_back(convertBf16ToFp32(loc, rewriter, element));
    } else if (srcEltType.isF32() && dstEltType.isBF16()) {
      for (Value element : getElementsFromStruct(loc, adaptor.from(), rewriter))
        resultVals.push_back(convertFp32ToBf16(loc, rewriter, element));
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported type casting");
    }

    assert(resultVals.size() == elems);
    auto convertedDstTensorType =
        this->getTypeConverter()->convertType(dstTensorType);
    auto result = getStructFromElements(loc, resultVals, rewriter,
                                        convertedDstTensorType);
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Common driver for elementwise op lowerings (CRTP).
///
/// The per-thread values of every operand are unpacked from their LLVM
/// structs, and `ConcreteT::createDestOp` is invoked once per element. The
/// results are then packed into a struct of the converted result type. If
/// `createDestOp` returns a null Value for any element the pattern fails,
/// which lets a specialized pattern decline an op so another pattern can
/// lower it.
template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase
    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
                                       PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultTy = op.getType();
    Location loc = op->getLoc();

    unsigned elems = getElemsPerThread(resultTy);
    auto resultElementTy = getElementTypeOrSelf(resultTy);
    Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
    Type structTy = this->getTypeConverter()->convertType(resultTy);

    auto *concreteThis = static_cast<const ConcreteT *>(this);
    auto operands = getOperands(rewriter, adaptor, elems, loc);
    SmallVector<Value> resultVals(elems);
    for (unsigned i = 0; i < elems; ++i) {
      resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
                                                 operands[i], loc);
      // A null value means the concrete pattern declined this op.
      if (!bool(resultVals[i]))
        return failure();
    }
    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
  }

protected:
  /// Unpacks every operand and regroups the values so that entry `i` holds
  /// the i-th per-thread element of each operand, in operand order.
  SmallVector<SmallVector<Value>>
  getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
              const unsigned elems, Location loc) const {
    SmallVector<SmallVector<Value>> operands(elems);
    for (auto operand : adaptor.getOperands()) {
      auto subOperands = getElementsFromStruct(loc, operand, rewriter);
      assert(subOperands.size() >= elems &&
             "operand has fewer per-thread elements than the result");
      for (size_t i = 0; i < elems; ++i)
        operands[i].push_back(subOperands[i]);
    }
    return operands;
  }
};

/// Generic one-to-one lowering: each element becomes a `DestOp` built from
/// the element operands and the source op's attributes.
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
    : public ElementwiseOpConversionBase<
          SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
  using Base =
      ElementwiseOpConversionBase<SourceOp,
                                  ElementwiseOpConversion<SourceOp, DestOp>>;
  using Base::Base;
  using OpAdaptor = typename Base::OpAdaptor;

  explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
      : ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
            typeConverter, benefit) {}

  // An interface to support variant DestOp builder.
  DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Type elemTy,
                      ValueRange operands, Location loc) const {
    return rewriter.create<DestOp>(loc, elemTy, operands,
                                   adaptor.getAttributes().getValue());
  }
};

/// Lowers triton::gpu::CmpIOp to LLVM::ICmpOp, mapping the arith predicate
/// to the LLVM predicate of the same name.
struct CmpIOpConversion
    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
                                         CmpIOpConversion> {
  using Base =
      ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // An interface to support variant DestOp builder.
  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Type elemTy,
                            ValueRange operands, Location loc) const {
    return rewriter.create<LLVM::ICmpOp>(
        loc, elemTy, ArithCmpIPredicateToLLVM(op.predicate()), operands[0],
        operands[1]);
  }

  static LLVM::ICmpPredicate
  ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) {
    switch (predicate) {
#define __PRED_ENUM(item__)                                                    \
  case arith::CmpIPredicate::item__:                                           \
    return LLVM::ICmpPredicate::item__

      __PRED_ENUM(eq);
      __PRED_ENUM(ne);
      __PRED_ENUM(sgt);
      __PRED_ENUM(sge);
      __PRED_ENUM(slt);
      __PRED_ENUM(sle);
      __PRED_ENUM(ugt);
      __PRED_ENUM(uge);
      __PRED_ENUM(ult);
      __PRED_ENUM(ule);

#undef __PRED_ENUM
    }
    // Only reachable for out-of-range enum values.
    return LLVM::ICmpPredicate::eq;
  }
};

/// Lowers triton::gpu::CmpFOp to LLVM::FCmpOp with the matching predicate.
struct CmpFOpConversion
    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
                                         CmpFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // An interface to support variant DestOp builder.
  static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, ValueRange operands,
                                   Location loc) {
    return rewriter.create<LLVM::FCmpOp>(
        loc, elemTy, ArithCmpFPredicateToLLVM(op.predicate()), operands[0],
        operands[1]);
  }

  static LLVM::FCmpPredicate
  ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) {
    switch (predicate) {
#define __PRED_ENUM(item__, item1__)                                           \
  case arith::CmpFPredicate::item__:                                           \
    return LLVM::FCmpPredicate::item1__

      __PRED_ENUM(OEQ, oeq);
      __PRED_ENUM(ONE, one);
      __PRED_ENUM(OGT, ogt);
      __PRED_ENUM(OGE, oge);
      __PRED_ENUM(OLT, olt);
      __PRED_ENUM(OLE, ole);
      __PRED_ENUM(ORD, ord);
      __PRED_ENUM(UEQ, ueq);
      __PRED_ENUM(UGT, ugt);
      __PRED_ENUM(UGE, uge);
      __PRED_ENUM(ULT, ult);
      __PRED_ENUM(ULE, ule);
      __PRED_ENUM(UNE, une);
      __PRED_ENUM(UNO, uno);
      __PRED_ENUM(AlwaysTrue, _true);
      __PRED_ENUM(AlwaysFalse, _false);

#undef __PRED_ENUM
    }
    // Only reachable for out-of-range enum values.
    return LLVM::FCmpPredicate::_true;
  }
};

/// Lowers triton::ExtElemwiseOp to a call to an external function named by
/// the op's symbol. The callee is declared in the enclosing module on first
/// use and tagged with the op's `libname` / `libpath` attributes.
struct ExtElemwiseOpConversion
    : public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
                                         ExtElemwiseOpConversion> {
  using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
                                           ExtElemwiseOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    StringRef funcName = op.symbol();
    if (funcName.empty()) {
      // An empty symbol cannot name a callee: decline the op instead of
      // emitting a call to an unnamed function.
      llvm::errs() << "ExtElemwiseOpConversion: empty symbol name\n";
      return {};
    }
    Type funcType = getFunctionType(elemTy, operands);
    LLVM::LLVMFuncOp funcOp =
        appendOrGetFuncOp(rewriter, op, funcName, funcType);
    return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
  }

private:
  /// Builds `resultType (operand types...)` as an LLVM function type.
  Type getFunctionType(Type resultType, ValueRange operands) const {
    SmallVector<Type> operandTypes(operands.getTypes());
    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
  }

  /// Returns the existing function named `funcName` visible from `op`, or
  /// declares a new one just before the LLVM function that contains `op`.
  LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
                                     triton::ExtElemwiseOp op,
                                     StringRef funcName, Type funcType) const {
    using LLVM::LLVMFuncOp;

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
    if (funcOp)
      return cast<LLVMFuncOp>(*funcOp);

    mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
    auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
    ret.getOperation()->setAttr(
        "libname", StringAttr::get(op->getContext(), op.libname()));
    ret.getOperation()->setAttr(
        "libpath", StringAttr::get(op->getContext(), op.libpath()));
    return ret;
  }
};

/// Lowers arith::DivFOp to PTX `div.full.f32` (f32) or `div.rn.f64` (f64).
/// Any other width makes the pattern fail rather than emit malformed PTX.
struct FDivOpConversion
    : ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
    if (bitwidth != 32 && bitwidth != 64)
      return {}; // not supported: decline the op

    PTXBuilder ptxBuilder;
    auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
    if (bitwidth == 32)
      fdiv.o("full").o("f32");
    else
      fdiv.o("rn").o("f64");

    auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
    auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
    auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
    fdiv(res, lhs, rhs);

    Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
    return ret;
  }
};

/// Lowers arith::MulFOp. For bf16 x bf16 it emits `fma.rn.bf16 a * b + -0.0`;
/// otherwise an LLVM::FMulOp.
struct FMulOpConversion
    : ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::MulFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsElemTy = getElementType(op.getLhs());
    auto rhsElemTy = getElementType(op.getRhs());
    if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
      PTXBuilder builder;
      // 0x8000 is bf16 -0.0: the additive identity that also preserves the
      // sign of a zero product.
      auto ptxAsm = " { .reg .b16 c; \n"
                    " mov.b16 c, 0x8000U; \n"
                    " fma.rn.bf16 $0, $1, $2, c; } \n";
      auto &fMul = *builder.create<>(ptxAsm);
      auto res = builder.newOperand("=h");
      auto lhs = builder.newOperand(operands[0], "h");
      auto rhs = builder.newOperand(operands[1], "h");
      fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
      return builder.launch(rewriter, loc, i16_ty, false);
    }
    return rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0],
                                         operands[1]);
  }
};

/// Lowers arith::AddFOp. For bf16 + bf16 it emits `fma.rn.bf16 a * 1.0 + b`;
/// otherwise an LLVM::FAddOp.
struct FAddOpConversion
    : ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::AddFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsElemTy = getElementType(op.getLhs());
    auto rhsElemTy = getElementType(op.getRhs());
    if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
      PTXBuilder builder;
      // 0x3f80 is bf16 1.0.
      auto ptxAsm = "{ .reg .b16 c; \n"
                    " mov.b16 c, 0x3f80U; \n"
                    " fma.rn.bf16 $0, $1, c, $2; } \n";
      auto &fAdd = *builder.create<>(ptxAsm);
      auto res = builder.newOperand("=h");
      auto lhs = builder.newOperand(operands[0], "h");
      auto rhs = builder.newOperand(operands[1], "h");
      fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
      return builder.launch(rewriter, loc, i16_ty, false);
    }
    return rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0],
                                         operands[1]);
  }
};

/// Lowers arith::SubFOp. For bf16 - bf16 it emits `fma.rn.bf16 b * -1.0 + a`;
/// otherwise an LLVM::FSubOp.
struct FSubOpConversion
    : ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::SubFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto lhsElemTy = getElementType(op.getLhs());
    auto rhsElemTy = getElementType(op.getRhs());
    if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
      PTXBuilder builder;
      // 0xbf80 is bf16 -1.0.
      auto ptxAsm = " { .reg .b16 c; \n"
                    " mov.b16 c, 0xbf80U; \n"
                    " fma.rn.bf16 $0, $2, c, $1;} \n";
      auto &fSub = *builder.create<>(ptxAsm);
      auto res = builder.newOperand("=h");
      auto lhs = builder.newOperand(operands[0], "h");
      auto rhs = builder.newOperand(operands[1], "h");
      fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
      return builder.launch(rewriter, loc, i16_ty, false);
    }
    return rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0],
                                         operands[1]);
  }
};

/// Lowers arith::SIToFPOp. A bf16 destination goes through f32 and then
/// `cvt.rn.bf16.f32`.
struct SIToFPOpConversion
    : ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::SIToFPOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto outElemTy = getElementType(op.getOut());
    if (outElemTy.isBF16()) {
      auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0]);
      return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value);
    }
    return rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0]);
  }
};

/// Lowers arith::FPToSIOp. A bf16 source is first widened to f32.
struct FPToSIOpConversion
    : ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::FPToSIOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto inElemTy = getElementType(op.getIn());
    if (inElemTy.isBF16()) {
      auto value =
          FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
      return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value);
    }
    return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0]);
  }
};

/// Lowers arith::ExtFOp.
///
/// A bf16 source is widened with `cvt.rn.f32.bf16`. When the destination is
/// f64, an `fpext` is added; both widenings are exact. Any other bf16
/// extension is declined.
struct ExtFOpConversion
    : ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::ExtFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto inElemTy = getElementType(op.getIn());
    if (inElemTy.isBF16()) {
      auto outElemTy = getElementType(op.getOut());
      if (!outElemTy.isF32() && !outElemTy.isF64())
        return {}; // unsupported conversion: decline the op
      Value f32Value =
          FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
      if (outElemTy.isF32())
        return f32Value;
      return rewriter.create<LLVM::FPExtOp>(loc, elemTy, f32Value);
    }
    return rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0]);
  }
};

/// Lowers arith::TruncFOp.
///
/// A bf16 destination is supported only from f32 (`cvt.rn.bf16.f32`); other
/// sources are declined, because going through f32 would round twice.
struct TruncFOpConversion
    : ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::arith::TruncFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    auto outElemTy = getElementType(op.getOut());
    if (outElemTy.isBF16()) {
      auto inElemTy = getElementType(op.getIn());
      if (!inElemTy.isF32())
        return {}; // unsupported conversion: decline the op
      return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0]);
    }
    return rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0]);
  }
};

/// Lowers f32 math::ExpOp to `ex2.approx.f32(x * log2(e))`.
///
/// Every other element width (f16, i16-carried bf16, f64, ...) is declined,
/// so the generic ElementwiseOpConversion<math::ExpOp, math::ExpOp> pattern
/// lowers the op instead.
struct ExpOpConversionApprox
    : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
  using Base =
      ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    // The sequence below multiplies in f32 and uses ex2.approx.f32, so it is
    // only well-typed for 32-bit inputs.
    if (elemTy.getIntOrFloatBitWidth() != 32)
      return {};

    // exp(x) == 2^(x * log2(e))
    const double log2e = 1.4426950408889634;
    Value prod = fmul(f32_ty, operands[0], f32_val(log2e));

    PTXBuilder ptxBuilder;
    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
    auto output = ptxBuilder.newOperand("=f");
    auto input = ptxBuilder.newOperand(prod, "f");
    exp2(output, input);
    return ptxBuilder.launch(rewriter, loc, f32_ty, false);
  }
};

/// Registers all elementwise lowering patterns.
///
/// `numWarps`, `axisInfoAnalysis`, `allocation` and `smem` are not used by
/// these patterns.
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         int numWarps,
                                         AxisInfoAnalysis &axisInfoAnalysis,
                                         const Allocation *allocation,
                                         Value smem, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP)                                    \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
  POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP)                                     \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
  POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
  POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
  POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
  POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
  POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
  POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
  POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
  POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)   // &
  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)     // |
  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)   // ^
  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)   // <<
  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
  POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
  POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
  POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
  POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
  POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
  POPULATE_UNARY_OP(math::LogOp, math::LogOp)
  POPULATE_UNARY_OP(math::CosOp, math::CosOp)
  POPULATE_UNARY_OP(math::SinOp, math::SinOp)
  POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
  POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
  POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
  POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
  POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

  patterns.add<CmpIOpConversion>(typeConverter, benefit);
  patterns.add<CmpFOpConversion>(typeConverter, benefit);

  patterns.add<FDivOpConversion>(typeConverter, benefit);
  patterns.add<FSubOpConversion>(typeConverter, benefit);
  patterns.add<FAddOpConversion>(typeConverter, benefit);
  patterns.add<FMulOpConversion>(typeConverter, benefit);

  patterns.add<ExtFOpConversion>(typeConverter, benefit);
  patterns.add<TruncFOpConversion>(typeConverter, benefit);
  patterns.add<FPToSIOpConversion>(typeConverter, benefit);
  patterns.add<SIToFPOpConversion>(typeConverter, benefit);

  patterns.add<FpToFpOpConversion>(typeConverter, benefit);

  patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
  // ExpOpConversionApprox lowers f32 math::ExpOp with ex2.approx.f32 and
  // declines every other element type. Declined ops are handled by the
  // generic ElementwiseOpConversion<math::ExpOp, math::ExpOp> registered
  // above via POPULATE_UNARY_OP.
  //
  // The approx pattern gets a strictly higher benefit so that it is always
  // attempted before that fallback; the relative order of equal-benefit
  // patterns is not something to rely on.
  patterns.add<ExpOpConversionApprox>(typeConverter,
                                      PatternBenefit(benefit.getBenefit() + 1));
}