[Triton-MLIR][BACKEND] tiny code cleanup (#899)
- Remove the unnecessary `static` in the anonymous namespace - Remove several unnecessary functions - Several simple rewrites to make code more clear
This commit is contained in:
@@ -51,8 +51,7 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Create a 32-bit integer constant.
|
// Create a 32-bit integer constant.
|
||||||
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
|
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
||||||
int32_t v) {
|
|
||||||
auto i32ty = rewriter.getIntegerType(32);
|
auto i32ty = rewriter.getIntegerType(32);
|
||||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||||
IntegerAttr::get(i32ty, v));
|
IntegerAttr::get(i32ty, v));
|
||||||
@@ -71,16 +70,16 @@ Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create an index type constant.
|
// Create an index type constant.
|
||||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||||
TypeConverter *converter, int64_t value) {
|
TypeConverter *converter, int64_t value) {
|
||||||
Type ty = converter->convertType(builder.getIndexType());
|
Type ty = converter->convertType(builder.getIndexType());
|
||||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||||
builder.getIntegerAttr(ty, value));
|
builder.getIntegerAttr(ty, value));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an integer constant of \param width bits.
|
// Create an integer constant of \param width bits.
|
||||||
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||||
short width, int64_t value) {
|
int64_t value) {
|
||||||
Type ty = builder.getIntegerType(width);
|
Type ty = builder.getIntegerType(width);
|
||||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||||
builder.getIntegerAttr(ty, value));
|
builder.getIntegerAttr(ty, value));
|
||||||
@@ -187,9 +186,8 @@ template <typename T> void printScalar(const T &e, const std::string &info) {
|
|||||||
/// Only retain those attributes that are not constructed by
|
/// Only retain those attributes that are not constructed by
|
||||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||||
/// attributes.
|
/// attributes.
|
||||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, bool filterArgAttrs,
|
||||||
bool filterArgAttrs,
|
SmallVectorImpl<NamedAttribute> &result) {
|
||||||
SmallVectorImpl<NamedAttribute> &result) {
|
|
||||||
for (const auto &attr : attrs) {
|
for (const auto &attr : attrs) {
|
||||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||||
@@ -202,7 +200,7 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Helper function for wrapping all attributes into a single DictionaryAttr
|
/// Helper function for wrapping all attributes into a single DictionaryAttr
|
||||||
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
||||||
return DictionaryAttr::get(
|
return DictionaryAttr::get(
|
||||||
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
|
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
|
||||||
}
|
}
|
||||||
@@ -359,7 +357,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
|||||||
|
|
||||||
// delinearize supposing order is [0, 1, .. , n]
|
// delinearize supposing order is [0, 1, .. , n]
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
||||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||||
size_t rank = shape.size();
|
size_t rank = shape.size();
|
||||||
T accMul = product(shape.drop_back());
|
T accMul = product(shape.drop_back());
|
||||||
@@ -376,8 +374,8 @@ static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
||||||
ArrayRef<unsigned> order) {
|
ArrayRef<unsigned> order) {
|
||||||
size_t rank = shape.size();
|
size_t rank = shape.size();
|
||||||
assert(rank == order.size());
|
assert(rank == order.size());
|
||||||
auto reordered = reorder(shape, order);
|
auto reordered = reorder(shape, order);
|
||||||
@@ -391,7 +389,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
|||||||
|
|
||||||
// linearize supposing order is [0, 1, .. , n]
|
// linearize supposing order is [0, 1, .. , n]
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||||
assert(multiDimIndex.size() == shape.size());
|
assert(multiDimIndex.size() == shape.size());
|
||||||
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
|
||||||
size_t rank = shape.size();
|
size_t rank = shape.size();
|
||||||
@@ -407,15 +405,15 @@ static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
|
T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
|
||||||
ArrayRef<unsigned> order) {
|
ArrayRef<unsigned> order) {
|
||||||
assert(shape.size() == order.size());
|
assert(shape.size() == order.size());
|
||||||
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
|
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
|
||||||
reorder(shape, order));
|
reorder(shape, order));
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||||
Value ptr, Value val, Value pred) {
|
Value val, Value pred) {
|
||||||
MLIRContext *ctx = rewriter.getContext();
|
MLIRContext *ctx = rewriter.getContext();
|
||||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||||
@@ -517,10 +515,9 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
|
|||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
||||||
auto rank = (elems.size() - 1) / 2;
|
auto rank = (elems.size() - 1) / 2;
|
||||||
return SharedMemoryObject(
|
return {/*base=*/elems[0],
|
||||||
/*base=*/elems[0],
|
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
|
||||||
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
|
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
@@ -1018,13 +1015,13 @@ struct ArithConstantSplatOpConversion
|
|||||||
|
|
||||||
// Contains some helper functions for both Load and Store conversions.
|
// Contains some helper functions for both Load and Store conversions.
|
||||||
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||||
LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||||
: AxisAnalysisPass(axisAnalysisPass) {}
|
: AxisAnalysisPass(axisAnalysisPass) {}
|
||||||
|
|
||||||
// Get corresponding LLVM element values of \param value.
|
// Get corresponding LLVM element values of \param value.
|
||||||
SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Location loc) const {
|
Location loc) {
|
||||||
if (!value)
|
if (!value)
|
||||||
return {};
|
return {};
|
||||||
if (!llValue.getType().isa<LLVM::LLVMStructType>())
|
if (!llValue.getType().isa<LLVM::LLVMStructType>())
|
||||||
@@ -1601,7 +1598,7 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
|
|||||||
default:
|
default:
|
||||||
llvm::report_fatal_error("Unsupported reduce op");
|
llvm::report_fatal_error("Unsupported reduce op");
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
||||||
Location loc, Value val, int i) const {
|
Location loc, Value val, int i) const {
|
||||||
@@ -1946,7 +1943,7 @@ struct PrintfOpConversion
|
|||||||
std::string formatStr;
|
std::string formatStr;
|
||||||
llvm::raw_string_ostream os(formatStr);
|
llvm::raw_string_ostream os(formatStr);
|
||||||
os << op.prefix();
|
os << op.prefix();
|
||||||
if (operands.size() > 0) {
|
if (!operands.empty()) {
|
||||||
os << getFormatSubstr(operands[0]);
|
os << getFormatSubstr(operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2130,7 +2127,7 @@ struct MakeRangeOpConversion
|
|||||||
auto idxs = emitIndices(loc, rewriter, layout, shape);
|
auto idxs = emitIndices(loc, rewriter, layout, shape);
|
||||||
unsigned elems = idxs.size();
|
unsigned elems = idxs.size();
|
||||||
SmallVector<Value> retVals(elems);
|
SmallVector<Value> retVals(elems);
|
||||||
for (auto multiDim : llvm::enumerate(idxs)) {
|
for (const auto &multiDim : llvm::enumerate(idxs)) {
|
||||||
assert(multiDim.value().size() == 1);
|
assert(multiDim.value().size() == 1);
|
||||||
retVals[multiDim.index()] = add(multiDim.value()[0], start);
|
retVals[multiDim.index()] = add(multiDim.value()[0], start);
|
||||||
}
|
}
|
||||||
@@ -2633,7 +2630,7 @@ struct FpToFpOpConversion
|
|||||||
};
|
};
|
||||||
|
|
||||||
// A CRTP style of base class.
|
// A CRTP style of base class.
|
||||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
template <typename SourceOp, typename ConcreteT>
|
||||||
class ElementwiseOpConversionBase
|
class ElementwiseOpConversionBase
|
||||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||||
public:
|
public:
|
||||||
@@ -2688,16 +2685,16 @@ protected:
|
|||||||
template <typename SourceOp, typename DestOp>
|
template <typename SourceOp, typename DestOp>
|
||||||
struct ElementwiseOpConversion
|
struct ElementwiseOpConversion
|
||||||
: public ElementwiseOpConversionBase<
|
: public ElementwiseOpConversionBase<
|
||||||
SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
||||||
using Base =
|
using Base =
|
||||||
ElementwiseOpConversionBase<SourceOp, DestOp,
|
ElementwiseOpConversionBase<SourceOp,
|
||||||
ElementwiseOpConversion<SourceOp, DestOp>>;
|
ElementwiseOpConversion<SourceOp, DestOp>>;
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using OpAdaptor = typename Base::OpAdaptor;
|
using OpAdaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
|
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
|
||||||
PatternBenefit benefit = 1)
|
PatternBenefit benefit = 1)
|
||||||
: ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
|
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
|
||||||
typeConverter, benefit) {}
|
typeConverter, benefit) {}
|
||||||
|
|
||||||
// An interface to support variant DestOp builder.
|
// An interface to support variant DestOp builder.
|
||||||
@@ -2714,10 +2711,10 @@ struct ElementwiseOpConversion
|
|||||||
//
|
//
|
||||||
|
|
||||||
struct CmpIOpConversion
|
struct CmpIOpConversion
|
||||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
|
||||||
CmpIOpConversion> {
|
CmpIOpConversion> {
|
||||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
using Base =
|
||||||
CmpIOpConversion>;
|
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using Adaptor = typename Base::OpAdaptor;
|
using Adaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
@@ -2755,17 +2752,18 @@ struct CmpIOpConversion
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CmpFOpConversion
|
struct CmpFOpConversion
|
||||||
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
|
||||||
CmpFOpConversion> {
|
CmpFOpConversion> {
|
||||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
using Base =
|
||||||
CmpFOpConversion>;
|
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using Adaptor = typename Base::OpAdaptor;
|
using Adaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
// An interface to support variant DestOp builder.
|
// An interface to support variant DestOp builder.
|
||||||
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
ConversionPatternRewriter &rewriter,
|
||||||
ValueRange operands, Location loc) const {
|
Type elemTy, ValueRange operands,
|
||||||
|
Location loc) {
|
||||||
return rewriter.create<LLVM::FCmpOp>(
|
return rewriter.create<LLVM::FCmpOp>(
|
||||||
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
|
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
|
||||||
operands[1]);
|
operands[1]);
|
||||||
@@ -2945,13 +2943,6 @@ private:
|
|||||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
|
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
|
||||||
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
|
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
|
||||||
|
|
||||||
// shared -> dot_operand if the result layout is blocked
|
|
||||||
Value lowerSharedToDotOperandBlocked(
|
|
||||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter,
|
|
||||||
const BlockedEncodingAttr &blockedLayout,
|
|
||||||
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void ConvertLayoutOpConversion::processReplica(
|
void ConvertLayoutOpConversion::processReplica(
|
||||||
@@ -2960,7 +2951,7 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
ArrayRef<unsigned> multiDimRepId, unsigned vec,
|
ArrayRef<unsigned> multiDimRepId, unsigned vec,
|
||||||
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
|
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
|
||||||
SmallVector<Value> &vals, Value smemBase) const {
|
SmallVector<Value> &vals, Value smemBase) const {
|
||||||
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
||||||
auto layout = type.getEncoding();
|
auto layout = type.getEncoding();
|
||||||
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
|
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
|
||||||
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
||||||
@@ -2989,13 +2980,12 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
auto multiDimCTAInRepId =
|
auto multiDimCTAInRepId =
|
||||||
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
|
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
|
||||||
SmallVector<unsigned> multiDimCTAId(rank);
|
SmallVector<unsigned> multiDimCTAId(rank);
|
||||||
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
|
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
|
||||||
auto d = it.index();
|
auto d = it.index();
|
||||||
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned linearCTAId =
|
auto linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
||||||
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
|
||||||
// TODO: This is actually redundant index calculation, we should
|
// TODO: This is actually redundant index calculation, we should
|
||||||
// consider of caching the index calculation result in case
|
// consider of caching the index calculation result in case
|
||||||
// of performance issue observed.
|
// of performance issue observed.
|
||||||
@@ -3073,7 +3063,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
|||||||
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
|
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
|
||||||
}
|
}
|
||||||
// Potentially we need to store for multiple CTAs in this replication
|
// Potentially we need to store for multiple CTAs in this replication
|
||||||
unsigned accumNumReplicates = product<unsigned>(numReplicates);
|
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||||
// unsigned elems = getElemsPerThread(srcTy);
|
// unsigned elems = getElemsPerThread(srcTy);
|
||||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||||
unsigned inVec = 0;
|
unsigned inVec = 0;
|
||||||
@@ -3118,7 +3108,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
|||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOp(op, result);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
};
|
}
|
||||||
|
|
||||||
LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||||
@@ -3144,7 +3134,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
||||||
unsigned numElems = getElemsPerThread(srcTy);
|
unsigned numElems = getElemsPerThread(srcTy);
|
||||||
auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||||
unsigned srcAccumSizeInThreads =
|
auto srcAccumSizeInThreads =
|
||||||
product<unsigned>(srcBlockedLayout.getSizePerThread());
|
product<unsigned>(srcBlockedLayout.getSizePerThread());
|
||||||
auto elemTy = srcTy.getElementType();
|
auto elemTy = srcTy.getElementType();
|
||||||
auto wordTy = vec_ty(elemTy, minVec);
|
auto wordTy = vec_ty(elemTy, minVec);
|
||||||
@@ -3177,7 +3167,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
smemBase = bitcast(smemBase, elemPtrTy);
|
smemBase = bitcast(smemBase, elemPtrTy);
|
||||||
auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
|
auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
|
||||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||||
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
|
auto numWordsEachRep = product<unsigned>(wordsInEachRep);
|
||||||
SmallVector<Value> wordVecs(numWordsEachRep);
|
SmallVector<Value> wordVecs(numWordsEachRep);
|
||||||
// TODO: We should get less barriers if it is handled by membar pass
|
// TODO: We should get less barriers if it is handled by membar pass
|
||||||
// instead of the backend, since the later can only handle it in
|
// instead of the backend, since the later can only handle it in
|
||||||
@@ -3196,7 +3186,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
|
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
|
||||||
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
|
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
|
||||||
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
|
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
|
||||||
unsigned wordVecIdx =
|
auto wordVecIdx =
|
||||||
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
|
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
|
||||||
wordVecs[wordVecIdx] =
|
wordVecs[wordVecIdx] =
|
||||||
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
|
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
|
||||||
@@ -3265,7 +3255,6 @@ public:
|
|||||||
cMatShape = matShape[order[0]];
|
cMatShape = matShape[order[0]];
|
||||||
sMatShape = matShape[order[1]];
|
sMatShape = matShape[order[1]];
|
||||||
|
|
||||||
cStride = smemStrides[order[0]];
|
|
||||||
sStride = smemStrides[order[1]];
|
sStride = smemStrides[order[1]];
|
||||||
|
|
||||||
// rule: k must be the fast-changing axis.
|
// rule: k must be the fast-changing axis.
|
||||||
@@ -3636,7 +3625,6 @@ private:
|
|||||||
int cMatShape;
|
int cMatShape;
|
||||||
int sMatShape;
|
int sMatShape;
|
||||||
|
|
||||||
Value cStride;
|
|
||||||
Value sStride;
|
Value sStride;
|
||||||
|
|
||||||
bool needTrans;
|
bool needTrans;
|
||||||
@@ -3651,13 +3639,6 @@ private:
|
|||||||
int warpOffStride;
|
int warpOffStride;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool isSplatLike(Value value) {
|
|
||||||
if (auto constv = dyn_cast<arith::ConstantOp>(value.getDefiningOp()))
|
|
||||||
if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>())
|
|
||||||
return attr.isSplat();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||||
enum class TensorCoreType : uint8_t {
|
enum class TensorCoreType : uint8_t {
|
||||||
// floating-point tensor core instr
|
// floating-point tensor core instr
|
||||||
@@ -3790,7 +3771,6 @@ struct DotOpMmaV1ConversionHelper {
|
|||||||
int getRepN(int N) const {
|
int getRepN(int N) const {
|
||||||
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
|
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
|
||||||
}
|
}
|
||||||
int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }
|
|
||||||
|
|
||||||
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
|
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
|
||||||
|
|
||||||
@@ -3857,9 +3837,6 @@ struct DotOpMmaV1ConversionHelper {
|
|||||||
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
||||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||||
|
|
||||||
// Loading $c to registers, returns a LLVM::Struct.
|
|
||||||
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
|
|
||||||
|
|
||||||
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
|
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
|
||||||
|
|
||||||
// Compute the offset of the matrix to load.
|
// Compute the offset of the matrix to load.
|
||||||
@@ -3900,13 +3877,6 @@ struct DotOpMmaV2ConversionHelper {
|
|||||||
mmaType = getTensorCoreTypeFromOperand(operandTy);
|
mmaType = getTensorCoreTypeFromOperand(operandTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the M and N of mat instruction shape.
|
|
||||||
static std::tuple<int, int> getMatShapeMN() {
|
|
||||||
// According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
|
|
||||||
// shape's M,N are {8,8}
|
|
||||||
return {8, 8};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the M and N of mma instruction shape.
|
// Get the M and N of mma instruction shape.
|
||||||
static std::tuple<int, int> getInstrShapeMN() {
|
static std::tuple<int, int> getInstrShapeMN() {
|
||||||
// According to DotOpConversionHelper::mmaInstrShape, all the M,N are
|
// According to DotOpConversionHelper::mmaInstrShape, all the M,N are
|
||||||
@@ -4561,7 +4531,7 @@ struct DotOpFMAConversionHelper {
|
|||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Location loc) const;
|
Location loc) const;
|
||||||
|
|
||||||
Value getStructFromValueTable(ValueTable vals,
|
Value getStructFromValueTable(const ValueTable &vals,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Location loc) const {
|
Location loc) const {
|
||||||
SmallVector<Type> elemTypes(vals.size(), f32_ty);
|
SmallVector<Type> elemTypes(vals.size(), f32_ty);
|
||||||
@@ -4838,7 +4808,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
|||||||
auto mma = builder.create("mma.sync.aligned.m8n8k4")
|
auto mma = builder.create("mma.sync.aligned.m8n8k4")
|
||||||
->o(isARow ? "row" : "col")
|
->o(isARow ? "row" : "col")
|
||||||
.o(isBRow ? "row" : "col")
|
.o(isBRow ? "row" : "col")
|
||||||
.o(".f32.f16.f16.f32");
|
.o("f32.f16.f16.f32");
|
||||||
|
|
||||||
mma(resOprs, AOprs, BOprs, COprs);
|
mma(resOprs, AOprs, BOprs, COprs);
|
||||||
|
|
||||||
@@ -5095,11 +5065,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value DotOpMmaV1ConversionHelper::loadC(
|
|
||||||
Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
|
|
||||||
return llTensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<Value, Value, Value, Value>
|
std::tuple<Value, Value, Value, Value>
|
||||||
DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
|
DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
|
||||||
bool isBRow, ArrayRef<int> fpw,
|
bool isBRow, ArrayRef<int> fpw,
|
||||||
@@ -5847,11 +5812,10 @@ struct InsertSliceAsyncOpConversion
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ExtElemwiseOpConversion
|
struct ExtElemwiseOpConversion
|
||||||
: public ElementwiseOpConversionBase<
|
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||||
triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
|
ExtElemwiseOpConversion> {
|
||||||
using Base =
|
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||||
ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
|
ExtElemwiseOpConversion>;
|
||||||
ExtElemwiseOpConversion>;
|
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using Adaptor = typename Base::OpAdaptor;
|
using Adaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
@@ -5895,10 +5859,9 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct FDivOpConversion
|
struct FDivOpConversion
|
||||||
: ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
|
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
|
||||||
FDivOpConversion> {
|
using Base =
|
||||||
using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
|
ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
|
||||||
LLVM::InlineAsmOp, FDivOpConversion>;
|
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using Adaptor = typename Base::OpAdaptor;
|
using Adaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
@@ -5911,30 +5874,26 @@ struct FDivOpConversion
|
|||||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||||
if (32 == bitwidth) {
|
if (32 == bitwidth) {
|
||||||
fdiv.o("full").o("f32");
|
fdiv.o("full").o("f32");
|
||||||
auto res = ptxBuilder.newOperand("=r");
|
|
||||||
auto lhs = ptxBuilder.newOperand(operands[0], "r");
|
|
||||||
auto rhs = ptxBuilder.newOperand(operands[1], "r");
|
|
||||||
fdiv(res, lhs, rhs);
|
|
||||||
} else if (64 == bitwidth) {
|
} else if (64 == bitwidth) {
|
||||||
fdiv.o("rn").o("f64");
|
fdiv.o("rn").o("f64");
|
||||||
auto res = ptxBuilder.newOperand("=l");
|
|
||||||
auto lhs = ptxBuilder.newOperand(operands[0], "l");
|
|
||||||
auto rhs = ptxBuilder.newOperand(operands[1], "l");
|
|
||||||
fdiv(res, lhs, rhs);
|
|
||||||
} else {
|
} else {
|
||||||
assert(0 && bitwidth && "not supported");
|
assert(0 && bitwidth && "not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
|
||||||
|
auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
|
||||||
|
auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
|
||||||
|
fdiv(res, lhs, rhs);
|
||||||
|
|
||||||
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
|
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ExpOpConversionApprox
|
struct ExpOpConversionApprox
|
||||||
: ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
: ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
|
||||||
ExpOpConversionApprox> {
|
using Base =
|
||||||
using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
|
||||||
ExpOpConversionApprox>;
|
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
using Adaptor = typename Base::OpAdaptor;
|
using Adaptor = typename Base::OpAdaptor;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user