[BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)

Also fixed various errors that showed up in `test_core.py`, and added more TODOs for the remaining open issues (hopefully all relatively minor).
Author: Philippe Tillet (committed by GitHub)
Date: 2022-10-28 01:17:55 -07:00
Parent: 3685194456
Commit: ac0f6793cc
6 changed files with 269 additions and 419 deletions


@@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
@@ -310,6 +311,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = insert_val(structType, llvmStruct, v.value(),
@@ -369,6 +374,10 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
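A rough sketch of the round trip these two scalar early-exits enable (illustrative only; `scalarVal` and the surrounding rewriter setup are assumed, not part of the patch):

// For a converted scalar (e.g. plain f32), unpacking yields the value itself ...
auto elems = getElementsFromStruct(loc, scalarVal, rewriter);  // -> {scalarVal}
// ... and packing a single value whose "struct type" is not an LLVM struct is the
// identity, so no llvm.mlir.undef / llvm.insertvalue chain is emitted for scalars.
Value repacked = getStructFromElements(loc, elems, rewriter, scalarVal.getType());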
@@ -678,7 +687,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
- size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
+ size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
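The `getElemsPerThread(Type)` overload used throughout this patch is presumably a thin wrapper over the old layout-based query; a minimal sketch of what it would look like (an assumption, the real helper presumably lives with the TritonGPU layout utilities):

unsigned getElemsPerThread(Type type) {
  auto tensorTy = type.dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;  // scalars occupy exactly one slot per thread
  return getElemsPerThread(tensorTy.getEncoding(), tensorTy.getShape());
}

This is what lets the splat, load/store, and elementwise conversions below treat a scalar as the one-element case of the tensor path.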
@@ -760,64 +769,49 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};
auto shape = value.getType().cast<RankedTensorType>().getShape();
if (!llValue.getType().isa<LLVM::LLVMStructType>())
return {llValue};
// Here, we assume that all inputs should have a blockedLayout
auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
return valueVals;
}
// Get the blocked layout.
std::tuple<BlockedEncodingAttr, unsigned> getLayout(Value val) const {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(layout && "unexpected layout in getLayout");
auto shape = ty.getShape();
unsigned valueElems = layout.getElemsPerThread(shape);
return {layout, valueElems};
}
unsigned getVectorSize(Value ptr) const {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
- auto order = layout.getOrder();
+ auto order = getOrder(layout);
unsigned align = getAlignment(ptr, layout);
auto ty = ptr.getType().dyn_cast<RankedTensorType>();
assert(ty);
auto shape = ty.getShape();
- unsigned contigPerThread = layout.getSizePerThread()[order[0]];
+ unsigned contigPerThread = getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
unsigned getMaskAlignment(Value mask) const {
auto maskOrder = mask.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<BlockedEncodingAttr>()
.getOrder();
unsigned getAlignment(Value val, const Attribute &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = getOrder(layout);
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned getMaskAlignment(Value mask) const {
auto tensorTy = mask.getType().cast<RankedTensorType>();
auto maskOrder = getOrder(tensorTy.getEncoding());
auto maskAxis = getAxisInfo(mask);
return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
}
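For intuition, the vector width is the tightest of three constraints: pointer alignment (divisibility vs. contiguity from AxisInfo), the per-thread contiguity of the layout, and the tensor extent along the fastest-varying dimension. A worked example with made-up AxisInfo values:

// divisibility = 16, contiguity = 128 along order[0]  -> alignment = 16
unsigned alignment       = std::min(16u, 128u);
// sizePerThread[order[0]] = 4                          -> contigPerThread = 4
unsigned contigPerThread = 4;
unsigned vec             = std::min(alignment, contigPerThread);  // 4
vec = std::min(256u, vec);  // shape[order[0]] = 256, so vec stays 4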
@@ -848,46 +842,39 @@ struct LoadOpConversion
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
// original values
Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
// adaptor values
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
auto [layout, numElems] = getLayout(ptr);
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
assert(ptrElems.size() == numElems);
// Determine the vectorization size
size_t vec = getVectorizeSize(ptr, layout);
Type valueTy = op.getResult().getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));
// Get the LLVM values for pointers
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
assert(ptrElems.size() == numElems);
// Get the LLVM values for mask
SmallVector<Value> maskElems;
if (llMask) {
unsigned maskAlignment = getMaskAlignment(mask);
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(ptrElems.size() == maskElems.size());
size_t maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(maskElems.size() == numElems);
}
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
// Get the LLVM values for `other`
// TODO: (goostavz) handle when other is const but not splat, which
// should be rarely seen
bool otherIsSplatConstInt = false;
@@ -898,8 +885,12 @@ struct LoadOpConversion
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
- auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
+ auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
const int numVecs = numElems / vec;
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
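A worked example of the loop bounds just computed, using made-up values: an fp16 result with 8 elements per thread and vec = 4 gives

const int valueElemNbits = std::max(8u, 16u);  // 16 bits per f16 element
const int numVecs        = 8 / 4;              // 2 iterations of the vectorized loop above
// Each iteration covers vec * valueElemNbits = 64 bits of data, which the rest of the
// pattern (not shown here) turns into something like a ld.global.v4.b16.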
@@ -1060,30 +1051,23 @@ struct StoreOpConversion
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) {
store(llValue, llPtr);
rewriter.eraseOp(op);
return success();
}
auto valueTy = value.getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
auto [layout, numElems] = getLayout(ptr);
- auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
- auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
+ auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
+ auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
size_t vec = getVectorizeSize(ptr, layout);
SmallVector<Value> maskElems;
if (llMask) {
- maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
+ maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(valueElems.size() == maskElems.size());
- size_t maskAlign = getMaskAlignment(mask);
+ unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}
@@ -1146,7 +1130,7 @@ struct StoreOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<PTXIOInstr>("st")->global().b(width).v(nWords);
ptxBuilder.create<PTXIOInstr>("st")->global().v(nWords).b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
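PTX spells the vector suffix before the scalar width, so, assuming the PTX builder appends suffixes in call order, putting `.v(nWords)` ahead of `.b(width)` yields the intended instruction shape for nWords = 4, width = 32:

// st.global.v4.b32 [addr], {r0, r1, r2, r3};   (intended output, illustrative only)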
@@ -1223,8 +1207,9 @@ struct BroadcastOpConversion
for (auto it : llvm::enumerate(broadcastDims)) {
// Incase there are multiple indices in the src that is actually
// calculating the same element, srcLogicalShape may not need to be 1.
- // Such as the case when src of shape [256, 1], and with a blocked layout:
- // sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
+ // Such as the case when src of shape [256, 1], and with a blocked
+ // layout: sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA:
+ // [1, 2]
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
broadcastSizes[it.index()] = d;
duplicates *= d;
@@ -1234,10 +1219,10 @@ struct BroadcastOpConversion
duplicates *= d;
}
- unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
+ unsigned srcElems = getElemsPerThread(srcTy);
auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, rewriter);
- unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
+ unsigned resultElems = getElemsPerThread(resultTy);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
@@ -1256,8 +1241,10 @@ struct BroadcastOpConversion
}
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
@@ -1389,7 +1376,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto smemShape = getScratchConfigForReduce(op);
- unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+ unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
@@ -1446,7 +1433,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto resultLayout = resultTy.getEncoding();
auto resultShape = resultTy.getShape();
- unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+ unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
@@ -1498,7 +1485,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
unsigned sizeIntraWarps = threadsPerWarp[axis];
unsigned sizeInterWarps = warpsPerCTA[axis];
- unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+ unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
@@ -1586,7 +1573,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
- unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+ unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
@@ -1633,7 +1620,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto resultShape = resultTy.getShape();
- unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
+ unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1712,7 +1699,7 @@ struct AddPtrOpConversion
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTensorTy.getShape();
- unsigned elems = resultLayout.getElemsPerThread(resultShape);
+ unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1769,8 +1756,8 @@ struct ExtractSliceOpConversion
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
- // axis > 0 will result in non-contiguous memory access if the result tensor
- // is an alias of the source tensor.
+ // axis > 0 will result in non-contiguous memory access if the result
+ // tensor is an alias of the source tensor.
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
@@ -1806,22 +1793,14 @@ public:
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
auto resultTy = op.getType();
Location loc = op->getLoc();
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout &&
"Unexpected resultLayout in ElementwiseOpConversionBase");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
unsigned elems = getElemsPerThread(resultTy);
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Type structTy = this->getTypeConverter()->convertType(resultTy);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, elems, loc);
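For a genuinely scalar op the same path now degenerates cleanly; a sketch of the values involved, assuming the scalar conventions noted earlier and that the loop that follows mirrors the unary/ternary versions removed below:

unsigned elems = getElemsPerThread(resultTy);  // 1 for a plain f32/i32 result
Type structTy  = this->getTypeConverter()->convertType(resultTy);  // the scalar type itself
// createDestOp runs once, and getStructFromElements hands that single value back
// unchanged, so no llvm.insertvalue chain is emitted for scalars.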
@@ -1874,152 +1853,6 @@ struct ElementwiseOpConversion
}
};
//
// Ternary
//
template <typename SourceOp, typename DestOp, typename ConcreteT>
class TernaryOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto lhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
auto thss =
this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
rhss[i], thss[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
template <typename SourceOp, typename DestOp>
struct TernaryOpConversion
: public TernaryOpConversionBase<SourceOp, DestOp,
TernaryOpConversion<SourceOp, DestOp>> {
explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: TernaryOpConversionBase<SourceOp, DestOp,
TernaryOpConversion<SourceOp, DestOp>>(
typeConverter, benefit) {}
using OpAdaptor = typename SourceOp::Adaptor;
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
Type elemTy, Value lhs, Value rhs, Value th,
Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
}
};
//
// Unary
//
template <typename SourceOp, typename DestOp, typename ConcreteT>
class UnaryOpConversionBase : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit UnaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout && "Unexpected resultLayout in UnaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto srcs = this->getElementsFromStruct(loc, concreteThis->getSrc(adaptor),
rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] =
concreteThis->createDestOp(op, rewriter, elemTy, srcs[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
template <typename SourceOp, typename DestOp>
struct UnaryOpConversion
: public UnaryOpConversionBase<SourceOp, DestOp,
UnaryOpConversion<SourceOp, DestOp>> {
explicit UnaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: UnaryOpConversionBase<SourceOp, DestOp,
UnaryOpConversion<SourceOp, DestOp>>(
typeConverter, benefit) {}
using OpAdaptor = typename SourceOp::Adaptor;
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
Type elemTy, Value src, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, src);
}
// Get the source operand of the op.
Value getSrc(OpAdaptor adaptor) const {
auto operands = adaptor.getOperands();
if (operands.size() > 1)
llvm::report_fatal_error("unary operator has more than one operand");
return operands.front();
}
};
//
// comparisons
//
@@ -2367,13 +2200,13 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
- unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
+ unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
- unsigned outElems = getElemsPerThread(dstLayout, shape);
+ unsigned outElems = getElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
@@ -2431,7 +2264,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
- unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+ unsigned numElems = getElemsPerThread(srcTy);
auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned srcAccumSizeInThreads =
product<unsigned>(srcBlockedLayout.getSizePerThread());
@@ -2609,7 +2442,8 @@ public:
Value c = urem(lane, i32_val(8));
Value s = udiv(lane, i32_val(8)); // sub-warp-id
- // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a warp
+ // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a
+ // warp
Value s0 = urem(s, i32_val(2));
Value s1 = udiv(s, i32_val(2));
@@ -2756,8 +2590,8 @@ public:
llvm::report_fatal_error("unsupported mma type found");
// The main difference with the original triton code is we removed the
- // prefetch-related logic here for the upstream optimizer phase should take
- // care with it, and that is transparent in dot conversion.
+ // prefetch-related logic here for the upstream optimizer phase should
+ // take care with it, and that is transparent in dot conversion.
auto getPtr = [&](int idx) { return ptrs[idx]; };
Value ptr = getPtr(ptrIdx);
@@ -2768,7 +2602,8 @@ public:
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
PTXBuilder builder;
- // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread.
+ // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
+ // thread.
auto resArgs = builder.newListOperand(4, "=r");
auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
@@ -3067,7 +2902,8 @@ struct DotOpConversionHelper {
// Get the M and N of mma instruction shape.
static std::tuple<int, int> getInstrShapeMN() {
- // According to DotOpConversionHelper::mmaInstrShape, all the M,N are {16,8}
+ // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
+ // {16,8}
return {16, 8};
}
@@ -3808,8 +3644,7 @@ public:
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
- unsigned numElementsPerThread =
- getElemsPerThread(layout, type.getShape());
+ unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
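Illustrative only, with assumed layout parameters:

// A tensor whose blocked layout gives each thread 8 elements:
//   tensor<128x64xf16, #blocked>  ->  !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
// A bare scalar keeps its converted scalar type:
//   f16                           ->  f16
// The scalar case is exactly what the new non-struct branches in
// getStructFromElements / getElementsFromStruct were added to handle.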
@@ -3927,7 +3762,7 @@ struct InsertSliceAsyncOpConversion
Value llIndex = adaptor.index();
// %src
- auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);
+ auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
// %dst
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
@@ -3943,7 +3778,7 @@ struct InsertSliceAsyncOpConversion
// %mask
SmallVector<Value> maskElems;
if (llMask) {
- maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
+ maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}
@@ -3954,15 +3789,14 @@ struct InsertSliceAsyncOpConversion
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
assert(false && "insert_slice_async: Other value not supported yet");
- otherElems =
- getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
+ otherElems = getLLVMElems(other, llOther, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}
- unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
+ unsigned inVec = getVectorSize(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
- unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+ unsigned numElems = getElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
@@ -4212,6 +4046,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
@@ -4221,14 +4056,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<FDivOpConversion>(typeConverter, benefit);
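For reference, each POPULATE_UNARY_OP(SRC_OP, DST_OP) line expands, per the #define above, into a registration of the shared elementwise pattern, e.g.:

patterns.add<ElementwiseOpConversion<math::ExpOp, math::ExpOp>>(typeConverter, benefit);

so the dedicated unary and ternary conversion classes removed earlier in this diff are no longer needed; all of these ops funnel through ElementwiseOpConversion, including its new scalar handling.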