[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)

This commit is contained in:
goostavz
2022-09-18 05:58:42 +08:00
committed by GitHub
parent 13669b46a6
commit 15bfd0cb79
17 changed files with 1025 additions and 191 deletions

View File

@@ -64,16 +64,6 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
return nullptr;
}
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
pm.addPass(createConvertTritonGPUToLLVMPass());
if (failed(pm.run(module->getOperation()))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}
return module;
}

View File

@@ -14,7 +14,12 @@ namespace mlir {
namespace triton {
class AllocationAnalysis;
}
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an interval, specified using a start and an end

View File

@@ -2,7 +2,10 @@
#define TRITON_ANALYSIS_UTILITY_H
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
#include <string>
namespace mlir {
bool isSharedEncoding(Value value);
@@ -11,6 +14,12 @@ bool maybeSharedAllocationOp(Operation *op);
std::string getValueOperandName(Value value, AsmState &state);
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -18,6 +18,14 @@ public:
mlir::LLVMTypeConverter &typeConverter);
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
mlir::LLVMTypeConverter &typeConverter;
public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
// Names for identifying different NVVM annotations. It is used as attribute

View File

@@ -16,4 +16,16 @@
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace gpu {
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
unsigned getShapePerCTA(const Attribute &layout, unsigned d);
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -31,6 +31,10 @@ Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
}];
}
//===----------------------------------------------------------------------===//
@@ -64,6 +68,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
@@ -93,6 +99,8 @@ Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
@@ -171,11 +179,10 @@ for
}]>
];
let extraClassDeclaration = [{
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
@@ -282,6 +289,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
"unsigned":$version,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -311,6 +320,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
// TODO: constraint here to only take distributed encodings
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration;
}

View File

@@ -22,6 +22,7 @@
#ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP
#include <algorithm>
#include <cstdlib>
#include <string>
@@ -37,6 +38,14 @@ inline std::string getenv(const char *name) {
return result;
}
inline bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}
} // namespace tools
} // namespace triton

View File

@@ -8,6 +8,11 @@
#include <algorithm>
#include <limits>
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
namespace mlir {
@@ -15,6 +20,54 @@ namespace mlir {
// Shared Memory Allocation Analysis
//===----------------------------------------------------------------------===//
namespace triton {
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
// TODO: move to TritonGPUAttrDefs.h.inc
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<BlockedEncodingAttr>()) {
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstBlockedLayout.getOrder();
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1
: inOrd[0] == 0 ? 1
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
outVec =
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d],
getShapePerCTA(dstLayout, d)));
}
paddedRepShape[outOrd[0]] += pad;
}
return paddedRepShape;
}
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -73,6 +126,27 @@ private:
tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (srcEncoding.isa<SharedEncodingAttr>() ||
dstEncoding.isa<SharedEncodingAttr>()) {
// Only blocked -> blocked conversion requires for scratch allocation
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}

View File

@@ -1,4 +1,5 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
@@ -46,6 +47,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
DimVectorT contiguity(rank, 1);
@@ -203,6 +209,13 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// UnrealizedConversionCast
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
curr = operands[0]->getValue();
}
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}

View File

@@ -11,7 +11,9 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -26,8 +28,11 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
namespace mlir {
namespace LLVM {
@@ -43,11 +48,6 @@ namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO(Superjomn) Move to somewhere general utilities locates.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
@@ -214,36 +214,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
}
};
static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
ArrayRef<int64_t> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
int64_t rank = shape.size();
int64_t acc_mul = 1;
for (int64_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
int64_t linear_index = 0;
for (int64_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}
static unsigned getElemsPerThread(BlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
size_t rank = shape.size();
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
}
return product<unsigned>(elemsPerThreadPerDim);
}
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
@@ -309,9 +279,9 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
}
struct ConvertTritonGPUOpToLLVMPatternBase {
SmallVector<Value>
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter) {
SmallVector<Value> results(elems);
for (unsigned i = 0; i < elems; ++i) {
Type type =
@@ -344,7 +314,12 @@ public:
for (unsigned i = 0; i < rank; ++i) {
reordered[i] = shape[order[i]];
}
return delinearize(rewriter, loc, linear, reordered);
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
SmallVector<Value> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -370,13 +345,29 @@ public:
return multiDim;
}
// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
int rank = multiDim.size();
Value linear = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
if (rank > 0) {
linear = multiDim.front();
for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
Value dimSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
std::get<1>(z));
linear = rewriter.create<LLVM::AddOp>(
loc, rewriter.create<LLVM::MulOp>(loc, linear, dimSize),
std::get<0>(z));
}
}
return linear;
}
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = b.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
@@ -391,7 +382,6 @@ public:
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
SmallVector<Value, 4> threadIds(rank);
// step 1, delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
@@ -400,8 +390,19 @@ public:
delinearize(b, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] *
// threadsPerWarp[k]) *
// Wrap around multiDimWarpId/multiDimThreadId incase
// shape[k] > shapePerCTA[k]
unsigned maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = b.create<LLVM::URemOp>(
loc, multiDimWarpId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps));
multiDimThreadId[k] = b.create<LLVM::URemOp>(
loc, multiDimThreadId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK =
createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]);
@@ -413,17 +414,100 @@ public:
loc, multiDimThreadId[k],
b.create<LLVM::MulOp>(loc, multiDimWarpId[k], threadsPerWarpK)));
}
return multiDimBase;
}
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForBlockedLayout(loc, b, blocked, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim) {
paddedShape[d] = shape[d];
} else if (d == dim) {
paddedShape[d] = 1;
} else {
paddedShape[d] = shape[d - 1];
}
}
auto paddedIndices =
emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape);
unsigned numIndices = paddedIndices.size();
SmallVector<SmallVector<Value>> resultIndices(numIndices);
for (unsigned i = 0; i < numIndices; ++i) {
for (unsigned d = 0; d < rank + 1; ++d) {
if (d != dim) {
resultIndices[i].push_back(paddedIndices[i][d]);
}
}
}
return resultIndices;
} else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
"is not implemented yet");
return {};
} else {
assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
"slice not implemented yet");
return {};
}
}
// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization. We might
// implement a indiceCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned k = 0; k < rank; ++k) {
shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
}
// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape);
// step 2, get offset of each element
unsigned elemsPerThread = 1;
SmallVector<SmallVector<unsigned>> offset(rank);
SmallVector<unsigned> multiDimElemsPerThread(rank);
for (unsigned k = 0; k < rank; ++k) {
multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k];
multiDimElemsPerThread[k] =
ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
elemsPerThread *= multiDimElemsPerThread[k];
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
for (unsigned blockOffset = 0;
blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -445,7 +529,7 @@ public:
std::multiplies<unsigned>());
SmallVector<unsigned> threadsPerDim(rank);
for (unsigned k = 0; k < rank; ++k) {
threadsPerDim[k] = shape[k] / sizePerThread[k];
threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
}
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / accumSizePerThread;
@@ -469,6 +553,20 @@ public:
return multiDimIdx;
}
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
Value smem, const Allocation *allocation,
Operation *op) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
auto bufferId = allocation->getBufferId(op);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
Value base = rewriter.create<LLVM::GEPOp>(loc, ptrTy, smem, offVal);
return base;
}
};
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -482,19 +580,10 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
auto numElems = layout.getSizePerThread();
size_t totalElems =
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
std::multiplies<>{});
size_t numThreads =
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
size_t numElemsPerThread = totalElems / numThreads;
size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
@@ -580,7 +669,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);
auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
@@ -595,16 +684,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(layout && "unexpected layout in getLayout");
auto shape = ty.getShape();
unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);
return std::make_tuple(layout, valueElems);
}
unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
@@ -614,22 +702,18 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
auto contig = axisInfo->getContiguity();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = layout.getOrder();
unsigned align = getAlignment(ptr, layout);
auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
return shape;
};
// unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
auto ty = ptr.getType().dyn_cast<RankedTensorType>();
assert(ty);
auto shape = ty.getShape();
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
@@ -819,25 +903,22 @@ struct BroadcastOpConversion
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
// TODO: [goostavz] double confirm the op semantics with Phil
assert(rank == resultTy.getRank());
SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
SmallVector<unsigned, 2> broadcastDims;
SmallVector<int64_t, 2> broadcastSizes;
int64_t duplicates = 1;
for (unsigned d = 0; d < rank; ++d) {
int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d]);
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
broadcastSizes.push_back(resultShape[d]);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] = 1;
duplicates *= resultShape[d];
srcLogicalShape[d + rank] =
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -845,18 +926,37 @@ struct BroadcastOpConversion
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
}
unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
int64_t duplicates = 1;
SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
for (auto it : llvm::enumerate(broadcastDims)) {
// Incase there are multiple indices in the src that is actually
// calculating the same element, srcLogicalShape may not need to be 1.
// Such as the case when src of shape [256, 1], and with a blocked layout:
// sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
broadcastSizes[it.index()] = d;
duplicates *= d;
d = resultLogicalShape[it.value() + rank] /
srcLogicalShape[it.value() + rank];
broadcastSizes[it.index() + broadcastDims.size()] = d;
duplicates *= d;
}
unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
auto resultMultiDim = srcMultiDim;
for (int64_t j = 0; j < duplicates; ++j) {
auto resultMultiDim = srcMultiDim;
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value() + rank] +=
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
}
auto resultLinearIndex =
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
@@ -871,27 +971,29 @@ struct BroadcastOpConversion
}
};
struct ViewOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern;
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals =
this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
@@ -911,12 +1013,12 @@ struct MakeRangeOpConversion
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = rankedTy.getEncoding();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
auto idxs = emitIndices(loc, rewriter, layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) {
@@ -1184,8 +1286,9 @@ struct AddPtrOpConversion
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1225,7 +1328,8 @@ public:
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1244,6 +1348,228 @@ public:
}
};
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
ConvertLayoutOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
benefit),
allocation_(allocation), smem_(smem) {}
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>())) {
// TODO: not implemented
assert(0 &&
"convert_layout except for blocked -> blocked is not implemented");
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_,
op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
auto getContigPerThread = [&](const Attribute &layout,
unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d];
} else {
assert(0 && "Unimplemented usage of getContigPerThread");
return 0;
}
};
auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return product<unsigned>(blockedLayout.getSizePerThread());
} else {
assert(0 && "Unimplemented usage of getAccumElemsPerThread");
return 0;
}
};
auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getOrder();
} else {
assert(0 && "Unimplemented usage of getAccumElemsPerThread");
return {};
}
};
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
unsigned outPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
// TODO: confirm this
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec,
paddedRepShape, outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
private:
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[rank - 1 - it.value()] = input[it.index()];
}
return result;
};
void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = layout.getSizePerThread();
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned d = 0; d < rank; ++d) {
shapePerCTA[d] = layout.getSizePerThread()[d] *
layout.getThreadsPerWarp()[d] *
layout.getWarpsPerCTA()[d];
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
// for (unsigned elemId = linearCTAId * accumSizePerThread;
// elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto multiDimElemId =
getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
loc, multiDimOffsetFirstElem[d],
createIndexAttrConstant(rewriter, loc, llvmIndexTy,
multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value ptr =
rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, smemBase, offset);
auto vecTy = VectorType::get(vec, llvmElemTy);
ptr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
valVec = rewriter.create<LLVM::InsertElementOp>(
loc, vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
}
rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
} else {
Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
vals[elemId + linearCTAId * accumSizePerThread + v] =
rewriter.create<LLVM::ExtractElementOp>(loc, llvmElemTy, valVec,
vVal);
}
}
}
}
}
const Allocation *allocation_;
Value smem_;
};
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -1266,9 +1592,10 @@ public:
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
Attribute layout = type.getEncoding();
if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
if (layout && (layout.isa<BlockedEncodingAttr>() ||
layout.isa<SliceEncodingAttr>())) {
unsigned numElementsPerThread =
getElemsPerThread(blocked_layout, type.getShape());
getElemsPerThread(layout, type.getShape());
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
@@ -1285,7 +1612,8 @@ public:
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit = 1) {
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
@@ -1296,17 +1624,19 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
}
class ConvertTritonGPUToLLVM
@@ -1322,19 +1652,34 @@ public:
// TODO: need confirm
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, typeConverter);
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// step 1: Convert FuncOp to LLVMFuncOp via partial conversion
// step 2: Allocate for shared memories
// step 3: Convert the rest of ops via partial conversion
// The reason for a seperation between 1/3 is that, step 2 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem_
// is not revised during the conversion of step 3.
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();
Allocation allocation(mod);
auto axisAnalysis = runAxisAnalysis(mod);
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
RewritePatternSet patterns(context);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);
*axisAnalysis, &allocation, smem_,
10 /*benefit*/);
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
@@ -1352,10 +1697,35 @@ protected:
auto axisAnalysisPass =
std::make_unique<AxisInfoAnalysis>(module->getContext());
axisAnalysisPass->run(module);
return axisAnalysisPass;
}
void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter);
Value smem_;
};
void ConvertTritonGPUToLLVM::initSharedMemory(
size_t size, TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
"global_smem", /*value=*/Attribute(),
/*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem_ = b.create<LLVM::AddressOfOp>(loc, global);
}
} // namespace
namespace mlir {
@@ -1366,10 +1736,20 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
// addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {

View File

@@ -39,6 +39,37 @@ static Type getPointeeType(Type type) {
return Type();
}
namespace gpu {
// TODO: Inheritation of layout attributes
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
size_t rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}
unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};
} // namespace gpu
} // namespace triton
} // namespace mlir
@@ -108,6 +139,55 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == getSizePerThread().size() &&
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t =
getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] =
ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
}
return product<unsigned>(elemsPerThreadPerDim);
}
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
unsigned dim = getDim();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
paddedShape[d] = shape[d];
else if (d == dim)
paddedShape[d] = 1;
else
paddedShape[d] = shape[d - 1];
}
return blockedParent.getElemsPerThread(paddedShape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
return 0;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
return 0;
}
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//

View File

@@ -14,6 +14,7 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/Constants.h"
namespace mlir {
@@ -124,6 +125,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
pm.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass());
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp

View File

@@ -19,6 +19,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
@@ -100,14 +101,6 @@ long pow2_divisor(long N) {
return 1;
}
bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
@@ -1635,7 +1628,7 @@ void init_triton_ir(py::module &&m) {
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return getBoolEnv("MLIR_ENABLE_DUMP");
return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,

View File

@@ -0,0 +1,68 @@
import pytest
import torch
from torch.testing import assert_allclose
import triton
import triton.language as tl
import triton.runtime as runtime
@triton.jit
def kernel(x_ptr, stride_xm,
z_ptr, stride_zn,
SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
off_m = tl.arange(0, SIZE_M)
off_n = tl.arange(0, SIZE_N)
Xs = x_ptr + off_m[:, None] * stride_xm + off_n[None, :] * 1
Zs = z_ptr + off_m[:, None] * 1 + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
# These sizes cover the case of:
# - blocked layout and sliced layout with block parent
# -- blocked layout in which sizePerThread/threadsPerWarp/warpsPerCTA
# need/need not to be wrapped
# -- sliced layout incase sizePerThread need to be wrapped
# -- different orders
# - LayoutConversion from blocked -> blocked
# - tt.Broadcast which requires for broadcast in either/both of
# CTA/perThread level
# What is not covered and requires for TODO:
# - vectorization load/store of shared memory
# - multiple replication of layout conversion
@pytest.mark.parametrize('NUM_WARPS,SIZE_M,SIZE_N', [
[1, 16, 16],
[1, 32, 32],
[1, 32, 64],
[2, 64, 128],
[2, 128, 64]
])
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
# TODO: this is to initialize the cuda context since it is not properly
# dealed with in the existing runtime, remove this when the runtime
# is updated
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel,
"*fp32,i32,*fp32,i32",
constants={"SIZE_M": SIZE_M,
"SIZE_N": SIZE_N},
num_warps=NUM_WARPS,
num_stages=3)
grid = lambda META: (1, )
x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
runtime.launch_kernel(kernel=binary,
device=device,
grid=grid,
x_ptr=x,
stride_xm=x.stride(0),
z_ptr=z,
stride_zn=z.stride(0),
SIZE_M=tl.constexpr(SIZE_M),
SIZE_N=tl.constexpr(SIZE_N))
golden_z = torch.t(x)
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)

View File

@@ -48,6 +48,7 @@ def vecadd_no_scf_tester(num_warps, block_size):
def test_vecadd_no_scf():
vecadd_no_scf_tester(num_warps=4, block_size=256)
vecadd_no_scf_tester(num_warps=2, block_size=256)
vecadd_no_scf_tester(num_warps=1, block_size=256)

View File

@@ -1,37 +0,0 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s
func @test_splat(%ptr: !tt.ptr<f32>) {
// Here, 128 elements, 64(2*32) threads, so each need to process 2 elements
//
// CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%b = tt.splat %a : (f32) -> tensor<128xf32>
// Here, each thread process only 1 element
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
%mask = tt.splat %true : (i1) -> tensor<64xi1>
return
}
// -----
func @test_store_splat(%ptr: !tt.ptr<f32>) {
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };",
// CHECK-SAME: "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
tt.store %ptrs, %vs, %mask : tensor<128xf32>
return
}

View File

@@ -1,16 +1,13 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
} // end module
// -----
@@ -58,7 +55,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
// TODO: Pending on the support of isSplat constant
// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
@@ -71,10 +68,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other_vec
func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256
func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
// CHECK-LABEL: global_load_store_no_vec
func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -86,22 +96,107 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK: ld.global.v4.b32
// Load 4 elements from vector0
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// Load 4 elements from vector1
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// CHECK: ld.global.v4.b32
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global
// CHECK: st.global.b32.v4
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Load 4 elements from A with single one vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// Load 4 elements from B with single one vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global with single one vectorized store instruction
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Load 8 elements from A with two vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// Load 8 elements from B with two vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 8 elements to global with two vectorized store instruction
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
// is from an addptr with const idx
@@ -217,10 +312,121 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
// func @make_range_sliced_layout() {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
// }