[BACKEND] Add C++ tests for PTXFormat and some tiny refinement (#109)
This PR does two things: 1. Adds some C++ tests for `PTXFormat`. 2. Enhances the functionality of `PTXFormat` so that a `PTXInstr` instance can be called multiple times, similar to a C function.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonGPUToLLVM.cpp
|
||||
PtxAsmFormat.cpp
|
||||
PtxAsmFormat.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||
|
@@ -49,7 +49,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
|
||||
return newConstantOperand(ss.str());
|
||||
}
|
||||
|
||||
std::string PTXBuilder::getConstrains() const {
|
||||
std::string PTXBuilder::getConstraints() const {
|
||||
auto args = getAllArgs();
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto arg : args)
|
||||
@@ -78,7 +78,7 @@ std::string PTXInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return llvm::formatv("${0}", idx);
|
||||
return "$" + std::to_string(idx);
|
||||
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
@@ -90,7 +90,9 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
||||
StringRef constraint, int off) {
|
||||
auto *opr = newOperand(addr, constraint);
|
||||
opr->repr = [off](int idx) -> std::string {
|
||||
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
||||
std::stringstream ss;
|
||||
ss << "[ $" << idx << " + " << off << " ]";
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
return opr;
|
||||
@@ -98,14 +100,24 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
||||
|
||||
std::string PTXBuilder::dump() const {
|
||||
llvm::SmallVector<std::string> lines;
|
||||
for (auto &instr : instrs) {
|
||||
lines.push_back(instr->dump());
|
||||
for (auto &exec : executions) {
|
||||
lines.push_back(exec->dump());
|
||||
}
|
||||
|
||||
return strJoin(lines, "\n\t");
|
||||
return strJoin(lines, "\r\n");
|
||||
}
|
||||
|
||||
std::string PTXInstrCommon::dump() const {
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
||||
builder->executions.emplace_back(
|
||||
std::make_unique<PTXInstrExecution>(this, oprs));
|
||||
return *builder->executions.back();
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||
return call(oprs);
|
||||
}
|
||||
|
||||
std::string PTXInstrExecution::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
if (pred)
|
||||
@@ -114,7 +126,7 @@ std::string PTXInstrCommon::dump() const {
|
||||
else
|
||||
os << pred->repr(pred->idx);
|
||||
|
||||
std::string instrRepr = strJoin(instrParts, ".");
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
@@ -128,7 +140,8 @@ std::string PTXInstrCommon::dump() const {
|
||||
return osStr;
|
||||
}
|
||||
|
||||
SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
||||
SmallVector<PTXInstrExecution::Operand *>
|
||||
PTXInstrExecution::getArgList() const {
|
||||
SmallVector<Operand *> args;
|
||||
for (auto *arg : argsInOrder) {
|
||||
if (arg->isList())
|
||||
@@ -138,11 +151,5 @@ SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
void PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||
for (auto *opr : oprs) {
|
||||
addOperand(opr);
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
@@ -756,12 +756,12 @@ struct StoreOpConversion
|
||||
llMask ? maskElems[vecStart]
|
||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||
rewriter.getIntegerType(1), 1);
|
||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
||||
ptxStoreInstr.global().b(width).v(nWords);
|
||||
|
||||
auto *asmAddr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
ptxStoreInstr(asmAddr, asmArgList);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||
for (int i = 0; i < nWords; i++)
|
||||
@@ -772,7 +772,7 @@ struct StoreOpConversion
|
||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
||||
ptxBuilder.dump(), // asm_string
|
||||
ptxBuilder.getConstrains(), // constraints
|
||||
ptxBuilder.getConstraints(), // constraints
|
||||
// TODO(Superjomn) determine the side effect.
|
||||
true, // has_side_effects
|
||||
false, // is_align_stack
|
||||
@@ -1045,8 +1045,7 @@ struct LoadOpConversion
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.predicate(pred, "b")
|
||||
.o("violatile", op.isVolatile())
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
@@ -1064,15 +1063,15 @@ struct LoadOpConversion
|
||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||
|
||||
if (!evictOpr)
|
||||
ld(dstsOpr, addrOpr);
|
||||
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||
else
|
||||
ld(dstsOpr, addrOpr, evictOpr);
|
||||
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||
|
||||
SmallVector<Value> others;
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ii++) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.predicateNot(pred, "b").o("u", width);
|
||||
mov.o("u", width);
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
@@ -1096,7 +1095,7 @@ struct LoadOpConversion
|
||||
others.push_back(v);
|
||||
}
|
||||
|
||||
mov(dstsOpr->listGet(ii), opr);
|
||||
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1114,7 +1113,7 @@ struct LoadOpConversion
|
||||
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
|
||||
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
|
||||
/*asm_string=*/ptxBuilder.dump(),
|
||||
/*constraints=*/ptxBuilder.getConstrains(),
|
||||
/*constraints=*/ptxBuilder.getConstraints(),
|
||||
/*has_side_effects=*/true,
|
||||
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
|
||||
/*operand_attrs=*/ArrayAttr());
|
||||
|
Reference in New Issue
Block a user