[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)

Keren Zhou
2022-10-04 09:37:00 -07:00
committed by GitHub
parent f9d7f2f126
commit 289ff293cc
9 changed files with 412 additions and 57 deletions

View File

@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#include "mlir/IR/Value.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
@@ -99,8 +100,9 @@ struct PTXBuilder {
std::string dump() const;
};
-template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
-  instrs.emplace_back(std::make_unique<INSTR>(this, name));
+template <typename INSTR = PTXInstr, typename... Args>
+INSTR *create(Args &&...args) {
+  instrs.emplace_back(std::make_unique<INSTR>(this, args...));
return static_cast<INSTR *>(instrs.back().get());
}
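For illustration, a sketch of how the variadic create gets used elsewhere in this patch: the extra arguments are simply forwarded to the instruction's constructor.

    PTXBuilder builder;
    // The name-only form still works; "ld" is forwarded to PTXInstrBase(builder, name).
    auto &ld = *builder.create<PTXIOInstr>("ld");
    // Extra arguments reach PTXCpAsyncLoadInstr(builder, modifier, policy), defined below.
    auto &cpAsync = *builder.create<PTXCpAsyncLoadInstr>(
        triton::CacheModifier::CG, triton::EvictionPolicy::NORMAL);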
@@ -188,6 +190,7 @@ struct PTXInstrCommon {
using Operand = PTXBuilder::Operand;
// clang-format off
+PTXInstrExecution& operator()() { return call({}); }
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
@@ -238,17 +241,17 @@ struct PTXInstr : public PTXInstrBase<PTXInstr> {
// PtxIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
-struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
-using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
+struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
+using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
// Add ".global" suffix to instruction
-PtxIOInstr &global(bool predicate = true) {
+PTXIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Add ".v" suffix to instruction
-PtxIOInstr &v(int vecWidth, bool predicate = true) {
+PTXIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
@@ -256,12 +259,43 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
}
// Add ".b" suffix to instruction
-PtxIOInstr &b(int width) {
+PTXIOInstr &b(int width) {
o("b" + std::to_string(width));
return *this;
}
};
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
: PTXInstrBase(builder, "cp.async") {}
};
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("commit_group");
}
};
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("wait_group");
}
};
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier,
triton::EvictionPolicy policy)
: PTXCpAsyncInstrBase(builder) {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
o("L2::" + triton::stringifyEvictionPolicy(policy).str());
}
};
// Record the operands and context for "launching" a PtxInstr.
struct PTXInstrExecution {
using Operand = PTXBuilder::Operand;
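For orientation, a sketch of the PTX mnemonics the new cp.async wrappers spell out; the operands come from the lowering patterns added later in this patch, and one builder is reused here only for brevity:

    PTXBuilder builder;
    // PTXCpAsyncLoadInstr(CG, NORMAL) builds "cp.async.cg.shared.global.L2::evict_normal";
    // its [dst], [src], cp-size and src-size operands are attached at call time.
    auto &commit = *builder.create<PTXCpAsyncCommitGroupInstr>();
    commit();                              // -> "cp.async.commit_group ;"
    auto &wait = *builder.create<PTXCpAsyncWaitGroupInstr>();
    wait(builder.newConstantOperand(4));   // -> "cp.async.wait_group 0x4 ;"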

View File

@@ -16,7 +16,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
-I32EnumAttrCase<"NORMAL", 1, "normal">,
+I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {

View File

@@ -24,6 +24,8 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<unsigned> getSizePerThread(Attribute layout);
+SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);

View File

@@ -54,7 +54,7 @@ in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
-A_{1, 0} A_{0, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
+A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _
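The docstring change above only fixes a row index in the example; for reference, a hedged sketch of the swizzling arithmetic this layout describes (the helper name and signature are illustrative, and the formula mirrors the index math used by InsertSliceAsyncOpConversion later in this patch):

    // Swizzled column for the element at (row, col) of a shared layout with
    // parameters vec, perPhase and maxPhase.
    unsigned swizzledCol(unsigned row, unsigned col, unsigned vec,
                         unsigned perPhase, unsigned maxPhase) {
      unsigned phase = (row / perPhase) % maxPhase; // rows that share a phase use the same XOR
      unsigned group = col / vec;                   // swizzling permutes whole groups of `vec`
      return (group ^ phase) * vec + col % vec;     // XOR the group index, keep the in-group offset
    }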

View File

@@ -141,11 +141,12 @@ PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
-if (pred)
+if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
else
os << pred->repr(pred->idx);
+}
std::string instrRepr = strJoin(instr->instrParts, ".");

View File

@@ -32,6 +32,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::getThreadsPerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -75,7 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
-#define bit_cast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
+#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
@@ -86,6 +87,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
+#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
@@ -630,7 +632,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
-auto llSrc = bit_cast(srcType, constVal);
+auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
@@ -706,22 +708,14 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
-TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};
-auto ty = value.getType().cast<RankedTensorType>();
-auto shape = ty.getShape();
+auto shape = value.getType().cast<RankedTensorType>().getShape();
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = layout.getElemsPerThread(shape);
-auto llvmElemTy = typeConverter->convertType(ty.getElementType());
-auto llvmElemPtrPtrTy =
-LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
}
@@ -810,16 +804,13 @@ struct StoreOpConversion
auto [layout, numElems] = getLayout(ptr);
-auto ptrElems =
-getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
-auto valueElems =
-getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
+auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
+auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
SmallVector<Value> maskElems;
if (llMask) {
-maskElems =
-getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
+maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(valueElems.size() == maskElems.size());
}
@@ -848,7 +839,7 @@ struct StoreOpConversion
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
-auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
+auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
llvm::SmallVector<std::string> asmArgs;
@@ -866,7 +857,7 @@ struct StoreOpConversion
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
-elem = bit_cast(valueElemTy, elem);
+elem = bitcast(valueElemTy, elem);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =
@@ -874,7 +865,7 @@ struct StoreOpConversion
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
}
-llWord = bit_cast(valArgTy, llWord);
+llWord = bitcast(valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
@@ -1100,14 +1091,12 @@ struct LoadOpConversion
auto [layout, numElems] = getLayout(ptr);
-auto ptrElems =
-getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
+auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
assert(ptrElems.size() == numElems);
SmallVector<Value> maskElems;
if (llMask) {
-maskElems =
-getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
+maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(ptrElems.size() == maskElems.size());
}
@@ -1132,8 +1121,7 @@ struct LoadOpConversion
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
-auto otherElems =
-getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
+auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
@@ -1153,7 +1141,7 @@ struct LoadOpConversion
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
-auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
+auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
// TODO(Superjomn) Need to check masks before vectorize the load for all
// the values share one predicate? Here assume all the mask values are
@@ -1198,7 +1186,6 @@ struct LoadOpConversion
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
-SmallVector<Value> others;
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
@@ -1214,14 +1201,13 @@ struct LoadOpConversion
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
-v = bit_cast(IntegerType::get(getContext(), width), v);
+v = bitcast(IntegerType::get(getContext(), width), v);
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal);
} else {
opr = ptxBuilder.newOperand(v, readConstraint);
-others.push_back(v);
}
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
@@ -1253,7 +1239,7 @@ struct LoadOpConversion
} else {
curr = ret;
}
-curr = bit_cast(
+curr = bitcast(
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
rets.push_back(curr);
@@ -1360,9 +1346,8 @@ struct ExtractSliceOpConversion
// axis > 0 will result in non-contiguous memory access if the result tensor
// is an alias of the source tensor.
-auto axis =
-op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
-assert(axis == 0 && "Only axis=0 is supported for now");
+auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
+assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
// Example:
// %dst = extract_slice %src, %index {axis = 0}
@@ -1372,12 +1357,11 @@ struct ExtractSliceOpConversion
auto base = product<int64_t>(dstTy.getShape());
auto baseVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), base);
-Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
+Value offset = mul(adaptor.index(), baseVal);
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-Value resultVal =
-rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
+Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
rewriter.replaceOp(op, resultVal);
return success();
}
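A quick worked example of the offset computation in this hunk, with hypothetical shapes for illustration:

    // Suppose %dst : tensor<16x64xf32> and %index = 1.
    // base      = product(dstTy.getShape()) = 16 * 64 = 1024 elements
    // offset    = mul(%index, base)         = 1 * 1024 = 1024
    // resultVal = gep(elemPtrTy, adaptor.src(), offset), i.e. the slice begins
    //             1024 elements into the shared-memory buffer.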
@@ -1581,7 +1565,7 @@ void ConvertLayoutOpConversion::processReplica(
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
-ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
+ptr = bitcast(ptr_ty(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
@@ -1614,7 +1598,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
-smemBase = bit_cast(elemPtrTy, smemBase);
+smemBase = bitcast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
@@ -1732,7 +1716,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
-smemBase = bit_cast(elemPtrTy, smemBase);
+smemBase = bitcast(elemPtrTy, smemBase);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {
@@ -1783,7 +1767,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
-smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
+smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
store(wordVecs[linearWordIdx], smemAddr);
}
}
@@ -2126,7 +2110,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
-i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
} else { // k first
Value offset = i32_val(sOffsetElem);
@@ -2144,7 +2128,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
-i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
}
@@ -2628,7 +2612,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
-bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
+bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
}
bool needTrans = kOrder != order[0];
@@ -2777,6 +2761,229 @@ public:
}
};
struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
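For reference, the net effect of this pattern on the op exercised in the lit test below; the llvm.inline_asm wrapper itself comes from PTXBuilder::launch:

    // TritonGPU IR input:  triton_gpu.async_wait {num = 4 : i32}
    // Emitted PTX string:  cp.async.wait_group 0x4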
struct InsertSliceAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// insert_slice_async %src, %dst, %index, %mask, %other
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.dst();
Value res = op.result();
Value mask = op.mask();
Value other = op.other();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = resTy.getElementType();
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.dst();
Value llSrc = adaptor.src();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
Value llIndex = adaptor.index();
// %src
auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);
// %dst
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
auto dstBase = createIndexAttrConstant(rewriter, loc,
getTypeConverter()->getIndexType(),
product<int64_t>(resTy.getShape()));
Value offset = mul(llIndex, dstBase);
auto dstPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}
// %other
SmallVector<Value> otherElems;
if (llOther) {
// TODO(Keren): support "other" tensor.
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
assert(false && "insert_slice_async: Other value not supported yet");
otherElems =
getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}
unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
auto outOrder = resSharedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements
// across phases.
// If perPhase * maxPhase == threadsPerCTA, swizzle is not allowed
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
// baseOffsetCol = 0 baseOffsetCol = 0
// tileVecIdxCol = 0 tileVecIdxCol = 1
// -/\- -/\-
// [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
// baseOffsetRow [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
auto vecIdx = elemIdx / minVec;
auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
auto baseOffsetCol =
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
// Swizzling
// Since the swizzling index is related to outVec, and we know minVec
// already, inVec doesn't matter
//
// (Numbers represent row indices)
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] ... [15 16] |
// | [3 4] [5 6] ... [1 2] |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
Value rowOffset =
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
Value swizzleColOffset =
add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
urem(colOffset, i32_val(outVec)));
Value tileOffset = add(rowOffset, swizzleColOffset);
tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
gep(dstPtrTy, dstPtrBase, tileOffset);
}
// 16 * 8 = 128bits
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
// XXX(Keren): Tune CG and CA here.
CacheModifier srcCacheModifier =
bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
srcCacheModifier, op.evict());
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
auto *dstOperand =
ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
auto *srcSize = copySize;
if (op.mask()) {
// We don't use predicate in this case, setting src-size to 0
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
i32_val(bitWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
}
}
PTXBuilder ptxBuilder;
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
auto ret =
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
rewriter.replaceOp(op, ret);
return success();
}
};
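To make the word-width arithmetic above concrete, a worked example using the element type and shared-layout vec from the two lit tests added below (f32, so 32 bits per element):

    // basic_insert_slice_async_v4: outVec = 4 (shared vec), minVec = 4
    //   vecBitWidth = 32 * 4 = 128, bitWidth = min(128, 128) = 128
    //   -> cp.async.cg, size 0x80 in the CHECK lines; numWords = 1, numWordElems = 4
    // basic_insert_slice_async_v1: outVec = 1 (shared vec), minVec = 1
    //   vecBitWidth = 32 * 1 = 32, bitWidth = min(128, 32) = 32
    //   -> cp.async.ca, size 0x20 in the CHECK lines; numWords = 1, numWordElems = 1
    // With a mask, src-size becomes select(mask, bitWidth, 0), so cp.async
    // zero-fills the masked-off part instead of predicating the whole copy.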
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -2786,6 +2993,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
+patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
@@ -2800,6 +3008,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
+patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
+axisInfoAnalysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);

View File

@@ -72,6 +72,21 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(0 && "Unimplemented usage of MmaEncodingAttr");
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
return threads;
}
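As a quick sanity check of the new helper: for the #AL blocked layout used in the lit tests below (threadsPerWarp = [4, 8], warpsPerCTA = [4, 1]):

    // threadsPerCTA[d] = threadsPerWarp[d] * warpsPerCTA[d]
    // d = 0: 4 * 4 = 16
    // d = 1: 8 * 1 = 8    -> getThreadsPerCTA(#AL) == {16, 8}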
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {

View File

@@ -333,6 +333,99 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_async_wait
func @basic_async_wait() {
// CHECK: cp.async.wait_group 0x4
triton_gpu.async_wait {num = 4: i32}
return
}
}
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v4
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
%cst_scalar = arith.constant 64 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
return
}
}
// -----
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v1
func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
%cst_scalar = arith.constant 32 : i32
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2>
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x32xi32, #block3>) -> tensor<16x32xi32, #block3>
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x32xi32, #block2>) -> tensor<16x32xi32, #AL>
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL>
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
%a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x32x!tt.ptr<f32>, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: basic_splat
@@ -351,9 +444,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return

View File

@@ -76,7 +76,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
auto &ld =
builder
-.create<PtxIOInstr>("ld") //
+.create<PTXIOInstr>("ld") //
->o("volatile", isVolatile)
.global()
.o("ca", cache == CacheModifier::CA)