[BACKEND] Add LLVM-translation for store and splat ops (#47)

Yan Chunwei
2022-08-15 15:46:37 +08:00
committed by GitHub
parent 993ba7035a
commit 95bbac41e7
8 changed files with 815 additions and 57 deletions

View File

@@ -0,0 +1,39 @@
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
#define TRITON_CONVERSION_MLIR_TYPES_H_
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// This file redefines some common MLIR types for easy usage.
namespace mlir {
namespace triton {
namespace type {
// Integer types
Type i32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Signed);
}
Type i8Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 8, IntegerType::Signed);
}
Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Signless);
}
Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }
// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
}
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
} // namespace type
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_MLIR_TYPES_H_

View File

@@ -0,0 +1,191 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <string>
namespace mlir {
namespace triton {
using llvm::StringRef;
// TODO(Superjomn) Move to a global utility file?
std::string strJoin(llvm::ArrayRef<std::string> strs,
llvm::StringRef delimiter);
// A helper for building a single inline ASM instruction. PtxInstr is a thin
// encapsulation that makes the ASM code emitted for the MLIR LLVM dialect
// clearer. Several conveniences are introduced to reduce the need for mixing
// strings and C++ if-else code.
// Usage:
// To build: asm("mul.lo.u32 %0, %1, %2;" : "=r"(i) : "r"(j), "k"(k));
//
// PtxInstr mulr("mul");
// mulr.o("lo").o("u32").addOperand(valueI, "=r") // %0 binds to valueI
// .addOperand(valueJ, "r") // %1 binds to valueJ
// .addOperand(valueK, "k"); // %2 binds to valueK
//
// mulr.getConstrains() // get "=r,r,k"
// mulr.getAllMlirArgs() // get {valueI, valueJ, valueK}
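// mulr.dump() // get "mul.lo.u32 %0, %1, %2;" (a sketch of the rendered
// instruction, following the dump() implementation in PtxAsmFormat.cpp)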
//
// TODO(Superjomn) Add multi-line ASM code support and register support later.
struct PtxInstr {
explicit PtxInstr(const std::string &name) { o(name); }
struct Operand {
std::string constraint;
Value value;
int idx{-1};
llvm::SmallVector<Operand *> list;
std::function<std::string(int idx)> repr;
// The default ctor is used to create a list operand (see newList())
Operand() = default;
Operand(Value value, StringRef constraint)
: value(value), constraint(constraint) {}
bool isList() const { return !value; }
Operand *listAppend(Operand *arg) {
list.push_back(arg);
return this;
}
std::string dump() const;
};
// Create a new operand. It will not be added to the operand list.
// @value: the MLIR value bound to this operand.
// @constraint: the ASM operand constraint, e.g. "=r"
// @formater: an extra formatter for this operand's representation in the ASM
// code; the default is "%{0}".format(operand.idx).
Operand *newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int idx)> formater = nullptr);
// Append the operand to the instruction's operand list.
Operand *addOperand(Operand *opr) {
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
argsInOrder.end());
argsInOrder.push_back(opr);
return opr;
}
// Create and add an operand to the instruction's operand list.
Operand *addOperand(mlir::Value value, StringRef constraint) {
auto *opr = newOperand(value, constraint);
return addOperand(opr);
}
// Prefix a predicate to the instruction.
PtxInstr &predicate(mlir::Value value, StringRef constraint) {
pred = newOperand(value, constraint);
return *this;
}
// Append a suffix to the instruction.
// e.g. PtxInstr("add").o("s32") yields add.s32.
// A predicate can be passed to decide whether to apply the suffix, so that no
// if-else code is needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32",
// !isS32);` yields `add.s32` if isS32 is true.
PtxInstr &o(const std::string &suffix, bool predicate = true) {
if (predicate)
instrParts.push_back(suffix);
return *this;
}
PtxInstr &addListOperation(llvm::ArrayRef<Operand *> list) {
auto *opr = newList();
for (auto *v : list)
opr->listAppend(v);
addOperand(opr);
return *this;
}
// Create a list of operands.
Operand *newList() {
argArchive.emplace_back(std::make_unique<Operand>());
return argArchive.back().get();
}
std::string dump() const;
llvm::SmallVector<Operand *, 4> getArgList() const;
llvm::SmallVector<Operand *, 4> getAllArgs() const {
llvm::SmallVector<Operand *, 4> res;
for (auto &x : argArchive)
if (!x->isList())
res.push_back(x.get());
return res;
}
std::string getConstrains() const {
auto args = getAllArgs();
llvm::SmallVector<std::string, 4> argReprs;
for (auto arg : args)
argReprs.push_back(arg->constraint);
return strJoin(argReprs, ",");
}
llvm::SmallVector<Value, 4> getAllMlirArgs() const {
llvm::SmallVector<Value, 4> res;
for (auto &arg : argArchive) {
if (!arg->isList())
res.push_back(arg->value);
}
return res;
}
protected:
Operand *pred{};
int oprCounter{};
llvm::SmallVector<std::string, 4> instrParts;
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
llvm::SmallVector<Operand *> argsInOrder;
std::string argStr;
};
// A helper for PTX ld/st instruction.
// Usage:
// PtxIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
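//
// With the address operand and a value list appended, store.dump() renders
// roughly (a sketch, cf. the store lit test below):
// "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };"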
struct PtxIOInstr : public PtxInstr {
PtxIOInstr(const std::string &name) : PtxInstr(name) {}
// Add ".global" suffix to instruction
PtxIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Add ".v" suffix to instruction
PtxIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o(llvm::formatv("v{0}", vecWidth), predicate);
}
return *this;
}
// Add ".b" suffix to instruction
PtxIOInstr &b(int width) {
o(llvm::formatv("b{0}", width));
return *this;
}
PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) {
auto *operand = newAddrOperand(addr, constraint, off);
addOperand(operand);
return *this;
}
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
};
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_

View File

@@ -1,5 +1,6 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
PtxAsmFormat.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM

View File

@@ -0,0 +1,81 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace triton {
std::string strJoin(llvm::ArrayRef<std::string> strs,
llvm::StringRef delimiter) {
std::string osStr;
llvm::raw_string_ostream os(osStr);
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++)
os << strs[i] << delimiter;
if (!strs.empty())
os << strs.back();
os.flush();
return osStr;
}
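// A usage sketch: strJoin({"st", "global", "b32"}, ".") yields "st.global.b32".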
std::string PtxInstr::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
if (pred)
os << "@" << pred->dump() << " ";
std::string instrRepr = strJoin(instrParts, ".");
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
argReprs.push_back(arg->dump());
}
std::string argsRepr = strJoin(argReprs, ", ");
os << instrRepr << " " << argsRepr << ";";
os.flush();
return osStr;
}
llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
SmallVector<Operand *> args;
for (auto *arg : argsInOrder) {
if (arg->isList())
args.insert(args.end(), arg->list.begin(), arg->list.end());
else
args.push_back(arg);
}
return args;
}
PtxInstr::Operand *
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int)> formater) {
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
auto *opr = argArchive.back().get();
opr->repr = formater;
opr->idx = oprCounter++;
return opr;
}
std::string PtxInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return llvm::formatv("%{0}", idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
oprs.push_back(opr->dump());
return "{ " + strJoin(oprs, ", ") + " }";
}
PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
return llvm::formatv("[ %{0} + {1} ]", idx, off);
};
return opr;
}
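// e.g., for the first operand (idx == 0) with off = 16, the address operand
// renders as "[ %0 + 16 ]".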
} // namespace triton
} // namespace mlir

View File

@@ -1,5 +1,4 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -10,10 +9,17 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include <string>
using namespace mlir;
using namespace mlir::triton;
@@ -31,8 +37,15 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace {
namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO(Superjomn) Move to wherever the general utilities live.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// The following code is borrowed from the MLIR project, including the
// following functions or classes:
// - filterFuncAttributes
@@ -81,7 +94,6 @@ protected:
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
assert(llvmType);
if (!llvmType)
return nullptr;
@@ -124,6 +136,8 @@ protected:
}
linkage = attr.getLinkage();
}
auto oldArgs = funcOp.getArguments();
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
@@ -134,6 +148,25 @@ protected:
&result)))
return nullptr;
// Convert the arguments
llvm::DenseMap<Value, Value> argMap;
for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
Value oldArg = oldArgs[i];
Value newArg = newFuncOp.getArgument(i);
argMap.try_emplace(oldArg, newArg);
}
newFuncOp.getBody().walk([&](Operation *op) {
// Convert the function argument types, e.g., from !tt.ptr<fp16> to
// ptr<fp16>
for (int i = 0; i < op->getNumOperands(); i++) {
auto arg = op->getOperand(i);
auto it = argMap.find(arg);
if (it != argMap.end())
op->setOperand(i, it->second);
}
});
return newFuncOp;
}
};
@@ -143,8 +176,9 @@ protected:
/// information.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
: FuncOpConversionBase(converter), NumWarps(numWarps) {}
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), NumWarps(numWarps) {}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
@@ -154,11 +188,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
return failure();
auto ctx = funcOp->getContext();
auto i32 = IntegerType::get(ctx, 32);
// Set an attribute for maxntidx; it can be used later in LLVM codegen
// for the `nvvm.annotation` metadata.
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
rewriter.getIntegerAttr(i32, 32 * NumWarps));
newFuncOp->setAttr(
NVVMMetadataField::MaxNTid,
rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps));
rewriter.eraseOp(funcOp);
return success();
@@ -190,22 +224,47 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
}
};
// Extract numWarps information from a TritonGPU module; return 0 on failure.
// This is a naive implementation: it assumes that all the blocked layouts in
// a module share the same numWarps setting, finds one blocked layout
// encoding, and returns its warpsPerCTA field.
int extractNumWarps(mlir::ModuleOp module) {
int numWarps{};
if (module->hasAttr(AttrNumWarpsName))
numWarps = module->getAttr(AttrNumWarpsName)
.dyn_cast<IntegerAttr>()
.getValue()
.getZExtValue();
else
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps;
}
static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
ArrayRef<int64_t> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
int64_t rank = shape.size();
int64_t acc_mul = 1;
for (int64_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
int64_t linear_index = 0;
for (int64_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}
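// A worked example (a sketch): with shape = {2, 3, 4} and multidim_index =
// {1, 2, 3}, acc_mul starts at 3*4 = 12 and the result is
// 1*12 + 2*4 + 3*1 = 23, i.e. the row-major linear offset.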
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
return product(shape) / (product(layout.getThreadsPerWarp()) *
product(layout.getWarpsPerCTA()));
}
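// A worked example (a sketch): a tensor<128xf32> with threadsPerWarp = {32}
// and warpsPerCTA = {2} gives 128 / (32 * 2) = 2 elements per thread, as in
// the splat lit test below.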
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType, Type elemPtrPtrType) {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
loc, structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
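// For a two-element struct this emits roughly (a sketch, cf. the splat lit
// test below):
// %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>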
template <typename T>
@@ -247,24 +306,6 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
return linear_index;
}
static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout,
ArrayRef<int64_t> shape) {
unsigned elems = 1;
size_t rank = shape.size();
assert(rank == layout.getThreadsPerWarp().size());
for (size_t d = 0; d < rank; ++d) {
elems *=
shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]);
}
return elems;
}
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(resultType, value));
}
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp> {
@@ -443,6 +484,358 @@ public:
}
};
// Convert SplatOp or arith::ConstantOp with a SplatElementsAttr to an
// LLVM::StructType value.
//
// @elemType: the element type of the operand.
// @resType: the return type of the Splat-like op.
// @constVal: an LLVM::ConstantOp or other scalar value.
Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
auto numElems = layout.getSizePerThread();
size_t totalElems =
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
std::multiplies<>{});
size_t numThreads =
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
size_t numElemsPerThread = totalElems / numThreads;
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
auto llElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
auto llStruct =
getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
return llStruct;
}
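// For example (a sketch): tt.splat %x : (f32) -> tensor<128xf32> under a
// blocked layout with 64 threads yields numElemsPerThread = 2, i.e. an
// !llvm.struct<(f32, f32)> whose two fields hold the bitcast scalar.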
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = op->getOperand(0);
LLVM::ConstantOp arithConstantOp;
if (src.getDefiningOp() &&
(arithConstantOp =
llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
Value constant;
auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
assert(values.size() == 1);
Attribute val;
if (type::isInt(src.getType())) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(src.getType())) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "Constant op type not supported";
return failure();
}
src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
}
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
return success();
}
};
// This pattern converts arith::ConstantOp (with a SplatElementsAttr); the
// logic is the same as for triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto value = op.getValue();
if (!value.dyn_cast<SplatElementsAttr>())
return failure();
auto loc = op->getLoc();
LLVM::ConstantOp arithConstantOp;
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
auto elemType = values.getElementType();
Attribute val;
if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto llStruct = convertSplatLikeOp(elemType, op.getType(), constOp,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, llStruct);
return success();
}
};
struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
AxisAnalysisPass(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr(); // these should already be LLVM values
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
Type valueElemTy = getTypeConverter()->convertType(
value.getType().cast<RankedTensorType>().getElementType());
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
auto getLLVMElems = [&](Value value, Value llValue,
const TritonGPUBlockedEncodingAttr &layout)
-> SmallVector<Value, 4> {
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs have a blocked layout
unsigned valueElems = getElemsPerThread(layout, shape);
auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals =
getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
};
auto getLayout =
[&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs have a blocked layout
auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
};
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
auto getAlign =
[this](Value val,
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
assert(axisInfo.hasValue());
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
};
// Get the vectorization width
auto getVec = [this, &getAlign](
Value val,
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
auto contig = axisInfo->getContiguity();
// Here order is sorted with the most contiguous dimension first, so the
// first element has the largest contiguity.
auto order = layout.getOrder();
unsigned align = getAlign(val, layout);
assert(!order.empty());
// Is this right?
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
// TODO(Superjomn) Consider the is_mma_first_row in the legacy code
bool isMMAFirstRow = false;
if (isMMAFirstRow)
vec = std::min<size_t>(2, align);
return vec;
};
// Determine the vectorization size
size_t vec = getVec(ptr, ptrLayout);
const size_t dtsize = value.getType()
.cast<RankedTensorType>()
.getElementType()
.getIntOrFloatBitWidth() /
8;
const size_t valueElemNbits = dtsize * 8;
const int numVecs = ptrNumElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
size_t in_off{};
auto ptrProducer = llPtr.getDefiningOp();
auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);
if (in_gep) {
auto indices = in_gep.getIndices();
auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
in_off =
cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
ptr = cst ? in_gep.getBase() : in_gep;
}
// Pack sub-words (< 32/64 bits) into words:
// each store has width min(nbits * vec, 32/64),
// and there are (nbits * vec) / width of them.
const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits;
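// A worked example (a sketch): storing f32 values (valueElemNbits = 32) with
// vec = 4 gives totalWidth = 128, width = 32, nWords = 4, wordNElems = 1 and
// vecNElems = 4, i.e. four b32 register operands per vectorized store.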
assert(wordNElems * nWords * numVecs == valueElems.size());
// TODO(Superjomn) Add cache policy to store.
const bool hasL2EvictPolicy = false;
PtxIOInstr asmStoreInstr("st");
asmStoreInstr.predicate(llMask, "b");
asmStoreInstr.global().v(width).b(nWords);
llvm::SmallVector<std::string> asmArgs;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
auto *asmArgList = asmStoreInstr.newList();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a word-sized composition of wordNElems value elements
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition
for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
Value elem =
valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
llWord = rewriter.create<LLVM::InsertElementOp>(
loc, wordTy, llWord, elem,
rewriter.create<LLVM::ConstantOp>(
loc, type::u32Ty(ctx),
IntegerAttr::get(type::u32Ty(ctx), elemIdx)));
}
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
}
asmStoreInstr.addOperand(asmAddr);
asmStoreInstr.addOperand(asmArgList);
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
for (int i = 0; i < nWords; i++)
argTys.push_back(valArgTy);
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
asmStoreInstr.dump(), // asm_string
asmStoreInstr.getConstrains(), // constraints
// TODO(Superjomn) determine the side effect.
true, // has_side_effects
false, // is_align_stack
LLVM::AsmDialectAttr::get(ctx,
LLVM::AsmDialect::AD_ATT), // asm_dialect
ArrayAttr::get(ctx, {}) // operand_attrs
);
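// For the f32 store in the lit test below, this materializes roughly
// (a sketch; %mask, %ptr, %word stand for the actual SSA operands):
// llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = []
// "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };", "b,l,r"
// %mask, %ptr, %word : (...) -> !llvm.struct<()>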
rewriter.replaceOp(op, inlineAsm.getRes());
}
return success();
}
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
private:
AxisInfoAnalysis &AxisAnalysisPass;
};
// Extract numWarps information from a TritonGPU module; return 0 on failure.
// This is a naive implementation: it assumes that all the blocked layouts in
// a module share the same numWarps setting, finds one blocked layout
// encoding, and returns its warpsPerCTA field.
int extractNumWarps(mlir::ModuleOp module) {
int numWarps{};
if (module->hasAttr(AttrNumWarpsName))
numWarps = module->getAttr(AttrNumWarpsName)
.dyn_cast<IntegerAttr>()
.getValue()
.getZExtValue();
else
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps;
}
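// e.g., a module carrying the attribute
// module attributes {"triton_gpu.num-warps" = 4 : i32}
// yields numWarps = 4 (cf. the empty-kernel lit test below).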
struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -647,8 +1040,7 @@ struct LoadOpConversion
for (size_t i = 0; i < numElems; i += vecWidth) {
Value ptr = ptrVals[i];
// TODO: Handle the optimization if ptr is from GEP and the idx is
// constant
// This should be a canonicalization pattern in LLVM Dialect
// constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0;
Value pred = maskVals[i];
@@ -826,13 +1218,18 @@ public:
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps) {
patterns.add<BroadcastOpConversion>(typeConverter);
patterns.add<FuncOpConversion>(typeConverter, numWarps);
patterns.add<LoadOpConversion>(typeConverter);
patterns.add<MakeRangeOpConversion>(typeConverter);
patterns.add<ReturnOpConversion>(typeConverter);
patterns.add<ViewOpConversion>(typeConverter);
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
PatternBenefit benefit = 1) {
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
}
class ConvertTritonGPUToLLVM
@@ -851,20 +1248,33 @@ public:
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
// TODO: (goostavz) Temporarily disable this, since the lowering of
// arithmetic ops in tensor format is not complete yet.
// Add arith's patterns to help convert scalar expression to LLVM.
// mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
// patterns);
int numWarps = extractNumWarps(mod);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
auto axisAnalysis = runAxisAnalysis(mod);
// We set a higher benefit here to ensure Triton's patterns run before the
// arith patterns, since some encodings are not supported by the community
// patterns.
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);
// Add arith's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
protected:
std::unique_ptr<AxisInfoAnalysis> runAxisAnalysis(ModuleOp module) {
auto axisAnalysisPass =
std::make_unique<AxisInfoAnalysis>(module->getContext());
axisAnalysisPass->run(module);
return axisAnalysisPass;
}
};
} // namespace

View File

@@ -102,6 +102,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
return ret;

View File

@@ -0,0 +1,36 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s
func @test_splat(%ptr: !tt.ptr<f32>) {
// Here there are 128 elements and 64 (2*32) threads, so each thread needs
// to process 2 elements
//
// CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%b = tt.splat %a : (f32) -> tensor<128xf32>
// Here, each thread processes only 1 element
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
%mask = tt.splat %true : (i1) -> tensor<64xi1>
return
}
func @test_store_splat(%ptr: !tt.ptr<f32>) {
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
// CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
return
}

View File

@@ -1,11 +1,10 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return