[BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77)

This PR does the following:

Enhance PTXFormat by:
Introducing PTXBuilder to enable multiple instructions in a single asm program
Overriding PTXInstr's operator() method to enable the instr(opr0, opr1) style of setting operands for an instruction
Refactor the PTX code used in LoadOpConversion with PTXFormat

Authored-by: goostavz <gzhu@nvidia.com>
This commit is contained in:
Yan Chunwei
2022-08-24 06:51:13 +08:00
committed by GitHub
parent 0ebef11c77
commit 1b513c9866
3 changed files with 316 additions and 203 deletions

View File

@@ -708,18 +708,20 @@ struct StoreOpConversion
// TODO(Superjomn) deal with cache policy.
const bool hasL2EvictPolicy = false;
PtxIOInstr asmStoreInstr("st");
asmStoreInstr.predicate(maskElems[vecIdx], "b");
asmStoreInstr.global().b(width).v(nWords);
PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
ptxStoreInstr.predicate(maskElems[vecIdx], "b")
.global()
.b(width)
.v(nWords);
llvm::SmallVector<std::string> asmArgs;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr =
asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmArgList = asmStoreInstr.newList();
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmArgList = ptxBuilder.newListOperand();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a width-len composition
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
@@ -740,11 +742,10 @@ struct StoreOpConversion
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
}
asmStoreInstr.addOperand(asmAddr);
asmStoreInstr.addOperand(asmArgList);
ptxStoreInstr(asmAddr, asmArgList);
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
for (int i = 0; i < nWords; i++)
@@ -753,9 +754,9 @@ struct StoreOpConversion
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
asmStoreInstr.dump(), // asm_string
asmStoreInstr.getConstrains(), // constraints
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
ptxBuilder.dump(), // asm_string
ptxBuilder.getConstrains(), // constraints
// TODO(Superjomn) determine the side effect.
true, // has_side_effects
false, // is_align_stack
@@ -1008,43 +1009,54 @@ struct LoadOpConversion
// ---
// create inline asm string
// ---
// TODO: (Superjomn) refactor with AsmInstr abstraction
std::ostringstream asmOss;
asmOss << "@$" << n_words; // predicate
asmOss << " ld";
if (op.isVolatile()) {
asmOss << ".volatile";
const std::string writeConstrait =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
const std::string readConstrait =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
PTXBuilder ptxBuilder;
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// Define the instruction opcode
ld.predicate(pred, "b")
.o("violatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", has_l2_evict_policy)
.v(n_words)
.b(width);
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (int i = 0; i < n_words; i++) {
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
dstsOpr->listAppend(opr);
}
asmOss << ".global";
if (op.cache() == triton::CacheModifier::CA)
asmOss << ".ca";
if (op.cache() == triton::CacheModifier::CG)
asmOss << ".cg";
if (op.evict() == triton::EvictionPolicy::EVICT_FIRST)
asmOss << ".L1::evict_first";
if (op.evict() == triton::EvictionPolicy::EVICT_LAST)
asmOss << ".L1::evict_last";
if (has_l2_evict_policy)
asmOss << ".L2::cache_hint";
if (n_words > 1)
asmOss << ".v" << n_words; // vector width
asmOss << ".b" << width; // word size
asmOss << " {";
for (int i = 0; i < n_words; i++) { // return values
if (i > 0)
asmOss << ",";
asmOss << "$" << i;
}
asmOss << "}";
asmOss << ", [ $" << n_words + 1; // load
asmOss << " + " << in_off << "]"; // constant offset
if (has_l2_evict_policy)
asmOss << ", $" << n_words + 2;
asmOss << ";";
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
PTXBuilder::Operand *evictOpr{};
// There is no mlir::Value available to bind to this operand yet, so it is disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
if (!evictOpr)
ld(dstsOpr, addrOpr);
else
ld(dstsOpr, addrOpr, evictOpr);
SmallVector<Value> others;
if (other != nullptr) {
if (other) {
for (size_t ii = 0; ii < n_words; ii++) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
mov.predicateNot(pred, "b").o("u", width);
size_t size = width / nbits;
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t s = 0; s < size; s++) {
@@ -1056,20 +1068,19 @@ struct LoadOpConversion
}
v = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(getContext(), width), v);
asmOss << "\n ";
asmOss << "@!$" << n_words << " mov.u" << width;
asmOss << " $" << ii << ", ";
std::ios_base::fmtflags flags(asmOss.flags());
if (otherIsSplatConstInt)
asmOss << "0x" << std::hex << splatVal;
else {
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal);
} else {
opr = ptxBuilder.newOperand(v, readConstrait);
others.push_back(v);
}
asmOss.flags(flags);
asmOss << ";";
mov(dstsOpr->listGet(ii), opr);
}
}
// ---
// create inline ASM signature
// ---
@@ -1077,39 +1088,18 @@ struct LoadOpConversion
Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0];
// ---
// create inline ASM constraints
// ---
std::string asmCstrt;
for (int ii = 0; ii < n_words; ii++) {
if (ii > 0)
asmCstrt += ",";
asmCstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
}
asmCstrt += ",b,l";
for (size_t ii = 0; ii < others.size(); ii++) {
asmCstrt += ",";
asmCstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy) {
asmCstrt += ",l";
}
// ---
// finally call inline ASM
// ---
SmallVector<Value> args = {pred, ptr};
for (Value v : others) {
args.push_back(v);
}
// TODO: if (has_l2_evict_policy)
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, retTy, /*operands=*/args, /*asm_string=*/asmOss.str(),
/*constraints=*/asmCstrt, /*has_side_effects=*/true,
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
/*asm_string=*/ptxBuilder.dump(),
/*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
Value ret = inlineAsmOp.getResult(0);
// ---
// extract and store return values
// ---
@@ -1135,7 +1125,7 @@ struct LoadOpConversion
loc, elemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded);
}
}
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
@@ -1314,7 +1304,8 @@ public:
auto axisAnalysis = runAxisAnalysis(mod);
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community patterns.
// arith patterns for some encoding not supported by the community
// patterns.
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);