[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment correctly considered. Some cases: ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { // mask = make_range(128) < n_element } ``` This should get the vec=2 `ld`/`st` instructions. While the following example ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { // mask = make_range(128) < n_element } ``` it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// CmpI
|
||||
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
|
||||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
|
||||
op->getResult(0).getType().dyn_cast<TensorType>()) {
|
||||
auto resTy = op->getResult(0).getType().cast<TensorType>();
|
||||
short rank = resTy.getRank();
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto shape = resTy.getShape();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
for (short d = 0; d < rank; ++d) {
|
||||
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
|
||||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
|
||||
constancy.push_back(
|
||||
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
|
||||
else
|
||||
constancy.push_back(1);
|
||||
|
||||
divisibility.push_back(shape[d]);
|
||||
contiguity.push_back(1);
|
||||
}
|
||||
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// UnrealizedConversionCast
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
|
||||
// join all lattice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(curr);
|
||||
|
@@ -759,6 +759,17 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
return vec;
|
||||
}
|
||||
|
||||
unsigned getMaskAlignment(Value mask) const {
|
||||
auto maskOrder = mask.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
}
|
||||
|
||||
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
|
||||
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
|
||||
return it->getValue();
|
||||
@@ -771,6 +782,208 @@ protected:
|
||||
AxisInfoAnalysis &AxisAnalysisPass;
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
|
||||
public LoadStoreConversionBase {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(LLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value ptr = op.ptr();
|
||||
Value mask = op.mask();
|
||||
Value other = op.other();
|
||||
|
||||
Value llPtr = adaptor.ptr();
|
||||
Value llMask = adaptor.mask();
|
||||
Value llOther = adaptor.other();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy)
|
||||
return failure();
|
||||
Type valueElemTy =
|
||||
getTypeConverter()->convertType(valueTy.getElementType());
|
||||
|
||||
auto [layout, numElems] = getLayout(ptr);
|
||||
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == numElems);
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
unsigned maskAlignment = getMaskAlignment(mask);
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == maskElems.size());
|
||||
|
||||
size_t maskAlign = getMaskAlignment(mask);
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
// TODO: (goostavz) handle when other is const but not splat, which
|
||||
// should be rarely seen
|
||||
bool otherIsSplatConstInt = false;
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(op.other(), m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
|
||||
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
|
||||
|
||||
SmallVector<Value> loadedVals;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
||||
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
|
||||
const std::string readConstraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
|
||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||
// if (has_l2_evict_policy)
|
||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||
|
||||
if (!evictOpr)
|
||||
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||
else
|
||||
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.o("u", width);
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
v = bitcast(IntegerType::get(getContext(), width), v);
|
||||
|
||||
PTXInstr::Operand *opr{};
|
||||
if (otherIsSplatConstInt)
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
else
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
|
||||
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
|
||||
Type retTy = retTys.size() > 1
|
||||
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
// extract and store return values
|
||||
// ---
|
||||
SmallVector<Value> rets;
|
||||
for (unsigned int ii = 0; ii < nWords; ++ii) {
|
||||
Value curr;
|
||||
if (retTy.isa<LLVM::LLVMStructType>()) {
|
||||
curr = extract_val(IntegerType::get(getContext(), width), ret,
|
||||
rewriter.getI64ArrayAttr(ii));
|
||||
} else {
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(
|
||||
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
|
||||
curr);
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
|
||||
loadedVals.push_back(loaded);
|
||||
}
|
||||
} // end vec
|
||||
|
||||
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StoreOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
|
||||
public LoadStoreConversionBase {
|
||||
@@ -814,14 +1027,8 @@ struct StoreOpConversion
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
auto maskOrder = mask.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
size_t maskAlign = getMaskAlignment(mask);
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
@@ -846,15 +1053,10 @@ struct StoreOpConversion
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = vec_ty(valueElemTy, wordNElems);
|
||||
|
||||
auto *asmArgList = ptxBuilder.newListOperand();
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
@@ -876,23 +1078,25 @@ struct StoreOpConversion
|
||||
llWord = bitcast(valArgTy, llWord);
|
||||
std::string constraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
||||
asmArgs.emplace_back(llWord, constraint);
|
||||
}
|
||||
|
||||
// TODO(Superjomn) Need to check masks before vectorize the load for
|
||||
// the values share one predicate? Here assume all the mask values are
|
||||
// the same.
|
||||
// Prepare the PTX inline asm.
|
||||
PTXBuilder ptxBuilder;
|
||||
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
|
||||
|
||||
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
|
||||
ptxStoreInstr.global().b(width).v(nWords);
|
||||
|
||||
auto *asmAddr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
auto &ptxStoreInstr =
|
||||
ptxBuilder.create<PTXIOInstr>("st")->global().b(width).v(nWords);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
|
||||
for (int i = 0; i < nWords; ++i)
|
||||
argTys.push_back(valArgTy);
|
||||
argTys.insert(argTys.end(), nWords, valArgTy);
|
||||
|
||||
auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
|
||||
|
||||
@@ -1065,209 +1269,6 @@ struct MakeRangeOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
|
||||
public LoadStoreConversionBase {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(LLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value ptr = op.ptr();
|
||||
Value mask = op.mask();
|
||||
Value other = op.other();
|
||||
|
||||
Value llPtr = adaptor.ptr();
|
||||
Value llMask = adaptor.mask();
|
||||
Value llOther = adaptor.other();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy)
|
||||
return failure();
|
||||
Type valueElemTy =
|
||||
getTypeConverter()->convertType(valueTy.getElementType());
|
||||
|
||||
auto [layout, numElems] = getLayout(ptr);
|
||||
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == numElems);
|
||||
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == maskElems.size());
|
||||
}
|
||||
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
// TODO: (goostavz) handle when other is const but not splat, which
|
||||
// should be rarely seen
|
||||
bool otherIsSplatConstInt = false;
|
||||
DenseElementsAttr constAttr;
|
||||
int64_t splatVal = 0;
|
||||
if (valueElemTy.isa<IntegerType>() &&
|
||||
matchPattern(op.other(), m_Constant(&constAttr)) &&
|
||||
constAttr.isSplat()) {
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
|
||||
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
|
||||
|
||||
SmallVector<Value> loadedVals;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
||||
|
||||
// TODO(Superjomn) Need to check masks before vectorize the load for all
|
||||
// the values share one predicate? Here assume all the mask values are
|
||||
// the same.
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
|
||||
const std::string readConstraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
|
||||
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||
// if (has_l2_evict_policy)
|
||||
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||
|
||||
if (!evictOpr)
|
||||
ld(dstsOpr, addrOpr).predicate(pred, "b");
|
||||
else
|
||||
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
|
||||
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.o("u", width);
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
v = bitcast(IntegerType::get(getContext(), width), v);
|
||||
|
||||
PTXInstr::Operand *opr{};
|
||||
if (otherIsSplatConstInt) {
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
} else {
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
}
|
||||
|
||||
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
|
||||
Type retTy = retTys.size() > 1
|
||||
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
// extract and store return values
|
||||
// ---
|
||||
SmallVector<Value> rets;
|
||||
for (unsigned int ii = 0; ii < nWords; ++ii) {
|
||||
Value curr;
|
||||
if (retTy.isa<LLVM::LLVMStructType>()) {
|
||||
curr = extract_val(IntegerType::get(getContext(), width), ret,
|
||||
rewriter.getI64ArrayAttr(ii));
|
||||
} else {
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(
|
||||
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
|
||||
curr);
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
|
||||
loadedVals.push_back(loaded);
|
||||
}
|
||||
} // end vec
|
||||
|
||||
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetProgramIdOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include <numeric>
|
||||
@@ -23,6 +24,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
|
||||
return contiguity[x] > contiguity[y];
|
||||
});
|
||||
|
||||
int numElems = product(origType.getShape());
|
||||
int numThreads = numWarps * 32;
|
||||
int numElemsPerThread = std::max(numElems / numThreads, 1);
|
||||
|
||||
// Thread tile size depends on memory alignment
|
||||
SmallVector<unsigned, 4> sizePerThread(rank, 1);
|
||||
PointerType ptrType = origType.getElementType().cast<PointerType>();
|
||||
@@ -31,7 +37,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned maxContig = info.getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
unsigned perThread = std::min(alignment, 128 / numBits);
|
||||
sizePerThread[order[0]] = perThread;
|
||||
sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
|
||||
|
||||
SmallVector<unsigned> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
|
Reference in New Issue
Block a user