[Triton-MLIR][BACKEND] tiny code cleanup (#899)

- Remove the unnecessary `static` in the anonymous namespace
- Remove several unnecessary functions
- Several simple rewrites to make the code clearer
This commit is contained in:
Yan Chunwei
2022-11-21 16:00:46 +08:00
committed by GitHub
parent 04b852e031
commit 136668bac3

View File

@@ -51,8 +51,7 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace { namespace {
// Create a 32-bit integer constant. // Create a 32-bit integer constant.
static Value createConstantI32(Location loc, PatternRewriter &rewriter, Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
int32_t v) {
auto i32ty = rewriter.getIntegerType(32); auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty, return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v)); IntegerAttr::get(i32ty, v));
@@ -71,16 +70,16 @@ Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
} }
// Create an index type constant. // Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc, Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) { TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType()); Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty, return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value)); builder.getIntegerAttr(ty, value));
} }
// Create an integer constant of \param width bits. // Create an integer constant of \param width bits.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
short width, int64_t value) { int64_t value) {
Type ty = builder.getIntegerType(width); Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty, return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value)); builder.getIntegerAttr(ty, value));
@@ -187,9 +186,8 @@ template <typename T> void printScalar(const T &e, const std::string &info) {
/// Only retain those attributes that are not constructed by /// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes. /// attributes.
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, bool filterArgAttrs,
bool filterArgAttrs, SmallVectorImpl<NamedAttribute> &result) {
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) { for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() || if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == FunctionOpInterface::getTypeAttrName() ||
@@ -202,7 +200,7 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
} }
/// Helper function for wrapping all attributes into a single DictionaryAttr /// Helper function for wrapping all attributes into a single DictionaryAttr
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
return DictionaryAttr::get( return DictionaryAttr::get(
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs)); b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
} }
@@ -359,7 +357,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
// delinearize supposing order is [0, 1, .. , n] // delinearize supposing order is [0, 1, .. , n]
template <typename T> template <typename T>
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) { SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size(); size_t rank = shape.size();
T accMul = product(shape.drop_back()); T accMul = product(shape.drop_back());
@@ -376,8 +374,8 @@ static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
} }
template <typename T> template <typename T>
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape, SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
ArrayRef<unsigned> order) { ArrayRef<unsigned> order) {
size_t rank = shape.size(); size_t rank = shape.size();
assert(rank == order.size()); assert(rank == order.size());
auto reordered = reorder(shape, order); auto reordered = reorder(shape, order);
@@ -391,7 +389,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
// linearize supposing order is [0, 1, .. , n] // linearize supposing order is [0, 1, .. , n]
template <typename T> template <typename T>
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) { T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size()); assert(multiDimIndex.size() == shape.size());
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size(); size_t rank = shape.size();
@@ -407,15 +405,15 @@ static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
} }
template <typename T> template <typename T>
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape, T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
ArrayRef<unsigned> order) { ArrayRef<unsigned> order) {
assert(shape.size() == order.size()); assert(shape.size() == order.size());
return getLinearIndexImpl<T>(reorder(multiDimIndex, order), return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
reorder(shape, order)); reorder(shape, order));
} }
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value ptr, Value val, Value pred) { Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext(); MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth(); unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -517,10 +515,9 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter); auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
auto rank = (elems.size() - 1) / 2; auto rank = (elems.size() - 1) / 2;
return SharedMemoryObject( return {/*base=*/elems[0],
/*base=*/elems[0], /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
/*offsets=*/{elems.begin() + 1 + rank, elems.end()});
} }
static Value static Value
@@ -1018,13 +1015,13 @@ struct ArithConstantSplatOpConversion
// Contains some helper functions for both Load and Store conversions. // Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: AxisAnalysisPass(axisAnalysisPass) {} : AxisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value. // Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue, static SmallVector<Value> getLLVMElems(Value value, Value llValue,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Location loc) const { Location loc) {
if (!value) if (!value)
return {}; return {};
if (!llValue.getType().isa<LLVM::LLVMStructType>()) if (!llValue.getType().isa<LLVM::LLVMStructType>())
@@ -1601,7 +1598,7 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
default: default:
llvm::report_fatal_error("Unsupported reduce op"); llvm::report_fatal_error("Unsupported reduce op");
} }
}; }
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
Location loc, Value val, int i) const { Location loc, Value val, int i) const {
@@ -1946,7 +1943,7 @@ struct PrintfOpConversion
std::string formatStr; std::string formatStr;
llvm::raw_string_ostream os(formatStr); llvm::raw_string_ostream os(formatStr);
os << op.prefix(); os << op.prefix();
if (operands.size() > 0) { if (!operands.empty()) {
os << getFormatSubstr(operands[0]); os << getFormatSubstr(operands[0]);
} }
@@ -2130,7 +2127,7 @@ struct MakeRangeOpConversion
auto idxs = emitIndices(loc, rewriter, layout, shape); auto idxs = emitIndices(loc, rewriter, layout, shape);
unsigned elems = idxs.size(); unsigned elems = idxs.size();
SmallVector<Value> retVals(elems); SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) { for (const auto &multiDim : llvm::enumerate(idxs)) {
assert(multiDim.value().size() == 1); assert(multiDim.value().size() == 1);
retVals[multiDim.index()] = add(multiDim.value()[0], start); retVals[multiDim.index()] = add(multiDim.value()[0], start);
} }
@@ -2633,7 +2630,7 @@ struct FpToFpOpConversion
}; };
// A CRTP style of base class. // A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT> template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase class ElementwiseOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> { : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public: public:
@@ -2688,16 +2685,16 @@ protected:
template <typename SourceOp, typename DestOp> template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion struct ElementwiseOpConversion
: public ElementwiseOpConversionBase< : public ElementwiseOpConversionBase<
SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> { SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base = using Base =
ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversionBase<SourceOp,
ElementwiseOpConversion<SourceOp, DestOp>>; ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base; using Base::Base;
using OpAdaptor = typename Base::OpAdaptor; using OpAdaptor = typename Base::OpAdaptor;
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter, explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1) PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>( : ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
typeConverter, benefit) {} typeConverter, benefit) {}
// An interface to support variant DestOp builder. // An interface to support variant DestOp builder.
@@ -2714,10 +2711,10 @@ struct ElementwiseOpConversion
// //
struct CmpIOpConversion struct CmpIOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp, : public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
CmpIOpConversion> { CmpIOpConversion> {
using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp, using Base =
CmpIOpConversion>; ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
using Base::Base; using Base::Base;
using Adaptor = typename Base::OpAdaptor; using Adaptor = typename Base::OpAdaptor;
@@ -2755,17 +2752,18 @@ struct CmpIOpConversion
}; };
struct CmpFOpConversion struct CmpFOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp, : public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
CmpFOpConversion> { CmpFOpConversion> {
using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp, using Base =
CmpFOpConversion>; ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
using Base::Base; using Base::Base;
using Adaptor = typename Base::OpAdaptor; using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder. // An interface to support variant DestOp builder.
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor, static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy, ConversionPatternRewriter &rewriter,
ValueRange operands, Location loc) const { Type elemTy, ValueRange operands,
Location loc) {
return rewriter.create<LLVM::FCmpOp>( return rewriter.create<LLVM::FCmpOp>(
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0], loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
operands[1]); operands[1]);
@@ -2945,13 +2943,6 @@ private:
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout, ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const; const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
// shared -> dot_operand if the result layout is blocked
Value lowerSharedToDotOperandBlocked(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
}; };
void ConvertLayoutOpConversion::processReplica( void ConvertLayoutOpConversion::processReplica(
@@ -2960,7 +2951,7 @@ void ConvertLayoutOpConversion::processReplica(
ArrayRef<unsigned> multiDimRepId, unsigned vec, ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd, ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const { SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep); auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding(); auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>(); auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>(); auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
@@ -2989,13 +2980,12 @@ void ConvertLayoutOpConversion::processReplica(
auto multiDimCTAInRepId = auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order); getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank); SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) { for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index(); auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
} }
unsigned linearCTAId = auto linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
// TODO: This is actually redundant index calculation, we should // TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case // consider of caching the index calculation result in case
// of performance issue observed. // of performance issue observed.
@@ -3073,7 +3063,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA); outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
} }
// Potentially we need to store for multiple CTAs in this replication // Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates); auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy); // unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0; unsigned inVec = 0;
@@ -3118,7 +3108,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
return success(); return success();
}; }
LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3144,7 +3134,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
unsigned maxPhase = dstSharedLayout.getMaxPhase(); unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = getElemsPerThread(srcTy); unsigned numElems = getElemsPerThread(srcTy);
auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter); auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned srcAccumSizeInThreads = auto srcAccumSizeInThreads =
product<unsigned>(srcBlockedLayout.getSizePerThread()); product<unsigned>(srcBlockedLayout.getSizePerThread());
auto elemTy = srcTy.getElementType(); auto elemTy = srcTy.getElementType();
auto wordTy = vec_ty(elemTy, minVec); auto wordTy = vec_ty(elemTy, minVec);
@@ -3177,7 +3167,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
smemBase = bitcast(smemBase, elemPtrTy); smemBase = bitcast(smemBase, elemPtrTy);
auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter); auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep); auto numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep); SmallVector<Value> wordVecs(numWordsEachRep);
// TODO: We should get less barriers if it is handled by membar pass // TODO: We should get less barriers if it is handled by membar pass
// instead of the backend, since the later can only handle it in // instead of the backend, since the later can only handle it in
@@ -3196,7 +3186,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd); linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec; unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
multiDimIdxInNanoTile[inOrd[0]] /= minVec; multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned wordVecIdx = auto wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd); getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
wordVecs[wordVecIdx] = wordVecs[wordVecIdx] =
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos)); insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
@@ -3265,7 +3255,6 @@ public:
cMatShape = matShape[order[0]]; cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]]; sMatShape = matShape[order[1]];
cStride = smemStrides[order[0]];
sStride = smemStrides[order[1]]; sStride = smemStrides[order[1]];
// rule: k must be the fast-changing axis. // rule: k must be the fast-changing axis.
@@ -3636,7 +3625,6 @@ private:
int cMatShape; int cMatShape;
int sMatShape; int sMatShape;
Value cStride;
Value sStride; Value sStride;
bool needTrans; bool needTrans;
@@ -3651,13 +3639,6 @@ private:
int warpOffStride; int warpOffStride;
}; };
bool isSplatLike(Value value) {
if (auto constv = dyn_cast<arith::ConstantOp>(value.getDefiningOp()))
if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>())
return attr.isSplat();
return false;
}
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
enum class TensorCoreType : uint8_t { enum class TensorCoreType : uint8_t {
// floating-point tensor core instr // floating-point tensor core instr
@@ -3790,7 +3771,6 @@ struct DotOpMmaV1ConversionHelper {
int getRepN(int N) const { int getRepN(int N) const {
return std::max<int>(N / (wpt[1] * instrShape[1]), 1); return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
} }
int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; } static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
@@ -3857,9 +3837,6 @@ struct DotOpMmaV1ConversionHelper {
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread, Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const; Location loc, ConversionPatternRewriter &rewriter) const;
// Loading $c to registers, returns a LLVM::Struct.
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
static ArrayRef<unsigned> getOrder() { return mmaOrder; } static ArrayRef<unsigned> getOrder() { return mmaOrder; }
// Compute the offset of the matrix to load. // Compute the offset of the matrix to load.
@@ -3900,13 +3877,6 @@ struct DotOpMmaV2ConversionHelper {
mmaType = getTensorCoreTypeFromOperand(operandTy); mmaType = getTensorCoreTypeFromOperand(operandTy);
} }
// Get the M and N of mat instruction shape.
static std::tuple<int, int> getMatShapeMN() {
// According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
// shape's M,N are {8,8}
return {8, 8};
}
// Get the M and N of mma instruction shape. // Get the M and N of mma instruction shape.
static std::tuple<int, int> getInstrShapeMN() { static std::tuple<int, int> getInstrShapeMN() {
// According to DotOpConversionHelper::mmaInstrShape, all the M,N are // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
@@ -4561,7 +4531,7 @@ struct DotOpFMAConversionHelper {
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Location loc) const; Location loc) const;
Value getStructFromValueTable(ValueTable vals, Value getStructFromValueTable(const ValueTable &vals,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Location loc) const { Location loc) const {
SmallVector<Type> elemTypes(vals.size(), f32_ty); SmallVector<Type> elemTypes(vals.size(), f32_ty);
@@ -4838,7 +4808,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto mma = builder.create("mma.sync.aligned.m8n8k4") auto mma = builder.create("mma.sync.aligned.m8n8k4")
->o(isARow ? "row" : "col") ->o(isARow ? "row" : "col")
.o(isBRow ? "row" : "col") .o(isBRow ? "row" : "col")
.o(".f32.f16.f16.f32"); .o("f32.f16.f16.f32");
mma(resOprs, AOprs, BOprs, COprs); mma(resOprs, AOprs, BOprs, COprs);
@@ -5095,11 +5065,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
return res; return res;
} }
Value DotOpMmaV1ConversionHelper::loadC(
Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
return llTensor;
}
std::tuple<Value, Value, Value, Value> std::tuple<Value, Value, Value, Value>
DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow, DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
bool isBRow, ArrayRef<int> fpw, bool isBRow, ArrayRef<int> fpw,
@@ -5847,11 +5812,10 @@ struct InsertSliceAsyncOpConversion
}; };
struct ExtElemwiseOpConversion struct ExtElemwiseOpConversion
: public ElementwiseOpConversionBase< : public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> { ExtElemwiseOpConversion> {
using Base = using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion>;
ExtElemwiseOpConversion>;
using Base::Base; using Base::Base;
using Adaptor = typename Base::OpAdaptor; using Adaptor = typename Base::OpAdaptor;
@@ -5895,10 +5859,9 @@ private:
}; };
struct FDivOpConversion struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp, : ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
FDivOpConversion> { using Base =
using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp, ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
LLVM::InlineAsmOp, FDivOpConversion>;
using Base::Base; using Base::Base;
using Adaptor = typename Base::OpAdaptor; using Adaptor = typename Base::OpAdaptor;
@@ -5911,30 +5874,26 @@ struct FDivOpConversion
unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
if (32 == bitwidth) { if (32 == bitwidth) {
fdiv.o("full").o("f32"); fdiv.o("full").o("f32");
auto res = ptxBuilder.newOperand("=r");
auto lhs = ptxBuilder.newOperand(operands[0], "r");
auto rhs = ptxBuilder.newOperand(operands[1], "r");
fdiv(res, lhs, rhs);
} else if (64 == bitwidth) { } else if (64 == bitwidth) {
fdiv.o("rn").o("f64"); fdiv.o("rn").o("f64");
auto res = ptxBuilder.newOperand("=l");
auto lhs = ptxBuilder.newOperand(operands[0], "l");
auto rhs = ptxBuilder.newOperand(operands[1], "l");
fdiv(res, lhs, rhs);
} else { } else {
assert(0 && bitwidth && "not supported"); assert(0 && bitwidth && "not supported");
} }
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
fdiv(res, lhs, rhs);
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false); Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
return ret; return ret;
} }
}; };
struct ExpOpConversionApprox struct ExpOpConversionApprox
: ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp, : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
ExpOpConversionApprox> { using Base =
using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp, ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
ExpOpConversionApprox>;
using Base::Base; using Base::Base;
using Adaptor = typename Base::OpAdaptor; using Adaptor = typename Base::OpAdaptor;