[BACKEND] Add C++ tests for PTXFormat and some tiny refinements (#109)

This PR does the following:

1. Add some C++ tests for `PTXFormat`
2. Enhance the functionality of `PTXFormat` so that a `PTXInstr` instance
can be called multiple times, similar to a C function.
This commit is contained in:
Yan Chunwei
2022-09-10 00:15:07 +08:00
committed by GitHub
parent a9464f4993
commit 2a852044d9
10 changed files with 231 additions and 92 deletions

View File

@@ -756,12 +756,12 @@ struct StoreOpConversion
llMask ? maskElems[vecStart]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
ptxStoreInstr.global().b(width).v(nWords);
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
ptxStoreInstr(asmAddr, asmArgList);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
for (int i = 0; i < nWords; i++)
@@ -772,7 +772,7 @@ struct StoreOpConversion
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
ptxBuilder.dump(), // asm_string
ptxBuilder.getConstrains(), // constraints
ptxBuilder.getConstraints(), // constraints
// TODO(Superjomn) determine the side effect.
true, // has_side_effects
false, // is_align_stack
@@ -1045,8 +1045,7 @@ struct LoadOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode
ld.predicate(pred, "b")
.o("violatile", op.isVolatile())
ld.o("volatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
@@ -1064,15 +1063,15 @@ struct LoadOpConversion
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
if (!evictOpr)
ld(dstsOpr, addrOpr);
ld(dstsOpr, addrOpr).predicate(pred, "b");
else
ld(dstsOpr, addrOpr, evictOpr);
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
SmallVector<Value> others;
if (other) {
for (size_t ii = 0; ii < nWords; ii++) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
mov.predicateNot(pred, "b").o("u", width);
mov.o("u", width);
size_t size = width / valueElemNbits;
@@ -1096,7 +1095,7 @@ struct LoadOpConversion
others.push_back(v);
}
mov(dstsOpr->listGet(ii), opr);
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
}
}
@@ -1114,7 +1113,7 @@ struct LoadOpConversion
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
/*asm_string=*/ptxBuilder.dump(),
/*constraints=*/ptxBuilder.getConstrains(),
/*constraints=*/ptxBuilder.getConstraints(),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());