[Triton-MLIR][BACKEND] tiny code cleanup (#899)

- Remove unnecessary `static` qualifiers inside the anonymous namespace
- Remove several unnecessary functions
- Several small rewrites to make the code clearer
Author: Yan Chunwei
Date: 2022-11-21 16:00:46 +08:00
Committed by: GitHub
Parent: 04b852e031
Commit: 136668bac3


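A note on the first bullet before the diff: every symbol declared inside an anonymous namespace already has internal linkage, so a `static` qualifier on such a function is redundant. A minimal, self-contained sketch (hypothetical names, not code from this patch) of the two equivalent spellings:

```cpp
#include <cstdio>

namespace {

// Redundant: the anonymous namespace already gives this internal linkage.
static int addStatic(int a, int b) { return a + b; }

// Equivalent linkage, and the spelling this commit standardizes on.
int add(int a, int b) { return a + b; }

} // namespace

int main() {
  std::printf("%d %d\n", addStatic(1, 2), add(1, 2)); // prints "3 3"
  return 0;
}
```

Either spelling yields an internal-linkage symbol; the commit simply drops the redundant qualifier.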
@@ -51,8 +51,7 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
 namespace {
 // Create a 32-bit integer constant.
-static Value createConstantI32(Location loc, PatternRewriter &rewriter,
-                               int32_t v) {
+Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
   auto i32ty = rewriter.getIntegerType(32);
   return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
                                            IntegerAttr::get(i32ty, v));
@@ -71,16 +70,16 @@ Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
 }

 // Create an index type constant.
-static Value createIndexConstant(OpBuilder &builder, Location loc,
-                                 TypeConverter *converter, int64_t value) {
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
   Type ty = converter->convertType(builder.getIndexType());
   return builder.create<LLVM::ConstantOp>(loc, ty,
                                           builder.getIntegerAttr(ty, value));
 }

 // Create an integer constant of \param width bits.
-static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
-                                       short width, int64_t value) {
+Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
+                                int64_t value) {
   Type ty = builder.getIntegerType(width);
   return builder.create<LLVM::ConstantOp>(loc, ty,
                                           builder.getIntegerAttr(ty, value));
@@ -187,9 +186,8 @@ template <typename T> void printScalar(const T &e, const std::string &info) {
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                 bool filterArgAttrs,
-                                 SmallVectorImpl<NamedAttribute> &result) {
+void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, bool filterArgAttrs,
+                          SmallVectorImpl<NamedAttribute> &result) {
   for (const auto &attr : attrs) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
@@ -202,7 +200,7 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
 }

 /// Helper function for wrapping all attributes into a single DictionaryAttr
-static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
+auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
   return DictionaryAttr::get(
       b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
 }
@@ -359,7 +357,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
 // delinearize supposing order is [0, 1, .. , n]
 template <typename T>
-static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
+SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
   // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
   size_t rank = shape.size();
   T accMul = product(shape.drop_back());
@@ -376,8 +374,8 @@ static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
 }

 template <typename T>
-static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
-                                       ArrayRef<unsigned> order) {
+SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
+                                ArrayRef<unsigned> order) {
   size_t rank = shape.size();
   assert(rank == order.size());
   auto reordered = reorder(shape, order);
@@ -391,7 +389,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
 // linearize supposing order is [0, 1, .. , n]
 template <typename T>
-static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
   // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
   size_t rank = shape.size();
@@ -407,15 +405,15 @@ static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
 }

 template <typename T>
-static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
-                        ArrayRef<unsigned> order) {
+T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
+                 ArrayRef<unsigned> order) {
   assert(shape.size() == order.size());
   return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
                                reorder(shape, order));
 }

-static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
-                         Value ptr, Value val, Value pred) {
+Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
+                  Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
   unsigned bits = val.getType().getIntOrFloatBitWidth();
   const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -517,10 +515,9 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
                                  ConversionPatternRewriter &rewriter) {
     auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
     auto rank = (elems.size() - 1) / 2;
-    return SharedMemoryObject(
-        /*base=*/elems[0],
-        /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
-        /*offsets=*/{elems.begin() + 1 + rank, elems.end()});
+    return {/*base=*/elems[0],
+            /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
+            /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
   }

   static Value
@@ -1018,13 +1015,13 @@ struct ArithConstantSplatOpConversion
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
-  LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
       : AxisAnalysisPass(axisAnalysisPass) {}

   // Get corresponding LLVM element values of \param value.
-  SmallVector<Value> getLLVMElems(Value value, Value llValue,
-                                  ConversionPatternRewriter &rewriter,
-                                  Location loc) const {
+  static SmallVector<Value> getLLVMElems(Value value, Value llValue,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc) {
     if (!value)
       return {};
     if (!llValue.getType().isa<LLVM::LLVMStructType>())
@@ -1601,7 +1598,7 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   default:
     llvm::report_fatal_error("Unsupported reduce op");
   }
-};
+}

 Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
                                    Location loc, Value val, int i) const {
@@ -1946,7 +1943,7 @@ struct PrintfOpConversion
     std::string formatStr;
     llvm::raw_string_ostream os(formatStr);
     os << op.prefix();
-    if (operands.size() > 0) {
+    if (!operands.empty()) {
       os << getFormatSubstr(operands[0]);
     }
@@ -2130,7 +2127,7 @@ struct MakeRangeOpConversion
     auto idxs = emitIndices(loc, rewriter, layout, shape);
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
-    for (auto multiDim : llvm::enumerate(idxs)) {
+    for (const auto &multiDim : llvm::enumerate(idxs)) {
       assert(multiDim.value().size() == 1);
       retVals[multiDim.index()] = add(multiDim.value()[0], start);
     }
@@ -2633,7 +2630,7 @@ struct FpToFpOpConversion
 };

 // A CRTP style of base class.
-template <typename SourceOp, typename DestOp, typename ConcreteT>
+template <typename SourceOp, typename ConcreteT>
 class ElementwiseOpConversionBase
     : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
@@ -2688,16 +2685,16 @@ protected:
 template <typename SourceOp, typename DestOp>
 struct ElementwiseOpConversion
     : public ElementwiseOpConversionBase<
-          SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
+          SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
   using Base =
-      ElementwiseOpConversionBase<SourceOp, DestOp,
+      ElementwiseOpConversionBase<SourceOp,
                                   ElementwiseOpConversion<SourceOp, DestOp>>;
   using Base::Base;
   using OpAdaptor = typename Base::OpAdaptor;

   explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
-      : ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
+      : ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
             typeConverter, benefit) {}

   // An interface to support variant DestOp builder.
@@ -2714,10 +2711,10 @@ struct ElementwiseOpConversion
 //
 struct CmpIOpConversion
-    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
+    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
                                          CmpIOpConversion> {
-  using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
-                                           CmpIOpConversion>;
+  using Base =
+      ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
@@ -2755,17 +2752,18 @@ struct CmpIOpConversion
 };

 struct CmpFOpConversion
-    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
+    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
                                          CmpFOpConversion> {
-  using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
-                                           CmpFOpConversion>;
+  using Base =
+      ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;

   // An interface to support variant DestOp builder.
-  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
-                            ConversionPatternRewriter &rewriter, Type elemTy,
-                            ValueRange operands, Location loc) const {
+  static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, ValueRange operands,
+                                   Location loc) {
     return rewriter.create<LLVM::FCmpOp>(
         loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
         operands[1]);
@@ -2945,13 +2943,6 @@ private:
       triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
       const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;

-  // shared -> dot_operand if the result layout is blocked
-  Value lowerSharedToDotOperandBlocked(
-      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter,
-      const BlockedEncodingAttr &blockedLayout,
-      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
 };

 void ConvertLayoutOpConversion::processReplica(
@@ -2960,7 +2951,7 @@ void ConvertLayoutOpConversion::processReplica(
     ArrayRef<unsigned> multiDimRepId, unsigned vec,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
     SmallVector<Value> &vals, Value smemBase) const {
-  unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
+  auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
   auto layout = type.getEncoding();
   auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
   auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
@@ -2989,13 +2980,12 @@ void ConvertLayoutOpConversion::processReplica(
     auto multiDimCTAInRepId =
         getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
     SmallVector<unsigned> multiDimCTAId(rank);
-    for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
+    for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
       auto d = it.index();
       multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
     }

-    unsigned linearCTAId =
-        getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
+    auto linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
     // TODO: This is actually redundant index calculation, we should
     //       consider of caching the index calculation result in case
     //       of performance issue observed.
@@ -3073,7 +3063,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
     outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
   }
   // Potentially we need to store for multiple CTAs in this replication
-  unsigned accumNumReplicates = product<unsigned>(numReplicates);
+  auto accumNumReplicates = product<unsigned>(numReplicates);
   // unsigned elems = getElemsPerThread(srcTy);
   auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
@@ -3118,7 +3108,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   rewriter.replaceOp(op, result);
   return success();
-};
+}

 LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3144,7 +3134,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
   auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
-  unsigned srcAccumSizeInThreads =
+  auto srcAccumSizeInThreads =
       product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto elemTy = srcTy.getElementType();
   auto wordTy = vec_ty(elemTy, minVec);
@@ -3177,7 +3167,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   smemBase = bitcast(smemBase, elemPtrTy);
   auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
   auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-  unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
+  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
   // TODO: We should get less barriers if it is handled by membar pass
   //       instead of the backend, since the later can only handle it in
@@ -3196,7 +3186,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
             linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
         unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
         multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-        unsigned wordVecIdx =
+        auto wordVecIdx =
             getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
         wordVecs[wordVecIdx] =
             insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
@@ -3265,7 +3255,6 @@ public:
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];

-    cStride = smemStrides[order[0]];
     sStride = smemStrides[order[1]];

     // rule: k must be the fast-changing axis.
@@ -3636,7 +3625,6 @@ private:
   int cMatShape;
   int sMatShape;

-  Value cStride;
   Value sStride;

   bool needTrans;
@@ -3651,13 +3639,6 @@ private:
   int warpOffStride;
 };

-bool isSplatLike(Value value) {
-  if (auto constv = dyn_cast<arith::ConstantOp>(value.getDefiningOp()))
-    if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>())
-      return attr.isSplat();
-  return false;
-}

 struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   enum class TensorCoreType : uint8_t {
     // floating-point tensor core instr
@@ -3790,7 +3771,6 @@ struct DotOpMmaV1ConversionHelper {
   int getRepN(int N) const {
     return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
   }
-  int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }

   static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
@@ -3857,9 +3837,6 @@ struct DotOpMmaV1ConversionHelper {
   Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
               Location loc, ConversionPatternRewriter &rewriter) const;

-  // Loading $c to registers, returns a LLVM::Struct.
-  Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;

   static ArrayRef<unsigned> getOrder() { return mmaOrder; }

   // Compute the offset of the matrix to load.
@@ -3900,13 +3877,6 @@ struct DotOpMmaV2ConversionHelper {
     mmaType = getTensorCoreTypeFromOperand(operandTy);
   }

-  // Get the M and N of mat instruction shape.
-  static std::tuple<int, int> getMatShapeMN() {
-    // According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
-    // shape's M,N are {8,8}
-    return {8, 8};
-  }

   // Get the M and N of mma instruction shape.
   static std::tuple<int, int> getInstrShapeMN() {
     // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
@@ -4561,7 +4531,7 @@ struct DotOpFMAConversionHelper {
                                 ConversionPatternRewriter &rewriter,
                                 Location loc) const;

-  Value getStructFromValueTable(ValueTable vals,
+  Value getStructFromValueTable(const ValueTable &vals,
                                 ConversionPatternRewriter &rewriter,
                                 Location loc) const {
     SmallVector<Type> elemTypes(vals.size(), f32_ty);
@@ -4838,7 +4808,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto mma = builder.create("mma.sync.aligned.m8n8k4")
                  ->o(isARow ? "row" : "col")
                  .o(isBRow ? "row" : "col")
-                 .o(".f32.f16.f16.f32");
+                 .o("f32.f16.f16.f32");

   mma(resOprs, AOprs, BOprs, COprs);
@@ -5095,11 +5065,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
   return res;
 }

-Value DotOpMmaV1ConversionHelper::loadC(
-    Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
-  return llTensor;
-}

 std::tuple<Value, Value, Value, Value>
 DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
                                            bool isBRow, ArrayRef<int> fpw,
@@ -5847,11 +5812,10 @@ struct InsertSliceAsyncOpConversion
 };

 struct ExtElemwiseOpConversion
-    : public ElementwiseOpConversionBase<
-          triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
-  using Base =
-      ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
-                                  ExtElemwiseOpConversion>;
+    : public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
+                                         ExtElemwiseOpConversion> {
+  using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
+                                           ExtElemwiseOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
@@ -5895,10 +5859,9 @@ private:
 };

 struct FDivOpConversion
-    : ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
-                                  FDivOpConversion> {
-  using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
-                                           LLVM::InlineAsmOp, FDivOpConversion>;
+    : ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
@@ -5911,30 +5874,26 @@ struct FDivOpConversion
     unsigned bitwidth = elemTy.getIntOrFloatBitWidth();

     if (32 == bitwidth) {
       fdiv.o("full").o("f32");
-      auto res = ptxBuilder.newOperand("=r");
-      auto lhs = ptxBuilder.newOperand(operands[0], "r");
-      auto rhs = ptxBuilder.newOperand(operands[1], "r");
-      fdiv(res, lhs, rhs);
     } else if (64 == bitwidth) {
       fdiv.o("rn").o("f64");
-      auto res = ptxBuilder.newOperand("=l");
-      auto lhs = ptxBuilder.newOperand(operands[0], "l");
-      auto rhs = ptxBuilder.newOperand(operands[1], "l");
-      fdiv(res, lhs, rhs);
     } else {
       assert(0 && bitwidth && "not supported");
     }

+    auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
+    auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
+    auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
+    fdiv(res, lhs, rhs);
+
     Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
     return ret;
   }
 };

 struct ExpOpConversionApprox
-    : ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
-                                  ExpOpConversionApprox> {
-  using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
-                                           ExpOpConversionApprox>;
+    : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
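The recurring `ElementwiseOpConversionBase` hunks above all drop the `DestOp` template parameter: the CRTP base never used it, because only the concrete conversion's `createDestOp` needs to know the destination op. A minimal sketch of the pattern, with made-up types (`ConversionBase`, `AddOp`, `AddOpConversion` are illustrative, not Triton's real classes):

```cpp
#include <iostream>
#include <string>

// The base is parameterized only by the source op and the concrete class;
// the destination op type lives entirely in the concrete conversion.
template <typename SourceOp, typename ConcreteT>
struct ConversionBase {
  // CRTP dispatch: forward to the concrete class's createDestOp.
  std::string convert(const SourceOp &op) const {
    return static_cast<const ConcreteT *>(this)->createDestOp(op);
  }
};

struct AddOp {
  int lhs, rhs;
};

struct AddOpConversion : ConversionBase<AddOp, AddOpConversion> {
  // Only here does the "destination op" (a string stand-in) appear.
  std::string createDestOp(const AddOp &op) const {
    return "llvm.add(" + std::to_string(op.lhs) + ", " +
           std::to_string(op.rhs) + ")";
  }
};

int main() {
  AddOpConversion conv;
  std::cout << conv.convert(AddOp{1, 2}) << "\n"; // prints: llvm.add(1, 2)
}
```

Removing the unused parameter shortens every instantiation site, which is most of what the `ElementwiseOpConversionBase` hunks in this diff accomplish.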