[BACKEND] Add LLVM-translation for store and splat ops (#47)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonGPUToLLVM.cpp
|
||||
PtxAsmFormat.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||
|
81
lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
Normal file
81
lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
Normal file
@@ -0,0 +1,81 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Concatenates the entries of `strs` into one string, writing `delimiter`
// between consecutive entries. An empty `strs` produces an empty string.
std::string strJoin(llvm::ArrayRef<std::string> strs,
                    llvm::StringRef delimiter) {
  std::string joined;
  llvm::raw_string_ostream stream(joined);
  bool isFirst = true;
  for (const std::string &piece : strs) {
    // The delimiter goes in front of every entry except the first one.
    if (!isFirst)
      stream << delimiter;
    stream << piece;
    isFirst = false;
  }
  stream.flush();
  return joined;
}
|
||||
|
||||
std::string PtxInstr::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
if (pred)
|
||||
os << "@" << pred->dump() << " ";
|
||||
|
||||
std::string instrRepr = strJoin(instrParts, ".");
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
argReprs.push_back(arg->dump());
|
||||
}
|
||||
|
||||
std::string argsRepr = strJoin(argReprs, ", ");
|
||||
|
||||
os << instrRepr << " " << argsRepr << ";";
|
||||
os.flush();
|
||||
return osStr;
|
||||
}
|
||||
|
||||
// Returns the operands of `argsInOrder` as a flat list: an operand that is a
// list contributes its members (one level deep), any other operand
// contributes itself.
llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
  // Use exactly the declared return type for the local so the result is
  // returned directly instead of being converted from a differently sized
  // SmallVector.
  llvm::SmallVector<PtxInstr::Operand *, 4> args;
  for (auto *arg : argsInOrder) {
    if (arg->isList())
      args.insert(args.end(), arg->list.begin(), arg->list.end());
    else
      args.push_back(arg);
  }
  return args;
}
|
||||
|
||||
// Creates a new operand for `value` with the given `constraint`, stores it in
// `argArchive` (which owns it), assigns it the next operand index and the
// given formatter, and returns a non-owning pointer to it.
PtxInstr::Operand *
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
                     std::function<std::string(int)> formater) {
  argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
  Operand &created = *argArchive.back();
  created.repr = formater;
  created.idx = oprCounter++;
  return &created;
}
|
||||
|
||||
std::string PtxInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return llvm::formatv("%{0}", idx);
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
oprs.push_back(opr->dump());
|
||||
return "{ " + strJoin(oprs, ", ") + " }";
|
||||
}
|
||||
|
||||
// Creates an operand for `addr` that is printed as an address expression of
// the form "[ %<idx> + <off> ]".
PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
                                              StringRef constraint, int off) {
  auto formatAddress = [off](int idx) -> std::string {
    return llvm::formatv("[ %{0} + {1} ]", idx, off);
  };

  auto *addrOperand = newOperand(addr, constraint);
  addrOperand->repr = formatAddress;
  return addrOperand;
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
@@ -1,5 +1,4 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
@@ -10,10 +9,17 @@
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -31,8 +37,15 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||
|
||||
namespace {
|
||||
|
||||
namespace type = mlir::triton::type;
|
||||
|
||||
class TritonGPUToLLVMTypeConverter;
|
||||
|
||||
// TODO(Superjomn) Move to somewhere general utilities locates.
|
||||
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
|
||||
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
|
||||
}
|
||||
|
||||
// The following code is borrowed from the MLIR project, including the following
|
||||
// functions or classes:
|
||||
// - filterFuncAttributes
|
||||
@@ -81,7 +94,6 @@ protected:
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
||||
assert(llvmType);
|
||||
if (!llvmType)
|
||||
return nullptr;
|
||||
|
||||
@@ -124,6 +136,8 @@ protected:
|
||||
}
|
||||
linkage = attr.getLinkage();
|
||||
}
|
||||
|
||||
auto oldArgs = funcOp.getArguments();
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, attributes);
|
||||
@@ -134,6 +148,25 @@ protected:
|
||||
&result)))
|
||||
return nullptr;
|
||||
|
||||
// Convert argument
|
||||
llvm::DenseMap<Value, Value> argMap;
|
||||
for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
|
||||
Value oldArg = oldArgs[i];
|
||||
Value newArg = newFuncOp.getArgument(i);
|
||||
argMap.try_emplace(oldArg, newArg);
|
||||
}
|
||||
|
||||
newFuncOp.getBody().walk([&](Operation *op) {
|
||||
// Convert the function argument types, e.g, from !tt.ptr<fp16> to
|
||||
// ptr<fp16>
|
||||
for (int i = 0; i < op->getNumOperands(); i++) {
|
||||
auto arg = op->getOperand(i);
|
||||
auto it = argMap.find(arg);
|
||||
if (it != argMap.end())
|
||||
op->setOperand(i, it->second);
|
||||
}
|
||||
});
|
||||
|
||||
return newFuncOp;
|
||||
}
|
||||
};
|
||||
@@ -143,8 +176,9 @@ protected:
|
||||
/// information.
|
||||
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
||||
struct FuncOpConversion : public FuncOpConversionBase {
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
|
||||
: FuncOpConversionBase(converter), NumWarps(numWarps) {}
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
|
||||
PatternBenefit benefit)
|
||||
: FuncOpConversionBase(converter, benefit), NumWarps(numWarps) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
||||
@@ -154,11 +188,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
return failure();
|
||||
|
||||
auto ctx = funcOp->getContext();
|
||||
auto i32 = IntegerType::get(ctx, 32);
|
||||
// Set an attribute for maxntidx; it could be used in later LLVM codegen
|
||||
// for `nvvm.annotation` metadata.
|
||||
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
|
||||
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
||||
newFuncOp->setAttr(
|
||||
NVVMMetadataField::MaxNTid,
|
||||
rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps));
|
||||
|
||||
rewriter.eraseOp(funcOp);
|
||||
return success();
|
||||
@@ -190,22 +224,47 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Extract numWarps information from TritonGPU module, return 0 if failed.
|
||||
// This is a naive implementation, it assumes that all the blocked layout should
|
||||
// have the same numWarps setting in a module, it just find a blocked layout
|
||||
// encoding and return the warpsPerCTA field.
|
||||
int extractNumWarps(mlir::ModuleOp module) {
|
||||
int numWarps{};
|
||||
if (module->hasAttr(AttrNumWarpsName))
|
||||
numWarps = module->getAttr(AttrNumWarpsName)
|
||||
.dyn_cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getZExtValue();
|
||||
else
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
|
||||
ArrayRef<int64_t> shape) {
|
||||
assert(multidim_index.size() == shape.size());
|
||||
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
|
||||
int64_t rank = shape.size();
|
||||
int64_t acc_mul = 1;
|
||||
for (int64_t i = 1; i < rank; ++i) {
|
||||
acc_mul *= shape[i];
|
||||
}
|
||||
int64_t linear_index = 0;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
linear_index += multidim_index[i] * acc_mul;
|
||||
if (i != (rank - 1)) {
|
||||
acc_mul = acc_mul / shape[i + 1];
|
||||
}
|
||||
}
|
||||
return linear_index;
|
||||
}
|
||||
|
||||
return numWarps;
|
||||
// Returns the number of elements of `shape` divided by the product of the
// layout's threads-per-warp entries times the product of its warps-per-CTA
// entries.
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
                                  ArrayRef<int64_t> shape) {
  const size_t totalElems = product(shape);
  const size_t threadsPerWarp = product(layout.getThreadsPerWarp());
  const size_t warpsPerCTA = product(layout.getWarpsPerCTA());
  return totalElems / (threadsPerWarp * warpsPerCTA);
}
|
||||
|
||||
// Creates an LLVM::ConstantOp of type `resultType` whose value is an integer
// attribute of the builder's index type holding `value`.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                     Type resultType, int64_t value) {
  auto indexAttr = builder.getIntegerAttr(builder.getIndexType(), value);
  return builder.create<LLVM::ConstantOp>(loc, resultType, indexAttr);
}
|
||||
|
||||
// Builds a value of `structType` by starting from an undef value and
// inserting `resultVals[i]` at position i, in order. `elemPtrPtrType` is not
// used by this function.
Value getStructFromElements(Location loc, ValueRange resultVals,
                            ConversionPatternRewriter &rewriter,
                            Type structType, Type elemPtrPtrType) {
  Value packed = rewriter.create<LLVM::UndefOp>(loc, structType);
  int64_t position = 0;
  for (Value element : resultVals) {
    packed = rewriter.create<LLVM::InsertValueOp>(
        loc, structType, packed, element, rewriter.getI64ArrayAttr(position));
    ++position;
  }
  return packed;
}
|
||||
|
||||
template <typename T>
|
||||
@@ -247,24 +306,6 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
|
||||
return linear_index;
|
||||
}
|
||||
|
||||
static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout,
|
||||
ArrayRef<int64_t> shape) {
|
||||
unsigned elems = 1;
|
||||
size_t rank = shape.size();
|
||||
assert(rank == layout.getThreadsPerWarp().size());
|
||||
for (size_t d = 0; d < rank; ++d) {
|
||||
elems *=
|
||||
shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]);
|
||||
}
|
||||
return elems;
|
||||
}
|
||||
|
||||
// Creates an LLVM::ConstantOp of type `resultType` whose value is an integer
// attribute of that same type holding `value`.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                     Type resultType, int64_t value) {
  auto typedAttr = builder.getIntegerAttr(resultType, value);
  return builder.create<LLVM::ConstantOp>(loc, resultType, typedAttr);
}
|
||||
|
||||
template <typename SourceOp>
|
||||
class ConvertTritonGPUOpToLLVMPattern
|
||||
: public ConvertOpToLLVMPattern<SourceOp> {
|
||||
@@ -443,6 +484,358 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
|
||||
// LLVM::StructType value.
|
||||
//
|
||||
// @elemType: the element type in operand.
|
||||
// @resType: the return type of the Splat-like op.
|
||||
// @constVal: a LLVM::ConstantOp or other scalar value.
|
||||
Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
|
||||
|
||||
auto numElems = layout.getSizePerThread();
|
||||
size_t totalElems =
|
||||
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
|
||||
std::multiplies<>{});
|
||||
size_t numThreads =
|
||||
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
|
||||
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
|
||||
size_t numElemsPerThread = totalElems / numThreads;
|
||||
|
||||
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
|
||||
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
auto llElemPtrPtrTy =
|
||||
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
|
||||
auto llStruct =
|
||||
getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
|
||||
return llStruct;
|
||||
}
|
||||
|
||||
struct SplatOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto src = op->getOperand(0);
|
||||
|
||||
LLVM::ConstantOp arithConstantOp;
|
||||
if (src.getDefiningOp() &&
|
||||
(arithConstantOp =
|
||||
llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
|
||||
Value constant;
|
||||
auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
|
||||
|
||||
assert(values.size() == 1);
|
||||
Attribute val;
|
||||
if (type::isInt(src.getType())) {
|
||||
val = values.getValues<IntegerAttr>()[0];
|
||||
} else if (type::isFloat(src.getType())) {
|
||||
val = values.getValues<FloatAttr>()[0];
|
||||
} else {
|
||||
llvm::errs() << "Constant op type not supported";
|
||||
return failure();
|
||||
}
|
||||
|
||||
src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
|
||||
}
|
||||
|
||||
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
|
||||
getTypeConverter(), rewriter, loc);
|
||||
rewriter.replaceOp(op, {llStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the
|
||||
// logic is the same as triton::SplatOp, so the underlying implementation is
|
||||
// reused.
|
||||
struct ArithConstantSplatOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto value = op.getValue();
|
||||
if (!value.dyn_cast<SplatElementsAttr>())
|
||||
return failure();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
|
||||
LLVM::ConstantOp arithConstantOp;
|
||||
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
|
||||
auto elemType = values.getElementType();
|
||||
|
||||
Attribute val;
|
||||
if (type::isInt(elemType)) {
|
||||
val = values.getValues<IntegerAttr>()[0];
|
||||
} else if (type::isFloat(elemType)) {
|
||||
val = values.getValues<FloatAttr>()[0];
|
||||
} else {
|
||||
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
|
||||
<< value.getType() << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
|
||||
auto llStruct = convertSplatLikeOp(elemType, op.getType(), constOp,
|
||||
getTypeConverter(), rewriter, loc);
|
||||
rewriter.replaceOp(op, llStruct);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StoreOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
StoreOpConversion(LLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
|
||||
AxisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Value ptr = op.ptr();
|
||||
Value mask = op.mask();
|
||||
Value value = op.value();
|
||||
|
||||
Value llPtr = adaptor.ptr(); // should be LLVM ops
|
||||
Value llMask = adaptor.mask();
|
||||
Value llValue = adaptor.value();
|
||||
|
||||
Type valueElemTy = getTypeConverter()->convertType(
|
||||
value.getType().cast<RankedTensorType>().getElementType());
|
||||
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto getLLVMElems = [&](Value value, Value llValue,
|
||||
const TritonGPUBlockedEncodingAttr &layout)
|
||||
-> SmallVector<Value, 4> {
|
||||
auto ty = value.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
|
||||
unsigned valueElems = getElemsPerThread(layout, shape);
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
|
||||
auto llvmElemPtrPtrTy =
|
||||
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
|
||||
|
||||
auto valueVals =
|
||||
getElementsFromStruct(loc, llValue, valueElems, rewriter);
|
||||
return valueVals;
|
||||
};
|
||||
|
||||
auto getLayout =
|
||||
[&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
|
||||
auto ty = val.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
|
||||
unsigned valueElems = getElemsPerThread(layout, shape);
|
||||
|
||||
return std::make_tuple(layout, valueElems);
|
||||
};
|
||||
|
||||
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
|
||||
auto [maskLayout, maskNumElems] = getLayout(mask);
|
||||
auto [valueLayout, valueNumElems] = getLayout(value);
|
||||
|
||||
auto valueElems = getLLVMElems(value, llValue, valueLayout);
|
||||
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
|
||||
auto getAlign =
|
||||
[this](Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
assert(axisInfo.hasValue());
|
||||
|
||||
auto order = layout.getOrder();
|
||||
|
||||
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
|
||||
unsigned maxContig = axisInfo->getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
return alignment;
|
||||
};
|
||||
|
||||
// get align
|
||||
auto getVec = [this, &getAlign](
|
||||
Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
auto contig = axisInfo->getContiguity();
|
||||
// Here order should be ordered by contiguous first, so the first element
|
||||
// should have the largest contiguous.
|
||||
auto order = layout.getOrder();
|
||||
unsigned align = getAlign(val, layout);
|
||||
|
||||
assert(!order.empty());
|
||||
// Is this right?
|
||||
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
|
||||
unsigned vec = std::min(align, contigPerThread);
|
||||
|
||||
// TODO(Superjomn) Consider the is_mma_first_row in the legacy code
|
||||
bool isMMAFirstRow = false;
|
||||
|
||||
if (isMMAFirstRow)
|
||||
vec = std::min<size_t>(2, align);
|
||||
|
||||
return vec;
|
||||
};
|
||||
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVec(ptr, ptrLayout);
|
||||
|
||||
const size_t dtsize = value.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.getIntOrFloatBitWidth() /
|
||||
8;
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = ptrNumElems / vec;
|
||||
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
|
||||
|
||||
size_t in_off{};
|
||||
auto ptrProducer = llPtr.getDefiningOp();
|
||||
auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);
|
||||
|
||||
if (in_gep) {
|
||||
auto indices = in_gep.getIndices();
|
||||
auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
|
||||
in_off =
|
||||
cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
|
||||
ptr = cst ? in_gep.getBase() : in_gep;
|
||||
}
|
||||
|
||||
// pack sub-words (< 32/64bits) into words
|
||||
// each load has width min(nbits*vec, 32/64)
|
||||
// and there are (nbits * vec)/width of them
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
|
||||
assert(wordNElems * nWords * numVecs == valueElems.size());
|
||||
|
||||
// TODO(Superjomn) Add cache policy to store.
|
||||
// TODO(Superjomn) deal with cache policy.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PtxIOInstr asmStoreInstr("st");
|
||||
asmStoreInstr.predicate(llMask, "b");
|
||||
asmStoreInstr.global().v(width).b(nWords);
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
||||
|
||||
auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
|
||||
auto *asmArgList = asmStoreInstr.newList();
|
||||
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
// Insert each value element to the composition
|
||||
for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
|
||||
Value elem =
|
||||
valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
|
||||
if (elem.getType().isInteger(1))
|
||||
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
|
||||
elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
|
||||
|
||||
llWord = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, wordTy, llWord, elem,
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
loc, type::u32Ty(ctx),
|
||||
IntegerAttr::get(type::u32Ty(ctx), elemIdx)));
|
||||
}
|
||||
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
|
||||
std::string constraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
|
||||
}
|
||||
|
||||
asmStoreInstr.addOperand(asmAddr);
|
||||
asmStoreInstr.addOperand(asmArgList);
|
||||
|
||||
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
|
||||
for (int i = 0; i < nWords; i++)
|
||||
argTys.push_back(valArgTy);
|
||||
|
||||
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
|
||||
|
||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
|
||||
asmStoreInstr.dump(), // asm_string
|
||||
asmStoreInstr.getConstrains(), // constraints
|
||||
// TODO(Superjomn) determine the side effect.
|
||||
true, // has_side_effects
|
||||
false, // is_align_stack
|
||||
LLVM::AsmDialectAttr::get(ctx,
|
||||
LLVM::AsmDialect::AD_ATT), // asm_dialect
|
||||
ArrayAttr::get(ctx, {}) // operand_attrs
|
||||
);
|
||||
|
||||
rewriter.replaceOp(op, inlineAsm.getRes());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
|
||||
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
|
||||
return it->getValue();
|
||||
}
|
||||
|
||||
return llvm::Optional<AxisInfo>{};
|
||||
}
|
||||
|
||||
private:
|
||||
AxisInfoAnalysis &AxisAnalysisPass;
|
||||
};
|
||||
|
||||
// Extract numWarps information from TritonGPU module, return 0 if failed.
|
||||
// This is a naive implementation, it assumes that all the blocked layout should
|
||||
// have the same numWarps setting in a module, it just find a blocked layout
|
||||
// encoding and return the warpsPerCTA field.
|
||||
int extractNumWarps(mlir::ModuleOp module) {
|
||||
int numWarps{};
|
||||
if (module->hasAttr(AttrNumWarpsName))
|
||||
numWarps = module->getAttr(AttrNumWarpsName)
|
||||
.dyn_cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getZExtValue();
|
||||
else
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
|
||||
return numWarps;
|
||||
}
|
||||
|
||||
struct BroadcastOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -647,8 +1040,7 @@ struct LoadOpConversion
|
||||
for (size_t i = 0; i < numElems; i += vecWidth) {
|
||||
Value ptr = ptrVals[i];
|
||||
// TODO: Handle the optimization if ptr is from GEP and the idx is
|
||||
// constant
|
||||
// This should be a canonicalization pattern in LLVM Dialect
|
||||
// constant. This should be a canonicalization pattern in LLVM Dialect
|
||||
unsigned in_off = 0;
|
||||
Value pred = maskVals[i];
|
||||
|
||||
@@ -826,13 +1218,18 @@ public:
|
||||
};
|
||||
|
||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps) {
|
||||
patterns.add<BroadcastOpConversion>(typeConverter);
|
||||
patterns.add<FuncOpConversion>(typeConverter, numWarps);
|
||||
patterns.add<LoadOpConversion>(typeConverter);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter);
|
||||
patterns.add<ReturnOpConversion>(typeConverter);
|
||||
patterns.add<ViewOpConversion>(typeConverter);
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &analysis,
|
||||
PatternBenefit benefit = 1) {
|
||||
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
|
||||
patterns.add<SplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<LoadOpConversion>(typeConverter, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ViewOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
class ConvertTritonGPUToLLVM
|
||||
@@ -851,20 +1248,33 @@ public:
|
||||
TritonLLVMConversionTarget target(*context, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
// TODO: (goostavz) Temporarily disable this, since the lowering of
|
||||
// arithmetic ops in tensor format is not complete yet.
|
||||
// Add arith's patterns to help convert scalar expression to LLVM.
|
||||
// mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
// patterns);
|
||||
|
||||
int numWarps = extractNumWarps(mod);
|
||||
|
||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
||||
auto axisAnalysis = runAxisAnalysis(mod);
|
||||
|
||||
// We set a higher benefit here to ensure triton's patterns run before
|
||||
// arith patterns for some encoding not supported by the community patterns.
|
||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
*axisAnalysis, 10 /*benefit*/);
|
||||
|
||||
// Add arith's patterns to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
protected:
|
||||
// Creates an AxisInfoAnalysis for the module's context, runs it on `module`,
// and hands ownership of the analysis to the caller.
std::unique_ptr<AxisInfoAnalysis> runAxisAnalysis(ModuleOp module) {
  auto *context = module->getContext();
  auto analysis = std::make_unique<AxisInfoAnalysis>(context);
  analysis->run(module);
  return analysis;
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@@ -102,6 +102,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
return ret;
|
||||
|
Reference in New Issue
Block a user