[BUILD] Fix Warnings and Enable Warnings as Errors (#794)

Author: Ian Bearman
Date: 2022-10-28 12:36:09 -07:00
Committed by: GitHub
Parent: ac0f6793cc
Commit: f2106d0aa2
20 changed files with 205 additions and 213 deletions

View File

@@ -132,14 +132,15 @@ endif()
 # Python module
 if(TRITON_BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
+  set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
+  include_directories("." ${PYTHON_SRC_PATH})
   if (PYTHON_INCLUDE_DIRS)
-    set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-    include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
+    include_directories(${PYTHON_INCLUDE_DIRS})
     link_directories(${PYTHON_LINK_DIRS})
+    link_libraries(${PYTHON_LIBRARIES})
   else()
     find_package(Python3 REQUIRED COMPONENTS Development)
-    set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-    include_directories("." ${PYTHON_SRC_PATH} ${Python3_INCLUDE_DIRS})
+    include_directories(${Python3_INCLUDE_DIRS})
     link_directories(${Python3_LIBRARY_DIRS})
     link_libraries(${Python3_LIBRARIES})
     add_link_options(${Python3_LINK_OPTIONS})
@@ -169,7 +170,10 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
 include(TableGen) # required by AddMLIR
 include(AddLLVM)
 include(AddMLIR)
-# include(HandleLLVMOptions) # human-friendly error message
+include(HandleLLVMOptions) # human-friendly error message
+
+# Disable warnings that show up in external code (gtest;pybind11)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default")
 
 include_directories(${MLIR_INCLUDE_DIRS})
 include_directories(${LLVM_INCLUDE_DIRS})
@@ -192,7 +196,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
 
 target_link_libraries(triton
-  ${PYTHON_LIBRARIES}
   TritonAnalysis
   TritonTransforms
   TritonGPUTransforms
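
Note: HandleLLVMOptions imports LLVM's standard compiler flags, including its diagnostic setup (and -Werror when LLVM_ENABLE_WERROR is on), which is presumably what makes the warnings below hard errors. The extra -Wno-covered-switch-default exempts a Clang diagnostic that fires inside gtest/pybind11 headers. A minimal sketch of the construct that diagnostic flags (illustration only, not Triton code):

enum class Kind { A, B };

int classify(Kind k) {
  switch (k) {
  case Kind::A:
    return 0;
  case Kind::B:
    return 1;
  default: // warning: default label in switch which covers all enumeration
           // values [-Wcovered-switch-default]
    return -1;
  }
}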

View File

@@ -142,7 +142,7 @@ private:
     BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
     BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
     BufferT(BufferKind kind, size_t size, size_t offset)
-        : kind(kind), size(size), offset(offset), id(nextId++) {}
+        : kind(kind), id(nextId++), size(size), offset(offset) {}
 
     bool intersects(const BufferT &other) const {
       return Interval<size_t>(offset, offset + size)
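
Note: C++ initializes members in declaration order, not in initializer-list order, and Clang/GCC warn (-Wreorder) when the two disagree; the reordering above presumably matches an `id` member declared before `size` and `offset`. A minimal illustration (not the Triton code):

struct Buffer {
  int kind;
  int id;   // declared before size and offset
  int size;
  int offset;
  Buffer(int kind, int size, int offset)
      // : kind(kind), size(size), offset(offset), id(0) {}  // -Wreorder
      : kind(kind), id(0), size(size), offset(offset) {}     // matches order
};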

View File

@@ -28,8 +28,8 @@ public:
            DimVectorT knownConstancy)
       : contiguity(knownContiguity), divisibility(knownDivisibility),
         constancy(knownConstancy), rank(contiguity.size()) {
-    assert(knownDivisibility.size() == rank);
-    assert(knownConstancy.size() == rank);
+    assert(knownDivisibility.size() == (size_t)rank);
+    assert(knownConstancy.size() == (size_t)rank);
   }
 
   // Accessors
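
Note: `rank` is evidently a signed integer here while `size()` returns an unsigned size_t, so the bare comparison trips -Wsign-compare; the cast makes both sides unsigned. A minimal illustration (not the Triton code):

#include <cassert>
#include <vector>

void checkRank(const std::vector<int> &dims, int rank) {
  // assert(dims.size() == rank);      // -Wsign-compare
  assert(dims.size() == (size_t)rank); // both operands unsigned
}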

View File

@@ -15,9 +15,9 @@ class Location;
 namespace triton {
 using llvm::StringRef;
 
-class PTXInstr;
-class PTXInstrCommon;
-class PTXInstrExecution;
+struct PTXInstr;
+struct PTXInstrCommon;
+struct PTXInstrExecution;
 
 // PTXBuilder helps to manage a PTX asm program consists of one or multiple
 // instructions.
@@ -83,7 +83,7 @@ struct PTXBuilder {
     Operand() = default;
     Operand(const Operation &) = delete;
     Operand(Value value, StringRef constraint)
-        : value(value), constraint(constraint) {}
+        : constraint(constraint), value(value) {}
 
     bool isList() const { return !value && constraint.empty(); }
@@ -120,7 +120,7 @@ struct PTXBuilder {
   Operand *newListOperand(unsigned count, mlir::Value val,
                           const std::string &constraint) {
     auto *list = newOperand();
-    for (int i = 0; i < count; ++i) {
+    for (unsigned i = 0; i < count; ++i) {
       list->listAppend(newOperand(val, constraint));
     }
     return list;
@@ -128,7 +128,7 @@ struct PTXBuilder {
   Operand *newListOperand(unsigned count, const std::string &constraint) {
     auto *list = newOperand();
-    for (int i = 0; i < count; ++i) {
+    for (unsigned i = 0; i < count; ++i) {
       list->listAppend(newOperand(constraint));
     }
     return list;
@@ -172,8 +172,8 @@ private:
     return argArchive.back().get();
   }
 
-  friend class PTXInstr;
-  friend class PTXInstrCommon;
+  friend struct PTXInstr;
+  friend struct PTXInstrCommon;
 
 protected:
   llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
@@ -209,7 +209,7 @@ protected:
   PTXBuilder *builder{};
   llvm::SmallVector<std::string, 4> instrParts;
 
-  friend class PTXInstrExecution;
+  friend struct PTXInstrExecution;
 };
 
 template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
@@ -309,7 +309,7 @@ struct PTXInstrExecution {
   PTXInstrExecution() = default;
   explicit PTXInstrExecution(PTXInstrCommon *instr,
                              llvm::ArrayRef<Operand *> oprs)
-      : instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
+      : argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
 
   // Prefix a predicate to the instruction.
   PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
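
Note: PTXInstr and friends are defined with `struct`, so forward and friend declarations that said `class` tripped Clang's -Wmismatched-tags; MSVC even encodes the class-key in mangled names, so the tags should agree across declarations. A minimal illustration (not the Triton code):

struct Widget;   // class-key matches the definition below
// class Widget; // -Wmismatched-tags once Widget is defined as a struct
struct Widget {
  int id;
};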

View File

@@ -11,16 +11,12 @@ class ModuleOp;
 template <typename T> class OperationPass;
 
 class TritonLLVMConversionTarget : public ConversionTarget {
-  mlir::LLVMTypeConverter &typeConverter;
-
 public:
   explicit TritonLLVMConversionTarget(MLIRContext &ctx,
                                       mlir::LLVMTypeConverter &typeConverter);
 };
 
 class TritonLLVMFunctionConversionTarget : public ConversionTarget {
-  mlir::LLVMTypeConverter &typeConverter;
-
 public:
   explicit TritonLLVMFunctionConversionTarget(
       MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
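
Note: both targets stored the LLVMTypeConverter reference without ever reading it again, which Clang presumably flagged as -Wunused-private-field; the fix keeps the constructor parameter and drops the member (the definitions are updated at the bottom of the lowering file below). A minimal illustration (not the Triton code):

class Target {
  // int &cfg;  // never read after init: -Wunused-private-field
public:
  explicit Target(int & /*cfg*/) {}
};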

View File

@@ -26,11 +26,11 @@ public:
   DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
 
   virtual LogicalResult
-  inferReduceOpEncoding(Attribute operandEncoding, int axis,
+  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                         Attribute &resultEncoding) const = 0;
 
   virtual LogicalResult
-  inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
+  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
                             Attribute &resultEncoding) const = 0;
 };

View File

@@ -21,7 +21,6 @@ private:
 };
 
 class TritonGPUConversionTarget : public ConversionTarget {
-  TritonGPUTypeConverter &typeConverter;
 
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx,

View File

@@ -82,7 +82,6 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
   auto srcShape = srcTy.getShape();
-  auto rank = srcShape.size();
   auto axis = op.axis();
   bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension

View File

@@ -66,7 +66,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
   DimVectorT retContiguity;
   DimVectorT retDivisibility;
   DimVectorT retConstancy;
-  for (size_t d = 0; d < lhs.getRank(); ++d) {
+  for (int d = 0; d < lhs.getRank(); ++d) {
     retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
     retDivisibility.push_back(
         gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
@@ -88,7 +88,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
   AxisInfo::DimVectorT newContiguity;
   AxisInfo::DimVectorT newDivisibility;
   AxisInfo::DimVectorT newConstancy;
-  for (size_t d = 0; d < rank; ++d) {
+  for (int d = 0; d < rank; ++d) {
     newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
     newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
     newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
@@ -167,7 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   AxisInfo::DimVectorT contiguity;
   AxisInfo::DimVectorT divisibility;
   AxisInfo::DimVectorT constancy;
-  for (size_t d = 0; d < retTy.getRank(); ++d) {
+  for (int d = 0; d < retTy.getRank(); ++d) {
     contiguity.push_back(1);
     divisibility.push_back(opInfo.getDivisibility(0));
     constancy.push_back(retTy.getShape()[d]);
@@ -176,12 +176,6 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   }
   // expandDims
   if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
-    Type _retTy = *op->result_type_begin();
-    Type _opTy = *op->operand_type_begin();
-    TensorType retTy = _retTy.cast<TensorType>();
-    TensorType opTy = _opTy.cast<TensorType>();
-    ArrayRef<int64_t> retShape = retTy.getShape();
-    ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
     AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
     AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
@@ -203,7 +197,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   AxisInfo::DimVectorT contiguity;
   AxisInfo::DimVectorT divisibility;
   AxisInfo::DimVectorT constancy;
-  for (size_t d = 0; d < retTy.getRank(); ++d) {
+  for (int d = 0; d < retTy.getRank(); ++d) {
     contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
     divisibility.push_back(opInfo.getDivisibility(d));
     constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);

View File

@@ -28,11 +28,10 @@ bool maybeSharedAllocationOp(Operation *op) {
 }
 
 std::string getValueOperandName(Value value, AsmState &state) {
-  auto *op = value.getDefiningOp();
   std::string opName;
   llvm::raw_string_ostream ss(opName);
   value.printAsOperand(ss, state);
-  return std::move(opName);
+  return opName;
 }
 
 } // namespace mlir
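
Note: two fixes here: an unused local (-Wunused-variable), and `return std::move(opName)`, which Clang flags as -Wpessimizing-move because moving a local in a return statement blocks copy elision; a plain `return` still moves when elision does not apply. A minimal illustration (not the Triton code):

#include <string>

std::string makeName() {
  std::string name = "acc0";
  // return std::move(name); // -Wpessimizing-move: prevents copy elision
  return name;               // NRVO, or an implicit move at worst
}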

View File

@@ -254,7 +254,6 @@ protected:
 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
-static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
 struct FuncOpConversion : public FuncOpConversionBase {
   FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                    PatternBenefit benefit)
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
   LogicalResult
   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
     unsigned numArguments = op.getNumOperands();
 
     // Currently, Triton kernel function always return nothing.
@@ -482,7 +480,6 @@ public:
                                 ConversionPatternRewriter &rewriter,
                                 const BlockedEncodingAttr &blocked_layout,
                                 ArrayRef<int64_t> shape) const {
-    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value threadId = getThreadId(rewriter, loc);
     Value warpSize = idx_val(32);
     Value laneId = urem(threadId, warpSize);
@@ -654,7 +651,6 @@ public:
     auto bufferId = allocation->getBufferId(value);
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
-    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value offVal = idx_val(offset);
     Value base = gep(ptrTy, smem, offVal);
     return base;
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   auto tensorTy = resType.cast<RankedTensorType>();
   if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
     auto tensorTy = resType.cast<RankedTensorType>();
-    auto layout = tensorTy.getEncoding();
     auto srcType = typeConverter->convertType(elemType);
     auto llSrc = bitcast(srcType, constVal);
     size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -842,7 +837,6 @@ struct LoadOpConversion
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *ctx = rewriter.getContext();
     auto loc = op->getLoc();
 
     // original values
@@ -897,12 +891,11 @@ struct LoadOpConversion
     // TODO: optimization when ptr is GEP with constant offset
     size_t in_off = 0;
 
-    const int maxWordWidth = std::max<int>(32, valueElemNbits);
-    const int totalWidth = valueElemNbits * vec;
-    const int width = std::min(totalWidth, maxWordWidth);
-    const int nWords = std::max(1, totalWidth / width);
-    const int wordNElems = width / valueElemNbits;
-    const int vecNElems = totalWidth / valueElemNbits;
+    const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+    const size_t totalWidth = valueElemNbits * vec;
+    const size_t width = std::min(totalWidth, maxWordWidth);
+    const size_t nWords = std::max<size_t>(1, totalWidth / width);
+    const size_t wordNElems = width / valueElemNbits;
     assert(wordNElems * nWords * numVecs == numElems);
 
     // TODO(Superjomn) Add cache policy fields to StoreOp.
@@ -921,7 +914,7 @@ struct LoadOpConversion
 
     // prepare asm operands
     auto *dstsOpr = ptxBuilder.newListOperand();
-    for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+    for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
       auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
       dstsOpr->listAppend(opr);
     }
@@ -988,8 +981,8 @@ struct LoadOpConversion
                       : retTys[0];
 
     // TODO: if (has_l2_evict_policy)
-    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                    LLVM::AsmDialect::AD_ATT);
+    // auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+    //                                                 LLVM::AsmDialect::AD_ATT);
     Value ret = ptxBuilder.launch(rewriter, loc, retTy);
 
     // ---
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
     // TODO: optimization when ptr is AddPtr with constant offset
     size_t in_off = 0;
 
-    const int maxWordWidth = std::max<int>(32, valueElemNbits);
-    const int totalWidth = valueElemNbits * vec;
-    const int width = std::min(totalWidth, maxWordWidth);
-    const int nWords = std::max(1, totalWidth / width);
-    const int wordNElems = width / valueElemNbits;
-    const int vecNElems = totalWidth / valueElemNbits;
+    const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+    const size_t totalWidth = valueElemNbits * vec;
+    const size_t width = std::min(totalWidth, maxWordWidth);
+    const size_t nWords = std::max<size_t>(1, totalWidth / width);
+    const size_t wordNElems = width / valueElemNbits;
     assert(wordNElems * nWords * numVecs == numElems);
 
     // TODO(Superjomn) Add cache policy fields to StoreOp.
     // TODO(Superjomn) Deal with cache policy here.
-    const bool hasL2EvictPolicy = false;
 
     Type valArgTy = IntegerType::get(ctx, width);
     auto wordTy = vec_ty(valueElemTy, wordNElems);
 
     SmallVector<std::pair<Value, std::string>> asmArgs;
-    for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+    for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
       // llWord is a width-len composition
       Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
       // Insert each value element to the composition
-      for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
+      for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
         const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
         assert(elemOffset < valueElems.size());
         Value elem = valueElems[elemOffset];
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
     }
     unsigned srcElems = getElemsPerThread(srcTy);
-    auto elemTy = resultTy.getElementType();
     auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = getElemsPerThread(resultTy);
     SmallVector<Value> resultVals(resultElems);
@@ -1282,8 +1272,6 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto rank = srcTy.getShape().size();
   if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
 
 Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
                                    Location loc, Value val, int i) const {
-  MLIRContext *ctx = rewriter.getContext();
   unsigned bits = val.getType().getIntOrFloatBitWidth();
 
   if (bits == 64) {
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   barrier();
 
   SmallVector<Value> resultVals(resultElems);
-  for (int i = 0; i < resultElems; i++) {
+  for (size_t i = 0; i < resultElems; i++) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, ints[0]);
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
   auto srcShape = srcTy.getShape();
-  auto srcOrder = srcLayout.getOrder();
 
   auto threadsPerWarp = srcLayout.getThreadsPerWarp();
   auto warpsPerCTA = srcLayout.getWarpsPerCTA();
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   barrier();
 
   SmallVector<Value> resultVals(resultElems);
-  for (int i = 0; i < resultElems; i++) {
+  for (size_t i = 0; i < resultElems; i++) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, i32_val(0));
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
     // due to MLIR's restrictions
     Location loc = op->getLoc();
     auto resultTy = op.getType().template cast<RankedTensorType>();
-    auto resultShape = resultTy.getShape();
     unsigned elems = getElemsPerThread(resultTy);
     Type elemTy =
         this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
     auto resultLayout =
         resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
-    auto resultShape = resultTensorTy.getShape();
     unsigned elems = getElemsPerThread(resultTy);
     Type elemTy =
         getTypeConverter()->convertType(resultTensorTy.getElementType());
@@ -1821,7 +1805,7 @@ protected:
     SmallVector<SmallVector<Value>> operands(elems);
     for (auto operand : adaptor.getOperands()) {
       auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
-      for (int i = 0; i < elems; ++i) {
+      for (size_t i = 0; i < elems; ++i) {
        operands[i].push_back(sub_operands[i]);
      }
    }
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
     __PRED_ENUM(ORD, ord);
     __PRED_ENUM(UEQ, ueq);
     __PRED_ENUM(UGT, ugt);
+    __PRED_ENUM(UGE, uge);
     __PRED_ENUM(ULT, ult);
     __PRED_ENUM(ULE, ule);
     __PRED_ENUM(UNE, une);
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
   auto rank = type.getRank();
   auto sizePerThread = getSizePerThread(layout);
   auto accumSizePerThread = product<unsigned>(sizePerThread);
-  auto llvmIndexTy = getTypeConverter()->getIndexType();
   SmallVector<unsigned> numCTAs(rank);
   auto shapePerCTA = getShapePerCTA(layout);
   for (unsigned d = 0; d < rank; ++d) {
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
       multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
           loc, rewriter, blockedLayout, type.getShape());
     } else if (sliceLayout) {
-      unsigned dim = sliceLayout.getDim();
       auto parent = sliceLayout.getParent();
       if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
         SmallVector<int64_t> paddedShape =
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   }
   // Potentially we need to store for multiple CTAs in this replication
   unsigned accumNumReplicates = product<unsigned>(numReplicates);
-  unsigned elems = getElemsPerThread(srcTy);
+  // unsigned elems = getElemsPerThread(srcTy);
   auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
   unsigned outVec = 0;
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
 
 // Data loader for mma.16816 instruction.
 class MMA16816SmemLoader {
 public:
-  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
+  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
                      ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
                      ArrayRef<int> matShape, int perPhase, int maxPhase,
                      int elemBytes, ConversionPatternRewriter &rewriter,
                      TypeConverter *typeConverter, const Location &loc)
-      : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+      : order(order.begin(), order.end()), kOrder(kOrder),
         tileShape(tileShape.begin(), tileShape.end()),
         instrShape(instrShape.begin(), instrShape.end()),
         matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
-        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
+        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
+        ctx(rewriter.getContext()) {
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
@@ -2576,7 +2559,6 @@ public:
     assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
            "smem matrix load must be aligned");
     int matIdx[2] = {mat0, mat1};
-    int k = matIdx[kOrder];
 
     int ptrIdx{-1};
@@ -2596,7 +2578,6 @@ public:
     Value ptr = getPtr(ptrIdx);
-    Value resV4;
 
     if (canUseLdmatrix) {
       int sOffset =
           matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
@@ -2727,7 +2708,6 @@ public:
   }
 
 private:
-  int wpt;
   SmallVector<uint32_t> order;
   int kOrder;
   SmallVector<int64_t> tileShape;
@@ -2737,7 +2717,6 @@ private:
   int maxPhase;
   int elemBytes;
   ConversionPatternRewriter &rewriter;
-  TypeConverter *typeConverter{};
   const Location &loc;
   MLIRContext *ctx{};
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
     // D = A * B + C
     Value A = op.a();
-    Value B = op.b();
-    Value C = op.c();
     Value D = op.getResult();
-    MLIRContext *ctx = op->getContext();
-    bool allowTF32 = op.allowTF32();
 
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
     Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
     Type i8x4Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
-    Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
 
     switch (mmaType) {
     case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
     auto bTy = B.getType().cast<RankedTensorType>();
     // d = a*b + c
     auto dTy = op.d().getType().cast<RankedTensorType>();
-    auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
 
     if (dTy.getElementType().isF32()) {
      if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
   MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
                            ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
-        thread(thread) {
+      : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
+        rewriter(rewriter), typeConverter(typeConverter), loc(loc),
+        ctx(mmaLayout.getContext()) {
     wpt = mmaLayout.getWarpsPerCTA();
 
     Value _32 = i32_val(32);
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
     }
 
     // step1. Perform loading.
-    for (unsigned m = 0; m < numRepM; ++m)
-      for (unsigned k = 0; k < numRepK; ++k)
+    for (int m = 0; m < numRepM; ++m)
+      for (int k = 0; k < numRepK; ++k)
         loadFn(2 * m, 2 * k);
 
     // step2. Format the values to LLVM::Struct to passing to mma codegen.
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
 
-    for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
-      for (unsigned k = 0; k < numRepK; ++k)
+    for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
+      for (int k = 0; k < numRepK; ++k)
         loadFn(2 * n, 2 * k);
     }
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
     helper.deduceMmaType(op);
 
     auto aTensorTy = a.getType().cast<RankedTensorType>();
-    auto bTensorTy = b.getType().cast<RankedTensorType>();
-    auto cTensorTy = c.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();
 
     auto aShape = aTensorTy.getShape();
     auto dShape = dTensorTy.getShape();
-    int NK = aShape[1];
 
     // shape / shape_per_cta
-    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
-    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
     int numRepM = getNumRepM(aTensorTy, dShape[0]);
     int numRepN = getNumRepN(aTensorTy, dShape[1]);
     int numRepK = getNumRepK(aTensorTy, aShape[1]);
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
           extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
     };
 
-    for (unsigned k = 0; k < numRepK; ++k)
-      for (unsigned m = 0; m < numRepM; ++m)
-        for (unsigned n = 0; n < numRepN; ++n)
+    for (int k = 0; k < numRepK; ++k)
+      for (int m = 0; m < numRepM; ++m)
+        for (int n = 0; n < numRepN; ++n)
          callMma(2 * m, n, 2 * k);
 
     // replace with new packed result
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
 private:
   std::function<void(int, int)>
   getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
-                  int wpt, int kOrder, ArrayRef<int> instrShape,
+                  int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
                   ArrayRef<int> matShape, Value warpId,
                   ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -3486,8 +3452,8 @@ private:
   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
                                               int n1) const {
     std::vector<Value> elems;
-    for (unsigned m = 0; m < n0; ++m)
-      for (unsigned k = 0; k < n1; ++k) {
+    for (int m = 0; m < n0; ++m)
+      for (int k = 0; k < n1; ++k) {
        elems.push_back(vals.at({2 * m, 2 * k}));
        elems.push_back(vals.at({2 * m, 2 * k + 1}));
        elems.push_back(vals.at({2 * m + 1, 2 * k}));
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   auto loc = op.getLoc();
   Value src = op.src();
   Value dst = op.result();
-  auto srcTensorTy = src.getType().cast<RankedTensorType>();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
   auto dotOperandLayout =
       dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
   MmaEncodingAttr mmaLayout =
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
     auto ctx = op.getContext();
     auto loc = op.getLoc();
     auto voidTy = void_ty(ctx);
-    auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
+    ptxBuilder.launch(rewriter, loc, voidTy);
 
     // Safe to remove the op since it doesn't have any return value.
     rewriter.eraseOp(op);
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
     unsigned perPhase = resSharedLayout.getPerPhase();
     unsigned maxPhase = resSharedLayout.getMaxPhase();
     auto sizePerThread = srcBlockedLayout.getSizePerThread();
-    auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
-    auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
     auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
 
     auto inOrder = srcBlockedLayout.getOrder();
-    auto outOrder = resSharedLayout.getOrder();
     // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
     // elements across phases. If perPhase * maxPhase == threadsPerCTA,
     // swizzle is not allowd
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
     auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
 
     auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-    for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+    for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
       PTXBuilder ptxBuilder;
       auto wordElemIdx = wordIdx * numWordElems;
      auto &copyAsyncOp =
@@ -4208,7 +4170,7 @@ namespace mlir {
 
 TritonLLVMConversionTarget::TritonLLVMConversionTarget(
     MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-    : ConversionTarget(ctx), typeConverter(typeConverter) {
+    : ConversionTarget(ctx) {
   addLegalDialect<LLVM::LLVMDialect>();
   addLegalDialect<NVVM::NVVMDialect>();
   // addIllegalDialect<triton::TritonDialect>();
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
 TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
     MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-    : ConversionTarget(ctx), typeConverter(typeConverter) {
+    : ConversionTarget(ctx) {
   addLegalDialect<LLVM::LLVMDialect>();
   // addLegalDialect<NVVM::NVVMDialect>();
   addIllegalOp<mlir::FuncOp>();
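
Note: most edits in this file fall into three buckets: locals computed but never read (removed, -Wunused-variable), loop induction types changed to match their bounds (-Wsign-compare), and constructor initializer lists reordered to declaration order (-Wreorder). One genuine omission also surfaces: the CmpF predicate table was missing UGE, exactly the kind of gap the switch-coverage warnings are meant to catch. A minimal sketch of the first two buckets (not the Triton code):

#include <cstddef>
#include <vector>

int sumAll(const std::vector<int> &vals) {
  // int rank = 2;                          // -Wunused-variable
  int acc = 0;
  for (size_t i = 0; i < vals.size(); ++i)  // induction type matches size()
    acc += vals[i];
  return acc;
}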

View File

@@ -21,9 +21,7 @@ public:
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    Op res =
-        rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
     return success();
   }
 };
@@ -37,9 +35,8 @@ public:
   matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    DstOp res =
-        rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
-                                           adaptor.getLhs(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
+                                       adaptor.getLhs(), adaptor.getRhs());
     return success();
   }
 };
@@ -129,10 +126,9 @@ public:
   matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    triton::gpu::SelectOp res =
-        rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
-            op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
-            adaptor.getFalseValue());
+    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
+        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
+        adaptor.getFalseValue());
     return success();
   }
 };
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
         triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
                                               retThreadsPerWarp, retWarpsPerCTA,
                                               retOrder);
-    // return type
-    RankedTensorType retType =
-        RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
     // convert operand to slice of return type
     Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
         getContext(), op.axis(), retEncoding);
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
                                     bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
     }
-    auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
+    rewriter.replaceOpWithNewOp<triton::DotOp>(
         op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
         adaptor.transB());
     return success();
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   LogicalResult
   matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+    rewriter.replaceOpWithNewOp<triton::StoreOp>(
         op, adaptor.ptr(), adaptor.value(), adaptor.mask());
     return success();
   }
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   LogicalResult
   matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
+    rewriter.replaceOpWithNewOp<triton::ReduceOp>(
        op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
     return success();
   }

View File

@@ -49,7 +49,6 @@ unsigned getElemsPerThread(Type type) {
   auto tensorType = type.cast<RankedTensorType>();
   auto layout = tensorType.getEncoding();
   auto shape = tensorType.getShape();
-  size_t rank = shape.size();
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -109,7 +108,7 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   SmallVector<unsigned> shape;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-    for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
+    for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
       shape.push_back(blockedLayout.getSizePerThread()[d] *
                       blockedLayout.getThreadsPerWarp()[d] *
                       blockedLayout.getWarpsPerCTA()[d]);
@@ -117,7 +116,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
     unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
     if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
+      for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
         if (d == dim)
           continue;
         shape.push_back(blockedParent.getSizePerThread()[d] *
@@ -258,7 +257,6 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   size_t rank = shape.size();
   auto parent = getParent();
-  unsigned dim = getDim();
   if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
     assert(rank == blockedParent.getSizePerThread().size() - 1 &&
            "unexpected rank in SliceEncodingAttr::getElemsPerThread");
@@ -512,11 +510,11 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
   auto encoding = srcType.getEncoding();
   auto srcShape = srcType.getShape();
   auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-  if (axis < 0 || axis > srcShape.size())
+  if (axis < 0 || (size_t)axis > srcShape.size())
     return failure();
   SmallVector<int64_t, 4> dstShape;
-  for (int i = 0; i < srcShape.size(); i++)
-    if (i != axis)
+  for (size_t i = 0; i < srcShape.size(); i++)
+    if (i != (size_t)axis)
       dstShape.push_back(srcShape[i]);
   auto returnType =
       RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
@@ -578,15 +576,17 @@ struct TritonGPUInferLayoutInterface
     : public triton::DialectInferLayoutInterface {
   using DialectInferLayoutInterface::DialectInferLayoutInterface;
 
-  LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
-                                      Attribute &resultEncoding) const {
+  LogicalResult
+  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
+                        Attribute &resultEncoding) const override {
     resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
                                             operandEncoding);
     return success();
   }
 
-  LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
-                                          Attribute &resultEncoding) const {
+  LogicalResult
+  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
+                            Attribute &resultEncoding) const override {
    auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
    if (!sliceEncoding) {
      llvm::report_fatal_error(
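
Note: the interface implementations now spell out `override` (Clang's -Winconsistent-missing-override complains when a class marks some overriding virtuals but not others), and the axis parameter switches to unsigned to match the base declaration; with `override` in place, such a signature mismatch becomes a hard error rather than a silent new virtual. A minimal illustration (not the Triton code):

struct Base {
  virtual int infer(unsigned axis) const = 0;
  virtual ~Base() = default;
};

struct Impl : Base {
  // int infer(int axis) const;  // would silently fail to override
  int infer(unsigned axis) const override { return (int)axis; }
};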

View File

@@ -87,7 +87,6 @@ public:
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
       return mlir::failure();
     auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
     auto dstType = convert.getType().cast<RankedTensorType>();
     // we don't handle conversions to DotOperandEncodingAttr
     // this is a heuristics to accomodate fused attention
@@ -219,10 +218,10 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
   auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
   if (typeInfer) {
     SmallVector<Type, 1> newType;
-    auto sucess = typeInfer.inferReturnTypes(
+    auto success = typeInfer.inferReturnTypes(
         newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
         newOp->getAttrDictionary(), newOp->getRegions(), newType);
-    if (success)
+    if (succeeded(success))
       newOp->getResult(0).setType(newType.front());
   }
   return newOp;
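
Note: this hunk fixes a real bug that the warnings flushed out: with the local misspelled `sucess`, the `if (success)` below it resolved to the address of mlir's success() function, which is never null, so the branch was always taken; Clang reports this as "address of function will always evaluate to true". The fix renames the local and tests it with succeeded(). A minimal sketch of the pattern (not the Triton code):

bool success() { return true; }

void demo(bool sucess /* typo */) {
  if (success) { // -Wbool-conversion: tests &success, always true
    // runs unconditionally
  }
  if (sucess) {  // what was intended
  }
}
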
@@ -364,10 +363,6 @@ public:
   rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
                        size_t i, RankedTensorType newType,
                        triton::gpu::ConvertLayoutOp origConversion) const {
-    auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
-    auto ctx = forOp.getContext();
-    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
-
     // Rewrite init argument
     Type origType = forOp.getInitArgs()[i].getType();
     SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
@@ -418,11 +413,10 @@ public:
     return newResults;
   }
 
-  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-                                      mlir::PatternRewriter &rewriter) const {
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     auto forOp = cast<scf::ForOp>(op);
-    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
     auto iterArgs = forOp.getRegionIterArgs();
     for (auto iterArg : llvm::enumerate(iterArgs)) {
       // if (iterArg.index() != 1)
@@ -480,7 +474,6 @@ public:
     auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
     if (!forOp)
       return mlir::failure();
-    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
 
     SetVector<Operation *> cvtSlices;

View File

@@ -17,11 +17,6 @@ using namespace mlir;
 namespace {
 class LoopPipeliner {
-  /// comments on numStages:
-  ///   [0, numStages-1) are in the prologue
-  ///   numStages-1 is appended after the loop body
-  int numStages;
   /// cache forOp we are working on
   scf::ForOp forOp;
@@ -43,6 +38,11 @@ class LoopPipeliner {
   ///
   Value loopIterIdx;
+  /// comments on numStages:
+  ///   [0, numStages-1) are in the prologue
+  ///   numStages-1 is appended after the loop body
+  int numStages;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -58,9 +58,6 @@ class LoopPipeliner {
   Value lookupOrDefault(Value origin, int stage);
-  /// return true if this op uses any of `loads`
-  bool isDirectUserOfAsyncLoad(Operation &op);
   /// returns a empty buffer of size <numStages, ...>
   triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
                                                  OpBuilder &builder);
@@ -84,7 +81,7 @@ public:
   /// create the new ForOp (add new args & insert prefetched ops)
   scf::ForOp createNewForOp();
-  friend class PipelinePass;
+  friend struct PipelinePass;
 };
 // helpers
@@ -123,19 +120,6 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
   }
 }
-bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
-  for (Value loadOp : loads) {
-    assert(loadOp.hasOneUse() &&
-           "load should only have one use (ConvertLayout)");
-    Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
-    for (Value opOperand : op.getOperands()) {
-      if (opOperand == loadUseResult)
-        return true;
-    }
-  }
-  return false;
-}
 triton::gpu::AllocTensorOp
 LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
   // allocate a buffer for each pipelined tensor
@@ -356,8 +340,8 @@ void LoopPipeliner::emitPrologue() {
   } // for (int stage = 0; stage < numStages - 1; ++stage)
   // async.wait & extract_slice
-  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
-      loads[0].getLoc(), loads.size() * (numStages - 2));
+  builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
+                                           loads.size() * (numStages - 2));
   loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (Value loadOp : loads) {
     Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
@@ -380,8 +364,7 @@ void LoopPipeliner::emitEpilogue() {
   OpBuilder builder(forOp);
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPointAfter(forOp);
-  Operation *asyncWait =
-      builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
+  builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
 }
 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -575,8 +558,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   yieldValues.push_back(loopIterIdx);
   builder.setInsertionPointToEnd(newForOp.getBody());
-  auto test = builder.create<scf::YieldOp>(
-      forOp.getBody()->getTerminator()->getLoc(), yieldValues);
+  builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
+                               yieldValues);
   return newForOp;
 }
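
Besides deleting the unused `isDirectUserOfAsyncLoad` helper and the never-read `asyncWait`/`test` results, this file moves the `numStages` declaration below the other members, presumably so declaration order matches the constructor's initializer list, which is what -Wreorder checks: C++ always initializes members in declaration order, so a mismatched list can silently read uninitialized state. A self-contained illustration of the rule:

    // Members initialize in declaration order, not initializer-list order.
    // Because `a` is declared before `b`, `b(a + 1)` reliably reads the
    // already-initialized `a`; with the declarations swapped, -Wreorder
    // fires and `b` would read garbage.
    struct Pipeliner {
      int a;
      int b;
      explicit Pipeliner(int x) : a(x), b(a + 1) {}
    };

    int main() { return Pipeliner(1).b == 2 ? 0 : 1; }

The `friend class` to `friend struct` change quiets -Wmismatched-tags, since `PipelinePass` is evidently declared as a `struct` elsewhere.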


@@ -30,7 +30,7 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
                          (ty.getElementType().getIntOrFloatBitWidth() / 8));
     perPhase = std::max<int>(perPhase, 1);
     // index of the inner dimension in `order`
-    int inner = (opIdx == 0) ? 0 : 1;
+    size_t inner = (opIdx == 0) ? 0 : 1;
     if (version == 1) {
       int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
       // TODO: handle rep (see
@@ -67,7 +67,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
   void runOnOperation() override {
     Operation *op = getOperation();
-    MLIRContext *context = &getContext();
     op->walk([&](triton::DotOp dotOp) -> void {
       OpBuilder builder(dotOp);
       auto _retEncoding =
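
`inner` switches from `int` to `size_t` because it indexes `order`, and mixing signed indices with unsigned sizes trips -Wsign-compare-style diagnostics once warnings are errors; the deleted `context` local is another -Wunused-variable fix. A small standalone example of the index pattern:

    #include <cstddef>
    #include <vector>

    // Using size_t for the index keeps the comparison against v.size()
    // homogeneous; declaring `i` as int here would warn under -Wsign-compare.
    std::size_t count_nonzero(const std::vector<int> &v) {
      std::size_t n = 0;
      for (std::size_t i = 0; i < v.size(); ++i)
        if (v[i] != 0)
          ++n;
      return n;
    }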


@@ -73,7 +73,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 //
 TritonGPUConversionTarget::TritonGPUConversionTarget(
     MLIRContext &context, TritonGPUTypeConverter &typeConverter)
-    : ConversionTarget(context), typeConverter(typeConverter) {
+    : ConversionTarget(context) {
   // TODO: we should also verify ops of TritonGPUDialect
   addLegalDialect<triton::gpu::TritonGPUDialect>();
@@ -90,7 +90,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });
   // We have requirements for the data layouts
-  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+  addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
     Attribute aEncoding =
         dotOp.a().getType().cast<RankedTensorType>().getEncoding();
     Attribute bEncoding =
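
Two more diagnostics are addressed here: the constructor stops initializing a `typeConverter` member that nothing reads (presumably deleted from the class as well, the -Wunused-private-field pattern), and `[this]` is dropped from a lambda whose body never touches the enclosing object, which -Wunused-lambda-capture flags. The capture fix in isolation:

    // The body uses only its parameter, so capturing `this` (or anything
    // else) would be dead weight that -Wunused-lambda-capture reports.
    auto isLegal = [](int dotOp) -> bool { return dotOp > 0; };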


@@ -63,7 +63,7 @@ static bool find_and_replace(std::string &str, const std::string &begin,
 static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
   // LLVM version in use may not officially support target hardware
   int max_nvvm_cc = 75;
-  int max_nvvm_ptx = 74;
+  // int max_nvvm_ptx = 74;
   // options
   auto options = llvm::cl::getRegisteredOptions();
   auto *short_ptr =
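
Commenting the constant out keeps it visible as documentation while silencing -Wunused-variable. An alternative that keeps the value compiled and type-checked, sketched with a hypothetical name:

    // [[maybe_unused]] (C++17) tells the compiler the non-use is deliberate,
    // so -Werror builds still pass without deleting the declaration.
    [[maybe_unused]] constexpr int max_nvvm_ptx_hint = 74;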


@@ -1,11 +1,12 @@
 import distutils
-import distutils.spawn
+import itertools
 import os
 import platform
 import re
 import shutil
 import subprocess
 import sys
+import sysconfig
 import tarfile
 import tempfile
 import urllib.request
@@ -16,6 +17,74 @@ from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
+
+# Logic from https://github.com/Kitware/CMake/blob/master/Modules/FindPythonLibs.cmake
+# Code from https://stackoverflow.com/questions/47423246
+def get_python_library():
+    """Get path to the python library associated with the current python
+    interpreter."""
+    # determine direct path to libpython
+    python_version = sysconfig.get_python_version()
+    python_library = sysconfig.get_config_var('LIBRARY')
+    # if static (or nonexistent), try to find a suitable dynamic libpython
+    if (python_library is None or os.path.splitext(python_library)[1][-2:] == '.a'):
+        candidate_lib_prefixes = ['', 'lib']
+        candidate_extensions = ['.lib', '.so', '.a']
+        if sysconfig.get_config_var('WITH_DYLD'):
+            candidate_extensions.insert(0, '.dylib')
+        candidate_versions = [python_version]
+        if python_version:
+            candidate_versions.append('')
+            candidate_versions.insert(
+                0, "".join(python_version.split(".")[:2]))
+        abiflags = getattr(sys, 'abiflags', '')
+        candidate_abiflags = [abiflags]
+        if abiflags:
+            candidate_abiflags.append('')
+        # Ensure the value injected by virtualenv is
+        # returned on windows.
+        # Because calling `sysconfig.get_config_var('multiarchsubdir')`
+        # returns an empty string on Linux, `du_sysconfig` is only used to
+        # get the value of `LIBDIR`.
+        libdir = distutils.sysconfig.get_config_var('LIBDIR')
+        if sysconfig.get_config_var('MULTIARCH'):
+            masd = sysconfig.get_config_var('multiarchsubdir')
+            if masd:
+                if masd.startswith(os.sep):
+                    masd = masd[len(os.sep):]
+                libdir = os.path.join(libdir, masd)
+        if libdir is None:
+            libdir = os.path.abspath(os.path.join(
+                sysconfig.get_config_var('LIBDEST'), "..", "libs"))
+        candidates = (
+            os.path.join(
+                libdir,
+                ''.join((pre, 'python', ver, abi, ext))
+            )
+            for (pre, ext, ver, abi) in itertools.product(
+                candidate_lib_prefixes,
+                candidate_extensions,
+                candidate_versions,
+                candidate_abiflags
+            )
+        )
+        for candidate in candidates:
+            if os.path.exists(candidate):
+                # we found a (likely alternate) libpython
+                python_library = candidate
+                break
+    return python_library
+
+
 # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
 def check_env_flag(name: str, default: str = "") -> bool:
     return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
@@ -136,14 +205,19 @@ class CMakeBuild(build_ext):
         if not os.path.exists(llvm_build_dir):
             os.makedirs(llvm_build_dir)
         # python directories
-        python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
+        python_include_dir = distutils.sysconfig.get_python_inc()
+        python_link_dir = distutils.sysconfig.get_python_lib()
+        python_library = get_python_library()
         cmake_args = [
+            "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
            "-DTRITON_BUILD_TUTORIALS=OFF",
            "-DTRITON_BUILD_PYTHON_MODULE=ON",
            # '-DPYTHON_EXECUTABLE=' + sys.executable,
            # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
-            "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs),
+            "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
+            "-DPYTHON_LINK_DIRS=" + python_link_dir,
+            "-DPYTHON_LIBRARIES=" + python_library,
            "-DLLVM_EXTERNAL_LIT=" + lit_dir
        ] + thirdparty_cmake_args


@@ -26,6 +26,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/SourceMgr.h"
@@ -1301,39 +1302,36 @@ void init_triton_translation(py::module &m) {
         py::gil_scoped_release allow_threads;
         // compile ptx with ptxas
-        char _fsrc[L_tmpnam];
-        char _flog[L_tmpnam];
-        std::tmpnam(_fsrc);
-        std::tmpnam(_flog);
-        std::string fsrc = _fsrc;
-        std::string flog = _flog;
-        std::string fbin = fsrc + ".o";
+        llvm::SmallString<64> fsrc;
+        llvm::SmallString<64> flog;
+        llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
+        llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
+        std::string fbin = std::string(fsrc) + ".o";
+        llvm::FileRemover srcRemover(fsrc);
+        llvm::FileRemover logRemover(flog);
+        llvm::FileRemover binRemover(fbin);
+        const char *_fsrc = fsrc.c_str();
+        const char *_flog = flog.c_str();
         const char *_fbin = fbin.c_str();
-        std::ofstream ofs(fsrc);
+        std::ofstream ofs(_fsrc);
         ofs << ptxCode << std::endl;
         ofs.close();
         std::string cmd;
         int err;
         cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
-              " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+              " " + _fsrc + " -o " + _fsrc + ".o 2> " + _flog;
         err = system(cmd.c_str());
         if (err != 0) {
           std::ifstream _log(_flog);
           std::string log(std::istreambuf_iterator<char>(_log), {});
-          unlink(_fsrc);
-          unlink(_flog);
           throw std::runtime_error("Internal Triton PTX codegen error: \n" +
                                    log);
         }
         std::ifstream _cubin(_fbin, std::ios::binary);
         std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
         _cubin.close();
-        unlink(_fsrc);
-        unlink(_flog);
-        unlink(_fbin);
         py::bytes bytes(cubin);
-        return bytes;
+        return std::move(bytes);
       });
   m.def("add_external_libs",
@@ -1345,8 +1343,8 @@ void init_triton_translation(py::module &m) {
 void init_triton(py::module &m) {
   py::module subm = m.def_submodule("triton");
-  // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
-  init_triton_runtime(std::move(subm.def_submodule("runtime")));
-  init_triton_ir(std::move(subm.def_submodule("ir")));
+  // init_triton_codegen(subm.def_submodule("code_gen"));
+  init_triton_runtime(subm.def_submodule("runtime"));
+  init_triton_ir(subm.def_submodule("ir"));
   init_triton_translation(subm);
 }
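
`return std::move(bytes);` looks odd but is deliberate: the enclosing lambda evidently declares a wider return type than `py::bytes` (e.g. `py::object`), and before C++20 such converting returns copy unless moved explicitly, which is exactly what Clang's -Wreturn-std-move suggests fixing. The same situation outside pybind11:

    #include <memory>
    #include <utility>

    std::shared_ptr<const int> make_value() {
      std::shared_ptr<int> p = std::make_shared<int>(42);
      // The return type differs from decltype(p), so pre-C++20 `return p;`
      // copy-constructs (an atomic refcount bump); std::move avoids it.
      return std::move(p);
    }

The dropped `std::move` calls around `def_submodule` go the other way: moving an already-temporary value is redundant and prevents copy elision, which -Wpessimizing-move reports.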