[BACKEND] Support optional mask in TritonGPUToLLVM (#80)
Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
@@ -248,6 +248,13 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
|
||||
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
|
||||
}
|
||||
|
||||
// Emits an LLVM::ConstantOp at `loc` that holds the integer `value`.
//
// The op's result type is `ty` run through `converter`, while the constant's
// attribute is an integer attribute built directly from `ty` and `value`.
// Returns the value produced by the newly created op.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter *converter, Type ty,
                                       int64_t value) {
  // Result type as seen after type conversion.
  Type convertedTy = converter->convertType(ty);
  // The attribute is keyed on the original (unconverted) integer type.
  auto constantAttr = builder.getIntegerAttr(ty, value);
  return builder.create<LLVM::ConstantOp>(loc, convertedTy, constantAttr);
}
|
||||
|
||||
Value getStructFromElements(Location loc, ValueRange resultVals,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type structType) {
|
||||
@@ -601,7 +608,7 @@ struct StoreOpConversion
|
||||
|
||||
auto getLLVMElems =
|
||||
[&](Value value, Value llValue,
|
||||
const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
|
||||
const BlockedEncodingAttr &layout) -> SmallVector<Value> {
|
||||
auto ty = value.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
@@ -630,13 +637,16 @@ struct StoreOpConversion
|
||||
};
|
||||
|
||||
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
|
||||
auto [maskLayout, maskNumElems] = getLayout(mask);
|
||||
auto [valueLayout, valueNumElems] = getLayout(value);
|
||||
|
||||
auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
|
||||
auto valueElems = getLLVMElems(value, llValue, valueLayout);
|
||||
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
auto [maskLayout, maskNumElems] = getLayout(mask);
|
||||
maskElems = getLLVMElems(mask, llMask, maskLayout);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
}
|
||||
|
||||
auto getAlign = [this](Value val,
|
||||
const BlockedEncodingAttr &layout) -> unsigned {
|
||||
@@ -710,10 +720,12 @@ struct StoreOpConversion
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
|
||||
ptxStoreInstr.predicate(maskElems[vecIdx], "b")
|
||||
.global()
|
||||
.b(width)
|
||||
.v(nWords);
|
||||
|
||||
Value maskVal =
|
||||
llMask ? maskElems[vecIdx]
|
||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||
rewriter.getIntegerType(1), 1);
|
||||
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
@@ -746,8 +758,8 @@ struct StoreOpConversion
|
||||
}
|
||||
|
||||
ptxStoreInstr(asmAddr, asmArgList);
|
||||
|
||||
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||
for (int i = 0; i < nWords; i++)
|
||||
argTys.push_back(valArgTy);
|
||||
|
||||
@@ -970,9 +982,12 @@ struct LoadOpConversion
|
||||
auto elemTy = resultTy.getElementType();
|
||||
unsigned numElems = getElemsPerThread(blockedLayout, shape);
|
||||
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
|
||||
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
|
||||
SmallVector<Value> maskVals;
|
||||
if (mask) {
|
||||
maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
|
||||
}
|
||||
SmallVector<Value> otherVals;
|
||||
if (other != nullptr) {
|
||||
if (other) {
|
||||
otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
|
||||
}
|
||||
unsigned nbits = elemTy.isa<FloatType>()
|
||||
@@ -1004,7 +1019,10 @@ struct LoadOpConversion
|
||||
// TODO: Handle the optimization if ptr is from GEP and the idx is
|
||||
// constant. This should be a canonicalization pattern in LLVM Dialect
|
||||
unsigned in_off = 0;
|
||||
Value pred = maskVals[i];
|
||||
Value pred =
|
||||
mask ? maskVals[i]
|
||||
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
|
||||
rewriter.getIntegerType(1), 1);
|
||||
|
||||
// ---
|
||||
// create inline asm string
|
||||
|
@@ -227,8 +227,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
|
||||
adaptor.evict(), adaptor.isVolatile());
|
||||
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
|
||||
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
|
||||
adaptor.isVolatile());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user