[BACKEND] Support optional mask in TritonGPUToLLVM (#80)

Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
goostavz
2022-08-25 08:51:37 +08:00
committed by GitHub
parent 84aa7d025a
commit bedbf221c0
5 changed files with 69 additions and 17 deletions

View File

@@ -248,6 +248,13 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
// Creates an LLVM::ConstantOp at `loc` for the integer `value`.
// The integer attribute is built from the original type `ty`, while the type
// handed to the op is `ty` after conversion through `converter`.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter *converter, Type ty,
                                       int64_t value) {
  Type convertedTy = converter->convertType(ty);
  auto constantAttr = builder.getIntegerAttr(ty, value);
  return builder.create<LLVM::ConstantOp>(loc, convertedTy, constantAttr);
}
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
@@ -601,7 +608,7 @@ struct StoreOpConversion
auto getLLVMElems =
[&](Value value, Value llValue,
const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
const BlockedEncodingAttr &layout) -> SmallVector<Value> {
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
@@ -630,13 +637,16 @@ struct StoreOpConversion
};
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
SmallVector<Value> maskElems;
if (llMask) {
auto [maskLayout, maskNumElems] = getLayout(mask);
maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
}
auto getAlign = [this](Value val,
const BlockedEncodingAttr &layout) -> unsigned {
@@ -710,10 +720,12 @@ struct StoreOpConversion
PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
ptxStoreInstr.predicate(maskElems[vecIdx], "b")
.global()
.b(width)
.v(nWords);
Value maskVal =
llMask ? maskElems[vecIdx]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
llvm::SmallVector<std::string> asmArgs;
@@ -746,8 +758,8 @@ struct StoreOpConversion
}
ptxStoreInstr(asmAddr, asmArgList);
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
for (int i = 0; i < nWords; i++)
argTys.push_back(valArgTy);
@@ -970,9 +982,12 @@ struct LoadOpConversion
auto elemTy = resultTy.getElementType();
unsigned numElems = getElemsPerThread(blockedLayout, shape);
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
SmallVector<Value> maskVals;
if (mask) {
maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
}
SmallVector<Value> otherVals;
if (other != nullptr) {
if (other) {
otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
}
unsigned nbits = elemTy.isa<FloatType>()
@@ -1004,7 +1019,10 @@ struct LoadOpConversion
// TODO: Handle the optimization if ptr is from GEP and the idx is
// constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0;
Value pred = maskVals[i];
Value pred =
mask ? maskVals[i]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
// ---
// create inline asm string

View File

@@ -227,8 +227,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
adaptor.evict(), adaptor.isVolatile());
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
adaptor.isVolatile());
return success();
}
};