[BUILD] Fix Warnings and Enable Warnings as Errors (#794)
This commit is contained in:
@@ -254,7 +254,6 @@ protected:
|
||||
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
||||
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
||||
/// information.
|
||||
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
||||
struct FuncOpConversion : public FuncOpConversionBase {
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
|
||||
PatternBenefit benefit)
|
||||
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
// Currently, Triton kernel function always return nothing.
|
||||
@@ -482,7 +480,6 @@ public:
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
@@ -654,7 +651,6 @@ public:
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(srcType, constVal);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
@@ -842,7 +837,6 @@ struct LoadOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// original values
|
||||
@@ -897,12 +891,11 @@ struct LoadOpConversion
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
@@ -921,7 +914,7 @@ struct LoadOpConversion
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
@@ -988,8 +981,8 @@ struct LoadOpConversion
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
// auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
// LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = vec_ty(valueElemTy, wordNElems);
|
||||
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
// Insert each value element to the composition
|
||||
for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
|
||||
assert(elemOffset < valueElems.size());
|
||||
Value elem = valueElems[elemOffset];
|
||||
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
|
||||
}
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTy);
|
||||
auto elemTy = resultTy.getElementType();
|
||||
auto srcVals = getElementsFromStruct(loc, src, rewriter);
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
@@ -1282,8 +1272,6 @@ private:
|
||||
LogicalResult
|
||||
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto rank = srcTy.getShape().size();
|
||||
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
|
||||
return matchAndRewriteFast(op, adaptor, rewriter);
|
||||
return matchAndRewriteBasic(op, adaptor, rewriter);
|
||||
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
|
||||
|
||||
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value val, int i) const {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
|
||||
if (bits == 64) {
|
||||
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
|
||||
barrier();
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (int i = 0; i < resultElems; i++) {
|
||||
for (size_t i = 0; i < resultElems; i++) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcOrder = srcLayout.getOrder();
|
||||
|
||||
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
|
||||
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
|
||||
barrier();
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (int i = 0; i < resultElems; i++) {
|
||||
for (size_t i = 0; i < resultElems; i++) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
// due to MLIR's restrictions
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
|
||||
auto resultLayout =
|
||||
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
|
||||
auto resultShape = resultTensorTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||
@@ -1821,7 +1805,7 @@ protected:
|
||||
SmallVector<SmallVector<Value>> operands(elems);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
||||
for (int i = 0; i < elems; ++i) {
|
||||
for (size_t i = 0; i < elems; ++i) {
|
||||
operands[i].push_back(sub_operands[i]);
|
||||
}
|
||||
}
|
||||
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
|
||||
__PRED_ENUM(ORD, ord);
|
||||
__PRED_ENUM(UEQ, ueq);
|
||||
__PRED_ENUM(UGT, ugt);
|
||||
__PRED_ENUM(UGE, uge);
|
||||
__PRED_ENUM(ULT, ult);
|
||||
__PRED_ENUM(ULE, ule);
|
||||
__PRED_ENUM(UNE, une);
|
||||
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
auto rank = type.getRank();
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
SmallVector<unsigned> numCTAs(rank);
|
||||
auto shapePerCTA = getShapePerCTA(layout);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
|
||||
loc, rewriter, blockedLayout, type.getShape());
|
||||
} else if (sliceLayout) {
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
auto parent = sliceLayout.getParent();
|
||||
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
SmallVector<int64_t> paddedShape =
|
||||
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
unsigned accumNumReplicates = product<unsigned>(numReplicates);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
// unsigned elems = getElemsPerThread(srcTy);
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
// Data loader for mma.16816 instruction.
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
||||
int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, const Location &loc)
|
||||
: wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
|
||||
: order(order.begin(), order.end()), kOrder(kOrder),
|
||||
tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
|
||||
typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
|
||||
ctx(rewriter.getContext()) {
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
|
||||
@@ -2576,7 +2559,6 @@ public:
|
||||
assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
|
||||
"smem matrix load must be aligned");
|
||||
int matIdx[2] = {mat0, mat1};
|
||||
int k = matIdx[kOrder];
|
||||
|
||||
int ptrIdx{-1};
|
||||
|
||||
@@ -2596,7 +2578,6 @@ public:
|
||||
|
||||
Value ptr = getPtr(ptrIdx);
|
||||
|
||||
Value resV4;
|
||||
if (canUseLdmatrix) {
|
||||
int sOffset =
|
||||
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
|
||||
@@ -2727,7 +2708,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
int wpt;
|
||||
SmallVector<uint32_t> order;
|
||||
int kOrder;
|
||||
SmallVector<int64_t> tileShape;
|
||||
@@ -2737,7 +2717,6 @@ private:
|
||||
int maxPhase;
|
||||
int elemBytes;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
TypeConverter *typeConverter{};
|
||||
const Location &loc;
|
||||
MLIRContext *ctx{};
|
||||
|
||||
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.a();
|
||||
Value B = op.b();
|
||||
Value C = op.c();
|
||||
Value D = op.getResult();
|
||||
MLIRContext *ctx = op->getContext();
|
||||
bool allowTF32 = op.allowTF32();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
|
||||
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
|
||||
Type i8x4Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
|
||||
Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
|
||||
|
||||
switch (mmaType) {
|
||||
case TensorCoreType::FP32_FP16_FP16_FP32:
|
||||
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
|
||||
auto bTy = B.getType().cast<RankedTensorType>();
|
||||
// d = a*b + c
|
||||
auto dTy = op.d().getType().cast<RankedTensorType>();
|
||||
auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
|
||||
if (dTy.getElementType().isF32()) {
|
||||
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
|
||||
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
|
||||
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, Location loc)
|
||||
: mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
|
||||
typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
|
||||
thread(thread) {
|
||||
: mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
|
||||
rewriter(rewriter), typeConverter(typeConverter), loc(loc),
|
||||
ctx(mmaLayout.getContext()) {
|
||||
wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
Value _32 = i32_val(32);
|
||||
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
|
||||
}
|
||||
|
||||
// step1. Perform loading.
|
||||
for (unsigned m = 0; m < numRepM; ++m)
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (int m = 0; m < numRepM; ++m)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * m, 2 * k);
|
||||
|
||||
// step2. Format the values to LLVM::Struct to passing to mma codegen.
|
||||
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
|
||||
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||
|
||||
for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * n, 2 * k);
|
||||
}
|
||||
|
||||
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
|
||||
helper.deduceMmaType(op);
|
||||
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = c.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto dShape = dTensorTy.getShape();
|
||||
|
||||
int NK = aShape[1];
|
||||
// shape / shape_per_cta
|
||||
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
|
||||
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
|
||||
int numRepM = getNumRepM(aTensorTy, dShape[0]);
|
||||
int numRepN = getNumRepN(aTensorTy, dShape[1]);
|
||||
int numRepK = getNumRepK(aTensorTy, aShape[1]);
|
||||
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
|
||||
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
|
||||
};
|
||||
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (unsigned m = 0; m < numRepM; ++m)
|
||||
for (unsigned n = 0; n < numRepN; ++n)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
for (int m = 0; m < numRepM; ++m)
|
||||
for (int n = 0; n < numRepN; ++n)
|
||||
callMma(2 * m, n, 2 * k);
|
||||
|
||||
// replace with new packed result
|
||||
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
|
||||
private:
|
||||
std::function<void(int, int)>
|
||||
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
|
||||
int wpt, int kOrder, ArrayRef<int> instrShape,
|
||||
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, Value warpId,
|
||||
ValueTable &vals) const {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
@@ -3486,8 +3452,8 @@ private:
|
||||
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
|
||||
int n1) const {
|
||||
std::vector<Value> elems;
|
||||
for (unsigned m = 0; m < n0; ++m)
|
||||
for (unsigned k = 0; k < n1; ++k) {
|
||||
for (int m = 0; m < n0; ++m)
|
||||
for (int k = 0; k < n1; ++k) {
|
||||
elems.push_back(vals.at({2 * m, 2 * k}));
|
||||
elems.push_back(vals.at({2 * m, 2 * k + 1}));
|
||||
elems.push_back(vals.at({2 * m + 1, 2 * k}));
|
||||
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.src();
|
||||
Value dst = op.result();
|
||||
auto srcTensorTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
|
||||
|
||||
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto dotOperandLayout =
|
||||
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
MmaEncodingAttr mmaLayout =
|
||||
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
|
||||
auto ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto voidTy = void_ty(ctx);
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
auto outOrder = resSharedLayout.getOrder();
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
||||
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
|
||||
// swizzle is not allowd
|
||||
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
|
||||
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
|
||||
|
||||
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
||||
for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto wordElemIdx = wordIdx * numWordElems;
|
||||
auto ©AsyncOp =
|
||||
@@ -4208,7 +4170,7 @@ namespace mlir {
|
||||
|
||||
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
// addIllegalDialect<triton::TritonDialect>();
|
||||
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
|
||||
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
// addLegalDialect<NVVM::NVVMDialect>();
|
||||
addIllegalOp<mlir::FuncOp>();
|
||||
|
@@ -21,9 +21,7 @@ public:
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -37,9 +35,8 @@ public:
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
DstOp res =
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -129,10 +126,9 @@ public:
|
||||
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
triton::gpu::SelectOp res =
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
// convert operand to slice of return type
|
||||
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
||||
getContext(), op.axis(), retEncoding);
|
||||
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
|
||||
adaptor.transB());
|
||||
return success();
|
||||
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
|
Reference in New Issue
Block a user