[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)
@@ -2,6 +2,7 @@
 #define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
 
 #include "mlir/IR/Value.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include <memory>
@@ -99,8 +100,9 @@ struct PTXBuilder {
     std::string dump() const;
   };
 
-  template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
-    instrs.emplace_back(std::make_unique<INSTR>(this, name));
+  template <typename INSTR = PTXInstr, typename... Args>
+  INSTR *create(Args &&...args) {
+    instrs.emplace_back(std::make_unique<INSTR>(this, args...));
     return static_cast<INSTR *>(instrs.back().get());
   }
 
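The factory change above makes `create` variadic: instead of accepting only an instruction name string, it now forwards an arbitrary argument pack to the instruction's constructor, which the `cp.async` helper classes added later in this commit depend on. A minimal sketch of the two call styles, assuming `modifier` and `policy` hold a `triton::CacheModifier` and a `triton::EvictionPolicy`:

    PTXBuilder builder;
    // Name-based creation still works: "st" is passed through to the
    // PTXInstr constructor, as before.
    auto *st = builder.create<PTXIOInstr>("st");
    // Subclasses with richer constructors go through the same factory;
    // the pack is forwarded to PTXCpAsyncLoadInstr(builder, modifier, policy).
    auto *cp = builder.create<PTXCpAsyncLoadInstr>(modifier, policy);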
@@ -188,6 +190,7 @@ struct PTXInstrCommon {
   using Operand = PTXBuilder::Operand;
 
   // clang-format off
+  PTXInstrExecution& operator()() { return call({}); }
   PTXInstrExecution& operator()(Operand* a) { return call({a}); }
   PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
   PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
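The new zero-operand `operator()` covers instructions that take no PTX operands at all; later in this commit the commit-group instruction is launched exactly this way:

    PTXBuilder ptxBuilder;
    // No operands: this produces a bare "cp.async.commit_group ;".
    ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();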
@@ -238,17 +241,17 @@ struct PTXInstr : public PTXInstrBase<PTXInstr> {
 // PtxIOInstr store("st");
 // store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
 // store.addAddr(addrValue, "l", off);
-struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
-  using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
+struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
+  using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
 
   // Add ".global" suffix to instruction
-  PtxIOInstr &global(bool predicate = true) {
+  PTXIOInstr &global(bool predicate = true) {
     o("global", predicate);
     return *this;
   }
 
   // Add ".v" suffix to instruction
-  PtxIOInstr &v(int vecWidth, bool predicate = true) {
+  PTXIOInstr &v(int vecWidth, bool predicate = true) {
     if (vecWidth > 1) {
       o("v" + std::to_string(vecWidth), predicate);
     }
@@ -256,12 +259,43 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
   }
 
   // Add ".b" suffix to instruction
-  PtxIOInstr &b(int width) {
+  PTXIOInstr &b(int width) {
     o("b" + std::to_string(width));
     return *this;
   }
 };
 
+struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
+  explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
+      : PTXInstrBase(builder, "cp.async") {}
+};
+
+struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
+  explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
+      : PTXCpAsyncInstrBase(builder) {
+    o("commit_group");
+  }
+};
+
+struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
+  explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
+      : PTXCpAsyncInstrBase(builder) {
+    o("wait_group");
+  }
+};
+
+struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
+  explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
+                               triton::CacheModifier modifier,
+                               triton::EvictionPolicy policy)
+      : PTXCpAsyncInstrBase(builder) {
+    o(triton::stringifyCacheModifier(modifier).str());
+    o("shared");
+    o("global");
+    o("L2::" + triton::stringifyEvictionPolicy(policy).str());
+  }
+};
+
 // Record the operands and context for "launching" a PtxInstr.
 struct PTXInstrExecution {
   using Operand = PTXBuilder::Operand;
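Each `o(...)` call in these helpers appends one dot-separated part to the instruction, so the parts join into the final mnemonic: `cp.async.commit_group`, `cp.async.wait_group`, or `cp.async.<cache>.shared.global.L2::<policy>` for the load variant. A minimal sketch of assembling the load form, assuming the four operands are `PTXBuilder::Operand *` values built as in the conversion code later in this commit:

    PTXBuilder builder;
    // Parts join to "cp.async.ca.shared.global.L2::evict_normal".
    auto &copyAsyncOp = *builder.create<PTXCpAsyncLoadInstr>(
        triton::CacheModifier::CA, triton::EvictionPolicy::NORMAL);
    copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);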
@@ -16,7 +16,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
 def TT_EvictionPolicyAttr : I32EnumAttr<
     "EvictionPolicy", "",
     [
-        I32EnumAttrCase<"NORMAL", 1, "normal">,
+        I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
         I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
         I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
     ]> {
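This mnemonic change feeds directly into `PTXCpAsyncLoadInstr` above, which splices the stringified policy after an `L2::` prefix via `o("L2::" + triton::stringifyEvictionPolicy(policy).str())`. With the old spelling the builder would have emitted `L2::normal`, whereas the PTX cache-eviction qualifier is spelled `L2::evict_normal`, as the CHECK lines in the new tests below confirm.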
@@ -24,6 +24,8 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
 
 SmallVector<unsigned> getSizePerThread(Attribute layout);
 
+SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
+
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
 
 SmallVector<unsigned> getOrder(const Attribute &layout);
@@ -54,7 +54,7 @@ in memory. For example, a swizzled row-major layout could store its data
 as follows:
 
 A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...   [phase 0] \ per_phase = 2
-A_{1, 0}  A_{0, 1}  A_{1, 2}  A_{1, 3} ...   [phase 0] /
+A_{1, 0}  A_{1, 1}  A_{1, 2}  A_{1, 3} ...   [phase 0] /
 groups of vec=2 elements
 are stored contiguously
 _ _ _ _ /\_ _ _ _
@@ -141,11 +141,12 @@ PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
 std::string PTXInstrExecution::dump() const {
   std::string osStr;
   llvm::raw_string_ostream os(osStr);
-  if (pred)
+  if (pred) {
     if (!pred->repr)
      os << "@" << pred->dump() << " ";
     else
      os << pred->repr(pred->idx);
+  }
 
   std::string instrRepr = strJoin(instr->instrParts, ".");
 
@@ -32,6 +32,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::getThreadsPerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -75,7 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
 #define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
-#define bit_cast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
+#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
 #define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
 #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
 #define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
@@ -86,6 +87,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
   rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
+#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
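The new `select` shorthand follows the same pattern as its neighbors and is what the `cp.async` lowering below uses to zero out the source size for masked-off lanes; a minimal sketch, with `maskElem` standing in for one lane's mask value:

    // srcSize = maskElem ? bitWidth : 0; cp.async zero-fills the
    // destination whenever cp-size exceeds src-size.
    auto selectOp = select(maskElem, i32_val(bitWidth), i32_val(0));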
@@ -630,7 +632,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   auto tensorTy = resType.cast<RankedTensorType>();
   auto layout = tensorTy.getEncoding();
   auto srcType = typeConverter->convertType(elemType);
-  auto llSrc = bit_cast(srcType, constVal);
+  auto llSrc = bitcast(srcType, constVal);
   size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
   llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
   llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
@@ -706,22 +708,14 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
   // Get corresponding LLVM element values of \param value.
   SmallVector<Value> getLLVMElems(Value value, Value llValue,
                                   const BlockedEncodingAttr &layout,
-                                  TypeConverter *typeConverter,
                                   ConversionPatternRewriter &rewriter,
                                   Location loc) const {
     if (!value)
       return {};
 
-    auto ty = value.getType().cast<RankedTensorType>();
-    auto shape = ty.getShape();
+    auto shape = value.getType().cast<RankedTensorType>().getShape();
     // Here, we assume that all inputs should have a blockedLayout
     unsigned valueElems = layout.getElemsPerThread(shape);
 
-    auto llvmElemTy = typeConverter->convertType(ty.getElementType());
-    auto llvmElemPtrPtrTy =
-        LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
-
     auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
     return valueVals;
   }
@@ -810,16 +804,13 @@ struct StoreOpConversion
 
     auto [layout, numElems] = getLayout(ptr);
 
-    auto ptrElems =
-        getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
-    auto valueElems =
-        getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
+    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
+    auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
     assert(ptrElems.size() == valueElems.size());
 
     SmallVector<Value> maskElems;
     if (llMask) {
-      maskElems =
-          getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
+      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(valueElems.size() == maskElems.size());
     }
 
@@ -848,7 +839,7 @@ struct StoreOpConversion
     const bool hasL2EvictPolicy = false;
 
     PTXBuilder ptxBuilder;
-    auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
+    auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
 
     llvm::SmallVector<std::string> asmArgs;
 
@@ -866,7 +857,7 @@ struct StoreOpConversion
         Value elem = valueElems[elemOffset];
         if (elem.getType().isInteger(1))
           elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
-        elem = bit_cast(valueElemTy, elem);
+        elem = bitcast(valueElemTy, elem);
 
         Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
         llWord =
@@ -874,7 +865,7 @@ struct StoreOpConversion
             rewriter.create<LLVM::ConstantOp>(
                 loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
       }
-      llWord = bit_cast(valArgTy, llWord);
+      llWord = bitcast(valArgTy, llWord);
       std::string constraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
       asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
@@ -1100,14 +1091,12 @@ struct LoadOpConversion
 
     auto [layout, numElems] = getLayout(ptr);
 
-    auto ptrElems =
-        getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
+    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
     assert(ptrElems.size() == numElems);
 
     SmallVector<Value> maskElems;
     if (llMask) {
-      maskElems =
-          getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
+      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(ptrElems.size() == maskElems.size());
     }
 
@@ -1132,8 +1121,7 @@ struct LoadOpConversion
       splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
     }
 
-    auto otherElems =
-        getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
+    auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
 
     SmallVector<Value> loadedVals;
     for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
@@ -1153,7 +1141,7 @@ struct LoadOpConversion
       const bool hasL2EvictPolicy = false;
 
       PTXBuilder ptxBuilder;
-      auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
+      auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
 
       // TODO(Superjomn) Need to check masks before vectorize the load for all
       // the values share one predicate? Here assume all the mask values are
@@ -1198,7 +1186,6 @@ struct LoadOpConversion
       else
         ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
 
-      SmallVector<Value> others;
      if (other) {
         for (size_t ii = 0; ii < nWords; ++ii) {
           PTXInstr &mov = *ptxBuilder.create<>("mov");
@@ -1214,14 +1201,13 @@ struct LoadOpConversion
                 rewriter, loc, this->getTypeConverter()->getIndexType(), s);
             v = insert_element(vecTy, v, falseVal, sVal);
           }
-          v = bit_cast(IntegerType::get(getContext(), width), v);
+          v = bitcast(IntegerType::get(getContext(), width), v);
 
           PTXInstr::Operand *opr{};
           if (otherIsSplatConstInt) {
             opr = ptxBuilder.newConstantOperand(splatVal);
           } else {
             opr = ptxBuilder.newOperand(v, readConstraint);
-            others.push_back(v);
           }
 
           mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
@@ -1253,7 +1239,7 @@ struct LoadOpConversion
         } else {
           curr = ret;
         }
-        curr = bit_cast(
+        curr = bitcast(
             LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
             curr);
         rets.push_back(curr);
@@ -1360,9 +1346,8 @@ struct ExtractSliceOpConversion
 
     // axis > 0 will result in non-contiguous memory access if the result tensor
     // is an alias of the source tensor.
-    auto axis =
-        op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
-    assert(axis == 0 && "Only axis=0 is supported for now");
+    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
+    assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
 
     // Example:
     // %dst = extract_slice %src, %index {axis = 0}
@@ -1372,12 +1357,11 @@ struct ExtractSliceOpConversion
     auto base = product<int64_t>(dstTy.getShape());
     auto baseVal = createIndexAttrConstant(
         rewriter, loc, getTypeConverter()->getIndexType(), base);
-    Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
+    Value offset = mul(adaptor.index(), baseVal);
 
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
     auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-    Value resultVal =
-        rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
+    Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
     rewriter.replaceOp(op, resultVal);
     return success();
   }
@@ -1581,7 +1565,7 @@ void ConvertLayoutOpConversion::processReplica(
       auto elemPtrTy = ptr_ty(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
-      ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
+      ptr = bitcast(ptr_ty(vecTy, 3), ptr);
       if (stNotRd) {
         Value valVec = undef(vecTy);
         for (unsigned v = 0; v < vec; ++v) {
@@ -1614,7 +1598,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   auto elemPtrTy = ptr_ty(llvmElemTy, 3);
-  smemBase = bit_cast(elemPtrTy, smemBase);
+  smemBase = bitcast(elemPtrTy, smemBase);
   auto shape = dstTy.getShape();
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> numReplicates(rank);
@@ -1732,7 +1716,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   Value minVecVal = idx_val(minVec);
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
-  smemBase = bit_cast(elemPtrTy, smemBase);
+  smemBase = bitcast(elemPtrTy, smemBase);
   unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
   for (unsigned i = 0; i < numElems; ++i) {
@@ -1783,7 +1767,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
 
         // step 3: store
         Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
+        smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
         store(wordVecs[linearWordIdx], smemAddr);
       }
     }
@@ -2126,7 +2110,7 @@ public:
         for (int e = 0; e < 4; ++e)
           i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                         i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
       }
     } else { // k first
       Value offset = i32_val(sOffsetElem);
@@ -2144,7 +2128,7 @@ public:
         for (int e = 0; e < 4; ++e)
           i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                         i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
       }
     }
 
@@ -2628,7 +2612,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   Type smemPtrTy = helper.getShemPtrTy();
   for (int i = 0; i < numPtrs; ++i) {
     ptrs[i] =
-        bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
+        bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
   }
 
   bool needTrans = kOrder != order[0];
@@ -2777,6 +2761,229 @@ public:
   }
 };
 
+struct AsyncWaitOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    PTXBuilder ptxBuilder;
+    auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
+    auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
+    asyncWaitOp(ptxBuilder.newConstantOperand(num));
+
+    auto ctx = op.getContext();
+    auto loc = op.getLoc();
+    auto voidTy = LLVM::LLVMVoidType::get(ctx);
+    auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
+
+    // Safe to remove the op since it doesn't have any return value.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct InsertSliceAsyncOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
+                               const Allocation *allocation, Value smem,
+                               AxisInfoAnalysis &axisAnalysisPass,
+                               PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
+            converter, allocation, smem, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // insert_slice_async %src, %dst, %index, %mask, %other
+    auto loc = op.getLoc();
+    Value src = op.src();
+    Value dst = op.dst();
+    Value res = op.result();
+    Value mask = op.mask();
+    Value other = op.other();
+    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+           "Only support in-place insert_slice_async for now");
+
+    auto srcTy = src.getType().cast<RankedTensorType>();
+    auto resTy = dst.getType().cast<RankedTensorType>();
+    auto resElemTy = resTy.getElementType();
+    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
+    auto srcShape = srcTy.getShape();
+    assert(srcShape.size() == 2 &&
+           "insert_slice_async: Unexpected rank of %src");
+
+    Value llDst = adaptor.dst();
+    Value llSrc = adaptor.src();
+    Value llMask = adaptor.mask();
+    Value llOther = adaptor.other();
+    Value llIndex = adaptor.index();
+
+    // %src
+    auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);
+
+    // %dst
+    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
+    assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
+    auto dstBase = createIndexAttrConstant(rewriter, loc,
+                                           getTypeConverter()->getIndexType(),
+                                           product<int64_t>(resTy.getShape()));
+    Value offset = mul(llIndex, dstBase);
+    auto dstPtrTy = LLVM::LLVMPointerType::get(
+        getTypeConverter()->convertType(resTy.getElementType()), 3);
+    Value dstPtrBase = gep(dstPtrTy, llDst, offset);
+
+    // %mask
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
+      assert(srcElems.size() == maskElems.size());
+    }
+
+    // %other
+    SmallVector<Value> otherElems;
+    if (llOther) {
+      // TODO(Keren): support "other" tensor.
+      // It's not necessary for now because the pipeline pass will skip
+      // generating insert_slice_async if the load op has any "other" tensor.
+      assert(false && "insert_slice_async: Other value not supported yet");
+      otherElems =
+          getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
+      assert(srcElems.size() == otherElems.size());
+    }
+
+    unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
+    unsigned outVec = resSharedLayout.getVec();
+    unsigned minVec = std::min(outVec, inVec);
+    unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+    unsigned perPhase = resSharedLayout.getPerPhase();
+    unsigned maxPhase = resSharedLayout.getMaxPhase();
+    auto sizePerThread = srcBlockedLayout.getSizePerThread();
+    auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
+    auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
+    auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
+
+    auto inOrder = srcBlockedLayout.getOrder();
+    auto outOrder = resSharedLayout.getOrder();
+    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements
+    // across phases.
+    // If perPhase * maxPhase == threadsPerCTA, swizzle is not allowed
+    auto numSwizzleRows = std::max<unsigned>(
+        (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
+    // A sharedLayout encoding has a "vec" parameter.
+    // On the column dimension, if inVec > outVec, it means we have to divide
+    // single vector read into multiple ones
+    auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
+
+    auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
+    // <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
+    DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
+    for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
+      // minVec = 2, inVec = 4, outVec = 2
+      //   baseOffsetCol = 0   baseOffsetCol = 0
+      //   tileVecIdxCol = 0   tileVecIdxCol = 1
+      //                -/\-                -/\-
+      //               [|x x| |x x| x x x x x]
+      //               [|x x| |x x| x x x x x]
+      // baseOffsetRow [|x x| |x x| x x x x x]
+      //               [|x x| |x x| x x x x x]
+      auto vecIdx = elemIdx / minVec;
+      auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
+      auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
+      auto baseOffsetCol =
+          vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
+      auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
+                           threadsPerCTA[inOrder[1]];
+      auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
+      auto tileVecIdxCol = vecIdxCol % numVecCols;
+      auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
+
+      if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
+        // Swizzling
+        // Since the swizzling index is related to outVec, and we know minVec
+        // already, inVec doesn't matter
+        //
+        // (Numbers represent row indices)
+        // Example1:
+        // outVec = 2, inVec = 2, minVec = 2
+        // outVec = 2, inVec = 4, minVec = 2
+        //     | [1 2] [3 4] ... [15 16] |
+        //     | [3 4] [5 6] ... [1 2] |
+        // Example2:
+        // outVec = 4, inVec = 2, minVec = 2
+        //     | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
+        //     | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
+        auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
+        Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
+                           i32_val(maxPhase));
+        Value rowOffset =
+            mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
+        Value colOffset =
+            add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
+        Value swizzleIdx = udiv(colOffset, i32_val(outVec));
+        Value swizzleColOffset =
+            add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
+                urem(colOffset, i32_val(outVec)));
+        Value tileOffset = add(rowOffset, swizzleColOffset);
+        tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
+            gep(dstPtrTy, dstPtrBase, tileOffset);
+      }
+
+      // 16 * 8 = 128bits
+      auto maxBitWidth =
+          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
+      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
+      auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
+      auto numWords = vecBitWidth / bitWidth;
+      auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
+
+      // XXX(Keren): Tune CG and CA here.
+      CacheModifier srcCacheModifier =
+          bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
+      assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
+
+      for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+        PTXBuilder ptxBuilder;
+        auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
+            srcCacheModifier, op.evict());
+
+        auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
+        auto *dstOperand =
+            ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
+        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
+        auto *srcSize = copySize;
+        if (op.mask()) {
+          // We don't use predicate in this case, setting src-size to 0
+          // if there's any mask. cp.async will automatically fill the
+          // remaining slots with 0 if cp-size > src-size.
+          // XXX(Keren): Always assume other = 0 for now.
+          auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
                                 i32_val(bitWidth), i32_val(0));
+          srcSize = ptxBuilder.newOperand(selectOp, "r");
+        }
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
+        ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
+      }
+    }
+
+    PTXBuilder ptxBuilder;
+    ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
+    auto ret =
+        ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
+    rewriter.replaceOp(op, ret);
+    return success();
+  }
+};
+
 void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns, int numWarps,
                                   AxisInfoAnalysis &axisInfoAnalysis,
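To make the swizzling arithmetic in InsertSliceAsyncOpConversion concrete, here is a worked example under the configuration of the basic_insert_slice_async_v4 test below (perPhase = 1, maxPhase = 4, outVec = 4; the row and column values are illustrative only):

    // For a source element at row r = 2, column c = 8:
    //   phase      = (r / perPhase) % maxPhase = (2 / 1) % 4 = 2
    //   swizzleIdx = c / outVec                = 8 / 4       = 2
    //   swizzleCol = (swizzleIdx ^ phase) * outVec + c % outVec
    //              = (2 ^ 2) * 4 + 0           = 0
    //   tileOffset = r * rowStride + swizzleCol
    // The vector that would naively land at column 8 of row 2 is
    // therefore stored starting at column 0 of that row.

Each inner-loop iteration then launches one inline-asm copy such as `cp.async.cg.shared.global.L2::evict_normal [dst], [src], 0x80, 0x80;` (the size operands are the computed bitWidth, 128 here), and a single `cp.async.commit_group` closes the group.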
@@ -2786,6 +2993,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
                                         benefit);
   patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
+  patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
   patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
                                                                benefit);
   patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
@@ -2800,6 +3008,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                          benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
+  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
+                                             axisInfoAnalysis, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
@@ -72,6 +72,21 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }
 
+SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
+  SmallVector<unsigned> threads;
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
+      threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
+                        blockedLayout.getWarpsPerCTA()[d]);
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(0 && "Unimplemented usage of MmaEncodingAttr");
+  } else {
+    assert(0 && "Unimplemented usage of getShapePerCTA");
+  }
+
+  return threads;
+}
+
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   SmallVector<unsigned> shape;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
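As a quick sanity check of `getThreadsPerCTA`: for the `#AL` blocked layout used in the new tests below, threadsPerWarp = [4, 8] and warpsPerCTA = [4, 1], so the function returns [4 * 4, 8 * 1] = [16, 8], i.e. 128 threads in total, matching the module's four 32-thread warps.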
@@ -333,6 +333,99 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_async_wait
+  func @basic_async_wait() {
+    // CHECK: cp.async.wait_group 0x4
+    triton_gpu.async_wait {num = 4: i32}
+    return
+  }
+}
+
+// -----
+
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
+#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_v4
+  func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
+    %off1_ = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #block1>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
+    %cst_scalar = arith.constant 64 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
+    return
+  }
+}
+
+// -----
+
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
+#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_v1
+  func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
+    %cst_scalar = arith.constant 32 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x32xi32, #block3>) -> tensor<16x32xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x32xi32, #block2>) -> tensor<16x32xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x32x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat
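The two insert_slice_async tests above exercise the word-splitting logic from the conversion: in basic_insert_slice_async_v4, the shared layout's vec = 4 combines with the 8 contiguous f32 per thread to give minVec = 4, so vecBitWidth = 4 * 32 = 128; a 128-bit word selects the .cg cache modifier and each thread's data is covered by two copies with 0x80 size operands. In basic_insert_slice_async_v1, vec = 1 forces minVec = 1, so every copy moves one 32-bit word (0x20) with the .ca modifier and four copies are emitted before the commit_group.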
@@ -351,9 +444,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_store
   func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK: llvm.inline_asm
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK: llvm.inline_asm
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
     tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
     return
@@ -76,7 +76,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
 
   auto &ld =
       builder
-          .create<PtxIOInstr>("ld") //
+          .create<PTXIOInstr>("ld") //
          ->o("volatile", isVolatile)
           .global()
           .o("ca", cache == CacheModifier::CA)