[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment correctly considered. Some cases: ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { // mask = make_range(128) < n_element } ``` This should get the vec=2 `ld`/`st` instructions. While the following example ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { // mask = make_range(128) < n_element } ``` it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// CmpI
|
||||
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
|
||||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
|
||||
op->getResult(0).getType().dyn_cast<TensorType>()) {
|
||||
auto resTy = op->getResult(0).getType().cast<TensorType>();
|
||||
short rank = resTy.getRank();
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto shape = resTy.getShape();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
for (short d = 0; d < rank; ++d) {
|
||||
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
|
||||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
|
||||
constancy.push_back(
|
||||
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
|
||||
else
|
||||
constancy.push_back(1);
|
||||
|
||||
divisibility.push_back(shape[d]);
|
||||
contiguity.push_back(1);
|
||||
}
|
||||
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// UnrealizedConversionCast
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
|
||||
// join all lattice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(curr);
|
||||
|
@@ -759,6 +759,17 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
return vec;
|
||||
}
|
||||
|
||||
unsigned getMaskAlignment(Value mask) const {
|
||||
auto maskOrder = mask.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
}
|
||||
|
||||
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
|
||||
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
|
||||
return it->getValue();
|
||||
@@ -771,6 +782,208 @@ protected:
|
||||
AxisInfoAnalysis &AxisAnalysisPass;
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
|
||||
public LoadStoreConversionBase {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(LLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value ptr = op.ptr();
|
||||
Value mask = op.mask();
|
||||
Value other = op.other();
|
||||
|
||||
Value llPtr = adaptor.ptr();
|
||||
Value llMask = adaptor.mask();
|
||||
Value llOther = adaptor.other();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy)
|
||||
return failure();
|
||||
Type valueElemTy =
|
||||
getTypeConverter()->convertType(valueTy.getElementType());
|
||||
|
||||
auto [layout, numElems] = getLayout(ptr);
|
||||
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == numElems);
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
unsigned maskAlignment = getMaskAlignment(mask);
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == maskElems.size());
|
||||
|
||||
size_t maskAlign = getMaskAlignment(mask);
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
// TODO: (goostavz) handle when other is const but not splat, which
|
||||
// should be rarely seen
|
||||
bool otherIsSplatConstInt = false;
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(op.other(), m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
|
||||
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
|
||||
|
||||
SmallVector<Value> loadedVals;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
||||
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
|
||||
const std::string readConstraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
|
||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||
// if (has_l2_evict_policy)
|
||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||
|
||||
if (!evictOpr)
|
||||
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||
else
|
||||
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.o("u", width);
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
v = bitcast(IntegerType::get(getContext(), width), v);
|
||||
|
||||
PTXInstr::Operand *opr{};
|
||||
if (otherIsSplatConstInt)
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
else
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
|
||||
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
|
||||
Type retTy = retTys.size() > 1
|
||||
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
// extract and store return values
|
||||
// ---
|
||||
SmallVector<Value> rets;
|
||||
for (unsigned int ii = 0; ii < nWords; ++ii) {
|
||||
Value curr;
|
||||
if (retTy.isa<LLVM::LLVMStructType>()) {
|
||||
curr = extract_val(IntegerType::get(getContext(), width), ret,
|
||||
rewriter.getI64ArrayAttr(ii));
|
||||
} else {
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(
|
||||
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
|
||||
curr);
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
|
||||
loadedVals.push_back(loaded);
|
||||
}
|
||||
} // end vec
|
||||
|
||||
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StoreOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
|
||||
public LoadStoreConversionBase {
|
||||
@@ -814,14 +1027,8 @@ struct StoreOpConversion
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
auto maskOrder = mask.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
size_t maskAlign = getMaskAlignment(mask);
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
@@ -846,15 +1053,10 @@ struct StoreOpConversion
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = vec_ty(valueElemTy, wordNElems);
|
||||
|
||||
auto *asmArgList = ptxBuilder.newListOperand();
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
@@ -876,23 +1078,25 @@ struct StoreOpConversion
|
||||
llWord = bitcast(valArgTy, llWord);
|
||||
std::string constraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
||||
asmArgs.emplace_back(llWord, constraint);
|
||||
}
|
||||
|
||||
// TODO(Superjomn) Need to check masks before vectorize the load for
|
||||
// the values share one predicate? Here assume all the mask values are
|
||||
// the same.
|
||||
// Prepare the PTX inline asm.
|
||||
PTXBuilder ptxBuilder;
|
||||
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
|
||||
|
||||
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
|
||||
ptxStoreInstr.global().b(width).v(nWords);
|
||||
|
||||
auto *asmAddr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
auto &ptxStoreInstr =
|
||||
ptxBuilder.create<PTXIOInstr>("st")->global().b(width).v(nWords);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||
for (int i = 0; i < nWords; ++i)
|
||||
argTys.push_back(valArgTy);
|
||||
argTys.insert(argTys.end(), nWords, valArgTy);
|
||||
|
||||
auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
|
||||
|
||||
@@ -1065,209 +1269,6 @@ struct MakeRangeOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
|
||||
public LoadStoreConversionBase {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(LLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value ptr = op.ptr();
|
||||
Value mask = op.mask();
|
||||
Value other = op.other();
|
||||
|
||||
Value llPtr = adaptor.ptr();
|
||||
Value llMask = adaptor.mask();
|
||||
Value llOther = adaptor.other();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy)
|
||||
return failure();
|
||||
Type valueElemTy =
|
||||
getTypeConverter()->convertType(valueTy.getElementType());
|
||||
|
||||
auto [layout, numElems] = getLayout(ptr);
|
||||
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == numElems);
|
||||
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == maskElems.size());
|
||||
}
|
||||
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
// TODO: (goostavz) handle when other is const but not splat, which
|
||||
// should be rarely seen
|
||||
bool otherIsSplatConstInt = false;
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(op.other(), m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
|
||||
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
|
||||
|
||||
SmallVector<Value> loadedVals;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
||||
|
||||
// TODO(Superjomn) Need to check masks before vectorize the load for all
|
||||
// the values share one predicate? Here assume all the mask values are
|
||||
// the same.
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
|
||||
const std::string readConstraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
|
||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||
// if (has_l2_evict_policy)
|
||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||
|
||||
if (!evictOpr)
|
||||
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||
else
|
||||
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.o("u", width);
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
v = bitcast(IntegerType::get(getContext(), width), v);
|
||||
|
||||
PTXInstr::Operand *opr{};
|
||||
if (otherIsSplatConstInt) {
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
} else {
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
}
|
||||
|
||||
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
|
||||
Type retTy = retTys.size() > 1
|
||||
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
// extract and store return values
|
||||
// ---
|
||||
SmallVector<Value> rets;
|
||||
for (unsigned int ii = 0; ii < nWords; ++ii) {
|
||||
Value curr;
|
||||
if (retTy.isa<LLVM::LLVMStructType>()) {
|
||||
curr = extract_val(IntegerType::get(getContext(), width), ret,
|
||||
rewriter.getI64ArrayAttr(ii));
|
||||
} else {
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(
|
||||
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
|
||||
curr);
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
|
||||
loadedVals.push_back(loaded);
|
||||
}
|
||||
} // end vec
|
||||
|
||||
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetProgramIdOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include <numeric>
|
||||
@@ -23,6 +24,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
|
||||
return contiguity[x] > contiguity[y];
|
||||
});
|
||||
|
||||
int numElems = product(origType.getShape());
|
||||
int numThreads = numWarps * 32;
|
||||
int numElemsPerThread = std::max(numElems / numThreads, 1);
|
||||
|
||||
// Thread tile size depends on memory alignment
|
||||
SmallVector<unsigned, 4> sizePerThread(rank, 1);
|
||||
PointerType ptrType = origType.getElementType().cast<PointerType>();
|
||||
@@ -31,7 +37,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned maxContig = info.getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
unsigned perThread = std::min(alignment, 128 / numBits);
|
||||
sizePerThread[order[0]] = perThread;
|
||||
sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
|
||||
|
||||
SmallVector<unsigned> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
|
@@ -188,7 +188,7 @@ def test_vecadd_no_scf(num_warps, block_size, shape):
|
||||
[2, 256, (3, 256 + 7)],
|
||||
[4, 256, (3, 256 + 7)],
|
||||
])
|
||||
def test_vecadd__no_scf_masked(num_warps, block_size, shape):
|
||||
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
|
||||
vecadd_no_scf_tester(num_warps, block_size, shape)
|
||||
|
||||
|
||||
|
@@ -50,3 +50,92 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
%1 = arith.muli %pid, %c128_i32 : i32
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
|
||||
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
|
||||
%4 = arith.addi %3, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
||||
tt.store %5, %cst, %mask : tensor<128xf32>
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
||||
func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
|
||||
func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
tt.store %15, %13, %10 : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
|
@@ -161,6 +161,37 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
// This test verifies the vectorization of Load and Store Ops.
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
// Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1.
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32, #blocked>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
|
||||
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
|
||||
// load op has a vector width = 1 due to the %mask's alignment
|
||||
// CHECK: ld.global.b32
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>
|
||||
tt.store %15, %13, %10 : tensor<64xf32, #blocked>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: global_load_store_vec8
|
||||
@@ -682,4 +713,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user