[Triton-MLIR][BACKEND] insert_slice_async on GPUs < sm80 (#908)
`insert_slice_async` is decomposed into `load + insert_slice` in the backend for GPUs older than sm80, since `cp.async` is only supported on Ampere and later. It is not clear whether V100 performance can match the master branch with this approach. Performance might improve if the instructions were arranged in the following form:

```
%0 = load
%1 = load
%2 = load
...
insert_slice %0
insert_slice %1
insert_slice %2
```

Tested on A100 by manually enabling this decomposition. Tests on V100 haven't been integrated yet; we can divide the tests into two phases:

1. Test only `load`, `insert_slice`, and `insert_slice_async`, given TritonGPU IRs, in `test_backend.py`.
2. End-to-end GEMM tests on V100.
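For illustration only, here is a minimal TritonGPU IR sketch of the decomposition (not the verbatim output of the pass); the shapes and the `#AL`/`#A_SHARED` layout attributes are borrowed from the lit tests in this change, and their definitions are omitted:

```mlir
// Before: cp.async-backed copy straight into the shared-memory tensor.
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other
       {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false}
       : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>

// After the decomposition (compute capability < 80): a plain load into
// registers followed by a synchronous insert_slice into shared memory.
%tmp = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
%a = tensor.insert_slice %tmp into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
```

After this rewrite there are no outstanding asynchronous copies, so the same pass also erases any remaining `triton_gpu.async_wait` ops.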
```diff
@@ -107,7 +107,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   }
 
   llvm::LLVMContext llvmContext;
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
   }
```
```diff
@@ -12,6 +12,8 @@ bool isSharedEncoding(Value value);
 
 bool maybeSharedAllocationOp(Operation *op);
 
+bool maybeAliasOp(Operation *op);
+
 std::string getValueOperandName(Value value, AsmState &state);
 
 template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
```
```diff
@@ -43,6 +43,12 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
                              "mlir::triton::gpu::TritonGPUDialect",
                              "mlir::NVVM::NVVMDialect",
                              "mlir::StandardOpsDialect"];
+
+  let options = [
+      Option<"computeCapability", "compute-capability",
+             "int32_t", /*default*/"80",
+             "device compute capability">
+  ];
 }
 
 #endif
```
```diff
@@ -33,7 +33,8 @@ struct NVVMMetadataField {
   static constexpr char Kernel[] = "nvvm.kernel";
 };
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertTritonGPUToLLVMPass(int computeCapability = 80);
 
 } // namespace triton
 
```
```diff
@@ -25,7 +25,8 @@ void addExternalLibs(mlir::ModuleOp &module,
 // Translate TritonGPU dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
-                           mlir::ModuleOp module);
+                           mlir::ModuleOp module,
+                           int computeCapability);
 
 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
```
```diff
@@ -26,13 +26,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
     // These ops may allocate a new shared memory buffer.
     auto result = op->getResult(0);
     // FIXME(Keren): extract and insert are always alias for now
-    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+    if (isa<tensor::ExtractSliceOp>(op)) {
       // extract_slice %src
       aliasInfo = AliasInfo(operands[0]->getValue());
       pessimistic = false;
-    } else if (auto insertSliceOp =
-                   dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
+    } else if (isa<tensor::InsertSliceOp>(op) ||
+               isa<triton::gpu::InsertSliceAsyncOp>(op)) {
       // insert_slice_async %src, %dst, %index
+      // insert_slice %src into %dst[%offsets]
       aliasInfo = AliasInfo(operands[1]->getValue());
       pessimistic = false;
     } else if (isSharedEncoding(result)) {
```
```diff
@@ -28,7 +28,7 @@ namespace mlir {
 namespace triton {
 
 // Bitwidth of pointers
 constexpr int kPtrBitWidth = 64;
 
 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
```
```diff
@@ -155,8 +155,7 @@ private:
     // For example: %a = scf.if -> yield
     // %a must be allocated elsewhere by other operations.
     // FIXME(Keren): extract and insert are always alias for now
-    if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
-        isa<triton::gpu::InsertSliceAsyncOp>(op)) {
+    if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
      return;
    }
 
```
```diff
@@ -210,9 +209,9 @@ private:
       auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
-      auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
-                   elems * kPtrBitWidth / 8 :
-                   elems * srcTy.getElementTypeBitWidth() / 8;
+      auto bytes = srcTy.getElementType().isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * srcTy.getElementTypeBitWidth() / 8;
       allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
       auto value = op->getOperand(0);
```
```diff
@@ -1,4 +1,5 @@
 #include "triton/Analysis/Membar.h"
+#include "triton/Analysis/Alias.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 
 #include "mlir/Dialect/GPU/GPUDialect.h"
```
```diff
@@ -71,11 +72,17 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
 
   RegionInfo curRegionInfo;
   for (Value value : op->getOperands()) {
-    // ConvertLayoutOp: shared memory -> registers
-    // Need to consider all alias buffers
     for (auto bufferId : allocation->getBufferIds(value)) {
       if (bufferId != Allocation::InvalidBufferId) {
-        curRegionInfo.syncReadBuffers.insert(bufferId);
+        if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
+            isa<tensor::InsertSliceOp>(op)) {
+          // FIXME(Keren): insert_slice and insert_slice_async are always alias
+          // for now
+          curRegionInfo.syncWriteBuffers.insert(bufferId);
+        } else {
+          // ConvertLayoutOp: shared memory -> registers
+          curRegionInfo.syncReadBuffers.insert(bufferId);
+        }
       }
     }
   }
```
```diff
@@ -28,6 +28,12 @@ bool maybeSharedAllocationOp(Operation *op) {
           dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
 }
 
+bool maybeAliasOp(Operation *op) {
+  return isa<tensor::ExtractSliceOp>(op) ||
+         isa<triton::gpu::InsertSliceAsyncOp>(op) ||
+         isa<tensor::InsertSliceOp>(op);
+}
+
 std::string getValueOperandName(Value value, AsmState &state) {
   std::string opName;
   llvm::raw_string_ostream ss(opName);
```
```diff
@@ -205,6 +205,20 @@ auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
       b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
 }
 
+/// Helper function to get strides from a given shape and its order
+auto getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
+                                 ArrayRef<unsigned> order, Location loc,
+                                 ConversionPatternRewriter &rewriter) {
+  auto rank = shape.size();
+  SmallVector<Value> strides(rank);
+  auto stride = 1;
+  for (auto idx : order) {
+    strides[idx] = i32_val(stride);
+    stride *= shape[idx];
+  }
+  return strides;
+}
+
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
```
```diff
@@ -452,13 +466,10 @@ struct SharedMemoryObject {
                      ArrayRef<unsigned> order, Location loc,
                      ConversionPatternRewriter &rewriter)
       : base(base) {
-    auto rank = shape.size();
-    auto stride = 1;
-    strides.resize(rank);
+    strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
+
     for (auto idx : order) {
-      strides[idx] = i32_val(stride);
       offsets.emplace_back(i32_val(0));
-      stride *= shape[idx];
     }
   }
 
```
```diff
@@ -2835,6 +2846,112 @@ public:
     return failure();
   }
 
+  static void storeBlockedToShared(Value src, Value llSrc,
+                                   ArrayRef<Value> srcStrides,
+                                   ArrayRef<Value> srcIndices, Value dst,
+                                   Value smemBase, Type elemPtrTy, Location loc,
+                                   ConversionPatternRewriter &rewriter) {
+    auto srcTy = src.getType().cast<RankedTensorType>();
+    auto srcShape = srcTy.getShape();
+    assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
+
+    auto elemTy = srcTy.getElementType();
+    auto dstTy = dst.getType().cast<RankedTensorType>();
+    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
+    auto inOrd = srcBlockedLayout.getOrder();
+    auto outOrd = dstSharedLayout.getOrder();
+    unsigned inVec =
+        inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+    unsigned outVec = dstSharedLayout.getVec();
+    unsigned minVec = std::min(outVec, inVec);
+    unsigned perPhase = dstSharedLayout.getPerPhase();
+    unsigned maxPhase = dstSharedLayout.getMaxPhase();
+    unsigned numElems = getElemsPerThread(srcTy);
+    auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
+    auto srcAccumSizeInThreads =
+        product<unsigned>(srcBlockedLayout.getSizePerThread());
+    auto wordTy = vec_ty(elemTy, minVec);
+
+    // TODO: [goostavz] We should make a cache for the calculation of
+    // emitBaseIndexForBlockedLayout in case backend compiler not being able to
+    // optimize that
+    SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
+    SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
+                               ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
+
+    // Visit each input value in the order they are placed in inVals
+    //
+    // Please note that the order was not awaring of blockLayout.getOrder(),
+    // thus the adjacent elems may not belong to a same word. This could be
+    // improved if we update the elements order by emitIndicesForBlockedLayout()
+    SmallVector<unsigned> wordsInEachRep(2);
+    wordsInEachRep[0] = inOrd[0] == 0
+                            ? srcBlockedLayout.getSizePerThread()[0] / minVec
+                            : srcBlockedLayout.getSizePerThread()[0];
+    wordsInEachRep[1] = inOrd[0] == 0
+                            ? srcBlockedLayout.getSizePerThread()[1]
+                            : srcBlockedLayout.getSizePerThread()[1] / minVec;
+    Value outVecVal = i32_val(outVec);
+    Value minVecVal = i32_val(minVec);
+    auto numWordsEachRep = product<unsigned>(wordsInEachRep);
+    SmallVector<Value> wordVecs(numWordsEachRep);
+    for (unsigned i = 0; i < numElems; ++i) {
+      if (i % srcAccumSizeInThreads == 0) {
+        // start of a replication
+        for (unsigned w = 0; w < numWordsEachRep; ++w) {
+          wordVecs[w] = undef(wordTy);
+        }
+      }
+      unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
+      auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
+          linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
+      unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
+      multiDimIdxInNanoTile[inOrd[0]] /= minVec;
+      auto wordVecIdx = getLinearIndex<unsigned>(multiDimIdxInNanoTile,
+                                                 wordsInEachRep, inOrd);
+      wordVecs[wordVecIdx] =
+          insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
+
+      if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
+        // end of replication, store the vectors into shared memory
+        unsigned linearRepIdx = i / srcAccumSizeInThreads;
+        auto multiDimRepIdx =
+            getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
+        for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
+             ++linearWordIdx) {
+          // step 1: recover the multidim_index from the index of input_elements
+          auto multiDimWordIdx =
+              getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
+          SmallVector<Value> multiDimIdx(2);
+          auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
+                             multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
+          auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
+                             multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
+          multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
+          multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
+
+          // step 2: do swizzling
+          Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+          multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+          Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
+          Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+          phaseId = urem(phaseId, i32_val(maxPhase));
+          Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+          off_0 = mul(off_0, outVecVal);
+          remained = udiv(remained, minVecVal);
+          off_0 = add(off_0, mul(remained, minVecVal));
+          Value offset = add(off_1, off_0);
+
+          // step 3: store
+          Value smemAddr = gep(elemPtrTy, smemBase, offset);
+          smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+          store(wordVecs[linearWordIdx], smemAddr);
+        }
+      }
+    }
+  }
+
 private:
   SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                        ConversionPatternRewriter &rewriter,
```
```diff
@@ -3129,110 +3246,91 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
   auto inOrd = srcBlockedLayout.getOrder();
   auto outOrd = dstSharedLayout.getOrder();
-  unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
-  unsigned outVec = dstSharedLayout.getVec();
-  unsigned minVec = std::min(outVec, inVec);
-  unsigned perPhase = dstSharedLayout.getPerPhase();
-  unsigned maxPhase = dstSharedLayout.getMaxPhase();
-  unsigned numElems = getElemsPerThread(srcTy);
-  auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
-  auto elemTy = srcTy.getElementType();
-  auto wordTy = vec_ty(elemTy, minVec);
-
-  // TODO: [goostavz] We should make a cache for the calculation of
-  // emitBaseIndexForBlockedLayout in case backend compiler not being able to
-  // optimize that
-  SmallVector<Value> multiDimOffsetFirstElem =
-      emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
-  Value outVecVal = idx_val(outVec);
-  Value minVecVal = idx_val(minVec);
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
+  auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
   smemBase = bitcast(smemBase, elemPtrTy);
 
+  auto srcStrides = getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
+  auto srcIndices =
+      emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
+  storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
+                       smemBase, elemPtrTy, loc, rewriter);
+
   auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
   auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
-  for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
-
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(multiDimOffsetFirstElem[0], idx_val(wordOffset0));
-        multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
-
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
-        phaseId = urem(phaseId, idx_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
-    }
-  }
-  // Barrier is not necessary.
-  // The membar pass knows that it writes to shared memory and will handle it
-  // properly.
   rewriter.replaceOp(op, retVal);
   return success();
 }
 
+struct InsertSliceOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // %dst = insert_slice %src into %dst[%offsets]
+    Location loc = op->getLoc();
+    Value dst = op.dest();
+    Value src = op.source();
+    Value res = op.result();
+    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+           "Only support in-place insert_slice for now");
+
+    auto srcTy = src.getType().dyn_cast<RankedTensorType>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto srcShape = srcTy.getShape();
+    assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");
+
+    auto dstTy = dst.getType().dyn_cast<RankedTensorType>();
+    auto dstLayout = dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
+    auto llDst = adaptor.dest();
+    assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion");
+    assert(op.hasUnitStride() &&
+           "Only unit stride supported by InsertSliceOpConversion");
+
+    // newBase = base + offset
+    // Triton support either static and dynamic offsets
+    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
+    SmallVector<Value, 4> offsets;
+    SmallVector<Value, 4> srcStrides;
+    auto mixedOffsets = op.getMixedOffsets();
+    for (auto i = 0; i < mixedOffsets.size(); ++i) {
+      if (op.isDynamicOffset(i)) {
+        offsets.emplace_back(adaptor.offsets()[i]);
+      } else {
+        offsets.emplace_back(i32_val(op.getStaticOffset(i)));
+      }
+      // Like insert_slice_async, we only support slice from one dimension,
+      // which has a slice size of 1
+      if (op.getStaticSize(i) != 1) {
+        srcStrides.emplace_back(smemObj.strides[i]);
+      }
+    }
+
+    // Compute the offset based on the original strides of the shared memory
+    // object
+    auto offset = dot(rewriter, loc, offsets, smemObj.strides);
+    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+    auto smemBase = gep(elemPtrTy, smemObj.base, offset);
+
+    auto llSrc = adaptor.source();
+    auto srcIndices =
+        emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
+    ConvertLayoutOpConversion::storeBlockedToShared(src, llSrc, srcStrides,
+                                                    srcIndices, dst, smemBase,
+                                                    elemPtrTy, loc, rewriter);
+    // Barrier is not necessary.
+    // The membar pass knows that it writes to shared memory and will handle it
+    // properly.
+    rewriter.replaceOp(op, llDst);
+    return success();
+  }
+};
+
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
```
```diff
@@ -5972,7 +6070,7 @@ struct AtomicRMWOpConversion
     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
     auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
 
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
     Type valueElemTy =
         valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
```
```diff
@@ -6166,11 +6264,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
-  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
+  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
+                                      axisInfoAnalysis, benefit);
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                          benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
+  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
+                                        benefit);
   patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
                                              axisInfoAnalysis, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
```
```diff
@@ -6216,8 +6317,57 @@ private:
     });
   }
 
+  void decomposeInsertSliceAsyncOp(ModuleOp mod,
+                                   TritonGPUToLLVMTypeConverter &converter) {
+    // cp.async is supported in Ampere and later
+    if (computeCapability >= 80)
+      return;
+
+    // insert_slice_async %src, %dst, %idx, %mask, %other
+    // =>
+    // %tmp = load %src, %mask, %other
+    // %res = insert_slice %tmp into %dst[%idx]
+    mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
+      OpBuilder builder(insertSliceAsyncOp);
+      // load
+      auto srcTy = insertSliceAsyncOp.src().getType().cast<RankedTensorType>();
+      auto dstTy = insertSliceAsyncOp.getType().cast<RankedTensorType>();
+      auto srcBlocked =
+          srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto elemTy = converter.convertType(dstTy.getElementType());
+      auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked);
+      auto loadOp = builder.create<triton::LoadOp>(
+          insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
+          insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
+          insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
+          insertSliceAsyncOp.isVolatile());
+      // insert_slice
+      auto axis = insertSliceAsyncOp.axis();
+      auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
+      auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
+      auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
+      auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
+      offsets[axis] = insertSliceAsyncOp.index();
+      for (size_t i = 0; i < dstTy.getRank(); i++) {
+        if (i != axis)
+          sizes[i] = intAttr(dstTy.getShape()[i]);
+      }
+      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+          insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.dst(),
+          offsets, sizes, strides);
+      // Replace
+      insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
+      insertSliceAsyncOp.erase();
+    });
+
+    mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
+      asyncWaitOp.erase();
+    });
+  }
+
 public:
-  ConvertTritonGPUToLLVM() = default;
+  explicit ConvertTritonGPUToLLVM(int computeCapability)
+      : computeCapability(computeCapability) {}
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
```
```diff
@@ -6233,18 +6383,22 @@ public:
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
     // step 1: Decompose unoptimized layout conversions to use shared memory
-    // step 2: Allocate shared memories and insert barriers
-    // step 3: Convert SCF to CFG
-    // step 4: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 5: Convert the rest of ops via partial conversion
-    // The reason for putting step 1 before step 2 is that the membar analysis
-    // currently only supports SCF but not CFG.
-    // The reason for a separation between 1/4 is that, step 3 is out of
-    // the scope of Dialect Conversion, thus we need to make sure the smem
-    // is not revised during the conversion of step 4.
+    // step 2: Decompose insert_slice_async to use load + insert_slice for
+    //         pre-Ampere architectures
+    // step 3: Allocate shared memories and insert barriers
+    // step 4: Convert SCF to CFG
+    // step 5: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 6: Convert the rest of ops via partial conversion
+    // The reason for putting step 1 before step 2 is that the membar
+    // analysis currently only supports SCF but not CFG. The reason for a
+    // separation between 1/4 is that, step 3 is out of the scope of Dialect
+    // Conversion, thus we need to make sure the smem is not revised during the
+    // conversion of step 4.
 
     decomposeBlockedToDotOperand(mod);
 
+    decomposeInsertSliceAsyncOp(mod, typeConverter);
+
     Allocation allocation(mod);
     MembarAnalysis membar(&allocation);
 
```
```diff
@@ -6303,6 +6457,8 @@ protected:
       TritonGPUToLLVMTypeConverter &typeConverter);
 
   Value smem;
+
+  int computeCapability{};
 };
 
 void ConvertTritonGPUToLLVM::initSharedMemory(
```
```diff
@@ -6365,8 +6521,9 @@ TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
 
 namespace triton {
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
-  return std::make_unique<::ConvertTritonGPUToLLVM>();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertTritonGPUToLLVMPass(int computeCapability) {
+  return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
 }
 
 } // namespace triton
```
```diff
@@ -202,8 +202,7 @@ LogicalResult LoopPipeliner::initialize() {
       bufferShape.insert(bufferShape.begin(), numStages);
       auto sharedEnc = ttg::SharedEncodingAttr::get(
           ty.getContext(), dotOpEnc, ty.getShape(),
-          triton::gpu::getOrder(ty.getEncoding()),
-          ty.getElementType());
+          triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
       loadsBufferType[loadOp] = RankedTensorType::get(
           bufferShape, ty.getElementType(), sharedEnc);
     }
```
```diff
@@ -119,7 +119,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
 
 std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
-                           mlir::ModuleOp module) {
+                           mlir::ModuleOp module, int computeCapability) {
   mlir::PassManager pm(module->getContext());
   applyPassManagerCLOptions(pm);
   auto printingFlags = mlir::OpPrintingFlags();
```
```diff
@@ -1107,7 +1107,8 @@ void init_triton_ir(py::module &&m) {
              mlir::Value &mask) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             mlir::Type dstType;
-            if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+            if (auto srcTensorType =
+                    ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
               mlir::Type dstElemType = srcTensorType.getElementType()
                                            .cast<mlir::triton::PointerType>()
                                            .getPointeeType();
```
```diff
@@ -1315,8 +1316,8 @@ void init_triton_translation(py::module &m) {
       "translate_triton_gpu_to_llvmir",
       [](mlir::ModuleOp op, int computeCapability) {
         llvm::LLVMContext llvmContext;
-        auto llvmModule =
-            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+        auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
+            &llvmContext, op, computeCapability);
         if (!llvmModule)
           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
 
```
```diff
@@ -65,6 +65,20 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   return
 }
 
+// CHECK-LABEL: insert_slice
+func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
+  %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
+  %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
+  // CHECK: %cst_0 -> %cst_0
+  %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
+  %index = arith.constant 0 : index
+  %a = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
+  // CHECK: %3 -> %cst_0
+  %b = tensor.insert_slice %a into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
+  return
+}
+
 // CHECK-LABEL: extract_slice
 func @extract_slice(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
```
```diff
@@ -119,8 +119,26 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED>
   %index = arith.constant 0 : i32
   %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: Membar 6
   %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
-  // CHECK: Membar 7
+  // CHECK: Membar 8
+  %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
+  return
+}
+
+// CHECK-LABEL: insert_slice
+func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
+  %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
+  %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
+  %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
+  %index = arith.constant 0 : index
+  %al = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
+  // CHECK: Membar 6
+  %a = tensor.insert_slice %al into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: Membar 8
+  %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
+  // CHECK: Membar 10
   %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
   return
 }
```
```diff
@@ -34,7 +34,8 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
 
   // create element type
   Type eltType = IntegerType::get(&ctx, params.typeWidth);
-  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
+  auto layout =
+      SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
 
   ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
   ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);
```