[BUILD] Fix Warnings and Enable Warnings as Errors (#794)

Ian Bearman
2022-10-28 12:36:09 -07:00
committed by GitHub
parent ac0f6793cc
commit f2106d0aa2
20 changed files with 205 additions and 213 deletions
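The hunks below fall into a handful of recurring warning classes: unused locals, parameters, and members are deleted; loop counters move from int or unsigned to size_t so they match the unsigned bounds they are compared against; constructor initializer lists are reordered to match member declaration order; and call results that were bound but never read are discarded. As a rough illustration, here is hypothetical code (none of these names are from Triton) that a Clang/GCC-style toolchain rejects once warnings are promoted to errors:

// warnings.cpp -- hypothetical repro, not part of this diff.
// clang++ -std=c++17 -Wall -Wextra -Werror -c warnings.cpp
#include <vector>

int sum(const std::vector<int> &v) {
  int unused = 42;                    // -Wunused-variable: never read again
  int total = 0;
  for (int i = 0; i < v.size(); ++i)  // -Wsign-compare: int vs. size_t
    total += v[i];
  return total;
}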

View File

@@ -254,7 +254,6 @@ protected:
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
- static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel function always return nothing.
@@ -482,7 +480,6 @@ public:
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
- auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
@@ -654,7 +651,6 @@ public:
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
- auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = idx_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
- auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -842,7 +837,6 @@ struct LoadOpConversion
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
// original values
@@ -897,12 +891,11 @@ struct LoadOpConversion
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
- const int maxWordWidth = std::max<int>(32, valueElemNbits);
- const int totalWidth = valueElemNbits * vec;
- const int width = std::min(totalWidth, maxWordWidth);
- const int nWords = std::max(1, totalWidth / width);
- const int wordNElems = width / valueElemNbits;
- const int vecNElems = totalWidth / valueElemNbits;
+ const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+ const size_t totalWidth = valueElemNbits * vec;
+ const size_t width = std::min(totalWidth, maxWordWidth);
+ const size_t nWords = std::max<size_t>(1, totalWidth / width);
+ const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
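This hunk is representative of many below: the widths derived from valueElemNbits and vec can never be negative, and keeping them in int forced a signed/unsigned comparison wherever they met a size_t or an unsigned loop bound. A minimal self-contained sketch of the same fix (variable names borrowed from the hunk, the enclosing function is hypothetical):

#include <algorithm>
#include <cstddef>

void packWords(size_t valueElemNbits, size_t vec) {
  // All size_t: the std::min/std::max calls and the loop comparison below
  // mix no signedness, so -Wsign-compare stays quiet under -Werror.
  const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
  const size_t totalWidth = valueElemNbits * vec;
  const size_t width = std::min(totalWidth, maxWordWidth);
  const size_t nWords = std::max<size_t>(1, totalWidth / width);
  for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
    // ... emit one width-bit word per iteration
  }
}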
@@ -921,7 +914,7 @@ struct LoadOpConversion
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
- for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+ for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr);
}
@@ -988,8 +981,8 @@ struct LoadOpConversion
: retTys[0];
// TODO: if (has_l2_evict_policy)
- auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                 LLVM::AsmDialect::AD_ATT);
+ // auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+ //                                                 LLVM::AsmDialect::AD_ATT);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
// ---
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
- const int maxWordWidth = std::max<int>(32, valueElemNbits);
- const int totalWidth = valueElemNbits * vec;
- const int width = std::min(totalWidth, maxWordWidth);
- const int nWords = std::max(1, totalWidth / width);
- const int wordNElems = width / valueElemNbits;
- const int vecNElems = totalWidth / valueElemNbits;
+ const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+ const size_t totalWidth = valueElemNbits * vec;
+ const size_t width = std::min(totalWidth, maxWordWidth);
+ const size_t nWords = std::max<size_t>(1, totalWidth / width);
+ const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);
SmallVector<std::pair<Value, std::string>> asmArgs;
- for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+ for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition
- for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
+ for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
}
unsigned srcElems = getElemsPerThread(srcTy);
- auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, rewriter);
unsigned resultElems = getElemsPerThread(resultTy);
SmallVector<Value> resultVals(resultElems);
@@ -1282,8 +1272,6 @@ private:
LogicalResult
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto srcTy = op.operand().getType().cast<RankedTensorType>();
- auto rank = srcTy.getShape().size();
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
return matchAndRewriteFast(op, adaptor, rewriter);
return matchAndRewriteBasic(op, adaptor, rewriter);
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
Location loc, Value val, int i) const {
- MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
barrier();
SmallVector<Value> resultVals(resultElems);
- for (int i = 0; i < resultElems; i++) {
+ for (size_t i = 0; i < resultElems; i++) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
- auto srcOrder = srcLayout.getOrder();
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
barrier();
SmallVector<Value> resultVals(resultElems);
- for (int i = 0; i < resultElems; i++) {
+ for (size_t i = 0; i < resultElems; i++) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
- auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
auto resultLayout =
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
- auto resultShape = resultTensorTy.getShape();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
@@ -1821,7 +1805,7 @@ protected:
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
- for (int i = 0; i < elems; ++i) {
+ for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
__PRED_ENUM(ORD, ord);
__PRED_ENUM(UEQ, ueq);
__PRED_ENUM(UGT, ugt);
+ __PRED_ENUM(UGE, uge);
__PRED_ENUM(ULT, ult);
__PRED_ENUM(ULE, ule);
__PRED_ENUM(UNE, une);
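The one added line, UGE, fills a gap in the predicate table (UGT and ULT were already mapped). With warnings as errors, a missing enumerator is typically caught by -Wswitch when a switch over the predicate enum omits a case; whether that is the exact diagnostic that flagged this spot is an assumption. A contrived sketch of the mechanism:

// Hypothetical enum standing in for the CmpF predicate kind.
enum class Pred { EQ, UGT, UGE, ULT };

const char *mnemonic(Pred p) {
  switch (p) {                  // -Wswitch: enumerator 'UGE' not handled
  case Pred::EQ:  return "eq";
  case Pred::UGT: return "ugt";
  case Pred::ULT: return "ult";
  }
  return "unknown";
}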
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
- auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (sliceLayout) {
- unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape =
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
- unsigned elems = getElemsPerThread(srcTy);
+ // unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
- MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
+ MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, const Location &loc)
-     : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+     : order(order.begin(), order.end()), kOrder(kOrder),
tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
-       maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
-       typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
+       maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
+       ctx(rewriter.getContext()) {
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
@@ -2576,7 +2559,6 @@ public:
assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
"smem matrix load must be aligned");
int matIdx[2] = {mat0, mat1};
- int k = matIdx[kOrder];
int ptrIdx{-1};
@@ -2596,7 +2578,6 @@ public:
Value ptr = getPtr(ptrIdx);
- Value resV4;
if (canUseLdmatrix) {
int sOffset =
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
@@ -2727,7 +2708,6 @@ public:
}
private:
- int wpt;
SmallVector<uint32_t> order;
int kOrder;
SmallVector<int64_t> tileShape;
@@ -2737,7 +2717,6 @@ private:
int maxPhase;
int elemBytes;
ConversionPatternRewriter &rewriter;
- TypeConverter *typeConverter{};
const Location &loc;
MLIRContext *ctx{};
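Deleting the wpt and typeConverter members pairs with the constructor change above: once the initializer list stops storing them, the fields are written by no one and read by no one, which Clang reports via -Wunused-private-field. A minimal sketch with a hypothetical class:

class Loader {
public:
  explicit Loader(int kOrder) : kOrder(kOrder) {}
  int order() const { return kOrder; } // kOrder is read, so it is "used"

private:
  int wpt;    // -Wunused-private-field: never referenced, an error with -Werror
  int kOrder;
};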
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
// D = A * B + C
Value A = op.a();
- Value B = op.b();
- Value C = op.c();
Value D = op.getResult();
- MLIRContext *ctx = op->getContext();
- bool allowTF32 = op.allowTF32();
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
Type i8x4Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
- Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
-     ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
switch (mmaType) {
case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
auto bTy = B.getType().cast<RankedTensorType>();
// d = a*b + c
auto dTy = op.d().getType().cast<RankedTensorType>();
- auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
if (dTy.getElementType().isF32()) {
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
-     : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
-       typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
-       thread(thread) {
+     : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
+       rewriter(rewriter), typeConverter(typeConverter), loc(loc),
+       ctx(mmaLayout.getContext()) {
wpt = mmaLayout.getWarpsPerCTA();
Value _32 = i32_val(32);
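The reordering above is about -Wreorder: members are constructed in the order they are declared, not the order they appear in the initializer list, and thread is evidently declared early in MMA16816ConversionHelper. A generic sketch:

struct Helper {
  int layout;
  int thread;   // declared second...
  int rewriter;

  // ...so it must also be listed second. Writing the list as
  //   : layout(l), rewriter(r), thread(t)
  // draws -Wreorder, because construction happens in declaration order
  // regardless of how the list is spelled.
  Helper(int l, int t, int r) : layout(l), thread(t), rewriter(r) {}
};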
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
}
// step1. Perform loading.
- for (unsigned m = 0; m < numRepM; ++m)
-   for (unsigned k = 0; k < numRepK; ++k)
+ for (int m = 0; m < numRepM; ++m)
+   for (int k = 0; k < numRepK; ++k)
loadFn(2 * m, 2 * k);
// step2. Format the values to LLVM::Struct to passing to mma codegen.
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
- for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
-   for (unsigned k = 0; k < numRepK; ++k)
+ for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
+   for (int k = 0; k < numRepK; ++k)
loadFn(2 * n, 2 * k);
}
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
helper.deduceMmaType(op);
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
- auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto dShape = dTensorTy.getShape();
- int NK = aShape[1];
- // shape / shape_per_cta
- auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
- auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
int numRepM = getNumRepM(aTensorTy, dShape[0]);
int numRepN = getNumRepN(aTensorTy, dShape[1]);
int numRepK = getNumRepK(aTensorTy, aShape[1]);
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
};
- for (unsigned k = 0; k < numRepK; ++k)
-   for (unsigned m = 0; m < numRepM; ++m)
-     for (unsigned n = 0; n < numRepN; ++n)
+ for (int k = 0; k < numRepK; ++k)
+   for (int m = 0; m < numRepM; ++m)
+     for (int n = 0; n < numRepN; ++n)
callMma(2 * m, n, 2 * k);
// replace with new packed result
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
private:
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
- int wpt, int kOrder, ArrayRef<int> instrShape,
+ int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
ValueTable &vals) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -3486,8 +3452,8 @@ private:
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
int n1) const {
std::vector<Value> elems;
- for (unsigned m = 0; m < n0; ++m)
-   for (unsigned k = 0; k < n1; ++k) {
+ for (int m = 0; m < n0; ++m)
+   for (int k = 0; k < n1; ++k) {
elems.push_back(vals.at({2 * m, 2 * k}));
elems.push_back(vals.at({2 * m, 2 * k + 1}));
elems.push_back(vals.at({2 * m + 1, 2 * k}));
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
- auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
- auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr mmaLayout =
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
- auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
+ ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
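launch returns the Value produced by the inline PTX, but the async-wait has no result this lowering needs (the op is erased right after), so the ret binding was dead weight under -Wunused-variable. Dropping the binding, as above, is the fix this commit favors; a (void) cast or [[maybe_unused]] would work where the name has to stay. A small sketch with a hypothetical function:

int sideEffectingCall();

void demo() {
  int ret = sideEffectingCall();  // -Wunused-variable: 'ret' never read
  sideEffectingCall();            // the fix used throughout this commit
  [[maybe_unused]] int r = sideEffectingCall(); // alternative: annotate it
}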
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
- auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
- auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
auto outOrder = resSharedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
// swizzle is not allowd
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
- for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+ for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
@@ -4208,7 +4170,7 @@ namespace mlir {
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-     : ConversionTarget(ctx), typeConverter(typeConverter) {
+     : ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-     : ConversionTarget(ctx), typeConverter(typeConverter) {
+     : ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
// addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();

View File

@@ -21,9 +21,7 @@ public:
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
- Op res =
-     rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
@@ -37,9 +35,8 @@ public:
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
- DstOp res =
-     rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
-                                        adaptor.getLhs(), adaptor.getRhs());
+ rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
+                                    adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -129,10 +126,9 @@ public:
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
- triton::gpu::SelectOp res =
-     rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
-         op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
-         adaptor.getFalseValue());
+ rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
+     op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
+     adaptor.getFalseValue());
return success();
}
};
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
- // return type
- RankedTensorType retType =
-     RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
// convert operand to slice of return type
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
getContext(), op.axis(), retEncoding);
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
- auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
+ rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
return success();
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+ rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
return success();
}
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
+ rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
return success();
}