[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store op's vector size by taking the mask's alignment into account. Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
  // mask = make_range(128) < n_element
}
```

This should get vec=2 `ld`/`st` instructions, while the following example

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
  // mask = make_range(128) < n_element
}
```

should get vec=1 `ld`/`st` instructions, because %n_elements carries no divisibility hint.
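The clamping that produces these two outcomes can be followed with a small standalone sketch. This is a simplification that folds the patch's two pieces together (the mask-alignment clamp in the load/store lowering and the per-thread-tile cap in the coalescing pass); the function name and parameters are illustrative, not from the Triton sources:

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the vector-width decision: the pointer-based width is clamped by
// the mask's constancy along the fastest-varying dimension (getMaskAlignment)
// and by the number of elements each thread owns (the Coalesce change).
int vectorWidth(int ptrVec, int maskConstancy, int numElems, int numWarps) {
  int numThreads = numWarps * 32; // 32 lanes per warp
  int numElemsPerThread = std::max(numElems / numThreads, 1);
  int vec = std::min(ptrVec, numElemsPerThread);
  int maskAlign = std::max(maskConstancy, 1); // getMaskAlignment()
  return std::min(vec, maskAlign);
}

int main() {
  // tt.divisibility = 16 on %n_elements -> mask constancy 16 -> vec = 2
  printf("%d\n", vectorWidth(4, 16, 128, 2));
  // no divisibility hint -> mask constancy 1 -> vec = 1
  printf("%d\n", vectorWidth(4, 1, 128, 2));
}
```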
@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     }
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
+
+  // CmpI
+  if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
+       llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
+      op->getResult(0).getType().dyn_cast<TensorType>()) {
+    auto resTy = op->getResult(0).getType().cast<TensorType>();
+    short rank = resTy.getRank();
+    auto lhsInfo = operands[0]->getValue();
+    auto rhsInfo = operands[1]->getValue();
+    auto shape = resTy.getShape();
+
+    AxisInfo::DimVectorT contiguity, divisibility, constancy;
+    for (short d = 0; d < rank; ++d) {
+      if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
+          rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
+        constancy.push_back(
+            gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
+      else
+        constancy.push_back(1);
+
+      divisibility.push_back(shape[d]);
+      contiguity.push_back(1);
+    }
+
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
+
   // UnrealizedConversionCast
   // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
   // in the process of a PartialConversion, where UnrealizedConversionCast
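In effect the new rule says: comparing a tensor against a splatted scalar yields a result that is constant over blocks of gcd(divisibility(lhs), divisibility(rhs)) elements, provided the rhs constancy is compatible with the lhs layout. A tiny sketch (helper name mine; the values mirror the new `vecadd_mask_align_16` lit test added below):

```cpp
#include <cstdio>
#include <numeric> // std::gcd

// Constancy of `cmpi(range, splat)` per the rule above: the comparison can
// flip at most once every gcd(divisibility(lhs), divisibility(rhs)) elements.
int cmpConstancy(int lhsDivisibility, int rhsDivisibility) {
  return std::gcd(lhsDivisibility, rhsDivisibility);
}

int main() {
  // %4 (the index vector) has divisibility 64; %9 (splat of %n_elements with
  // tt.divisibility = 16) has divisibility 16 -> "Constancy: [16]".
  printf("%d\n", cmpConstancy(64, 16));
}
```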
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   if (curr.getRank() == 0) {
     return markAllPessimisticFixpoint(op->getResults());
   }
-  // join all latice elements
+
+  // join all lattice elements
   ChangeResult result = ChangeResult::NoChange;
   for (Value value : op->getResults()) {
     result |= getLatticeElement(value).join(curr);
@@ -759,6 +759,17 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     return vec;
   }
 
+  unsigned getMaskAlignment(Value mask) const {
+    auto maskOrder = mask.getType()
+                         .cast<RankedTensorType>()
+                         .getEncoding()
+                         .cast<BlockedEncodingAttr>()
+                         .getOrder();
+
+    auto maskAxis = getAxisInfo(mask);
+    return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
+  }
+
   llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
     if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
       return it->getValue();
@@ -771,6 +782,208 @@ protected:
   AxisInfoAnalysis &AxisAnalysisPass;
 };
 
+struct LoadOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LoadOpConversion(LLVMTypeConverter &converter,
+                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value ptr = op.ptr();
+    Value mask = op.mask();
+    Value other = op.other();
+
+    Value llPtr = adaptor.ptr();
+    Value llMask = adaptor.mask();
+    Value llOther = adaptor.other();
+
+    auto loc = op->getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());
+
+    auto [layout, numElems] = getLayout(ptr);
+
+    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
+    assert(ptrElems.size() == numElems);
+    // Determine the vectorization size
+    size_t vec = getVectorizeSize(ptr, layout);
+
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
+      assert(ptrElems.size() == maskElems.size());
+
+      size_t maskAlign = getMaskAlignment(mask);
+      vec = std::min(vec, maskAlign);
+    }
+
+    const size_t dtsize =
+        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
+    const size_t valueElemNbits = dtsize * 8;
+
+    const int numVecs = numElems / vec;
+
+    // TODO: (goostavz) handle when other is const but not splat, which
+    // should be rarely seen
+    bool otherIsSplatConstInt = false;
+    DenseElementsAttr constAttr;
+    int64_t splatVal = 0;
+    if (valueElemTy.isa<IntegerType>() &&
+        matchPattern(op.other(), m_Constant(&constAttr)) &&
+        constAttr.isSplat()) {
+      otherIsSplatConstInt = true;
+      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
+    }
+
+    auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
+
+    SmallVector<Value> loadedVals;
+    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
+      // TODO: optimization when ptr is GEP with constant offset
+      size_t in_off = 0;
+
+      const int maxWordWidth = std::max<int>(32, valueElemNbits);
+      const int totalWidth = valueElemNbits * vec;
+      const int width = std::min(totalWidth, maxWordWidth);
+      const int nWords = std::max(1, totalWidth / width);
+      const int wordNElems = width / valueElemNbits;
+      const int vecNElems = totalWidth / valueElemNbits;
+      assert(wordNElems * nWords * numVecs == numElems);
+
+      // TODO(Superjomn) Add cache policy fields to StoreOp.
+      // TODO(Superjomn) Deal with cache policy here.
+      const bool hasL2EvictPolicy = false;
+
+      PTXBuilder ptxBuilder;
+      auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
+
+      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
+
+      const std::string readConstraint =
+          (width == 64) ? "l" : ((width == 32) ? "r" : "c");
+      const std::string writeConstraint =
+          (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
+
+      // prepare asm operands
+      auto *dstsOpr = ptxBuilder.newListOperand();
+      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+        auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
+        dstsOpr->listAppend(opr);
+      }
+
+      auto *addrOpr =
+          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
+
+      // Define the instruction opcode
+      ld.o("volatile", op.isVolatile())
+          .global()
+          .o("ca", op.cache() == triton::CacheModifier::CA)
+          .o("cg", op.cache() == triton::CacheModifier::CG)
+          .o("L1::evict_first",
+             op.evict() == triton::EvictionPolicy::EVICT_FIRST)
+          .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
+          .o("L1::cache_hint", hasL2EvictPolicy)
+          .v(nWords)
+          .b(width);
+
+      PTXBuilder::Operand *evictOpr{};
+
+      // There is no mlir::Value to bind to this operation, so it is disabled.
+      // if (has_l2_evict_policy)
+      //   evictOpr = ptxBuilder.newOperand(l2Evict, "l");
+
+      if (!evictOpr)
+        ld(dstsOpr, addrOpr).predicate(pred, "b");
+      else
+        ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
+
+      if (other) {
+        for (size_t ii = 0; ii < nWords; ++ii) {
+          PTXInstr &mov = *ptxBuilder.create<>("mov");
+          mov.o("u", width);
+
+          size_t size = width / valueElemNbits;
+
+          auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
+          Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+          for (size_t s = 0; s < size; ++s) {
+            Value falseVal = otherElems[vecStart + ii * size + s];
+            Value sVal = createIndexAttrConstant(
+                rewriter, loc, this->getTypeConverter()->getIndexType(), s);
+            v = insert_element(vecTy, v, falseVal, sVal);
+          }
+          v = bitcast(IntegerType::get(getContext(), width), v);
+
+          PTXInstr::Operand *opr{};
+          if (otherIsSplatConstInt)
+            opr = ptxBuilder.newConstantOperand(splatVal);
+          else
+            opr = ptxBuilder.newOperand(v, readConstraint);
+
+          mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
+        }
+      }
+
+      // ---
+      // create inline ASM signature
+      // ---
+      SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
+      Type retTy = retTys.size() > 1
+                       ? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
+                       : retTys[0];
+
+      // TODO: if (has_l2_evict_policy)
+      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+                                                      LLVM::AsmDialect::AD_ATT);
+      Value ret = ptxBuilder.launch(rewriter, loc, retTy);
+
+      // ---
+      // extract and store return values
+      // ---
+      SmallVector<Value> rets;
+      for (unsigned int ii = 0; ii < nWords; ++ii) {
+        Value curr;
+        if (retTy.isa<LLVM::LLVMStructType>()) {
+          curr = extract_val(IntegerType::get(getContext(), width), ret,
+                             rewriter.getI64ArrayAttr(ii));
+        } else {
+          curr = ret;
+        }
+        curr = bitcast(
+            LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
+            curr);
+        rets.push_back(curr);
+      }
+      int tmp = width / valueElemNbits;
+      for (size_t ii = 0; ii < vec; ++ii) {
+        Value vecIdx = createIndexAttrConstant(
+            rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
+        Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
+        loadedVals.push_back(loaded);
+      }
+    } // end vec
+
+    Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
+    Value resultStruct =
+        getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
+    rewriter.replaceOp(op, {resultStruct});
+    return success();
+  }
+};
+
 struct StoreOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
       public LoadStoreConversionBase {
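The word-splitting arithmetic in the loop above is easiest to follow with concrete numbers. A standalone sketch (constants mine, mirroring the arithmetic rather than the Triton sources) of a vec = 4 load of f32, which ends up as one `ld.global.v4.b32`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  const int valueElemNbits = 32;                        // f32
  const int vec = 4;                                    // vectorization size
  const int maxWordWidth = std::max(32, valueElemNbits);
  const int totalWidth = valueElemNbits * vec;          // 128 bits
  const int width = std::min(totalWidth, maxWordWidth); // 32-bit words
  const int nWords = std::max(1, totalWidth / width);   // 4 words
  const int wordNElems = width / valueElemNbits;        // 1 element per word
  assert(wordNElems * nWords == vec);
  // -> .v(nWords).b(width), i.e. ld.global.v4.b32
  printf("width=%d nWords=%d wordNElems=%d\n", width, nWords, wordNElems);
}
```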
@@ -814,14 +1027,8 @@ struct StoreOpConversion
     if (llMask) {
       maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(valueElems.size() == maskElems.size());
-      auto maskOrder = mask.getType()
-                           .cast<RankedTensorType>()
-                           .getEncoding()
-                           .cast<BlockedEncodingAttr>()
-                           .getOrder();
-
-      auto maskAxis = getAxisInfo(mask);
-      size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
+
+      size_t maskAlign = getMaskAlignment(mask);
       vec = std::min(vec, maskAlign);
     }
 
@@ -846,15 +1053,10 @@ struct StoreOpConversion
       // TODO(Superjomn) Deal with cache policy here.
       const bool hasL2EvictPolicy = false;
 
-      PTXBuilder ptxBuilder;
-      auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
-
-      llvm::SmallVector<std::string> asmArgs;
-
       Type valArgTy = IntegerType::get(ctx, width);
       auto wordTy = vec_ty(valueElemTy, wordNElems);
 
-      auto *asmArgList = ptxBuilder.newListOperand();
+      SmallVector<std::pair<Value, std::string>> asmArgs;
       for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         // llWord is a width-len composition
         Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
@@ -876,23 +1078,25 @@ struct StoreOpConversion
         llWord = bitcast(valArgTy, llWord);
         std::string constraint =
             (width == 64) ? "l" : ((width == 32) ? "r" : "c");
-        asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
+        asmArgs.emplace_back(llWord, constraint);
       }
 
-      // TODO(Superjomn) Need to check masks before vectorize the load for
-      // the values share one predicate? Here assume all the mask values are
-      // the same.
+      // Prepare the PTX inline asm.
+      PTXBuilder ptxBuilder;
+      auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
+
       Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
-      ptxStoreInstr.global().b(width).v(nWords);
 
       auto *asmAddr =
           ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
+
+      auto &ptxStoreInstr =
+          ptxBuilder.create<PTXIOInstr>("st")->global().b(width).v(nWords);
       ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
 
       Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
       llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
-      for (int i = 0; i < nWords; ++i)
-        argTys.push_back(valArgTy);
+      argTys.insert(argTys.end(), nWords, valArgTy);
 
       auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
 
@@ -1065,209 +1269,6 @@ struct MakeRangeOpConversion
   }
 };
 
-struct LoadOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
-      public LoadStoreConversionBase {
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
-
-  LoadOpConversion(LLVMTypeConverter &converter,
-                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(axisAnalysisPass) {}
-
-  LogicalResult
-  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value ptr = op.ptr();
-    Value mask = op.mask();
-    Value other = op.other();
-
-    Value llPtr = adaptor.ptr();
-    Value llMask = adaptor.mask();
-    Value llOther = adaptor.other();
-
-    auto loc = op->getLoc();
-    MLIRContext *ctx = rewriter.getContext();
-
-    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
-    if (!valueTy)
-      return failure();
-    Type valueElemTy =
-        getTypeConverter()->convertType(valueTy.getElementType());
-
-    auto [layout, numElems] = getLayout(ptr);
-
-    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
-    assert(ptrElems.size() == numElems);
-
-    SmallVector<Value> maskElems;
-    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
-      assert(ptrElems.size() == maskElems.size());
-    }
-
-    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
-
-    const size_t dtsize =
-        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
-    const size_t valueElemNbits = dtsize * 8;
-
-    const int numVecs = numElems / vec;
-
-    // TODO: (goostavz) handle when other is const but not splat, which
-    // should be rarely seen
-    bool otherIsSplatConstInt = false;
-    DenseElementsAttr constAttr;
-    int64_t splatVal = 0;
-    if (valueElemTy.isa<IntegerType>() &&
-        matchPattern(op.other(), m_Constant(&constAttr)) &&
-        constAttr.isSplat()) {
-      otherIsSplatConstInt = true;
-      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
-    }
-
-    auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
-
-    SmallVector<Value> loadedVals;
-    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
-      // TODO: optimization when ptr is GEP with constant offset
-      size_t in_off = 0;
-
-      const int maxWordWidth = std::max<int>(32, valueElemNbits);
-      const int totalWidth = valueElemNbits * vec;
-      const int width = std::min(totalWidth, maxWordWidth);
-      const int nWords = std::max(1, totalWidth / width);
-      const int wordNElems = width / valueElemNbits;
-      const int vecNElems = totalWidth / valueElemNbits;
-      assert(wordNElems * nWords * numVecs == numElems);
-
-      // TODO(Superjomn) Add cache policy fields to StoreOp.
-      // TODO(Superjomn) Deal with cache policy here.
-      const bool hasL2EvictPolicy = false;
-
-      PTXBuilder ptxBuilder;
-      auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
-
-      // TODO(Superjomn) Need to check masks before vectorize the load for all
-      // the values share one predicate? Here assume all the mask values are
-      // the same.
-      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
-
-      const std::string readConstraint =
-          (width == 64) ? "l" : ((width == 32) ? "r" : "c");
-      const std::string writeConstraint =
-          (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
-
-      // prepare asm operands
-      auto *dstsOpr = ptxBuilder.newListOperand();
-      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
-        auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
-        dstsOpr->listAppend(opr);
-      }
-
-      auto *addrOpr =
-          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
-
-      // Define the instruction opcode
-      ld.o("volatile", op.isVolatile())
-          .global()
-          .o("ca", op.cache() == triton::CacheModifier::CA)
-          .o("cg", op.cache() == triton::CacheModifier::CG)
-          .o("L1::evict_first",
-             op.evict() == triton::EvictionPolicy::EVICT_FIRST)
-          .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
-          .o("L1::cache_hint", hasL2EvictPolicy)
-          .v(nWords)
-          .b(width);
-
-      PTXBuilder::Operand *evictOpr{};
-
-      // Here lack a mlir::Value to bind to this operation, so disabled.
-      // if (has_l2_evict_policy)
-      //   evictOpr = ptxBuilder.newOperand(l2Evict, "l");
-
-      if (!evictOpr)
-        ld(dstsOpr, addrOpr).predicate(pred, "b");
-      else
-        ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
-
-      if (other) {
-        for (size_t ii = 0; ii < nWords; ++ii) {
-          PTXInstr &mov = *ptxBuilder.create<>("mov");
-          mov.o("u", width);
-
-          size_t size = width / valueElemNbits;
-
-          auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
-          Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (size_t s = 0; s < size; ++s) {
-            Value falseVal = otherElems[vecStart + ii * size + s];
-            Value sVal = createIndexAttrConstant(
-                rewriter, loc, this->getTypeConverter()->getIndexType(), s);
-            v = insert_element(vecTy, v, falseVal, sVal);
-          }
-          v = bitcast(IntegerType::get(getContext(), width), v);
-
-          PTXInstr::Operand *opr{};
-          if (otherIsSplatConstInt) {
-            opr = ptxBuilder.newConstantOperand(splatVal);
-          } else {
-            opr = ptxBuilder.newOperand(v, readConstraint);
-          }
-
-          mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
-        }
-      }
-
-      // ---
-      // create inline ASM signature
-      // ---
-      SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
-      Type retTy = retTys.size() > 1
-                       ? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
-                       : retTys[0];
-
-      // TODO: if (has_l2_evict_policy)
-      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                      LLVM::AsmDialect::AD_ATT);
-      Value ret = ptxBuilder.launch(rewriter, loc, retTy);
-
-      // ---
-      // extract and store return values
-      // ---
-      SmallVector<Value> rets;
-      for (unsigned int ii = 0; ii < nWords; ++ii) {
-        Value curr;
-        if (retTy.isa<LLVM::LLVMStructType>()) {
-          curr = extract_val(IntegerType::get(getContext(), width), ret,
-                             rewriter.getI64ArrayAttr(ii));
-        } else {
-          curr = ret;
-        }
-        curr = bitcast(
-            LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
-            curr);
-        rets.push_back(curr);
-      }
-      int tmp = width / valueElemNbits;
-      for (size_t ii = 0; ii < vec; ++ii) {
-        Value vecIdx = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
-        Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
-        loadedVals.push_back(loaded);
-      }
-    } // end vec
-
-    Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
-    Value resultStruct =
-        getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
-    rewriter.replaceOp(op, {resultStruct});
-    return success();
-  }
-};
-
 struct GetProgramIdOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -1,4 +1,5 @@
 #include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include <numeric>
@@ -23,6 +24,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
       return contiguity[x] > contiguity[y];
     });
+
+    int numElems = product(origType.getShape());
+    int numThreads = numWarps * 32;
+    int numElemsPerThread = std::max(numElems / numThreads, 1);
+
     // Thread tile size depends on memory alignment
     SmallVector<unsigned, 4> sizePerThread(rank, 1);
     PointerType ptrType = origType.getElementType().cast<PointerType>();
@@ -31,7 +37,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     unsigned maxContig = info.getContiguity(order[0]);
     unsigned alignment = std::min(maxMultiple, maxContig);
     unsigned perThread = std::min(alignment, 128 / numBits);
-    sizePerThread[order[0]] = perThread;
+    sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
+
     SmallVector<unsigned> dims(rank);
     std::iota(dims.begin(), dims.end(), 0);
     // create encoding
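With the numbers from the commit message (block_size = 128, num_warps = 2, f32), this new cap is what brings the per-thread tile from 4 down to 2, and hence the `ld`/`st` width to vec = 2. A standalone sketch of the arithmetic (constants mine, chosen to match that scenario):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int numBits = 32; // f32
  const int maxMultiple = 16, maxContig = 64;
  const int numElems = 128, numWarps = 2;
  const int numThreads = numWarps * 32;
  const int numElemsPerThread = std::max(numElems / numThreads, 1); // 2
  const int alignment = std::min(maxMultiple, maxContig);           // 16
  const int perThread = std::min(alignment, 128 / numBits);         // 4
  // Previously sizePerThread[order[0]] would be 4; with 64 threads and 128
  // elements each thread only owns 2 elements, so it is now clamped to 2.
  printf("%d\n", std::min(perThread, numElemsPerThread));
}
```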
@@ -188,7 +188,7 @@ def test_vecadd_no_scf(num_warps, block_size, shape):
     [2, 256, (3, 256 + 7)],
     [4, 256, (3, 256 + 7)],
 ])
-def test_vecadd__no_scf_masked(num_warps, block_size, shape):
+def test_vecadd_no_scf_masked(num_warps, block_size, shape):
     vecadd_no_scf_tester(num_warps, block_size, shape)
 
 
@@ -50,3 +50,92 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   tt.store %19, %20, %cst : tensor<128x128xf32>
   return
 }
+
+// -----
+
+module {
+
+// This is a tiny test for verifying StoreOp-related alignment. It simply stores a constant to a buffer.
+func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
+  // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
+  %pid = tt.get_program_id {axis = 0 : i32} : i32
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
+  %c128_i32 = arith.constant 128 : i32
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
+  %1 = arith.muli %pid, %c128_i32 : i32
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
+  %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
+  %3 = tt.splat %1 : (i32) -> tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
+  %4 = arith.addi %3, %2 : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
+  %5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
+  %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
+  %9 = tt.splat %n : (i32) -> tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
+  %mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
+  %cst = arith.constant dense<0.0> : tensor<128xf32>
+  tt.store %5, %cst, %mask : tensor<128xf32>
+  return
+}
+
+}
+
+// -----
+
+// This IR is dumped from the vecadd test.
+// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of the mask.
+func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : (i32) -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
+  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
+  %9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
+  // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+  %mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
+  tt.store %15, %13, %mask : tensor<64xf32>
+  return
+}
+
+// -----
+
+// This IR is dumped from the vecadd test.
+// Note, there is no divisibility hint for %n_elements; Triton should assume its divisibility to be 1 by default.
+func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : (i32) -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
+  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
+  %9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
+  // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+  %10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
+  tt.store %15, %13, %10 : tensor<64xf32>
+  return
+}
@@ -161,6 +161,37 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
 
 // -----
 
+// This test verifies the vectorization of Load and Store Ops.
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
+// Note, %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1; this affects the mask's alignment and further restricts the load/store ops' vector width to 1.
+module attributes {"triton_gpu.num-warps" = 2 : i32} {
+  func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    %1 = arith.muli %0, %c64_i32 : i32
+    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
+    %3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<64xi32, #blocked>
+    %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    %9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
+    %10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
+    // The load op has a vector width = 1 due to the mask's alignment.
+    // CHECK: ld.global.b32
+    %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
+    %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
+    %13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
+    %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    tt.store %15, %13, %10 : tensor<64xf32, #blocked>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
 // CHECK-LABEL: global_load_store_vec8