[BUILD] Fix Warnings and Enable Warnings as Errors (#794)

Ian Bearman
2022-10-28 12:36:09 -07:00
committed by GitHub
parent ac0f6793cc
commit f2106d0aa2
20 changed files with 205 additions and 213 deletions

View File

@@ -132,14 +132,15 @@ endif()
# Python module
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
+set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
+include_directories("." ${PYTHON_SRC_PATH})
if (PYTHON_INCLUDE_DIRS)
-set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
+include_directories(${PYTHON_INCLUDE_DIRS})
+link_directories(${PYTHON_LINK_DIRS})
+link_libraries(${PYTHON_LIBRARIES})
else()
find_package(Python3 REQUIRED COMPONENTS Development)
-set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-include_directories("." ${PYTHON_SRC_PATH} ${Python3_INCLUDE_DIRS})
+include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
@@ -169,7 +170,10 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)
-# include(HandleLLVMOptions) # human-friendly error message
+include(HandleLLVMOptions) # human-friendly error message
+# Disable warnings that show up in external code (gtest;pybind11)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default")
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
@@ -192,7 +196,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(triton
-${PYTHON_LIBRARIES}
TritonAnalysis
TritonTransforms
TritonGPUTransforms
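The two build changes above do the heavy lifting for this PR: re-enabling HandleLLVMOptions pulls in LLVM's standard warning flags (and setup.py below adds -DLLVM_ENABLE_WERROR=ON, so those warnings become errors), while -Wno-covered-switch-default is passed because that diagnostic fires in third-party code (gtest, pybind11). A minimal sketch of what the suppressed warning complains about, with names of our own invention:

enum class Color { Red, Green, Blue };

int toIndex(Color c) {
  switch (c) {
  case Color::Red:
    return 0;
  case Color::Green:
    return 1;
  case Color::Blue:
    return 2;
  default: // -Wcovered-switch-default: default label in switch which
           // covers all enumeration values
    return -1;
  }
}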

View File

@@ -142,7 +142,7 @@ private:
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
-: kind(kind), size(size), offset(offset), id(nextId++) {}
+: kind(kind), id(nextId++), size(size), offset(offset) {}
bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
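This is the first of several member-initializer-order fixes in the PR. C++ initializes non-static members in declaration order, not in the order the initializer list is written, so a mismatched list draws -Wreorder and can hide a real bug when one initializer reads another member. A hypothetical reduction:

struct Buffer {
  unsigned kind;
  size_t id;     // declared before size, so it is initialized first
  size_t size;
  size_t offset;

  // Bad: written as kind, size, offset, id, but executed as kind, id,
  // size, offset; if id's initializer read `size`, it would see garbage.
  //   Buffer() : kind(0), size(0), offset(0), id(size + 1) {} // -Wreorder
  Buffer() : kind(0), id(0), size(0), offset(0) {} // matches declarations
};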

View File

@@ -28,8 +28,8 @@ public:
DimVectorT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
-assert(knownDivisibility.size() == rank);
-assert(knownConstancy.size() == rank);
+assert(knownDivisibility.size() == (size_t)rank);
+assert(knownConstancy.size() == (size_t)rank);
}
// Accessors
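The casts silence -Wsign-compare: rank is a signed int here while size() returns the unsigned size_t, and comparing the two is a warning (an error under -Werror). A self-contained example of the pattern, with names of our own:

#include <cassert>
#include <vector>

void checkRank(const std::vector<int> &dims, int rank) {
  // assert(dims.size() == rank); // -Wsign-compare: size_t vs. int
  assert(dims.size() == (size_t)rank); // explicit cast, as above
}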

View File

@@ -15,9 +15,9 @@ class Location;
namespace triton {
using llvm::StringRef;
-class PTXInstr;
-class PTXInstrCommon;
-class PTXInstrExecution;
+struct PTXInstr;
+struct PTXInstrCommon;
+struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consisting of one or more
// instructions.
@@ -83,7 +83,7 @@ struct PTXBuilder {
Operand() = default;
Operand(const Operation &) = delete;
Operand(Value value, StringRef constraint)
-: value(value), constraint(constraint) {}
+: constraint(constraint), value(value) {}
bool isList() const { return !value && constraint.empty(); }
@@ -120,7 +120,7 @@ struct PTXBuilder {
Operand *newListOperand(unsigned count, mlir::Value val,
const std::string &constraint) {
auto *list = newOperand();
-for (int i = 0; i < count; ++i) {
+for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(val, constraint));
}
return list;
@@ -128,7 +128,7 @@ struct PTXBuilder {
Operand *newListOperand(unsigned count, const std::string &constraint) {
auto *list = newOperand();
-for (int i = 0; i < count; ++i) {
+for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(constraint));
}
return list;
@@ -172,8 +172,8 @@ private:
return argArchive.back().get();
}
-friend class PTXInstr;
-friend class PTXInstrCommon;
+friend struct PTXInstr;
+friend struct PTXInstrCommon;
protected:
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
@@ -209,7 +209,7 @@ protected:
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
-friend class PTXInstrExecution;
+friend struct PTXInstrExecution;
};
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
@@ -309,7 +309,7 @@ struct PTXInstrExecution {
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs)
-: instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
+: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
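PTXInstr and friends are defined with the struct keyword, so forward declarations and friend declarations must say struct as well; Clang's -Wmismatched-tags and MSVC's C4099 flag the disagreement, which is why every class-key in this file flips to struct. A sketch with a made-up type:

// class Widget;  // -Wmismatched-tags: declared with 'class' here...
struct Widget;    // ...but defined with 'struct'; keep them consistent
struct Widget {
  int x;
};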

View File

@@ -11,16 +11,12 @@ class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
-mlir::LLVMTypeConverter &typeConverter;
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
-mlir::LLVMTypeConverter &typeConverter;
public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
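Both conversion targets stored the type converter reference without ever reading it, which Clang rejects under -Wunused-private-field; the field is deleted here and the matching initializer is dropped in the .cpp hunks near the end of this diff. A hypothetical reduction of the warning:

class Target {
  int &converter; // -Wunused-private-field: 'converter' is never read
public:
  explicit Target(int &c) : converter(c) {}
};

class FixedTarget { // fix: drop the field, accept the argument unnamed
public:
  explicit FixedTarget(int & /*converter*/) {}
};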

View File

@@ -26,11 +26,11 @@ public:
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
virtual LogicalResult
-inferReduceOpEncoding(Attribute operandEncoding, int axis,
+inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
virtual LogicalResult
-inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
+inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
};
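Switching axis from int to unsigned changes the signature of these pure virtuals, so every implementation must change in lockstep; an overrider left at the old signature would silently become a new, unrelated function. The matching implementations later in this diff add the override keyword, which turns exactly that drift into a compile error (and -Winconsistent-missing-override flags the missing keyword). A small illustration with invented names:

struct InferLayoutInterface {
  virtual bool inferEncoding(unsigned axis) const = 0;
  virtual ~InferLayoutInterface() = default;
};

struct Impl : InferLayoutInterface {
  // bool inferEncoding(int axis) const { return false; } // would NOT override
  bool inferEncoding(unsigned axis) const override { return axis == 0; }
};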

View File

@@ -21,7 +21,6 @@ private:
};
class TritonGPUConversionTarget : public ConversionTarget {
-TritonGPUTypeConverter &typeConverter;
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx,

View File

@@ -82,7 +82,6 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
-auto rank = srcShape.size();
auto axis = op.axis();
bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension

View File

@@ -66,7 +66,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
DimVectorT retContiguity;
DimVectorT retDivisibility;
DimVectorT retConstancy;
-for (size_t d = 0; d < lhs.getRank(); ++d) {
+for (int d = 0; d < lhs.getRank(); ++d) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
@@ -88,7 +88,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
AxisInfo::DimVectorT newContiguity;
AxisInfo::DimVectorT newDivisibility;
AxisInfo::DimVectorT newConstancy;
-for (size_t d = 0; d < rank; ++d) {
+for (int d = 0; d < rank; ++d) {
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
@@ -167,7 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
-for (size_t d = 0; d < retTy.getRank(); ++d) {
+for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
@@ -176,12 +176,6 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
// expandDims
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
-Type _retTy = *op->result_type_begin();
-Type _opTy = *op->operand_type_begin();
-TensorType retTy = _retTy.cast<TensorType>();
-TensorType opTy = _opTy.cast<TensorType>();
-ArrayRef<int64_t> retShape = retTy.getShape();
-ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
@@ -203,7 +197,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
-for (size_t d = 0; d < retTy.getRank(); ++d) {
+for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);

View File

@@ -28,11 +28,10 @@ bool maybeSharedAllocationOp(Operation *op) {
}
std::string getValueOperandName(Value value, AsmState &state) {
auto *op = value.getDefiningOp();
std::string opName;
llvm::raw_string_ostream ss(opName);
value.printAsOperand(ss, state);
-return std::move(opName);
+return opName;
}
} // namespace mlir
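return std::move(local) defeats copy elision and triggers Clang's -Wpessimizing-move (or -Wredundant-move when the move is merely useless); returning the local by name, as the fix above does, lets NRVO or the implicit move apply. A minimal sketch:

#include <string>
#include <utility>

std::string bad() {
  std::string name = "op";
  return std::move(name); // -Wpessimizing-move: prevents copy elision
}

std::string good() {
  std::string name = "op";
  return name; // NRVO, or an implicit move as a fallback
}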

View File

@@ -254,7 +254,6 @@ protected:
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
-static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-Location loc = op->getLoc();
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel functions always return nothing.
@@ -482,7 +480,6 @@ public:
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
-auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
@@ -654,7 +651,6 @@ public:
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
-auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = idx_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
-auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -842,7 +837,6 @@ struct LoadOpConversion
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
// original values
@@ -897,12 +891,11 @@ struct LoadOpConversion
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
-const int maxWordWidth = std::max<int>(32, valueElemNbits);
-const int totalWidth = valueElemNbits * vec;
-const int width = std::min(totalWidth, maxWordWidth);
-const int nWords = std::max(1, totalWidth / width);
-const int wordNElems = width / valueElemNbits;
-const int vecNElems = totalWidth / valueElemNbits;
+const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+const size_t totalWidth = valueElemNbits * vec;
+const size_t width = std::min(totalWidth, maxWordWidth);
+const size_t nWords = std::max<size_t>(1, totalWidth / width);
+const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
@@ -921,7 +914,7 @@ struct LoadOpConversion
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
-for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr);
}
@@ -988,8 +981,8 @@ struct LoadOpConversion
: retTys[0];
// TODO: if (has_l2_evict_policy)
-auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-LLVM::AsmDialect::AD_ATT);
+// auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+// LLVM::AsmDialect::AD_ATT);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
// ---
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
-const int maxWordWidth = std::max<int>(32, valueElemNbits);
-const int totalWidth = valueElemNbits * vec;
-const int width = std::min(totalWidth, maxWordWidth);
-const int nWords = std::max(1, totalWidth / width);
-const int wordNElems = width / valueElemNbits;
-const int vecNElems = totalWidth / valueElemNbits;
+const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+const size_t totalWidth = valueElemNbits * vec;
+const size_t width = std::min(totalWidth, maxWordWidth);
+const size_t nWords = std::max<size_t>(1, totalWidth / width);
+const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);
SmallVector<std::pair<Value, std::string>> asmArgs;
-for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition
-for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
+for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
}
unsigned srcElems = getElemsPerThread(srcTy);
-auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, rewriter);
unsigned resultElems = getElemsPerThread(resultTy);
SmallVector<Value> resultVals(resultElems);
@@ -1282,8 +1272,6 @@ private:
LogicalResult
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
-auto srcTy = op.operand().getType().cast<RankedTensorType>();
-auto rank = srcTy.getShape().size();
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
return matchAndRewriteFast(op, adaptor, rewriter);
return matchAndRewriteBasic(op, adaptor, rewriter);
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
Location loc, Value val, int i) const {
-MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
barrier();
SmallVector<Value> resultVals(resultElems);
-for (int i = 0; i < resultElems; i++) {
+for (size_t i = 0; i < resultElems; i++) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
-auto srcOrder = srcLayout.getOrder();
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
barrier();
SmallVector<Value> resultVals(resultElems);
-for (int i = 0; i < resultElems; i++) {
+for (size_t i = 0; i < resultElems; i++) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
-auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
auto resultLayout =
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
-auto resultShape = resultTensorTy.getShape();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
@@ -1821,7 +1805,7 @@ protected:
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
-for (int i = 0; i < elems; ++i) {
+for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
__PRED_ENUM(ORD, ord);
__PRED_ENUM(UEQ, ueq);
__PRED_ENUM(UGT, ugt);
+__PRED_ENUM(UGE, uge);
__PRED_ENUM(ULT, ult);
__PRED_ENUM(ULE, ule);
__PRED_ENUM(UNE, une);
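The added UGE entry fills a gap in the predicate mapping: with -Wswitch (part of -Wall), a switch over an enum that has no default must handle every enumerator, and since -Wcovered-switch-default discourages adding a default, enumerating all cases is the fix. An illustrative reduction with invented names:

enum class Pred { EQ, NE, UGE };

const char *name(Pred p) {
  switch (p) {
  case Pred::EQ:
    return "eq";
  case Pred::NE:
    return "ne";
  case Pred::UGE: // without this case: -Wswitch ("enumeration value
    return "uge"; // 'UGE' not handled in switch")
  }
  return "?"; // unreachable; keeps -Wreturn-type happy
}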
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
-auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (sliceLayout) {
-unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape =
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
-unsigned elems = getElemsPerThread(srcTy);
+// unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
-MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
+MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, const Location &loc)
-: wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+: order(order.begin(), order.end()), kOrder(kOrder),
tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
-maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
-typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
+maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
+ctx(rewriter.getContext()) {
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
@@ -2576,7 +2559,6 @@ public:
assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
"smem matrix load must be aligned");
int matIdx[2] = {mat0, mat1};
-int k = matIdx[kOrder];
int ptrIdx{-1};
@@ -2596,7 +2578,6 @@ public:
Value ptr = getPtr(ptrIdx);
Value resV4;
if (canUseLdmatrix) {
int sOffset =
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
@@ -2727,7 +2708,6 @@ public:
}
private:
-int wpt;
SmallVector<uint32_t> order;
int kOrder;
SmallVector<int64_t> tileShape;
@@ -2737,7 +2717,6 @@ private:
int maxPhase;
int elemBytes;
ConversionPatternRewriter &rewriter;
-TypeConverter *typeConverter{};
const Location &loc;
MLIRContext *ctx{};
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-Location loc = op->getLoc();
// D = A * B + C
Value A = op.a();
Value B = op.b();
Value C = op.c();
-Value D = op.getResult();
-MLIRContext *ctx = op->getContext();
bool allowTF32 = op.allowTF32();
// Here we assume the DotOp's operands always come from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
Type i8x4Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
-Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
-ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
switch (mmaType) {
case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
auto bTy = B.getType().cast<RankedTensorType>();
// d = a*b + c
auto dTy = op.d().getType().cast<RankedTensorType>();
-auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
if (dTy.getElementType().isF32()) {
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
-: mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
-typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
-thread(thread) {
+: mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
+rewriter(rewriter), typeConverter(typeConverter), loc(loc),
+ctx(mmaLayout.getContext()) {
wpt = mmaLayout.getWarpsPerCTA();
Value _32 = i32_val(32);
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
}
// step1. Perform loading.
-for (unsigned m = 0; m < numRepM; ++m)
-for (unsigned k = 0; k < numRepK; ++k)
+for (int m = 0; m < numRepM; ++m)
+for (int k = 0; k < numRepK; ++k)
loadFn(2 * m, 2 * k);
// step2. Format the values to LLVM::Struct to passing to mma codegen.
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
-for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
-for (unsigned k = 0; k < numRepK; ++k)
+for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
+for (int k = 0; k < numRepK; ++k)
loadFn(2 * n, 2 * k);
}
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
helper.deduceMmaType(op);
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
-auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto dShape = dTensorTy.getShape();
-int NK = aShape[1];
// shape / shape_per_cta
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
int numRepM = getNumRepM(aTensorTy, dShape[0]);
int numRepN = getNumRepN(aTensorTy, dShape[1]);
int numRepK = getNumRepK(aTensorTy, aShape[1]);
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
};
-for (unsigned k = 0; k < numRepK; ++k)
-for (unsigned m = 0; m < numRepM; ++m)
-for (unsigned n = 0; n < numRepN; ++n)
+for (int k = 0; k < numRepK; ++k)
+for (int m = 0; m < numRepM; ++m)
+for (int n = 0; n < numRepN; ++n)
callMma(2 * m, n, 2 * k);
// replace with new packed result
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
private:
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
-int wpt, int kOrder, ArrayRef<int> instrShape,
+int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
ValueTable &vals) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -3486,8 +3452,8 @@ private:
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
int n1) const {
std::vector<Value> elems;
-for (unsigned m = 0; m < n0; ++m)
-for (unsigned k = 0; k < n1; ++k) {
+for (int m = 0; m < n0; ++m)
+for (int k = 0; k < n1; ++k) {
elems.push_back(vals.at({2 * m, 2 * k}));
elems.push_back(vals.at({2 * m, 2 * k + 1}));
elems.push_back(vals.at({2 * m + 1, 2 * k}));
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
-auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr mmaLayout =
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
-auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
+ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
-auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
-auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
auto outOrder = resSharedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
// swizzle is not allowed
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
@@ -4208,7 +4170,7 @@ namespace mlir {
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-: ConversionTarget(ctx), typeConverter(typeConverter) {
+: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-: ConversionTarget(ctx), typeConverter(typeConverter) {
+: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
// addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();
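Most deletions in this file remove locals that were computed but never read (loc, ctx, llvmIndexTy, shapes and layouts), which -Wunused-variable turns into errors; where only the side effect matters, the call is kept and the binding dropped, as with ptxBuilder.launch above. The two idioms, reduced to a sketch with names of our own:

#include <cstdio>

int launch() { return std::puts("launched"); } // side-effecting call

void lower() {
  // auto ret = launch(); // -Wunused-variable: 'ret' is never read
  launch();               // keep the side effect, drop the binding
  // auto loc = 42;       // values nothing reads are simply deleted
}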

View File

@@ -21,9 +21,7 @@ public:
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
-Op res =
-rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
+rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
@@ -37,9 +35,8 @@ public:
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
-DstOp res =
-rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
-adaptor.getLhs(), adaptor.getRhs());
+rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
+adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -129,10 +126,9 @@ public:
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
-triton::gpu::SelectOp res =
-rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
-op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
-adaptor.getFalseValue());
+rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
+op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
+adaptor.getFalseValue());
return success();
}
};
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
-// return type
-RankedTensorType retType =
-RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
// convert operand to slice of return type
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
getContext(), op.axis(), retEncoding);
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
-auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
+rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
return success();
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
return success();
}
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
+rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
return success();
}
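The same unused-variable cleanup applies to these rewrite patterns: replaceOpWithNewOp returns the newly created op, and binding it without using it warns, so the patterns simply drop the names. When a binding is worth keeping for readability, C++17's [[maybe_unused]] (or a (void) cast pre-17) is the conventional escape hatch; a sketch with a stand-in function, not the MLIR API:

int makeOp() { return 42; } // stand-in; not the MLIR API

void rewrite() {
  [[maybe_unused]] int newOp = makeOp(); // kept binding, no warning
  (void)makeOp();                        // or discard explicitly
}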

View File

@@ -49,7 +49,6 @@ unsigned getElemsPerThread(Type type) {
auto tensorType = type.cast<RankedTensorType>();
auto layout = tensorType.getEncoding();
auto shape = tensorType.getShape();
-size_t rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -109,7 +108,7 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
+for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
shape.push_back(blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
@@ -117,7 +116,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
+for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(blockedParent.getSizePerThread()[d] *
@@ -258,7 +257,6 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
-unsigned dim = getDim();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
@@ -512,11 +510,11 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
auto encoding = srcType.getEncoding();
auto srcShape = srcType.getShape();
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-if (axis < 0 || axis > srcShape.size())
+if (axis < 0 || (size_t)axis > srcShape.size())
return failure();
SmallVector<int64_t, 4> dstShape;
-for (int i = 0; i < srcShape.size(); i++)
-if (i != axis)
+for (size_t i = 0; i < srcShape.size(); i++)
+if (i != (size_t)axis)
dstShape.push_back(srcShape[i]);
auto returnType =
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
@@ -578,15 +576,17 @@ struct TritonGPUInferLayoutInterface
: public triton::DialectInferLayoutInterface {
using DialectInferLayoutInterface::DialectInferLayoutInterface;
-LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
-Attribute &resultEncoding) const {
+LogicalResult
+inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
+Attribute &resultEncoding) const override {
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
operandEncoding);
return success();
}
-LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
-Attribute &resultEncoding) const {
+LogicalResult
+inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
+Attribute &resultEncoding) const override {
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
if (!sliceEncoding) {
llvm::report_fatal_error(

View File

@@ -87,7 +87,6 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
@@ -219,10 +218,10 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
SmallVector<Type, 1> newType;
-auto sucess = typeInfer.inferReturnTypes(
+auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newType);
-if (success)
+if (succeeded(success))
newOp->getResult(0).setType(newType.front());
}
return newOp;
@@ -364,10 +363,6 @@ public:
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
size_t i, RankedTensorType newType,
triton::gpu::ConvertLayoutOp origConversion) const {
-auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
-auto ctx = forOp.getContext();
-auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
@@ -418,11 +413,10 @@ public:
return newResults;
}
-mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-mlir::PatternRewriter &rewriter) const {
+mlir::LogicalResult
+matchAndRewrite(mlir::Operation *op,
+mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
-auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
@@ -480,7 +474,6 @@ public:
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp)
return mlir::failure();
-auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
SetVector<Operation *> cvtSlices;
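The sucess/success change above is more than a spelling fix. With `using namespace mlir`, the old `if (success)` named the mlir::success function, and a function name in boolean context is always true (Clang: "address of function 'success' will always evaluate to 'true'", -Wpointer-bool-conversion), so the check could never fail; testing the renamed LogicalResult through succeeded() restores the intended logic. A self-contained reduction:

#include <cstdio>

bool success() { return true; }

void check() {
  bool sucess = false;             // the typo'd local
  if (success)                     // -Wpointer-bool-conversion: tests the
    std::puts("always taken");     // function's address, so always true
  if (sucess)                      // the intended test
    std::puts("never printed");
}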

View File

@@ -17,11 +17,6 @@ using namespace mlir;
namespace {
class LoopPipeliner {
-/// comments on numStages:
-/// [0, numStages-1) are in the prologue
-/// numStages-1 is appended after the loop body
-int numStages;
/// cache forOp we are working on
scf::ForOp forOp;
@@ -43,6 +38,11 @@ class LoopPipeliner {
///
Value loopIterIdx;
+/// comments on numStages:
+/// [0, numStages-1) are in the prologue
+/// numStages-1 is appended after the loop body
+int numStages;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -58,9 +58,6 @@ class LoopPipeliner {
Value lookupOrDefault(Value origin, int stage);
-/// return true if this op uses any of `loads`
-bool isDirectUserOfAsyncLoad(Operation &op);
/// returns an empty buffer of size <numStages, ...>
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
OpBuilder &builder);
@@ -84,7 +81,7 @@ public:
/// create the new ForOp (add new args & insert prefetched ops)
scf::ForOp createNewForOp();
-friend class PipelinePass;
+friend struct PipelinePass;
};
// helpers
@@ -123,19 +120,6 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
}
}
-bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
-for (Value loadOp : loads) {
-assert(loadOp.hasOneUse() &&
-"load should only have one use (ConvertLayout)");
-Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
-for (Value opOperand : op.getOperands()) {
-if (opOperand == loadUseResult)
-return true;
-}
-}
-return false;
-}
triton::gpu::AllocTensorOp
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
// allocate a buffer for each pipelined tensor
@@ -356,8 +340,8 @@ void LoopPipeliner::emitPrologue() {
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
-Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
-loads[0].getLoc(), loads.size() * (numStages - 2));
+builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
+loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
@@ -380,8 +364,7 @@ void LoopPipeliner::emitEpilogue() {
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
-Operation *asyncWait =
-builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
+builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
}
scf::ForOp LoopPipeliner::createNewForOp() {
@@ -575,8 +558,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
yieldValues.push_back(loopIterIdx);
builder.setInsertionPointToEnd(newForOp.getBody());
-auto test = builder.create<scf::YieldOp>(
-forOp.getBody()->getTerminator()->getLoc(), yieldValues);
+builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
+yieldValues);
return newForOp;
}

View File

@@ -30,7 +30,7 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
(ty.getElementType().getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
-int inner = (opIdx == 0) ? 0 : 1;
+size_t inner = (opIdx == 0) ? 0 : 1;
if (version == 1) {
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
// TODO: handle rep (see
@@ -67,7 +67,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
void runOnOperation() override {
Operation *op = getOperation();
-MLIRContext *context = &getContext();
op->walk([&](triton::DotOp dotOp) -> void {
OpBuilder builder(dotOp);
auto _retEncoding =

View File

@@ -73,7 +73,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
-: ConversionTarget(context), typeConverter(typeConverter) {
+: ConversionTarget(context) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
@@ -90,7 +90,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
});
// We have requirements for the data layouts
-addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
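The legality callback captured `this` without touching any member, which trips Clang's -Wunused-lambda-capture; emptying the capture list is the fix, and the same reasoning removed the unused typeConverter member above. A reduced example:

#include <functional>

struct Target {
  std::function<bool(int)> isLegal;
  void setup() {
    // isLegal = [this](int v) { return v > 0; }; // -Wunused-lambda-capture
    isLegal = [](int v) { return v > 0; };        // capture nothing
  }
};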

View File

@@ -63,7 +63,7 @@ static bool find_and_replace(std::string &str, const std::string &begin,
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
-int max_nvvm_ptx = 74;
+// int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr =

View File

@@ -1,11 +1,12 @@
import distutils
import distutils.spawn
+import itertools
import os
import platform
import re
import shutil
import subprocess
import sys
+import sysconfig
import tarfile
import tempfile
import urllib.request
@@ -16,6 +17,74 @@ from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
+# Logic from https://github.com/Kitware/CMake/blob/master/Modules/FindPythonLibs.cmake
+# Code from https://stackoverflow.com/questions/47423246
+def get_python_library():
+    """Get path to the python library associated with the current python
+    interpreter."""
+    # determine direct path to libpython
+    python_version = sysconfig.get_python_version()
+    python_library = sysconfig.get_config_var('LIBRARY')
+    # if static (or nonexistent), try to find a suitable dynamic libpython
+    if (python_library is None or os.path.splitext(python_library)[1][-2:] == '.a'):
+        candidate_lib_prefixes = ['', 'lib']
+        candidate_extensions = ['.lib', '.so', '.a']
+        if sysconfig.get_config_var('WITH_DYLD'):
+            candidate_extensions.insert(0, '.dylib')
+        candidate_versions = [python_version]
+        if python_version:
+            candidate_versions.append('')
+            candidate_versions.insert(
+                0, "".join(python_version.split(".")[:2]))
+        abiflags = getattr(sys, 'abiflags', '')
+        candidate_abiflags = [abiflags]
+        if abiflags:
+            candidate_abiflags.append('')
+        # Ensure the value injected by virtualenv is
+        # returned on windows.
+        # Because calling `sysconfig.get_config_var('multiarchsubdir')`
+        # returns an empty string on Linux, `du_sysconfig` is only used to
+        # get the value of `LIBDIR`.
+        libdir = distutils.sysconfig.get_config_var('LIBDIR')
+        if sysconfig.get_config_var('MULTIARCH'):
+            masd = sysconfig.get_config_var('multiarchsubdir')
+            if masd:
+                if masd.startswith(os.sep):
+                    masd = masd[len(os.sep):]
+                libdir = os.path.join(libdir, masd)
+        if libdir is None:
+            libdir = os.path.abspath(os.path.join(
+                sysconfig.get_config_var('LIBDEST'), "..", "libs"))
+        candidates = (
+            os.path.join(
+                libdir,
+                ''.join((pre, 'python', ver, abi, ext))
+            )
+            for (pre, ext, ver, abi) in itertools.product(
+                candidate_lib_prefixes,
+                candidate_extensions,
+                candidate_versions,
+                candidate_abiflags
+            )
+        )
+        for candidate in candidates:
+            if os.path.exists(candidate):
+                # we found a (likely alternate) libpython
+                python_library = candidate
+                break
+    return python_library
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
@@ -136,14 +205,19 @@ class CMakeBuild(build_ext):
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories
-python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
+python_include_dir = distutils.sysconfig.get_python_inc()
+python_link_dir = distutils.sysconfig.get_python_lib()
+python_library = get_python_library()
cmake_args = [
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs),
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DPYTHON_LINK_DIRS=" + python_link_dir,
"-DPYTHON_LIBRARIES=" + python_library,
"-DLLVM_EXTERNAL_LIT=" + lit_dir
] + thirdparty_cmake_args

View File

@@ -26,6 +26,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/SourceMgr.h"
@@ -1301,39 +1302,36 @@ void init_triton_translation(py::module &m) {
py::gil_scoped_release allow_threads;
// compile ptx with ptxas
-char _fsrc[L_tmpnam];
-char _flog[L_tmpnam];
-std::tmpnam(_fsrc);
-std::tmpnam(_flog);
-std::string fsrc = _fsrc;
-std::string flog = _flog;
-std::string fbin = fsrc + ".o";
+llvm::SmallString<64> fsrc;
+llvm::SmallString<64> flog;
+llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
+llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
+std::string fbin = std::string(fsrc) + ".o";
+llvm::FileRemover srcRemover(fsrc);
+llvm::FileRemover logRemover(flog);
+llvm::FileRemover binRemover(fbin);
+const char *_fsrc = fsrc.c_str();
+const char *_flog = flog.c_str();
+const char *_fbin = fbin.c_str();
-std::ofstream ofs(fsrc);
+std::ofstream ofs(_fsrc);
ofs << ptxCode << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
" " + fsrc + " -o " + fsrc + ".o 2> " + flog;
" " + _fsrc + " -o " + _fsrc + ".o 2> " + _flog;
err = system(cmd.c_str());
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
-unlink(_fsrc);
-unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
-unlink(_fsrc);
-unlink(_flog);
-unlink(_fbin);
py::bytes bytes(cubin);
-return bytes;
+return std::move(bytes);
});
m.def("add_external_libs",
@@ -1345,8 +1343,8 @@ void init_triton_translation(py::module &m) {
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
-// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
-init_triton_runtime(std::move(subm.def_submodule("runtime")));
-init_triton_ir(std::move(subm.def_submodule("ir")));
+// init_triton_codegen(subm.def_submodule("code_gen"));
+init_triton_runtime(subm.def_submodule("runtime"));
+init_triton_ir(subm.def_submodule("ir"));
init_triton_translation(subm);
}