Files
triton/lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp
Philippe Tillet 20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00

218 lines
6.0 KiB
C++

#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"
// TODO(Superjomn): unify to llvm::raw_string_ostream
#include <sstream>
namespace mlir {
namespace triton {
// TODO(Superjomn) Move to a global utility file?
std::string strJoin(llvm::ArrayRef<std::string> strs,
llvm::StringRef delimiter) {
std::string osStr;
llvm::raw_string_ostream os(osStr);
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; ++i)
os << strs[i] << delimiter;
if (!strs.empty())
os << strs.back();
os.flush();
return osStr;
}
PTXInstr::Operand *
PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int)> formatter) {
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
auto *opr = argArchive.back().get();
opr->repr = formatter;
opr->idx = oprCounter++;
return opr;
}
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) {
// Constraint should be something like "=r"
assert(!constraint.empty() && constraint[0] == '=');
auto *opr = newOperand();
opr->idx = oprCounter++;
opr->constraint = constraint;
return opr;
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
argArchive.emplace_back(std::make_unique<Operand>());
argArchive.back()->repr = [v](int idx) { return v; };
return argArchive.back().get();
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
std::stringstream ss;
ss << "0x" << std::hex << v;
return newConstantOperand(ss.str());
}
std::string PTXBuilder::getConstraints() const {
auto args = getAllArgs();
llvm::SmallVector<std::string, 4> argReprs;
for (auto arg : args)
argReprs.push_back(arg->constraint);
return strJoin(argReprs, ",");
}
llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
llvm::SmallVector<Value, 4> res;
for (auto &arg : argArchive) {
if (!arg->isList() && arg->value)
res.push_back(arg->value);
}
return res;
}
SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
llvm::SmallVector<Operand *, 4> res;
for (auto &x : argArchive)
if (!x->isList())
res.push_back(x.get());
return res;
}
mlir::Value PTXBuilder::launch(ConversionPatternRewriter &rewriter,
Location loc, Type resTy, bool hasSideEffect,
bool isAlignStack,
ArrayRef<Attribute> attrs) const {
auto *ctx = rewriter.getContext();
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, resTy, getAllMLIRArgs(), // operands
dump(), // asm_string
getConstraints(), // constraints
hasSideEffect, // has_side_effects
isAlignStack, // is_align_stack
LLVM::AsmDialectAttr::get(ctx,
LLVM::AsmDialect::AD_ATT), // asm_dialect
ArrayAttr::get(ctx, attrs) // operand_attrs
);
return inlineAsm.getRes();
}
std::string PTXInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return "$" + std::to_string(idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
oprs.push_back(opr->dump());
return "{ " + strJoin(oprs, ", ") + " }";
}
PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
std::stringstream ss;
ss << "[ $" << idx << " + " << off << " ]";
return ss.str();
};
return opr;
}
std::string PTXBuilder::dump() const {
llvm::SmallVector<std::string> lines;
for (auto &exec : executions) {
lines.push_back(exec->dump());
}
return strJoin(lines, "\n\t");
}
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
if (onlyAttachMLIRArgs) {
// Nearly impossible to make the $0,$1 in two PTX code snippets to point to
// the same MLIR values in onlyAttachMLIRArgs mode.
assert(builder->executions.empty() &&
"builder can only hold a single execution when onlyAttachMIIRArgs "
"is true.");
builder->reorderArgArchive(oprs);
}
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
return *builder->executions.back();
}
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
return call(oprs, onlyAttachMLIRArgs);
}
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
std::string instrRepr = strJoin(instr->instrParts, ".");
if (onlyAttachMLIRArgs)
return instrRepr;
if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
else
os << pred->repr(pred->idx) << " ";
}
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
argReprs.push_back(arg->dump());
}
std::string argsRepr = strJoin(argReprs, ", ");
os << instrRepr << " " << argsRepr << ";";
os.flush();
return osStr;
}
SmallVector<PTXInstrExecution::Operand *>
PTXInstrExecution::getArgList() const {
SmallVector<Operand *> args;
for (auto *arg : argsInOrder) {
if (arg->isList())
args.insert(args.end(), arg->list.begin(), arg->list.end());
else
args.push_back(arg);
}
return args;
}
PTXInstr &PTXInstr::global() {
o("global");
return *this;
}
PTXInstr &PTXInstr::shared() {
o("shared");
return *this;
}
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
PTXInstr &PTXInstr::b(int width) {
o("b" + std::to_string(width));
return *this;
}
} // namespace triton
} // namespace mlir