[BUILD] Fix Warnings and Enable Warnings as Errors (#794)
This commit is contained in:
@@ -132,14 +132,15 @@ endif()
|
||||
# Python module
|
||||
if(TRITON_BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
include_directories("." ${PYTHON_SRC_PATH})
|
||||
if (PYTHON_INCLUDE_DIRS)
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
link_directories(${PYTHON_LINK_DIRS})
|
||||
link_libraries(${PYTHON_LIBRARIES})
|
||||
else()
|
||||
find_package(Python3 REQUIRED COMPONENTS Development)
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${Python3_INCLUDE_DIRS})
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
link_directories(${Python3_LIBRARY_DIRS})
|
||||
link_libraries(${Python3_LIBRARIES})
|
||||
add_link_options(${Python3_LINK_OPTIONS})
|
||||
@@ -169,7 +170,10 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
include(TableGen) # required by AddMLIR
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
# include(HandleLLVMOptions) # human-friendly error message
|
||||
include(HandleLLVMOptions) # human-friendly error message
|
||||
|
||||
# Disable warnings that show up in external code (gtest;pybind11)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default")
|
||||
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
@@ -192,7 +196,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
target_link_libraries(triton
|
||||
${PYTHON_LIBRARIES}
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
|
@@ -142,7 +142,7 @@ private:
|
||||
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), size(size), offset(offset), id(nextId++) {}
|
||||
: kind(kind), id(nextId++), size(size), offset(offset) {}
|
||||
|
||||
bool intersects(const BufferT &other) const {
|
||||
return Interval<size_t>(offset, offset + size)
|
||||
|
@@ -28,8 +28,8 @@ public:
|
||||
DimVectorT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == rank);
|
||||
assert(knownConstancy.size() == rank);
|
||||
assert(knownDivisibility.size() == (size_t)rank);
|
||||
assert(knownConstancy.size() == (size_t)rank);
|
||||
}
|
||||
|
||||
// Accessors
|
||||
|
@@ -15,9 +15,9 @@ class Location;
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
class PTXInstr;
|
||||
class PTXInstrCommon;
|
||||
class PTXInstrExecution;
|
||||
struct PTXInstr;
|
||||
struct PTXInstrCommon;
|
||||
struct PTXInstrExecution;
|
||||
|
||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||
// instructions.
|
||||
@@ -83,7 +83,7 @@ struct PTXBuilder {
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: value(value), constraint(constraint) {}
|
||||
: constraint(constraint), value(value) {}
|
||||
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
@@ -120,7 +120,7 @@ struct PTXBuilder {
|
||||
Operand *newListOperand(unsigned count, mlir::Value val,
|
||||
const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (int i = 0; i < count; ++i) {
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(val, constraint));
|
||||
}
|
||||
return list;
|
||||
@@ -128,7 +128,7 @@ struct PTXBuilder {
|
||||
|
||||
Operand *newListOperand(unsigned count, const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (int i = 0; i < count; ++i) {
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(constraint));
|
||||
}
|
||||
return list;
|
||||
@@ -172,8 +172,8 @@ private:
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
friend class PTXInstr;
|
||||
friend class PTXInstrCommon;
|
||||
friend struct PTXInstr;
|
||||
friend struct PTXInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
@@ -209,7 +209,7 @@ protected:
|
||||
PTXBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
|
||||
friend class PTXInstrExecution;
|
||||
friend struct PTXInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
@@ -309,7 +309,7 @@ struct PTXInstrExecution {
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs)
|
||||
: instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
|
@@ -11,16 +11,12 @@ class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
mlir::LLVMTypeConverter &typeConverter;
|
||||
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
|
||||
mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
mlir::LLVMTypeConverter &typeConverter;
|
||||
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
|
||||
|
@@ -26,11 +26,11 @@ public:
|
||||
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
virtual LogicalResult
|
||||
inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
|
||||
virtual LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
};
|
||||
|
||||
|
@@ -21,7 +21,6 @@ private:
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
TritonGPUTypeConverter &typeConverter;
|
||||
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx,
|
||||
|
@@ -82,7 +82,6 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto rank = srcShape.size();
|
||||
auto axis = op.axis();
|
||||
|
||||
bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
|
||||
|
@@ -66,7 +66,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
DimVectorT retContiguity;
|
||||
DimVectorT retDivisibility;
|
||||
DimVectorT retConstancy;
|
||||
for (size_t d = 0; d < lhs.getRank(); ++d) {
|
||||
for (int d = 0; d < lhs.getRank(); ++d) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
@@ -88,7 +88,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
AxisInfo::DimVectorT newContiguity;
|
||||
AxisInfo::DimVectorT newDivisibility;
|
||||
AxisInfo::DimVectorT newConstancy;
|
||||
for (size_t d = 0; d < rank; ++d) {
|
||||
for (int d = 0; d < rank; ++d) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
@@ -167,7 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||
for (int d = 0; d < retTy.getRank(); ++d) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
@@ -176,12 +176,6 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
// expandDims
|
||||
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
TensorType opTy = _opTy.cast<TensorType>();
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
||||
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
||||
@@ -203,7 +197,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||
for (int d = 0; d < retTy.getRank(); ++d) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
|
@@ -28,11 +28,10 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
}
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state) {
|
||||
auto *op = value.getDefiningOp();
|
||||
std::string opName;
|
||||
llvm::raw_string_ostream ss(opName);
|
||||
value.printAsOperand(ss, state);
|
||||
return std::move(opName);
|
||||
return opName;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
@@ -254,7 +254,6 @@ protected:
|
||||
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
||||
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
||||
/// information.
|
||||
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
||||
struct FuncOpConversion : public FuncOpConversionBase {
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
|
||||
PatternBenefit benefit)
|
||||
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
// Currently, Triton kernel function always return nothing.
|
||||
@@ -482,7 +480,6 @@ public:
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
@@ -654,7 +651,6 @@ public:
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(srcType, constVal);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
@@ -842,7 +837,6 @@ struct LoadOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// original values
|
||||
@@ -897,12 +891,11 @@ struct LoadOpConversion
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
@@ -921,7 +914,7 @@ struct LoadOpConversion
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
}
|
||||
@@ -988,8 +981,8 @@ struct LoadOpConversion
|
||||
: retTys[0];
|
||||
|
||||
// TODO: if (has_l2_evict_policy)
|
||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT);
|
||||
// auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||
// LLVM::AsmDialect::AD_ATT);
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
|
||||
|
||||
// ---
|
||||
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
const int totalWidth = valueElemNbits * vec;
|
||||
const int width = std::min(totalWidth, maxWordWidth);
|
||||
const int nWords = std::max(1, totalWidth / width);
|
||||
const int wordNElems = width / valueElemNbits;
|
||||
const int vecNElems = totalWidth / valueElemNbits;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
Type valArgTy = IntegerType::get(ctx, width);
|
||||
auto wordTy = vec_ty(valueElemTy, wordNElems);
|
||||
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
// Insert each value element to the composition
|
||||
for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
|
||||
assert(elemOffset < valueElems.size());
|
||||
Value elem = valueElems[elemOffset];
|
||||
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
|
||||
}
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTy);
|
||||
auto elemTy = resultTy.getElementType();
|
||||
auto srcVals = getElementsFromStruct(loc, src, rewriter);
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
@@ -1282,8 +1272,6 @@ private:
|
||||
LogicalResult
|
||||
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto rank = srcTy.getShape().size();
|
||||
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
|
||||
return matchAndRewriteFast(op, adaptor, rewriter);
|
||||
return matchAndRewriteBasic(op, adaptor, rewriter);
|
||||
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
|
||||
|
||||
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value val, int i) const {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
|
||||
if (bits == 64) {
|
||||
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
|
||||
barrier();
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (int i = 0; i < resultElems; i++) {
|
||||
for (size_t i = 0; i < resultElems; i++) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcOrder = srcLayout.getOrder();
|
||||
|
||||
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
|
||||
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
|
||||
barrier();
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (int i = 0; i < resultElems; i++) {
|
||||
for (size_t i = 0; i < resultElems; i++) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
// due to MLIR's restrictions
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
|
||||
auto resultLayout =
|
||||
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
|
||||
auto resultShape = resultTensorTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||
@@ -1821,7 +1805,7 @@ protected:
|
||||
SmallVector<SmallVector<Value>> operands(elems);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
||||
for (int i = 0; i < elems; ++i) {
|
||||
for (size_t i = 0; i < elems; ++i) {
|
||||
operands[i].push_back(sub_operands[i]);
|
||||
}
|
||||
}
|
||||
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
|
||||
__PRED_ENUM(ORD, ord);
|
||||
__PRED_ENUM(UEQ, ueq);
|
||||
__PRED_ENUM(UGT, ugt);
|
||||
__PRED_ENUM(UGE, uge);
|
||||
__PRED_ENUM(ULT, ult);
|
||||
__PRED_ENUM(ULE, ule);
|
||||
__PRED_ENUM(UNE, une);
|
||||
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
auto rank = type.getRank();
|
||||
auto sizePerThread = getSizePerThread(layout);
|
||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
SmallVector<unsigned> numCTAs(rank);
|
||||
auto shapePerCTA = getShapePerCTA(layout);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
|
||||
loc, rewriter, blockedLayout, type.getShape());
|
||||
} else if (sliceLayout) {
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
auto parent = sliceLayout.getParent();
|
||||
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
SmallVector<int64_t> paddedShape =
|
||||
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
unsigned accumNumReplicates = product<unsigned>(numReplicates);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
// unsigned elems = getElemsPerThread(srcTy);
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
// Data loader for mma.16816 instruction.
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
||||
int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, const Location &loc)
|
||||
: wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
|
||||
: order(order.begin(), order.end()), kOrder(kOrder),
|
||||
tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
|
||||
typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
|
||||
ctx(rewriter.getContext()) {
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
|
||||
@@ -2576,7 +2559,6 @@ public:
|
||||
assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
|
||||
"smem matrix load must be aligned");
|
||||
int matIdx[2] = {mat0, mat1};
|
||||
int k = matIdx[kOrder];
|
||||
|
||||
int ptrIdx{-1};
|
||||
|
||||
@@ -2596,7 +2578,6 @@ public:
|
||||
|
||||
Value ptr = getPtr(ptrIdx);
|
||||
|
||||
Value resV4;
|
||||
if (canUseLdmatrix) {
|
||||
int sOffset =
|
||||
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
|
||||
@@ -2727,7 +2708,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
int wpt;
|
||||
SmallVector<uint32_t> order;
|
||||
int kOrder;
|
||||
SmallVector<int64_t> tileShape;
|
||||
@@ -2737,7 +2717,6 @@ private:
|
||||
int maxPhase;
|
||||
int elemBytes;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
TypeConverter *typeConverter{};
|
||||
const Location &loc;
|
||||
MLIRContext *ctx{};
|
||||
|
||||
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
// D = A * B + C
|
||||
Value A = op.a();
|
||||
Value B = op.b();
|
||||
Value C = op.c();
|
||||
Value D = op.getResult();
|
||||
MLIRContext *ctx = op->getContext();
|
||||
bool allowTF32 = op.allowTF32();
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
|
||||
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
|
||||
Type i8x4Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
|
||||
Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
|
||||
|
||||
switch (mmaType) {
|
||||
case TensorCoreType::FP32_FP16_FP16_FP32:
|
||||
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
|
||||
auto bTy = B.getType().cast<RankedTensorType>();
|
||||
// d = a*b + c
|
||||
auto dTy = op.d().getType().cast<RankedTensorType>();
|
||||
auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
|
||||
if (dTy.getElementType().isF32()) {
|
||||
if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
|
||||
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
|
||||
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, Location loc)
|
||||
: mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
|
||||
typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
|
||||
thread(thread) {
|
||||
: mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
|
||||
rewriter(rewriter), typeConverter(typeConverter), loc(loc),
|
||||
ctx(mmaLayout.getContext()) {
|
||||
wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
Value _32 = i32_val(32);
|
||||
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
|
||||
}
|
||||
|
||||
// step1. Perform loading.
|
||||
for (unsigned m = 0; m < numRepM; ++m)
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (int m = 0; m < numRepM; ++m)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * m, 2 * k);
|
||||
|
||||
// step2. Format the values to LLVM::Struct to passing to mma codegen.
|
||||
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
|
||||
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||
|
||||
for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * n, 2 * k);
|
||||
}
|
||||
|
||||
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
|
||||
helper.deduceMmaType(op);
|
||||
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = c.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto dShape = dTensorTy.getShape();
|
||||
|
||||
int NK = aShape[1];
|
||||
// shape / shape_per_cta
|
||||
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
|
||||
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
|
||||
int numRepM = getNumRepM(aTensorTy, dShape[0]);
|
||||
int numRepN = getNumRepN(aTensorTy, dShape[1]);
|
||||
int numRepK = getNumRepK(aTensorTy, aShape[1]);
|
||||
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
|
||||
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
|
||||
};
|
||||
|
||||
for (unsigned k = 0; k < numRepK; ++k)
|
||||
for (unsigned m = 0; m < numRepM; ++m)
|
||||
for (unsigned n = 0; n < numRepN; ++n)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
for (int m = 0; m < numRepM; ++m)
|
||||
for (int n = 0; n < numRepN; ++n)
|
||||
callMma(2 * m, n, 2 * k);
|
||||
|
||||
// replace with new packed result
|
||||
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
|
||||
private:
|
||||
std::function<void(int, int)>
|
||||
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
|
||||
int wpt, int kOrder, ArrayRef<int> instrShape,
|
||||
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, Value warpId,
|
||||
ValueTable &vals) const {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
@@ -3486,8 +3452,8 @@ private:
|
||||
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
|
||||
int n1) const {
|
||||
std::vector<Value> elems;
|
||||
for (unsigned m = 0; m < n0; ++m)
|
||||
for (unsigned k = 0; k < n1; ++k) {
|
||||
for (int m = 0; m < n0; ++m)
|
||||
for (int k = 0; k < n1; ++k) {
|
||||
elems.push_back(vals.at({2 * m, 2 * k}));
|
||||
elems.push_back(vals.at({2 * m, 2 * k + 1}));
|
||||
elems.push_back(vals.at({2 * m + 1, 2 * k}));
|
||||
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.src();
|
||||
Value dst = op.result();
|
||||
auto srcTensorTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
|
||||
|
||||
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto dotOperandLayout =
|
||||
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
MmaEncodingAttr mmaLayout =
|
||||
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
|
||||
auto ctx = op.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto voidTy = void_ty(ctx);
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
ptxBuilder.launch(rewriter, loc, voidTy);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
auto outOrder = resSharedLayout.getOrder();
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
||||
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
|
||||
// swizzle is not allowd
|
||||
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
|
||||
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
|
||||
|
||||
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
||||
for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto wordElemIdx = wordIdx * numWordElems;
|
||||
auto ©AsyncOp =
|
||||
@@ -4208,7 +4170,7 @@ namespace mlir {
|
||||
|
||||
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
// addIllegalDialect<triton::TritonDialect>();
|
||||
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
|
||||
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
// addLegalDialect<NVVM::NVVMDialect>();
|
||||
addIllegalOp<mlir::FuncOp>();
|
||||
|
@@ -21,9 +21,7 @@ public:
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -37,9 +35,8 @@ public:
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
DstOp res =
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -129,10 +126,9 @@ public:
|
||||
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
triton::gpu::SelectOp res =
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
// convert operand to slice of return type
|
||||
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
||||
getContext(), op.axis(), retEncoding);
|
||||
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
|
||||
adaptor.transB());
|
||||
return success();
|
||||
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
|
@@ -49,7 +49,6 @@ unsigned getElemsPerThread(Type type) {
|
||||
auto tensorType = type.cast<RankedTensorType>();
|
||||
auto layout = tensorType.getEncoding();
|
||||
auto shape = tensorType.getShape();
|
||||
size_t rank = shape.size();
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getElemsPerThread(shape);
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
@@ -109,7 +108,7 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
SmallVector<unsigned> shape;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
|
||||
for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
|
||||
shape.push_back(blockedLayout.getSizePerThread()[d] *
|
||||
blockedLayout.getThreadsPerWarp()[d] *
|
||||
blockedLayout.getWarpsPerCTA()[d]);
|
||||
@@ -117,7 +116,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
auto parent = sliceLayout.getParent();
|
||||
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
|
||||
for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
|
||||
if (d == dim)
|
||||
continue;
|
||||
shape.push_back(blockedParent.getSizePerThread()[d] *
|
||||
@@ -258,7 +257,6 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
|
||||
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
size_t rank = shape.size();
|
||||
auto parent = getParent();
|
||||
unsigned dim = getDim();
|
||||
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
|
||||
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
|
||||
@@ -512,11 +510,11 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
|
||||
auto encoding = srcType.getEncoding();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
if (axis < 0 || axis > srcShape.size())
|
||||
if (axis < 0 || (size_t)axis > srcShape.size())
|
||||
return failure();
|
||||
SmallVector<int64_t, 4> dstShape;
|
||||
for (int i = 0; i < srcShape.size(); i++)
|
||||
if (i != axis)
|
||||
for (size_t i = 0; i < srcShape.size(); i++)
|
||||
if (i != (size_t)axis)
|
||||
dstShape.push_back(srcShape[i]);
|
||||
auto returnType =
|
||||
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
|
||||
@@ -578,15 +576,17 @@ struct TritonGPUInferLayoutInterface
|
||||
: public triton::DialectInferLayoutInterface {
|
||||
using DialectInferLayoutInterface::DialectInferLayoutInterface;
|
||||
|
||||
LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const {
|
||||
LogicalResult
|
||||
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const override {
|
||||
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
|
||||
operandEncoding);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const {
|
||||
LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const override {
|
||||
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
|
||||
if (!sliceEncoding) {
|
||||
llvm::report_fatal_error(
|
||||
|
@@ -87,7 +87,6 @@ public:
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristics to accomodate fused attention
|
||||
@@ -219,10 +218,10 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
|
||||
if (typeInfer) {
|
||||
SmallVector<Type, 1> newType;
|
||||
auto sucess = typeInfer.inferReturnTypes(
|
||||
auto success = typeInfer.inferReturnTypes(
|
||||
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
|
||||
newOp->getAttrDictionary(), newOp->getRegions(), newType);
|
||||
if (success)
|
||||
if (succeeded(success))
|
||||
newOp->getResult(0).setType(newType.front());
|
||||
}
|
||||
return newOp;
|
||||
@@ -364,10 +363,6 @@ public:
|
||||
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
|
||||
size_t i, RankedTensorType newType,
|
||||
triton::gpu::ConvertLayoutOp origConversion) const {
|
||||
|
||||
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
|
||||
auto ctx = forOp.getContext();
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
// Rewrite init argument
|
||||
Type origType = forOp.getInitArgs()[i].getType();
|
||||
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
|
||||
@@ -418,11 +413,10 @@ public:
|
||||
return newResults;
|
||||
}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (auto iterArg : llvm::enumerate(iterArgs)) {
|
||||
// if (iterArg.index() != 1)
|
||||
@@ -480,7 +474,6 @@ public:
|
||||
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
|
||||
if (!forOp)
|
||||
return mlir::failure();
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
|
||||
SetVector<Operation *> cvtSlices;
|
||||
|
@@ -17,11 +17,6 @@ using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class LoopPipeliner {
|
||||
/// comments on numStages:
|
||||
/// [0, numStages-1) are in the prologue
|
||||
/// numStages-1 is appended after the loop body
|
||||
int numStages;
|
||||
|
||||
/// cache forOp we are working on
|
||||
scf::ForOp forOp;
|
||||
|
||||
@@ -43,6 +38,11 @@ class LoopPipeliner {
|
||||
///
|
||||
Value loopIterIdx;
|
||||
|
||||
/// comments on numStages:
|
||||
/// [0, numStages-1) are in the prologue
|
||||
/// numStages-1 is appended after the loop body
|
||||
int numStages;
|
||||
|
||||
/// value (in loop) => value at stage N
|
||||
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||
|
||||
@@ -58,9 +58,6 @@ class LoopPipeliner {
|
||||
|
||||
Value lookupOrDefault(Value origin, int stage);
|
||||
|
||||
/// return true if this op uses any of `loads`
|
||||
bool isDirectUserOfAsyncLoad(Operation &op);
|
||||
|
||||
/// returns a empty buffer of size <numStages, ...>
|
||||
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
|
||||
OpBuilder &builder);
|
||||
@@ -84,7 +81,7 @@ public:
|
||||
/// create the new ForOp (add new args & insert prefetched ops)
|
||||
scf::ForOp createNewForOp();
|
||||
|
||||
friend class PipelinePass;
|
||||
friend struct PipelinePass;
|
||||
};
|
||||
|
||||
// helpers
|
||||
@@ -123,19 +120,6 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
}
|
||||
}
|
||||
|
||||
bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
|
||||
for (Value loadOp : loads) {
|
||||
assert(loadOp.hasOneUse() &&
|
||||
"load should only have one use (ConvertLayout)");
|
||||
Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
|
||||
for (Value opOperand : op.getOperands()) {
|
||||
if (opOperand == loadUseResult)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
triton::gpu::AllocTensorOp
|
||||
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
|
||||
// allocate a buffer for each pipelined tensor
|
||||
@@ -356,8 +340,8 @@ void LoopPipeliner::emitPrologue() {
|
||||
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
||||
|
||||
// async.wait & extract_slice
|
||||
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
|
||||
loads[0].getLoc(), loads.size() * (numStages - 2));
|
||||
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
|
||||
loads.size() * (numStages - 2));
|
||||
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (Value loadOp : loads) {
|
||||
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
|
||||
@@ -380,8 +364,7 @@ void LoopPipeliner::emitEpilogue() {
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Operation *asyncWait =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
}
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
@@ -575,8 +558,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
yieldValues.push_back(loopIterIdx);
|
||||
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
auto test = builder.create<scf::YieldOp>(
|
||||
forOp.getBody()->getTerminator()->getLoc(), yieldValues);
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
|
@@ -30,7 +30,7 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
(ty.getElementType().getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
// index of the inner dimension in `order`
|
||||
int inner = (opIdx == 0) ? 0 : 1;
|
||||
size_t inner = (opIdx == 0) ? 0 : 1;
|
||||
if (version == 1) {
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
// TODO: handle rep (see
|
||||
@@ -67,7 +67,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
op->walk([&](triton::DotOp dotOp) -> void {
|
||||
OpBuilder builder(dotOp);
|
||||
auto _retEncoding =
|
||||
|
@@ -73,7 +73,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
: ConversionTarget(context) {
|
||||
// TODO: we should also verify ops of TritonGPUDialect
|
||||
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
|
||||
@@ -90,7 +90,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
});
|
||||
|
||||
// We have requirements for the data layouts
|
||||
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
|
||||
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||
Attribute aEncoding =
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding =
|
||||
|
@@ -63,7 +63,7 @@ static bool find_and_replace(std::string &str, const std::string &begin,
|
||||
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
int max_nvvm_ptx = 74;
|
||||
// int max_nvvm_ptx = 74;
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
auto *short_ptr =
|
||||
|
@@ -1,11 +1,12 @@
|
||||
import distutils
|
||||
import distutils.spawn
|
||||
import itertools
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import sysconfig
|
||||
import tarfile
|
||||
import tempfile
|
||||
import urllib.request
|
||||
@@ -16,6 +17,74 @@ from setuptools import Extension, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
# Logic from https://github.com/Kitware/CMake/blob/master/Modules/FindPythonLibs.cmake
|
||||
# Code from https://stackoverflow.com/questions/47423246
|
||||
def get_python_library():
|
||||
"""Get path to the python library associated with the current python
|
||||
interpreter."""
|
||||
# determine direct path to libpython
|
||||
python_version = sysconfig.get_python_version()
|
||||
python_library = sysconfig.get_config_var('LIBRARY')
|
||||
|
||||
# if static (or nonexistent), try to find a suitable dynamic libpython
|
||||
if (python_library is None or os.path.splitext(python_library)[1][-2:] == '.a'):
|
||||
|
||||
candidate_lib_prefixes = ['', 'lib']
|
||||
|
||||
candidate_extensions = ['.lib', '.so', '.a']
|
||||
if sysconfig.get_config_var('WITH_DYLD'):
|
||||
candidate_extensions.insert(0, '.dylib')
|
||||
|
||||
candidate_versions = [python_version]
|
||||
if python_version:
|
||||
candidate_versions.append('')
|
||||
candidate_versions.insert(
|
||||
0, "".join(python_version.split(".")[:2]))
|
||||
|
||||
abiflags = getattr(sys, 'abiflags', '')
|
||||
candidate_abiflags = [abiflags]
|
||||
if abiflags:
|
||||
candidate_abiflags.append('')
|
||||
|
||||
# Ensure the value injected by virtualenv is
|
||||
# returned on windows.
|
||||
# Because calling `sysconfig.get_config_var('multiarchsubdir')`
|
||||
# returns an empty string on Linux, `du_sysconfig` is only used to
|
||||
# get the value of `LIBDIR`.
|
||||
libdir = distutils.sysconfig.get_config_var('LIBDIR')
|
||||
if sysconfig.get_config_var('MULTIARCH'):
|
||||
masd = sysconfig.get_config_var('multiarchsubdir')
|
||||
if masd:
|
||||
if masd.startswith(os.sep):
|
||||
masd = masd[len(os.sep):]
|
||||
libdir = os.path.join(libdir, masd)
|
||||
|
||||
if libdir is None:
|
||||
libdir = os.path.abspath(os.path.join(
|
||||
sysconfig.get_config_var('LIBDEST'), "..", "libs"))
|
||||
|
||||
candidates = (
|
||||
os.path.join(
|
||||
libdir,
|
||||
''.join((pre, 'python', ver, abi, ext))
|
||||
)
|
||||
for (pre, ext, ver, abi) in itertools.product(
|
||||
candidate_lib_prefixes,
|
||||
candidate_extensions,
|
||||
candidate_versions,
|
||||
candidate_abiflags
|
||||
)
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if os.path.exists(candidate):
|
||||
# we found a (likely alternate) libpython
|
||||
python_library = candidate
|
||||
break
|
||||
|
||||
return python_library
|
||||
|
||||
|
||||
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
||||
def check_env_flag(name: str, default: str = "") -> bool:
|
||||
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
|
||||
@@ -136,14 +205,19 @@ class CMakeBuild(build_ext):
|
||||
if not os.path.exists(llvm_build_dir):
|
||||
os.makedirs(llvm_build_dir)
|
||||
# python directories
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
python_include_dir = distutils.sysconfig.get_python_inc()
|
||||
python_link_dir = distutils.sysconfig.get_python_lib()
|
||||
python_library = get_python_library()
|
||||
cmake_args = [
|
||||
"-DLLVM_ENABLE_WERROR=ON",
|
||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||
"-DTRITON_BUILD_TUTORIALS=OFF",
|
||||
"-DTRITON_BUILD_PYTHON_MODULE=ON",
|
||||
# '-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
|
||||
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs),
|
||||
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
|
||||
"-DPYTHON_LINK_DIRS=" + python_link_dir,
|
||||
"-DPYTHON_LIBRARIES=" + python_library,
|
||||
"-DLLVM_EXTERNAL_LIT=" + lit_dir
|
||||
] + thirdparty_cmake_args
|
||||
|
||||
|
@@ -26,6 +26,7 @@
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
@@ -1301,39 +1302,36 @@ void init_triton_translation(py::module &m) {
|
||||
py::gil_scoped_release allow_threads;
|
||||
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[L_tmpnam];
|
||||
char _flog[L_tmpnam];
|
||||
std::tmpnam(_fsrc);
|
||||
std::tmpnam(_flog);
|
||||
std::string fsrc = _fsrc;
|
||||
std::string flog = _flog;
|
||||
std::string fbin = fsrc + ".o";
|
||||
llvm::SmallString<64> fsrc;
|
||||
llvm::SmallString<64> flog;
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
|
||||
std::string fbin = std::string(fsrc) + ".o";
|
||||
llvm::FileRemover srcRemover(fsrc);
|
||||
llvm::FileRemover logRemover(flog);
|
||||
llvm::FileRemover binRemover(fbin);
|
||||
const char *_fsrc = fsrc.c_str();
|
||||
const char *_flog = flog.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(fsrc);
|
||||
std::ofstream ofs(_fsrc);
|
||||
ofs << ptxCode << std::endl;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
|
||||
" " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
" " + _fsrc + " -o " + _fsrc + ".o 2> " + _flog;
|
||||
err = system(cmd.c_str());
|
||||
if (err != 0) {
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
unlink(_fsrc);
|
||||
unlink(_flog);
|
||||
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
|
||||
log);
|
||||
}
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
unlink(_fsrc);
|
||||
unlink(_flog);
|
||||
unlink(_fbin);
|
||||
|
||||
py::bytes bytes(cubin);
|
||||
return bytes;
|
||||
return std::move(bytes);
|
||||
});
|
||||
|
||||
m.def("add_external_libs",
|
||||
@@ -1345,8 +1343,8 @@ void init_triton_translation(py::module &m) {
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
py::module subm = m.def_submodule("triton");
|
||||
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
// init_triton_codegen(subm.def_submodule("code_gen"));
|
||||
init_triton_runtime(subm.def_submodule("runtime"));
|
||||
init_triton_ir(subm.def_submodule("ir"));
|
||||
init_triton_translation(subm);
|
||||
}
|
||||
|
Reference in New Issue
Block a user