[Triton-MLIR][BACKEND] tiny code cleanup (#899)
- Remove the unnecessary `static` in the anonymous namespace - Remove several unnecessary functions - Several simple rewrites to make code more clear
This commit is contained in:
@@ -51,8 +51,7 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||
namespace {
|
||||
|
||||
// Create a 32-bit integer constant.
|
||||
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
|
||||
int32_t v) {
|
||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
||||
auto i32ty = rewriter.getIntegerType(32);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||
IntegerAttr::get(i32ty, v));
|
||||
@@ -71,16 +70,16 @@ Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
|
||||
}
|
||||
|
||||
// Create an index type constant.
|
||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
Type ty = converter->convertType(builder.getIndexType());
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// Create an integer constant of \param width bits.
|
||||
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
||||
short width, int64_t value) {
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value) {
|
||||
Type ty = builder.getIntegerType(width);
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
@@ -187,9 +186,8 @@ template <typename T> void printScalar(const T &e, const std::string &info) {
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||
@@ -202,7 +200,7 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
}
|
||||
|
||||
/// Helper function for wrapping all attributes into a single DictionaryAttr
|
||||
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
||||
auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
||||
return DictionaryAttr::get(
|
||||
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
|
||||
}
|
||||
@@ -359,7 +357,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
||||
|
||||
// delinearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
||||
SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
T accMul = product(shape.drop_back());
|
||||
@@ -376,8 +374,8 @@ static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
size_t rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
@@ -391,7 +389,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
||||
|
||||
// linearize supposing order is [0, 1, .. , n]
|
||||
template <typename T>
|
||||
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||
T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||
assert(multiDimIndex.size() == shape.size());
|
||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||
size_t rank = shape.size();
|
||||
@@ -407,15 +405,15 @@ static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
assert(shape.size() == order.size());
|
||||
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
|
||||
reorder(shape, order));
|
||||
}
|
||||
|
||||
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value ptr, Value val, Value pred) {
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
@@ -517,10 +515,9 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
||||
auto rank = (elems.size() - 1) / 2;
|
||||
return SharedMemoryObject(
|
||||
/*base=*/elems[0],
|
||||
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
|
||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()});
|
||||
return {/*base=*/elems[0],
|
||||
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
|
||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -1018,13 +1015,13 @@ struct ArithConstantSplatOpConversion
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
: AxisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
// Get corresponding LLVM element values of \param value.
|
||||
SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) const {
|
||||
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) {
|
||||
if (!value)
|
||||
return {};
|
||||
if (!llValue.getType().isa<LLVM::LLVMStructType>())
|
||||
@@ -1601,7 +1598,7 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
|
||||
default:
|
||||
llvm::report_fatal_error("Unsupported reduce op");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value val, int i) const {
|
||||
@@ -1946,7 +1943,7 @@ struct PrintfOpConversion
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.prefix();
|
||||
if (operands.size() > 0) {
|
||||
if (!operands.empty()) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
}
|
||||
|
||||
@@ -2130,7 +2127,7 @@ struct MakeRangeOpConversion
|
||||
auto idxs = emitIndices(loc, rewriter, layout, shape);
|
||||
unsigned elems = idxs.size();
|
||||
SmallVector<Value> retVals(elems);
|
||||
for (auto multiDim : llvm::enumerate(idxs)) {
|
||||
for (const auto &multiDim : llvm::enumerate(idxs)) {
|
||||
assert(multiDim.value().size() == 1);
|
||||
retVals[multiDim.index()] = add(multiDim.value()[0], start);
|
||||
}
|
||||
@@ -2633,7 +2630,7 @@ struct FpToFpOpConversion
|
||||
};
|
||||
|
||||
// A CRTP style of base class.
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
template <typename SourceOp, typename ConcreteT>
|
||||
class ElementwiseOpConversionBase
|
||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
@@ -2688,16 +2685,16 @@ protected:
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct ElementwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<
|
||||
SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
||||
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<SourceOp, DestOp,
|
||||
ElementwiseOpConversionBase<SourceOp,
|
||||
ElementwiseOpConversion<SourceOp, DestOp>>;
|
||||
using Base::Base;
|
||||
using OpAdaptor = typename Base::OpAdaptor;
|
||||
|
||||
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
|
||||
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
@@ -2714,10 +2711,10 @@ struct ElementwiseOpConversion
|
||||
//
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
|
||||
CmpIOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion>;
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
@@ -2755,17 +2752,18 @@ struct CmpIOpConversion
|
||||
};
|
||||
|
||||
struct CmpFOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
|
||||
CmpFOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
CmpFOpConversion>;
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, ValueRange operands,
|
||||
Location loc) {
|
||||
return rewriter.create<LLVM::FCmpOp>(
|
||||
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
|
||||
operands[1]);
|
||||
@@ -2945,13 +2943,6 @@ private:
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
|
||||
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
|
||||
|
||||
// shared -> dot_operand if the result layout is blocked
|
||||
Value lowerSharedToDotOperandBlocked(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blockedLayout,
|
||||
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
|
||||
};
|
||||
|
||||
void ConvertLayoutOpConversion::processReplica(
|
||||
@@ -2960,7 +2951,7 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
ArrayRef<unsigned> multiDimRepId, unsigned vec,
|
||||
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
|
||||
SmallVector<Value> &vals, Value smemBase) const {
|
||||
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
||||
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
||||
auto layout = type.getEncoding();
|
||||
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
||||
@@ -2989,13 +2980,12 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
auto multiDimCTAInRepId =
|
||||
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
|
||||
SmallVector<unsigned> multiDimCTAId(rank);
|
||||
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
|
||||
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
|
||||
auto d = it.index();
|
||||
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
||||
}
|
||||
|
||||
unsigned linearCTAId =
|
||||
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
||||
auto linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
||||
// TODO: This is actually redundant index calculation, we should
|
||||
// consider of caching the index calculation result in case
|
||||
// of performance issue observed.
|
||||
@@ -3073,7 +3063,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
||||
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
unsigned accumNumReplicates = product<unsigned>(numReplicates);
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
// unsigned elems = getElemsPerThread(srcTy);
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned inVec = 0;
|
||||
@@ -3118,7 +3108,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
return success();
|
||||
};
|
||||
}
|
||||
|
||||
LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
@@ -3144,7 +3134,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
||||
unsigned numElems = getElemsPerThread(srcTy);
|
||||
auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned srcAccumSizeInThreads =
|
||||
auto srcAccumSizeInThreads =
|
||||
product<unsigned>(srcBlockedLayout.getSizePerThread());
|
||||
auto elemTy = srcTy.getElementType();
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
@@ -3177,7 +3167,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
|
||||
auto numWordsEachRep = product<unsigned>(wordsInEachRep);
|
||||
SmallVector<Value> wordVecs(numWordsEachRep);
|
||||
// TODO: We should get less barriers if it is handled by membar pass
|
||||
// instead of the backend, since the later can only handle it in
|
||||
@@ -3196,7 +3186,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
|
||||
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
|
||||
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
|
||||
unsigned wordVecIdx =
|
||||
auto wordVecIdx =
|
||||
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
|
||||
wordVecs[wordVecIdx] =
|
||||
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
|
||||
@@ -3265,7 +3255,6 @@ public:
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
|
||||
cStride = smemStrides[order[0]];
|
||||
sStride = smemStrides[order[1]];
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
@@ -3636,7 +3625,6 @@ private:
|
||||
int cMatShape;
|
||||
int sMatShape;
|
||||
|
||||
Value cStride;
|
||||
Value sStride;
|
||||
|
||||
bool needTrans;
|
||||
@@ -3651,13 +3639,6 @@ private:
|
||||
int warpOffStride;
|
||||
};
|
||||
|
||||
bool isSplatLike(Value value) {
|
||||
if (auto constv = dyn_cast<arith::ConstantOp>(value.getDefiningOp()))
|
||||
if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>())
|
||||
return attr.isSplat();
|
||||
return false;
|
||||
}
|
||||
|
||||
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
enum class TensorCoreType : uint8_t {
|
||||
// floating-point tensor core instr
|
||||
@@ -3790,7 +3771,6 @@ struct DotOpMmaV1ConversionHelper {
|
||||
int getRepN(int N) const {
|
||||
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
|
||||
}
|
||||
int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }
|
||||
|
||||
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
|
||||
|
||||
@@ -3857,9 +3837,6 @@ struct DotOpMmaV1ConversionHelper {
|
||||
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// Loading $c to registers, returns a LLVM::Struct.
|
||||
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
|
||||
|
||||
// Compute the offset of the matrix to load.
|
||||
@@ -3900,13 +3877,6 @@ struct DotOpMmaV2ConversionHelper {
|
||||
mmaType = getTensorCoreTypeFromOperand(operandTy);
|
||||
}
|
||||
|
||||
// Get the M and N of mat instruction shape.
|
||||
static std::tuple<int, int> getMatShapeMN() {
|
||||
// According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
|
||||
// shape's M,N are {8,8}
|
||||
return {8, 8};
|
||||
}
|
||||
|
||||
// Get the M and N of mma instruction shape.
|
||||
static std::tuple<int, int> getInstrShapeMN() {
|
||||
// According to DotOpConversionHelper::mmaInstrShape, all the M,N are
|
||||
@@ -4561,7 +4531,7 @@ struct DotOpFMAConversionHelper {
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) const;
|
||||
|
||||
Value getStructFromValueTable(ValueTable vals,
|
||||
Value getStructFromValueTable(const ValueTable &vals,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) const {
|
||||
SmallVector<Type> elemTypes(vals.size(), f32_ty);
|
||||
@@ -4838,7 +4808,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto mma = builder.create("mma.sync.aligned.m8n8k4")
|
||||
->o(isARow ? "row" : "col")
|
||||
.o(isBRow ? "row" : "col")
|
||||
.o(".f32.f16.f16.f32");
|
||||
.o("f32.f16.f16.f32");
|
||||
|
||||
mma(resOprs, AOprs, BOprs, COprs);
|
||||
|
||||
@@ -5095,11 +5065,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
return res;
|
||||
}
|
||||
|
||||
Value DotOpMmaV1ConversionHelper::loadC(
|
||||
Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
|
||||
return llTensor;
|
||||
}
|
||||
|
||||
std::tuple<Value, Value, Value, Value>
|
||||
DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
|
||||
bool isBRow, ArrayRef<int> fpw,
|
||||
@@ -5847,11 +5812,10 @@ struct InsertSliceAsyncOpConversion
|
||||
};
|
||||
|
||||
struct ExtElemwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<
|
||||
triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
|
||||
ExtElemwiseOpConversion>;
|
||||
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||
ExtElemwiseOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||
ExtElemwiseOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
@@ -5895,10 +5859,9 @@ private:
|
||||
};
|
||||
|
||||
struct FDivOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
|
||||
FDivOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
|
||||
LLVM::InlineAsmOp, FDivOpConversion>;
|
||||
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
@@ -5911,30 +5874,26 @@ struct FDivOpConversion
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
if (32 == bitwidth) {
|
||||
fdiv.o("full").o("f32");
|
||||
auto res = ptxBuilder.newOperand("=r");
|
||||
auto lhs = ptxBuilder.newOperand(operands[0], "r");
|
||||
auto rhs = ptxBuilder.newOperand(operands[1], "r");
|
||||
fdiv(res, lhs, rhs);
|
||||
} else if (64 == bitwidth) {
|
||||
fdiv.o("rn").o("f64");
|
||||
auto res = ptxBuilder.newOperand("=l");
|
||||
auto lhs = ptxBuilder.newOperand(operands[0], "l");
|
||||
auto rhs = ptxBuilder.newOperand(operands[1], "l");
|
||||
fdiv(res, lhs, rhs);
|
||||
} else {
|
||||
assert(0 && bitwidth && "not supported");
|
||||
}
|
||||
|
||||
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
|
||||
auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
|
||||
auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
|
||||
fdiv(res, lhs, rhs);
|
||||
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpOpConversionApprox
|
||||
: ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox> {
|
||||
using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox>;
|
||||
: ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
|
Reference in New Issue
Block a user