[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)
@@ -64,16 +64,6 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
return nullptr;
}

mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);

pm.addPass(createConvertTritonGPUToLLVMPass());

if (failed(pm.run(module->getOperation()))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}

return module;
}

@@ -14,7 +14,12 @@ namespace mlir {

namespace triton {
class AllocationAnalysis;
}

SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);

} // namespace triton

/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an interval, specified using a start and an end

@@ -2,7 +2,10 @@
#define TRITON_ANALYSIS_UTILITY_H

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
#include <string>

namespace mlir {

bool isSharedEncoding(Value value);
@@ -11,6 +14,12 @@ bool maybeSharedAllocationOp(Operation *op);

std::string getValueOperandName(Value value, AsmState &state);

template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }

} // namespace mlir

#endif // TRITON_ANALYSIS_UTILITY_H

@@ -18,6 +18,14 @@ public:
mlir::LLVMTypeConverter &typeConverter);
};

class TritonLLVMFunctionConversionTarget : public ConversionTarget {
mlir::LLVMTypeConverter &typeConverter;

public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};

namespace triton {

// Names for identifying different NVVM annotations. It is used as attribute

@@ -16,4 +16,16 @@
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace gpu {

unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);

unsigned getShapePerCTA(const Attribute &layout, unsigned d);

} // namespace gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

@@ -31,6 +31,10 @@ Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:

Right now, Triton implements two classes of layouts: shared, and distributed.
}];

code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
}];
}

//===----------------------------------------------------------------------===//
@@ -64,6 +68,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);

let extraClassDeclaration = extraBaseClassDeclaration;
}

//===----------------------------------------------------------------------===//
@@ -93,6 +99,8 @@ Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];

let extraClassDeclaration = extraBaseClassDeclaration;
}

//===----------------------------------------------------------------------===//
@@ -171,11 +179,10 @@ for
}]>
];

let extraClassDeclaration = [{
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];


let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
@@ -282,6 +289,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
"unsigned":$version,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);

let extraClassDeclaration = extraBaseClassDeclaration;
}

def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -311,6 +320,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
// TODO: constraint here to only take distributed encodings
"Attribute":$parent
);

let extraClassDeclaration = extraBaseClassDeclaration;
}


@@ -22,6 +22,7 @@
#ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP

#include <algorithm>
#include <cstdlib>
#include <string>

@@ -37,6 +38,14 @@ inline std::string getenv(const char *name) {
return result;
}

inline bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}

} // namespace tools

} // namespace triton

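A quick way to sanity-check the new helper in isolation (illustrative only; MLIR_ENABLE_DUMP is the variable this diff reads elsewhere, everything else here is made up):

#include <cstdio>
#include "triton/tools/sys/getenv.hpp"  // include path used elsewhere in this diff

int main() {
  // "on", "true", or "1" (case-insensitive) enables the flag; anything else,
  // or an unset variable, disables it.
  bool dump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
  std::printf("MLIR_ENABLE_DUMP=%d\n", dump ? 1 : 0);
  return 0;
}
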
@@ -8,6 +8,11 @@

#include <algorithm>
#include <limits>
#include <numeric>

using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;

namespace mlir {

@@ -15,6 +20,54 @@ namespace mlir {
// Shared Memory Allocation Analysis
//===----------------------------------------------------------------------===//
namespace triton {

SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
assert(srcLayout && dstLayout &&
"Unexpected layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
// TODO: move to TritonGPUAttrDefs.h.inc
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<BlockedEncodingAttr>()) {
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstBlockedLayout.getOrder();
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1
: inOrd[0] == 0 ? 1
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
outVec =
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d],
getShapePerCTA(dstLayout, d)));
}
paddedRepShape[outOrd[0]] += pad;
}
return paddedRepShape;
}

class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -73,6 +126,27 @@ private:
tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (srcEncoding.isa<SharedEncodingAttr>() ||
dstEncoding.isa<SharedEncodingAttr>()) {
// Only blocked -> blocked conversion requires scratch allocation
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides implementing ConvertLayoutOp via shared memory, it is
// also possible to realize it with other approaches under restricted
// conditions, such as warp-shuffle
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}


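For intuition, a standalone sketch of the scratch-shape arithmetic performed by getScratchConfigForCvtLayout above; the tensor shape, per-CTA tile sizes, and vector widths below are made-up example values, not taken from the patch:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> shape = {128, 128};     // tensor shape (example)
  std::vector<unsigned> srcPerCTA = {16, 128};  // sizePerThread * threadsPerWarp * warpsPerCTA
  std::vector<unsigned> dstPerCTA = {128, 16};
  unsigned inVec = 1, outVec = 4;               // the patch derives these from the layout orders
  unsigned outFastestDim = 1;                   // outOrd[0] of the destination layout (assumed)
  std::vector<unsigned> rep(2);
  // Per dimension: the larger of the two per-CTA tiles, clamped to the tensor shape.
  for (unsigned d = 0; d < 2; ++d)
    rep[d] = std::max(std::min(shape[d], srcPerCTA[d]),
                      std::min(shape[d], dstPerCTA[d]));
  rep[outFastestDim] += std::max(inVec, outVec); // extra padding on the fastest output axis
  unsigned elems = rep[0] * rep[1];
  std::printf("scratch repetition shape: %u x %u -> %u elements (%u bytes for fp32)\n",
              rep[0], rep[1], elems, elems * 4);
  return 0;
}
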
@@ -1,4 +1,5 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>

@@ -46,6 +47,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
DimVectorT contiguity(rank, 1);
@@ -203,6 +209,13 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// UnrealizedConversionCast
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
curr = operands[0]->getValue();
}
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}

@@ -11,7 +11,9 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -26,8 +28,11 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

namespace mlir {
namespace LLVM {
@@ -43,11 +48,6 @@ namespace type = mlir::triton::type;

class TritonGPUToLLVMTypeConverter;

// TODO(Superjomn) Move to wherever general utilities are located.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in header files in MLIR v14
@@ -214,36 +214,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
}
};

static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
ArrayRef<int64_t> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
int64_t rank = shape.size();
int64_t acc_mul = 1;
for (int64_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
int64_t linear_index = 0;
for (int64_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}

static unsigned getElemsPerThread(BlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
size_t rank = shape.size();
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
}
return product<unsigned>(elemsPerThreadPerDim);
}

static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
@@ -309,9 +279,9 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
}

struct ConvertTritonGPUOpToLLVMPatternBase {
SmallVector<Value>
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter) {
SmallVector<Value> results(elems);
for (unsigned i = 0; i < elems; ++i) {
Type type =
@@ -344,7 +314,12 @@ public:
for (unsigned i = 0; i < rank; ++i) {
reordered[i] = shape[order[i]];
}
return delinearize(rewriter, loc, linear, reordered);
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
SmallVector<Value> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}

SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -370,13 +345,29 @@ public:
return multiDim;
}

// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the subsequent MLIR/LLVM optimization
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
int rank = multiDim.size();
Value linear = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
if (rank > 0) {
linear = multiDim.front();
for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
Value dimSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
std::get<1>(z));
linear = rewriter.create<LLVM::AddOp>(
loc, rewriter.create<LLVM::MulOp>(loc, linear, dimSize),
std::get<0>(z));
}
}
return linear;
}

SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = b.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
@@ -391,7 +382,6 @@ public:
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
SmallVector<Value, 4> threadIds(rank);

// step 1, delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
@@ -400,8 +390,19 @@ public:
delinearize(b, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] *
// threadsPerWarp[k]) *
// Wrap around multiDimWarpId/multiDimThreadId in case
// shape[k] > shapePerCTA[k]
unsigned maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = b.create<LLVM::URemOp>(
loc, multiDimWarpId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps));
multiDimThreadId[k] = b.create<LLVM::URemOp>(
loc, multiDimThreadId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK =
createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]);
@@ -413,17 +414,100 @@ public:
loc, multiDimThreadId[k],
b.create<LLVM::MulOp>(loc, multiDimWarpId[k], threadsPerWarpK)));
}
return multiDimBase;
}

SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForBlockedLayout(loc, b, blocked, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}

SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim) {
paddedShape[d] = shape[d];
} else if (d == dim) {
paddedShape[d] = 1;
} else {
paddedShape[d] = shape[d - 1];
}
}
auto paddedIndices =
emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape);
unsigned numIndices = paddedIndices.size();
SmallVector<SmallVector<Value>> resultIndices(numIndices);
for (unsigned i = 0; i < numIndices; ++i) {
for (unsigned d = 0; d < rank + 1; ++d) {
if (d != dim) {
resultIndices[i].push_back(paddedIndices[i][d]);
}
}
}
return resultIndices;

} else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
"is not implemented yet");
return {};

} else {
assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
"slice not implemented yet");
return {};
}
}

// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the subsequent MLIR/LLVM optimization. We might
// implement an index cache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned k = 0; k < rank; ++k) {
shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
}

// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape);

// step 2, get offset of each element
unsigned elemsPerThread = 1;
SmallVector<SmallVector<unsigned>> offset(rank);
SmallVector<unsigned> multiDimElemsPerThread(rank);
for (unsigned k = 0; k < rank; ++k) {
multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k];
multiDimElemsPerThread[k] =
ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
elemsPerThread *= multiDimElemsPerThread[k];
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
for (unsigned blockOffset = 0;
blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -445,7 +529,7 @@ public:
std::multiplies<unsigned>());
SmallVector<unsigned> threadsPerDim(rank);
for (unsigned k = 0; k < rank; ++k) {
threadsPerDim[k] = shape[k] / sizePerThread[k];
threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
}
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / accumSizePerThread;
@@ -469,6 +553,20 @@ public:

return multiDimIdx;
}

Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
Value smem, const Allocation *allocation,
Operation *op) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
auto bufferId = allocation->getBufferId(op);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
Value base = rewriter.create<LLVM::GEPOp>(loc, ptrTy, smem, offVal);
return base;
}
};

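On plain integers, the row-major recurrence that the linearize helper above emits (and that delinearize inverts) looks like this; a minimal sketch for illustration, not code from the patch:

#include <cassert>
#include <vector>

// linear = ((m0 * s1 + m1) * s2 + m2) ... ; the leading dimension is most significant.
unsigned linearizeInts(const std::vector<unsigned> &multiDim,
                       const std::vector<unsigned> &shape) {
  unsigned linear = 0;
  for (size_t i = 0; i < multiDim.size(); ++i)
    linear = linear * shape[i] + multiDim[i];
  return linear;
}

std::vector<unsigned> delinearizeInts(unsigned linear,
                                      const std::vector<unsigned> &shape) {
  std::vector<unsigned> multiDim(shape.size());
  for (size_t i = shape.size(); i-- > 0;) {
    multiDim[i] = linear % shape[i];
    linear /= shape[i];
  }
  return multiDim;
}

int main() {
  std::vector<unsigned> shape = {4, 8};
  assert(linearizeInts({3, 5}, shape) == 3 * 8 + 5);
  assert(delinearizeInts(29, shape) == (std::vector<unsigned>{3, 5}));
  return 0;
}
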
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -482,19 +580,10 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
ConversionPatternRewriter &rewriter, Location loc) {

auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);

auto numElems = layout.getSizePerThread();
size_t totalElems =
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
std::multiplies<>{});
size_t numThreads =
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
size_t numElemsPerThread = totalElems / numThreads;

size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
@@ -580,7 +669,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout

unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);

auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
@@ -595,16 +684,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(layout && "unexpected layout in getLayout");
auto shape = ty.getShape();
unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);
return std::make_tuple(layout, valueElems);
}

unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);

auto order = layout.getOrder();

unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
@@ -614,22 +702,18 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
auto contig = axisInfo->getContiguity();
// Here the order is sorted by contiguity, so the first element has the
// largest contiguity.
auto order = layout.getOrder();
unsigned align = getAlignment(ptr, layout);

auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
return shape;
};

// unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
auto ty = ptr.getType().dyn_cast<RankedTensorType>();
assert(ty);
auto shape = ty.getShape();

unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);

return vec;
}
@@ -819,25 +903,22 @@ struct BroadcastOpConversion
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
// TODO: [goostavz] double confirm the op semantics with Phil
assert(rank == resultTy.getRank());

SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
SmallVector<unsigned, 2> broadcastDims;
SmallVector<int64_t, 2> broadcastSizes;
int64_t duplicates = 1;
for (unsigned d = 0; d < rank; ++d) {
int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d]);
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
broadcastSizes.push_back(resultShape[d]);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] = 1;
duplicates *= resultShape[d];
srcLogicalShape[d + rank] =
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -845,18 +926,37 @@ struct BroadcastOpConversion
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
}
unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
int64_t duplicates = 1;
SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
for (auto it : llvm::enumerate(broadcastDims)) {
// In case there are multiple indices in the src that actually compute
// the same element, srcLogicalShape may not need to be 1.
// Such is the case when src has shape [256, 1] with a blocked layout:
// sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
broadcastSizes[it.index()] = d;
duplicates *= d;
d = resultLogicalShape[it.value() + rank] /
srcLogicalShape[it.value() + rank];
broadcastSizes[it.index() + broadcastDims.size()] = d;
duplicates *= d;
}

unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
auto resultMultiDim = srcMultiDim;
for (int64_t j = 0; j < duplicates; ++j) {
auto resultMultiDim = srcMultiDim;
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value() + rank] +=
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
}
auto resultLinearIndex =
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
@@ -871,27 +971,29 @@ struct BroadcastOpConversion
}
};

struct ViewOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern;
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

LogicalResult
matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals =
this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
@@ -911,12 +1013,12 @@ struct MakeRangeOpConversion
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = rankedTy.getEncoding();

auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
auto idxs = emitIndices(loc, rewriter, layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) {
@@ -1184,8 +1286,9 @@ struct AddPtrOpConversion
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1225,7 +1328,8 @@ public:
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1244,6 +1348,228 @@ public:
}
};

struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;

ConvertLayoutOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
benefit),
allocation_(allocation), smem_(smem) {}

LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>())) {
// TODO: not implemented
assert(0 &&
"convert_layout except for blocked -> blocked is not implemented");
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_,
op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);

auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
auto getContigPerThread = [&](const Attribute &layout,
unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d];
} else {
assert(0 && "Unimplemented usage of getContigPerThread");
return 0;
}
};
auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return product<unsigned>(blockedLayout.getSizePerThread());
} else {
assert(0 && "Unimplemented usage of getAccumElemsPerThread");
return 0;
}
};
auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getOrder();
} else {
assert(0 && "Unimplemented usage of getOrder");
return {};
}
};
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
unsigned outPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
// TODO: confirm this
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec,
paddedRepShape, outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}

SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}

private:
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[rank - 1 - it.value()] = input[it.index()];
}
return result;
};

void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = layout.getSizePerThread();
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned d = 0; d < rank; ++d) {
shapePerCTA[d] = layout.getSizePerThread()[d] *
layout.getThreadsPerWarp()[d] *
layout.getWarpsPerCTA()[d];
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}

unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation; we should
// consider caching the index calculation results if a performance
// issue is observed.
// for (unsigned elemId = linearCTAId * accumSizePerThread;
// elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto multiDimElemId =
getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
loc, multiDimOffsetFirstElem[d],
createIndexAttrConstant(rewriter, loc, llvmIndexTy,
multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value ptr =
rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, smemBase, offset);
auto vecTy = VectorType::get(vec, llvmElemTy);
ptr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
valVec = rewriter.create<LLVM::InsertElementOp>(
loc, vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
}
rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
} else {
Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
vals[elemId + linearCTAId * accumSizePerThread + v] =
rewriter.create<LLVM::ExtractElementOp>(loc, llvmElemTy, valVec,
vVal);
}
}
}
}
}

const Allocation *allocation_;
Value smem_;
};

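To make the replication bookkeeping above concrete, here is the per-dimension arithmetic with made-up tile sizes (illustrative only, not part of the patch):

#include <algorithm>
#include <cstdio>

int main() {
  unsigned shape = 128;                          // tensor extent along one dimension
  unsigned inPerCTA = std::min(shape, 32u);      // source per-CTA tile (assumed value)
  unsigned outPerCTA = std::min(shape, 64u);     // destination per-CTA tile (assumed value)
  unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
  unsigned numReplicates = (shape + maxPerCTA - 1) / maxPerCTA; // 2
  unsigned inNumCTAsEachRep = maxPerCTA / inPerCTA;             // 2 source tiles per round
  unsigned outNumCTAsEachRep = maxPerCTA / outPerCTA;           // 1 destination tile per round
  std::printf("replicates=%u, in tiles per rep=%u, out tiles per rep=%u\n",
              numReplicates, inNumCTAsEachRep, outNumCTAsEachRep);
  // Each replicate corresponds to one barrier/store/barrier/load round trip
  // through the shared-memory scratch buffer in ConvertLayoutOpConversion.
  return 0;
}
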
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -1266,9 +1592,10 @@ public:

llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
Attribute layout = type.getEncoding();
if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
if (layout && (layout.isa<BlockedEncodingAttr>() ||
layout.isa<SliceEncodingAttr>())) {
unsigned numElementsPerThread =
getElemsPerThread(blocked_layout, type.getShape());
getElemsPerThread(layout, type.getShape());
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
@@ -1285,7 +1612,8 @@ public:

void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit = 1) {
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
@@ -1296,17 +1624,19 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);

patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
}

class ConvertTritonGPUToLLVM
@@ -1322,19 +1652,34 @@ public:
// TODO: need to confirm
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, typeConverter);
TritonLLVMConversionTarget target(*context, typeConverter);

RewritePatternSet patterns(context);

int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

// step 1: Convert FuncOp to LLVMFuncOp via partial conversion
// step 2: Allocate for shared memories
// step 3: Convert the rest of ops via partial conversion
// The reason for a separation between 1/3 is that step 2 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem_
// is not revised during the conversion of step 3.
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();

Allocation allocation(mod);
auto axisAnalysis = runAxisAnalysis(mod);
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);

// We set a higher benefit here to ensure triton's patterns run before
// arith patterns for some encodings not supported by the community
// patterns.
RewritePatternSet patterns(context);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);
*axisAnalysis, &allocation, smem_,
10 /*benefit*/);

// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
@@ -1352,10 +1697,35 @@ protected:
auto axisAnalysisPass =
std::make_unique<AxisInfoAnalysis>(module->getContext());
axisAnalysisPass->run(module);

return axisAnalysisPass;
}

void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter);

Value smem_;
};

void ConvertTritonGPUToLLVM::initSharedMemory(
size_t size, TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
"global_smem", /*value=*/Attribute(),
/*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem_ = b.create<LLVM::AddressOfOp>(loc, global);
}

} // namespace

namespace mlir {
@@ -1366,10 +1736,20 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}

TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
// addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}

namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {

@@ -39,6 +39,37 @@ static Type getPointeeType(Type type) {
return Type();
}

namespace gpu {

// TODO: Inheritance of layout attributes
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
size_t rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}

unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir

@@ -108,6 +139,55 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}

unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == getSizePerThread().size() &&
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t =
getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] =
ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
}
return product<unsigned>(elemsPerThreadPerDim);
}

unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
unsigned dim = getDim();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
paddedShape[d] = shape[d];
else if (d == dim)
paddedShape[d] = 1;
else
paddedShape[d] = shape[d - 1];
}
return blockedParent.getElemsPerThread(paddedShape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}

unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
return 0;
}

unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
return 0;
}

//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//

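As a numeric check of the elements-per-thread formula above (the layout parameters are example values, not from the patch):

#include <cstdio>

// elemsPerThread[d] = ceil(shape[d] / (sizePerThread[d] * threadsPerWarp[d] *
//                          warpsPerCTA[d])) * sizePerThread[d]
unsigned ceilDiv(unsigned m, unsigned n) { return (m + n - 1) / n; }

int main() {
  unsigned shape[2] = {32, 128};
  unsigned sizePerThread[2] = {1, 4};
  unsigned threadsPerWarp[2] = {4, 8};
  unsigned warpsPerCTA[2] = {2, 2};
  unsigned elems = 1;
  for (int d = 0; d < 2; ++d) {
    unsigned t = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]; // 8, 64
    elems *= ceilDiv(shape[d], t) * sizePerThread[d];                   // 4, 8
  }
  std::printf("elements per thread: %u\n", elems); // 32
  // A slice layout along dim d reuses the same computation: it pads the shape
  // with a 1 at d and forwards it to the blocked parent, as shown above.
  return 0;
}
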
@@ -14,6 +14,7 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/Constants.h"

namespace mlir {
@@ -124,6 +125,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                           mlir::ModuleOp module) {
  mlir::PassManager pm(module->getContext());
  applyPassManagerCLOptions(pm);
  auto printingFlags = mlir::OpPrintingFlags();
  printingFlags.elideLargeElementsAttrs(16);
  pm.enableIRPrinting(
      /*shouldPrintBeforePass=*/nullptr,
      /*shouldPrintAfterPass=*/
      [](mlir::Pass *pass, mlir::Operation *) {
        return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
      },
      /*printModuleScope=*/false,
      /*printAfterOnlyOnChange=*/true,
      /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

  pm.addPass(createConvertTritonGPUToLLVMPass());
  // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
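Editor's note: the per-pass IR dump configured above is gated on an environment variable. A minimal usage sketch, assuming a Python driver that triggers compilation afterwards; the accepted truthy values follow the getBoolEnv helper shown further down in this commit.

import os

# getBoolEnv treats "on", "true" and "1" (case-insensitive) as enabled, so any
# of those values switches on the after-pass IR dump configured above. It must
# be set before any Triton compilation is triggered.
os.environ["MLIR_ENABLE_DUMP"] = "1"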
@@ -19,6 +19,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/tools/sys/getenv.hpp"

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
@@ -100,14 +101,6 @@ long pow2_divisor(long N) {
  return 1;
}

bool getBoolEnv(const std::string &env) {
  const char *s = std::getenv(env.c_str());
  std::string str(s ? s : "");
  std::transform(str.begin(), str.end(), str.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  return (str == "on" || str == "true" || str == "1");
}

// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
@@ -1635,7 +1628,7 @@ void init_triton_ir(py::module &&m) {
          /*shouldPrintBeforePass=*/nullptr,
          /*shouldPrintAfterPass=*/
          [](mlir::Pass *pass, mlir::Operation *) {
            return getBoolEnv("MLIR_ENABLE_DUMP");
            return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
          },
          /*printModuleScope=*/false,
          /*printAfterOnlyOnChange=*/true,
68 python/tests/test_transpose.py Normal file
@@ -0,0 +1,68 @@
import pytest
import torch
from torch.testing import assert_allclose

import triton
import triton.language as tl
import triton.runtime as runtime


@triton.jit
def kernel(x_ptr, stride_xm,
           z_ptr, stride_zn,
           SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
    off_m = tl.arange(0, SIZE_M)
    off_n = tl.arange(0, SIZE_N)
    Xs = x_ptr + off_m[:, None] * stride_xm + off_n[None, :] * 1
    Zs = z_ptr + off_m[:, None] * 1 + off_n[None, :] * stride_zn
    tl.store(Zs, tl.load(Xs))

# These sizes cover the cases of:
# - blocked layout and sliced layout with a blocked parent
# -- blocked layouts in which sizePerThread/threadsPerWarp/warpsPerCTA
#    do or do not need to be wrapped
# -- sliced layouts in which sizePerThread needs to be wrapped
# -- different orders
# - LayoutConversion from blocked -> blocked
# - tt.Broadcast, which requires broadcasting at either/both of the
#   CTA/perThread levels

# What is not covered and remains TODO:
# - vectorized load/store of shared memory
# - multiple replications of the layout conversion


@pytest.mark.parametrize('NUM_WARPS,SIZE_M,SIZE_N', [
    [1, 16, 16],
    [1, 32, 32],
    [1, 32, 64],
    [2, 64, 128],
    [2, 128, 64]
])
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
    # TODO: this is to initialize the cuda context since it is not properly
    # dealt with in the existing runtime; remove this when the runtime
    # is updated
    torch.zeros([10], device=torch.device('cuda'))
    device = torch.cuda.current_device()
    binary = runtime.build_kernel(kernel,
                                  "*fp32,i32,*fp32,i32",
                                  constants={"SIZE_M": SIZE_M,
                                             "SIZE_N": SIZE_N},
                                  num_warps=NUM_WARPS,
                                  num_stages=3)
    grid = lambda META: (1, )

    x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
    z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
    runtime.launch_kernel(kernel=binary,
                          device=device,
                          grid=grid,
                          x_ptr=x,
                          stride_xm=x.stride(0),
                          z_ptr=z,
                          stride_zn=z.stride(0),
                          SIZE_M=tl.constexpr(SIZE_M),
                          SIZE_N=tl.constexpr(SIZE_N))
    golden_z = torch.t(x)
    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
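Editor's note: the stride arithmetic in the kernel above is what turns the element-wise copy into a transpose. A plain NumPy sketch of the same index math, independent of Triton and for illustration only:

import numpy as np

# The kernel reads X at offset m * stride_xm + n and writes Z at offset
# m + n * stride_zn. For row-major X of shape (M, N) and Z of shape (N, M),
# stride_xm = N and stride_zn = M, so X[m, n] lands at Z's flat offset
# n * M + m, i.e. Z[n, m] -- a transpose.
M, N = 4, 6
x = np.arange(M * N, dtype=np.float32).reshape(M, N)
z = np.zeros(N * M, dtype=np.float32)

stride_xm, stride_zn = x.strides[0] // x.itemsize, M
for m in range(M):
    for n in range(N):
        z[m * 1 + n * stride_zn] = x.flat[m * stride_xm + n * 1]

assert np.array_equal(z.reshape(N, M), x.T)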
@@ -48,6 +48,7 @@ def vecadd_no_scf_tester(num_warps, block_size):


def test_vecadd_no_scf():
    vecadd_no_scf_tester(num_warps=4, block_size=256)
    vecadd_no_scf_tester(num_warps=2, block_size=256)
    vecadd_no_scf_tester(num_warps=1, block_size=256)
@@ -1,37 +0,0 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s

func @test_splat(%ptr: !tt.ptr<f32>) {
  // Here, 128 elements, 64 (2*32) threads, so each thread needs to process 2 elements
  //
  // CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
  // CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  // CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  // CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %a = arith.constant 1.0 : f32
  %true = arith.constant 1 : i1
  %b = tt.splat %a : (f32) -> tensor<128xf32>

  // Here, each thread processes only 1 element
  // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
  %mask = tt.splat %true : (i1) -> tensor<64xi1>

  return
}

// -----

func @test_store_splat(%ptr: !tt.ptr<f32>) {
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %a = arith.constant 1.0 : f32
  %true = arith.constant 1 : i1

  %vs = tt.splat %a : (f32) -> tensor<128xf32>
  %mask = tt.splat %true : (i1) -> tensor<128xi1>

  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };",
  // CHECK-SAME: "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
  tt.store %ptrs, %vs, %mask : tensor<128xf32>

  return
}
@@ -1,16 +1,13 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s

module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {

  // CHECK: llvm.return
  return
}

// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
  // CHECK: llvm.return
  return
}
} // end module

// -----
@@ -58,7 +55,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {

// -----

// TODO: Pending on the support of isSplat constant
// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
@@ -71,10 +68,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// -----

// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other_vec
func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
  %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
  %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256
func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
// CHECK-LABEL: global_load_store_no_vec
func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
  %c256_i32 = arith.constant 256 : i32
  %0 = tt.get_program_id {axis = 0 : i32} : i32
  %1 = arith.muli %0, %c256_i32 : i32
@@ -86,22 +96,107 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // CHECK: ld.global.v4.b32
  // Load 4 elements from vector0
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];

  // Load 4 elements from vector1
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
  %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  // CHECK: ld.global.v4.b32
  %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
  %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // Store 4 elements to global
  // CHECK: st.global.b32.v4
  // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  tt.store %13, %11 : tensor<256xf32, #blocked0>
  return
}
}

// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
  %c256_i32 = arith.constant 256 : i32
  %0 = tt.get_program_id {axis = 0 : i32} : i32
  %1 = arith.muli %0, %c256_i32 : i32
  %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
  %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
  %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // Load 4 elements from A with a single vectorized load instruction
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

  // Load 4 elements from B with a single vectorized load instruction
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

  %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
  %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // Store 4 elements to global with a single vectorized store instruction
  // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
  tt.store %13, %11 : tensor<256xf32, #blocked0>
  return
}
}

// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
  %c256_i32 = arith.constant 256 : i32
  %0 = tt.get_program_id {axis = 0 : i32} : i32
  %1 = arith.muli %0, %c256_i32 : i32
  %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
  %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
  %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // Load 8 elements from A with two vectorized load instructions
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

  // Load 8 elements from B with two vectorized load instructions
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
  // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

  %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
  %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
  %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
  %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

  // Store 8 elements to global with two vectorized store instructions
  // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
  // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
  tt.store %13, %11 : tensor<256xf32, #blocked0>
  return
}
}

// TODO: Add a test case to verify the optimization when the ptr of the LoadOp
// comes from an addptr with a constant idx
@@ -217,10 +312,121 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
  // CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
  // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
  // CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
  // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
  tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
  return
}
}

// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
  // CHECK: llvm.mlir.addressof @global_smem
  // CHECK: nvvm.barrier0
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: nvvm.barrier0
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
  %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
  return
}
}

// -----
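Editor's note: as a quick cross-check on the eight vector<1xf32> stores and loads checked above, the elemsPerThread formula from the Dialect.cpp hunk earlier gives 8 elements per thread for both layouts of the 16x16 tensor. A Python sketch of that arithmetic, for illustration only:

import math

# Per-thread element count for a 16x16 tensor, same arithmetic as
# BlockedEncodingAttr::getElemsPerThread above.
# #blocked0: sizePerThread=[1,4], threadsPerWarp=[8,4], warpsPerCTA=[1,1]
src = math.ceil(16 / (1 * 8 * 1)) * 1 * math.ceil(16 / (4 * 4 * 1)) * 4
# #blocked1: sizePerThread=[4,1], threadsPerWarp=[4,8], warpsPerCTA=[1,1]
dst = math.ceil(16 / (4 * 4 * 1)) * 4 * math.ceil(16 / (1 * 8 * 1)) * 1
assert src == dst == 8  # lines up with the eight scalar stores/loads checked above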
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
  // CHECK: llvm.mlir.addressof @global_smem
  // CHECK: nvvm.barrier0
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: nvvm.barrier0
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
  return
}
}

// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
  // CHECK: llvm.mlir.addressof @global_smem
  // CHECK: nvvm.barrier0
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: nvvm.barrier0
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: nvvm.barrier0
  // CHECK: llvm.store
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: nvvm.barrier0
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  // CHECK: llvm.load
  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
  %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
  return
}
}

// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
//   func @make_range_sliced_layout() {
//     %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
//     return
//   }
// }