[BUILD] Fix Warnings and Enable Warnings as Errors (#794)
@@ -132,14 +132,15 @@ endif()
 # Python module
 if(TRITON_BUILD_PYTHON_MODULE)
     message(STATUS "Adding Python module")
+    set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
+    include_directories("." ${PYTHON_SRC_PATH})
     if (PYTHON_INCLUDE_DIRS)
-        set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-        include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
+        include_directories(${PYTHON_INCLUDE_DIRS})
         link_directories(${PYTHON_LINK_DIRS})
+        link_libraries(${PYTHON_LIBRARIES})
     else()
         find_package(Python3 REQUIRED COMPONENTS Development)
-        set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-        include_directories("." ${PYTHON_SRC_PATH} ${Python3_INCLUDE_DIRS})
+        include_directories(${Python3_INCLUDE_DIRS})
         link_directories(${Python3_LIBRARY_DIRS})
         link_libraries(${Python3_LIBRARIES})
         add_link_options(${Python3_LINK_OPTIONS})
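The hunk above hoists the Python include paths out of the two branches; per the commit title, this build also turns warnings into errors. The flag that does so is not visible in these hunks, so the following is only a minimal sketch, assuming a GCC/Clang toolchain, of the usual CMake pattern:

    # Hypothetical sketch, not the project's exact flags.
    if(NOT MSVC)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror")
    endif()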
@@ -169,7 +170,10 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
 include(TableGen) # required by AddMLIR
 include(AddLLVM)
 include(AddMLIR)
-# include(HandleLLVMOptions) # human-friendly error message
+include(HandleLLVMOptions) # human-friendly error message
+
+# Disable warnings that show up in external code (gtest;pybind11)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default")
 
 include_directories(${MLIR_INCLUDE_DIRS})
 include_directories(${LLVM_INCLUDE_DIRS})
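Re-enabling include(HandleLLVMOptions) pulls in LLVM's standard warning configuration, and -Wno-covered-switch-default then mutes the one Clang diagnostic that third-party code (gtest, pybind11) trips. For reference, a small sketch of what that diagnostic flags: a default label in a switch that already covers every enumerator:

    enum class Kind { Load, Store };

    int cost(Kind k) {
      switch (k) {
      case Kind::Load:
        return 1;
      case Kind::Store:
        return 2;
      default: // unreachable once every enumerator has a case; Clang
               // reports -Wcovered-switch-default, an error under -Werror
        return 0;
      }
    }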
@@ -192,7 +196,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
 
 target_link_libraries(triton
-    ${PYTHON_LIBRARIES}
     TritonAnalysis
     TritonTransforms
     TritonGPUTransforms
@@ -142,7 +142,7 @@ private:
   BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
   BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
   BufferT(BufferKind kind, size_t size, size_t offset)
-      : kind(kind), size(size), offset(offset), id(nextId++) {}
+      : kind(kind), id(nextId++), size(size), offset(offset) {}
 
   bool intersects(const BufferT &other) const {
     return Interval<size_t>(offset, offset + size)
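Moving id(nextId++) earlier makes the mem-initializer list match the order in which the members are declared; writing it in a different order triggers GCC/Clang's -Wreorder. A standalone sketch of the rule (names hypothetical):

    #include <cstddef>

    struct BufferT {
      int kind;
      size_t id;   // declared before size and offset
      size_t size;
      size_t offset;
      // Members initialize in declaration order regardless of how the
      // list is written; listing id(...) last, as the old code did,
      // triggers -Wreorder, which -Werror promotes to a hard error.
      BufferT(int kind, size_t size, size_t offset)
          : kind(kind), id(0), size(size), offset(offset) {}
    };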
@@ -28,8 +28,8 @@ public:
            DimVectorT knownConstancy)
       : contiguity(knownContiguity), divisibility(knownDivisibility),
         constancy(knownConstancy), rank(contiguity.size()) {
-    assert(knownDivisibility.size() == rank);
-    assert(knownConstancy.size() == rank);
+    assert(knownDivisibility.size() == (size_t)rank);
+    assert(knownConstancy.size() == (size_t)rank);
   }
 
   // Accessors
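Here rank is signed while SmallVector::size() returns an unsigned type, so the bare comparison inside the assert trips -Wsign-compare. A minimal reproduction of the pattern and its fix:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    void check(const std::vector<int> &divisibility, int rank) {
      // size() is unsigned; rank is signed. The cast states that rank is
      // known to be non-negative here and silences -Wsign-compare.
      assert(divisibility.size() == (size_t)rank);
    }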
@@ -15,9 +15,9 @@ class Location;
 namespace triton {
 using llvm::StringRef;
 
-class PTXInstr;
-class PTXInstrCommon;
-class PTXInstrExecution;
+struct PTXInstr;
+struct PTXInstrCommon;
+struct PTXInstrExecution;
 
 // PTXBuilder helps to manage a PTX asm program consists of one or multiple
 // instructions.
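These types are defined with the struct keyword, so their forward declarations must use struct as well; mixing the tags is what Clang's -Wmismatched-tags (and MSVC's C4099) complains about. Sketch:

    struct PTXInstr; // the forward declaration must repeat the tag...

    struct PTXInstr { // ...used by the definition; `class PTXInstr;`
      int opcode;     // above would trigger -Wmismatched-tags
    };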
@@ -83,7 +83,7 @@ struct PTXBuilder {
     Operand() = default;
     Operand(const Operation &) = delete;
     Operand(Value value, StringRef constraint)
-        : value(value), constraint(constraint) {}
+        : constraint(constraint), value(value) {}
 
     bool isList() const { return !value && constraint.empty(); }
 
@@ -120,7 +120,7 @@ struct PTXBuilder {
   Operand *newListOperand(unsigned count, mlir::Value val,
                           const std::string &constraint) {
     auto *list = newOperand();
-    for (int i = 0; i < count; ++i) {
+    for (unsigned i = 0; i < count; ++i) {
       list->listAppend(newOperand(val, constraint));
     }
     return list;
@@ -128,7 +128,7 @@ struct PTXBuilder {
 
   Operand *newListOperand(unsigned count, const std::string &constraint) {
     auto *list = newOperand();
-    for (int i = 0; i < count; ++i) {
+    for (unsigned i = 0; i < count; ++i) {
       list->listAppend(newOperand(constraint));
     }
     return list;
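Both loops compare the induction variable against an unsigned bound; the recurring rule in this commit is to give the index the same signedness as the bound (here int becomes unsigned; in later hunks, where the bound is a signed getRank(), size_t becomes int). Sketch:

    void repeat(unsigned count) {
      // `int i < unsigned count` trips -Wsign-compare; matching the
      // bound's type removes the mixed-sign comparison altogether.
      for (unsigned i = 0; i < count; ++i) {
        // ... body elided ...
      }
    }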
@@ -172,8 +172,8 @@ private:
     return argArchive.back().get();
   }
 
-  friend class PTXInstr;
-  friend class PTXInstrCommon;
+  friend struct PTXInstr;
+  friend struct PTXInstrCommon;
 
 protected:
   llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
@@ -209,7 +209,7 @@ protected:
   PTXBuilder *builder{};
   llvm::SmallVector<std::string, 4> instrParts;
 
-  friend class PTXInstrExecution;
+  friend struct PTXInstrExecution;
 };
 
 template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
@@ -309,7 +309,7 @@ struct PTXInstrExecution {
   PTXInstrExecution() = default;
   explicit PTXInstrExecution(PTXInstrCommon *instr,
                              llvm::ArrayRef<Operand *> oprs)
-      : instr(instr), argsInOrder(oprs.begin(), oprs.end()) {}
+      : argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
 
   // Prefix a predicate to the instruction.
   PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
@@ -11,16 +11,12 @@ class ModuleOp;
 template <typename T> class OperationPass;
 
 class TritonLLVMConversionTarget : public ConversionTarget {
-  mlir::LLVMTypeConverter &typeConverter;
-
 public:
   explicit TritonLLVMConversionTarget(MLIRContext &ctx,
                                       mlir::LLVMTypeConverter &typeConverter);
 };
 
 class TritonLLVMFunctionConversionTarget : public ConversionTarget {
-  mlir::LLVMTypeConverter &typeConverter;
-
 public:
   explicit TritonLLVMFunctionConversionTarget(
       MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
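Both conversion targets stored the type converter without ever reading it, which Clang reports as -Wunused-private-field; the member is dropped here, and the matching constructor hunks near the end of this commit drop its initializer. Sketch of the diagnostic:

    class Target {
      int converter_; // never read anywhere in the class: Clang emits
                      // "private field 'converter_' is not used"
                      // [-Wunused-private-field], fatal under -Werror
    public:
      Target() : converter_(0) {}
    };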
@@ -26,11 +26,11 @@ public:
   DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
 
   virtual LogicalResult
-  inferReduceOpEncoding(Attribute operandEncoding, int axis,
+  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                         Attribute &resultEncoding) const = 0;
 
   virtual LogicalResult
-  inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
+  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
                             Attribute &resultEncoding) const = 0;
 };
 
@@ -21,7 +21,6 @@ private:
 };
 
 class TritonGPUConversionTarget : public ConversionTarget {
-  TritonGPUTypeConverter &typeConverter;
 
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx,
@@ -82,7 +82,6 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
   auto srcShape = srcTy.getShape();
-  auto rank = srcShape.size();
   auto axis = op.axis();
 
   bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
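This and many later hunks delete locals that were computed but never read (rank, llvmIndexTy, srcOrder, and so on) — the -Wunused-variable family. Sketch:

    #include <vector>

    int firstOrZero(const std::vector<int> &shape) {
      // auto rank = shape.size(); // set but never used:
      //                           // -Wunused-variable under -Werror
      return shape.empty() ? 0 : shape.front();
    }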
@@ -66,7 +66,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
   DimVectorT retContiguity;
   DimVectorT retDivisibility;
   DimVectorT retConstancy;
-  for (size_t d = 0; d < lhs.getRank(); ++d) {
+  for (int d = 0; d < lhs.getRank(); ++d) {
     retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
     retDivisibility.push_back(
         gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
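Note the direction of this fix is the opposite of the PTXBuilder hunks: getRank() is signed here, so the size_t index becomes int rather than the bound being cast. Either way the point is the same — both sides of the loop comparison get one signedness. Sketch:

    struct Info {
      int getRank() const { return rank; }
      int rank = 0;
    };

    void walk(const Info &info) {
      // the bound is int, so the index is int too; `size_t d <
      // info.getRank()` would be a mixed-sign comparison (-Wsign-compare)
      for (int d = 0; d < info.getRank(); ++d) {
        // ... per-dimension work elided ...
      }
    }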
@@ -88,7 +88,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
   AxisInfo::DimVectorT newContiguity;
   AxisInfo::DimVectorT newDivisibility;
   AxisInfo::DimVectorT newConstancy;
-  for (size_t d = 0; d < rank; ++d) {
+  for (int d = 0; d < rank; ++d) {
     newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
     newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
     newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
@@ -167,7 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    for (size_t d = 0; d < retTy.getRank(); ++d) {
+    for (int d = 0; d < retTy.getRank(); ++d) {
       contiguity.push_back(1);
       divisibility.push_back(opInfo.getDivisibility(0));
       constancy.push_back(retTy.getShape()[d]);
@@ -176,12 +176,6 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   }
   // expandDims
   if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
-    Type _retTy = *op->result_type_begin();
-    Type _opTy = *op->operand_type_begin();
-    TensorType retTy = _retTy.cast<TensorType>();
-    TensorType opTy = _opTy.cast<TensorType>();
-    ArrayRef<int64_t> retShape = retTy.getShape();
-    ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
     AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
     AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
@@ -203,7 +197,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    for (size_t d = 0; d < retTy.getRank(); ++d) {
+    for (int d = 0; d < retTy.getRank(); ++d) {
       contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
       divisibility.push_back(opInfo.getDivisibility(d));
       constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
@@ -28,11 +28,10 @@ bool maybeSharedAllocationOp(Operation *op) {
 }
 
 std::string getValueOperandName(Value value, AsmState &state) {
-  auto *op = value.getDefiningOp();
   std::string opName;
   llvm::raw_string_ostream ss(opName);
   value.printAsOperand(ss, state);
-  return std::move(opName);
+  return opName;
 }
 
 } // namespace mlir
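Returning std::move(local) defeats copy elision and draws Clang's -Wpessimizing-move; returning the local by name lets NRVO (or, at worst, the implicit move) kick in. Sketch:

    #include <string>

    std::string operandName() {
      std::string name = "%arg0";
      // `return std::move(name);` would block copy elision and trigger
      // -Wpessimizing-move; a plain return is already a move at worst.
      return name;
    }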
@@ -254,7 +254,6 @@ protected:
 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
-static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
 struct FuncOpConversion : public FuncOpConversionBase {
   FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                    PatternBenefit benefit)
@@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
   LogicalResult
   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
     unsigned numArguments = op.getNumOperands();
 
     // Currently, Triton kernel function always return nothing.
@@ -482,7 +480,6 @@ public:
                 ConversionPatternRewriter &rewriter,
                 const BlockedEncodingAttr &blocked_layout,
                 ArrayRef<int64_t> shape) const {
-    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value threadId = getThreadId(rewriter, loc);
     Value warpSize = idx_val(32);
     Value laneId = urem(threadId, warpSize);
@@ -654,7 +651,6 @@ public:
     auto bufferId = allocation->getBufferId(value);
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
-    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value offVal = idx_val(offset);
     Value base = gep(ptrTy, smem, offVal);
     return base;
@@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   auto tensorTy = resType.cast<RankedTensorType>();
   if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
     auto tensorTy = resType.cast<RankedTensorType>();
-    auto layout = tensorTy.getEncoding();
     auto srcType = typeConverter->convertType(elemType);
     auto llSrc = bitcast(srcType, constVal);
     size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -842,7 +837,6 @@ struct LoadOpConversion
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *ctx = rewriter.getContext();
     auto loc = op->getLoc();
 
     // original values
@@ -897,12 +891,11 @@ struct LoadOpConversion
     // TODO: optimization when ptr is GEP with constant offset
     size_t in_off = 0;
 
-    const int maxWordWidth = std::max<int>(32, valueElemNbits);
-    const int totalWidth = valueElemNbits * vec;
-    const int width = std::min(totalWidth, maxWordWidth);
-    const int nWords = std::max(1, totalWidth / width);
-    const int wordNElems = width / valueElemNbits;
-    const int vecNElems = totalWidth / valueElemNbits;
+    const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+    const size_t totalWidth = valueElemNbits * vec;
+    const size_t width = std::min(totalWidth, maxWordWidth);
+    const size_t nWords = std::max<size_t>(1, totalWidth / width);
+    const size_t wordNElems = width / valueElemNbits;
     assert(wordNElems * nWords * numVecs == numElems);
 
     // TODO(Superjomn) Add cache policy fields to StoreOp.
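Two details here: the word-width arithmetic moves wholly to size_t so later comparisons against .size()-style values stay sign-correct, and std::max gets an explicit template argument because mixed int/size_t operands would otherwise fail template argument deduction. Sketch:

    #include <algorithm>
    #include <cstddef>

    size_t numWords(size_t totalWidth, size_t width) {
      // std::max(1, totalWidth / width) does not compile: the deduced
      // types (int vs size_t) conflict. Naming the type resolves it.
      return std::max<size_t>(1, totalWidth / width);
    }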
@@ -921,7 +914,7 @@ struct LoadOpConversion
 
     // prepare asm operands
     auto *dstsOpr = ptxBuilder.newListOperand();
-    for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+    for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
       auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
       dstsOpr->listAppend(opr);
     }
@@ -988,8 +981,8 @@ struct LoadOpConversion
                      : retTys[0];
 
     // TODO: if (has_l2_evict_policy)
-    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                    LLVM::AsmDialect::AD_ATT);
+    // auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+    //                                                 LLVM::AsmDialect::AD_ATT);
     Value ret = ptxBuilder.launch(rewriter, loc, retTy);
 
     // ---
@@ -1080,27 +1073,25 @@ struct StoreOpConversion
     // TODO: optimization when ptr is AddPtr with constant offset
     size_t in_off = 0;
 
-    const int maxWordWidth = std::max<int>(32, valueElemNbits);
-    const int totalWidth = valueElemNbits * vec;
-    const int width = std::min(totalWidth, maxWordWidth);
-    const int nWords = std::max(1, totalWidth / width);
-    const int wordNElems = width / valueElemNbits;
-    const int vecNElems = totalWidth / valueElemNbits;
+    const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
+    const size_t totalWidth = valueElemNbits * vec;
+    const size_t width = std::min(totalWidth, maxWordWidth);
+    const size_t nWords = std::max<size_t>(1, totalWidth / width);
+    const size_t wordNElems = width / valueElemNbits;
     assert(wordNElems * nWords * numVecs == numElems);
 
     // TODO(Superjomn) Add cache policy fields to StoreOp.
     // TODO(Superjomn) Deal with cache policy here.
-    const bool hasL2EvictPolicy = false;
 
     Type valArgTy = IntegerType::get(ctx, width);
     auto wordTy = vec_ty(valueElemTy, wordNElems);
 
     SmallVector<std::pair<Value, std::string>> asmArgs;
-    for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
+    for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
       // llWord is a width-len composition
       Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
       // Insert each value element to the composition
-      for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
+      for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
         const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
         assert(elemOffset < valueElems.size());
         Value elem = valueElems[elemOffset];
@@ -1220,7 +1211,6 @@ struct BroadcastOpConversion
     }
 
     unsigned srcElems = getElemsPerThread(srcTy);
-    auto elemTy = resultTy.getElementType();
     auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = getElemsPerThread(resultTy);
     SmallVector<Value> resultVals(resultElems);
@@ -1282,8 +1272,6 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto rank = srcTy.getShape().size();
   if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
@@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
 
 Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
                                    Location loc, Value val, int i) const {
-  MLIRContext *ctx = rewriter.getContext();
   unsigned bits = val.getType().getIntOrFloatBitWidth();
 
   if (bits == 64) {
@@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
 
   barrier();
   SmallVector<Value> resultVals(resultElems);
-  for (int i = 0; i < resultElems; i++) {
+  for (size_t i = 0; i < resultElems; i++) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, ints[0]);
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
   auto srcShape = srcTy.getShape();
-  auto srcOrder = srcLayout.getOrder();
 
   auto threadsPerWarp = srcLayout.getThreadsPerWarp();
   auto warpsPerCTA = srcLayout.getWarpsPerCTA();
@@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
 
   barrier();
   SmallVector<Value> resultVals(resultElems);
-  for (int i = 0; i < resultElems; i++) {
+  for (size_t i = 0; i < resultElems; i++) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, i32_val(0));
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
     // due to MLIR's restrictions
     Location loc = op->getLoc();
     auto resultTy = op.getType().template cast<RankedTensorType>();
-    auto resultShape = resultTy.getShape();
     unsigned elems = getElemsPerThread(resultTy);
     Type elemTy =
         this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1698,7 +1683,6 @@ struct AddPtrOpConversion
     auto resultLayout =
         resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
-    auto resultShape = resultTensorTy.getShape();
     unsigned elems = getElemsPerThread(resultTy);
     Type elemTy =
         getTypeConverter()->convertType(resultTensorTy.getElementType());
@@ -1821,7 +1805,7 @@ protected:
     SmallVector<SmallVector<Value>> operands(elems);
     for (auto operand : adaptor.getOperands()) {
       auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
-      for (int i = 0; i < elems; ++i) {
+      for (size_t i = 0; i < elems; ++i) {
         operands[i].push_back(sub_operands[i]);
       }
     }
@@ -1931,6 +1915,7 @@ struct CmpFOpConversion
     __PRED_ENUM(ORD, ord);
     __PRED_ENUM(UEQ, ueq);
     __PRED_ENUM(UGT, ugt);
+    __PRED_ENUM(UGE, uge);
     __PRED_ENUM(ULT, ult);
     __PRED_ENUM(ULE, ule);
     __PRED_ENUM(UNE, une);
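Judging by the surrounding macro list, the CmpF lowering enumerates every comparison predicate and UGE was missing; with warnings as errors, an unhandled enumerator in a switch over the predicate enum fails the build via -Wswitch. A sketch of the mechanism (enum and names hypothetical):

    enum class Pred { OEQ, UGT, UGE };

    const char *spell(Pred p) {
      switch (p) { // a missing UGE case would warn: "enumeration value
                   // 'UGE' not handled in switch" [-Wswitch]
      case Pred::OEQ:
        return "oeq";
      case Pred::UGT:
        return "ugt";
      case Pred::UGE:
        return "uge";
      }
      return "";
    }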
@@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica(
   auto rank = type.getRank();
   auto sizePerThread = getSizePerThread(layout);
   auto accumSizePerThread = product<unsigned>(sizePerThread);
-  auto llvmIndexTy = getTypeConverter()->getIndexType();
   SmallVector<unsigned> numCTAs(rank);
   auto shapePerCTA = getShapePerCTA(layout);
   for (unsigned d = 0; d < rank; ++d) {
@@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica(
     multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
         loc, rewriter, blockedLayout, type.getShape());
   } else if (sliceLayout) {
-    unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
     if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
       SmallVector<int64_t> paddedShape =
@@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   }
   // Potentially we need to store for multiple CTAs in this replication
   unsigned accumNumReplicates = product<unsigned>(numReplicates);
-  unsigned elems = getElemsPerThread(srcTy);
+  // unsigned elems = getElemsPerThread(srcTy);
   auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
   unsigned outVec = 0;
@@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
 // Data loader for mma.16816 instruction.
 class MMA16816SmemLoader {
 public:
-  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
+  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
                      ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
                      ArrayRef<int> matShape, int perPhase, int maxPhase,
                      int elemBytes, ConversionPatternRewriter &rewriter,
                      TypeConverter *typeConverter, const Location &loc)
-      : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+      : order(order.begin(), order.end()), kOrder(kOrder),
         tileShape(tileShape.begin(), tileShape.end()),
         instrShape(instrShape.begin(), instrShape.end()),
         matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
-        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
+        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
+        ctx(rewriter.getContext()) {
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
 
@@ -2576,7 +2559,6 @@ public:
     assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
            "smem matrix load must be aligned");
     int matIdx[2] = {mat0, mat1};
-    int k = matIdx[kOrder];
 
     int ptrIdx{-1};
 
@@ -2596,7 +2578,6 @@ public:
 
     Value ptr = getPtr(ptrIdx);
 
-    Value resV4;
     if (canUseLdmatrix) {
       int sOffset =
           matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
@@ -2727,7 +2708,6 @@ public:
   }
 
 private:
-  int wpt;
   SmallVector<uint32_t> order;
   int kOrder;
   SmallVector<int64_t> tileShape;
@@ -2737,7 +2717,6 @@ private:
   int maxPhase;
   int elemBytes;
   ConversionPatternRewriter &rewriter;
-  TypeConverter *typeConverter{};
   const Location &loc;
   MLIRContext *ctx{};
 
@@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
     // D = A * B + C
     Value A = op.a();
-    Value B = op.b();
-    Value C = op.c();
     Value D = op.getResult();
-    MLIRContext *ctx = op->getContext();
-    bool allowTF32 = op.allowTF32();
 
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
@@ -2951,8 +2925,6 @@ struct DotOpConversionHelper {
     Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
     Type i8x4Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
-    Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(4, type::i32Ty(ctx)));
 
     switch (mmaType) {
     case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -3062,7 +3034,6 @@ struct DotOpConversionHelper {
     auto bTy = B.getType().cast<RankedTensorType>();
     // d = a*b + c
     auto dTy = op.d().getType().cast<RankedTensorType>();
-    auto mmaLayout = dTy.getEncoding().cast<MmaEncodingAttr>();
 
     if (dTy.getElementType().isF32()) {
       if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
@@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper {
   MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
                            ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
-        thread(thread) {
+      : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
+        rewriter(rewriter), typeConverter(typeConverter), loc(loc),
+        ctx(mmaLayout.getContext()) {
     wpt = mmaLayout.getWarpsPerCTA();
 
     Value _32 = i32_val(32);
@@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper {
     }
 
     // step1. Perform loading.
-    for (unsigned m = 0; m < numRepM; ++m)
-      for (unsigned k = 0; k < numRepK; ++k)
+    for (int m = 0; m < numRepM; ++m)
+      for (int k = 0; k < numRepK; ++k)
         loadFn(2 * m, 2 * k);
 
     // step2. Format the values to LLVM::Struct to passing to mma codegen.
@@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper {
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
 
-    for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
-      for (unsigned k = 0; k < numRepK; ++k)
+    for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
+      for (int k = 0; k < numRepK; ++k)
         loadFn(2 * n, 2 * k);
     }
 
@@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper {
     helper.deduceMmaType(op);
 
     auto aTensorTy = a.getType().cast<RankedTensorType>();
-    auto bTensorTy = b.getType().cast<RankedTensorType>();
-    auto cTensorTy = c.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();
 
     auto aShape = aTensorTy.getShape();
     auto dShape = dTensorTy.getShape();
 
-    int NK = aShape[1];
     // shape / shape_per_cta
-    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
-    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
     int numRepM = getNumRepM(aTensorTy, dShape[0]);
     int numRepN = getNumRepN(aTensorTy, dShape[1]);
     int numRepK = getNumRepK(aTensorTy, aShape[1]);
@@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper {
           extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
     };
 
-    for (unsigned k = 0; k < numRepK; ++k)
-      for (unsigned m = 0; m < numRepM; ++m)
-        for (unsigned n = 0; n < numRepN; ++n)
+    for (int k = 0; k < numRepK; ++k)
+      for (int m = 0; m < numRepM; ++m)
+        for (int n = 0; n < numRepN; ++n)
           callMma(2 * m, n, 2 * k);
 
     // replace with new packed result
@@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper {
 private:
   std::function<void(int, int)>
   getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
-                  int wpt, int kOrder, ArrayRef<int> instrShape,
+                  int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
                   ArrayRef<int> matShape, Value warpId,
                   ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -3486,8 +3452,8 @@ private:
   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
                                               int n1) const {
     std::vector<Value> elems;
-    for (unsigned m = 0; m < n0; ++m)
-      for (unsigned k = 0; k < n1; ++k) {
+    for (int m = 0; m < n0; ++m)
+      for (int k = 0; k < n1; ++k) {
         elems.push_back(vals.at({2 * m, 2 * k}));
         elems.push_back(vals.at({2 * m, 2 * k + 1}));
         elems.push_back(vals.at({2 * m + 1, 2 * k}));
@@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   auto loc = op.getLoc();
   Value src = op.src();
   Value dst = op.result();
-  auto srcTensorTy = src.getType().cast<RankedTensorType>();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
 
-  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
   auto dotOperandLayout =
       dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
   MmaEncodingAttr mmaLayout =
@@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion
     auto ctx = op.getContext();
     auto loc = op.getLoc();
     auto voidTy = void_ty(ctx);
-    auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
+    ptxBuilder.launch(rewriter, loc, voidTy);
 
     // Safe to remove the op since it doesn't have any return value.
     rewriter.eraseOp(op);
@@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion
     unsigned perPhase = resSharedLayout.getPerPhase();
     unsigned maxPhase = resSharedLayout.getMaxPhase();
     auto sizePerThread = srcBlockedLayout.getSizePerThread();
-    auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
-    auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
     auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
 
     auto inOrder = srcBlockedLayout.getOrder();
-    auto outOrder = resSharedLayout.getOrder();
     // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
     // elements across phases. If perPhase * maxPhase == threadsPerCTA,
     // swizzle is not allowd
@@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion
       auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
 
       auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-      for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+      for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
         PTXBuilder ptxBuilder;
         auto wordElemIdx = wordIdx * numWordElems;
         auto &copyAsyncOp =
@@ -4208,7 +4170,7 @@ namespace mlir {
 
 TritonLLVMConversionTarget::TritonLLVMConversionTarget(
     MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-    : ConversionTarget(ctx), typeConverter(typeConverter) {
+    : ConversionTarget(ctx) {
   addLegalDialect<LLVM::LLVMDialect>();
   addLegalDialect<NVVM::NVVMDialect>();
   // addIllegalDialect<triton::TritonDialect>();
@@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
 
 TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
     MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
-    : ConversionTarget(ctx), typeConverter(typeConverter) {
+    : ConversionTarget(ctx) {
   addLegalDialect<LLVM::LLVMDialect>();
   // addLegalDialect<NVVM::NVVMDialect>();
   addIllegalOp<mlir::FuncOp>();
@@ -21,9 +21,7 @@ public:
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    Op res =
-        rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
-
+    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
     return success();
   }
 };
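This and the next several hunks drop never-read bindings of replaceOpWithNewOp's return value: the rewriter already records the replacement, so calling it purely for the side effect silences -Wunused-variable. A sketch of the general shape:

    int compute() { return 42; }

    void use() {
      compute(); // call only for its effect; binding the result to a
                 // variable that is never read trips -Wunused-variable
      [[maybe_unused]] int r = compute(); // alternative when keeping a
                                          // name aids debugging
    }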
@@ -37,9 +35,8 @@ public:
   matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    DstOp res =
-        rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
-                                           adaptor.getLhs(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
+                                       adaptor.getLhs(), adaptor.getRhs());
     return success();
   }
 };
@@ -129,10 +126,9 @@ public:
   matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    triton::gpu::SelectOp res =
-        rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
-            op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
-            adaptor.getFalseValue());
+    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
+        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
+        adaptor.getFalseValue());
     return success();
   }
 };
@@ -204,9 +200,6 @@ struct TritonExpandDimsPattern
         triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
                                               retThreadsPerWarp, retWarpsPerCTA,
                                               retOrder);
-    // return type
-    RankedTensorType retType =
-        RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
     // convert operand to slice of return type
     Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
         getContext(), op.axis(), retEncoding);
@@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
                                 bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
     }
-    auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
+    rewriter.replaceOpWithNewOp<triton::DotOp>(
         op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
         adaptor.transB());
     return success();
@@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   LogicalResult
   matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+    rewriter.replaceOpWithNewOp<triton::StoreOp>(
         op, adaptor.ptr(), adaptor.value(), adaptor.mask());
     return success();
   }
@@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   LogicalResult
   matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
+    rewriter.replaceOpWithNewOp<triton::ReduceOp>(
         op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
     return success();
   }
@@ -49,7 +49,6 @@ unsigned getElemsPerThread(Type type) {
   auto tensorType = type.cast<RankedTensorType>();
   auto layout = tensorType.getEncoding();
   auto shape = tensorType.getShape();
-  size_t rank = shape.size();
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -109,7 +108,7 @@ SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   SmallVector<unsigned> shape;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-    for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
+    for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
       shape.push_back(blockedLayout.getSizePerThread()[d] *
                       blockedLayout.getThreadsPerWarp()[d] *
                       blockedLayout.getWarpsPerCTA()[d]);
@@ -117,7 +116,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
     unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
     if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
+      for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
         if (d == dim)
           continue;
         shape.push_back(blockedParent.getSizePerThread()[d] *
@@ -258,7 +257,6 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   size_t rank = shape.size();
   auto parent = getParent();
-  unsigned dim = getDim();
   if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
     assert(rank == blockedParent.getSizePerThread().size() - 1 &&
            "unexpected rank in SliceEncodingAttr::getElemsPerThread");
@@ -512,11 +510,11 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
   auto encoding = srcType.getEncoding();
   auto srcShape = srcType.getShape();
   auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-  if (axis < 0 || axis > srcShape.size())
+  if (axis < 0 || (size_t)axis > srcShape.size())
     return failure();
   SmallVector<int64_t, 4> dstShape;
-  for (int i = 0; i < srcShape.size(); i++)
-    if (i != axis)
+  for (size_t i = 0; i < srcShape.size(); i++)
+    if (i != (size_t)axis)
       dstShape.push_back(srcShape[i]);
   auto returnType =
       RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
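
Note: `getInt()` returns a signed value while `srcShape.size()` is a `size_t`, so the old comparisons trip `-Wsign-compare` once warnings become errors; the fix guards against negative values first, then casts. The same reasoning drives the `int` → `unsigned` loop counters in the hunks above. A minimal standalone sketch of the pattern (hypothetical code, not from this commit):

```cpp
#include <cstddef>
#include <vector>

// With -Wall -Werror, comparing a signed integer against a size_t is
// rejected: the signed value is implicitly converted to unsigned, so a
// negative axis would silently wrap around to a huge value.
bool inBounds(long long axis, const std::vector<long long> &shape) {
  // Reject negatives first, then cast explicitly so the comparison is
  // unsigned-vs-unsigned and warning-free.
  return axis >= 0 && (size_t)axis <= shape.size();
}

int main() {
  std::vector<long long> shape{4, 8};
  // A size_t induction variable matches the type of shape.size().
  for (size_t i = 0; i < shape.size(); ++i)
    (void)shape[i];
  return inBounds(-1, shape) ? 1 : 0; // returns 0: -1 is rejected
}
```
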
@@ -578,15 +576,17 @@ struct TritonGPUInferLayoutInterface
     : public triton::DialectInferLayoutInterface {
   using DialectInferLayoutInterface::DialectInferLayoutInterface;

-  LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
-                                      Attribute &resultEncoding) const {
+  LogicalResult
+  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
+                        Attribute &resultEncoding) const override {
     resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
                                             operandEncoding);
     return success();
   }

-  LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
-                                          Attribute &resultEncoding) const {
+  LogicalResult
+  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
+                            Attribute &resultEncoding) const override {
     auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
     if (!sliceEncoding) {
       llvm::report_fatal_error(
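
Note: besides switching `axis` to `unsigned`, the hooks gain `override`, so any future drift from the base interface's virtual signatures becomes a compile error instead of a silently unrelated overload (Clang enforces consistent use via `-Winconsistent-missing-override`). A standalone sketch with stand-in types (names here are illustrative, not the MLIR API):

```cpp
struct InferLayoutBase {
  virtual ~InferLayoutBase() = default;
  virtual bool inferEncoding(unsigned axis) const = 0;
};

struct GoodImpl : InferLayoutBase {
  // Exact parameter types plus `override`: checked against the base.
  bool inferEncoding(unsigned axis) const override { return axis < 8; }
};

struct BadImpl : InferLayoutBase {
  // Without `override`, a mismatched signature compiles as a brand-new
  // overload and the pure virtual stays unimplemented. With `override`
  // the compiler rejects the mismatch outright:
  //   bool inferEncoding(int axis) const override; // error: int != unsigned
  bool inferEncoding(unsigned axis) const override { return true; }
};

int main() {
  GoodImpl g;
  return g.inferEncoding(3) ? 0 : 1;
}
```
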

@@ -87,7 +87,6 @@ public:
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
       return mlir::failure();
     auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
     auto dstType = convert.getType().cast<RankedTensorType>();
     // we don't handle conversions to DotOperandEncodingAttr
     // this is a heuristics to accomodate fused attention
@@ -219,10 +218,10 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
   auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
   if (typeInfer) {
     SmallVector<Type, 1> newType;
-    auto sucess = typeInfer.inferReturnTypes(
+    auto success = typeInfer.inferReturnTypes(
         newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
         newOp->getAttrDictionary(), newOp->getRegions(), newType);
-    if (success)
+    if (succeeded(success))
       newOp->getResult(0).setType(newType.front());
   }
   return newOp;
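
Note: the `sucess` rename fixes a real bug, not just spelling. The old `if (success)` never read the misspelled local; unqualified lookup found the free function `mlir::success`, and a function name in boolean context decays to a never-null pointer, so the branch was always taken (the "address of function will always evaluate to true" diagnostic). The fix also tests the `LogicalResult` properly via `succeeded(...)`. A self-contained sketch of the same trap, with stand-ins for the MLIR types:

```cpp
#include <cstdio>

// Stand-ins for mlir::LogicalResult / mlir::success / mlir::succeeded.
struct LogicalResult { bool ok; };
LogicalResult success() { return {true}; }
bool succeeded(LogicalResult r) { return r.ok; }

LogicalResult inferSomething() { return {false}; }

int main() {
  auto sucess = inferSomething(); // misspelled local, as in the old code
  // Buggy version: `if (success)` tests the *function* above, whose
  // decayed pointer is never null, so the branch is always taken:
  //   if (success) { ... } // warning: always evaluates to true
  // Fixed version: spell the variable right and test the result.
  auto result = inferSomething();
  if (succeeded(result))
    std::puts("types inferred");
  else
    std::puts("inference failed");
  (void)sucess;
  return 0;
}
```
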
@@ -364,10 +363,6 @@ public:
   rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
                        size_t i, RankedTensorType newType,
                        triton::gpu::ConvertLayoutOp origConversion) const {
-
-    auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
-    auto ctx = forOp.getContext();
-    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
     // Rewrite init argument
     Type origType = forOp.getInitArgs()[i].getType();
     SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
@@ -418,11 +413,10 @@ public:
     return newResults;
   }

-  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-                                      mlir::PatternRewriter &rewriter) const {
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     auto forOp = cast<scf::ForOp>(op);
-    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
     auto iterArgs = forOp.getRegionIterArgs();
     for (auto iterArg : llvm::enumerate(iterArgs)) {
       // if (iterArg.index() != 1)
@@ -480,7 +474,6 @@ public:
     auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
     if (!forOp)
       return mlir::failure();
-    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };

     SetVector<Operation *> cvtSlices;

@@ -17,11 +17,6 @@ using namespace mlir;

 namespace {
 class LoopPipeliner {
-  /// comments on numStages:
-  ///   [0, numStages-1) are in the prologue
-  ///   numStages-1 is appended after the loop body
-  int numStages;
-
   /// cache forOp we are working on
   scf::ForOp forOp;

@@ -43,6 +38,11 @@ class LoopPipeliner {
   ///
   Value loopIterIdx;

+  /// comments on numStages:
+  ///   [0, numStages-1) are in the prologue
+  ///   numStages-1 is appended after the loop body
+  int numStages;
+
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;

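
Note: relocating `numStages` is most plausibly about keeping the declaration order consistent with the constructor's member-initializer list; members always initialize in declaration order, and `-Wreorder` flags any mismatch because a reordered initializer can read uninitialized state. A sketch of the failure mode this guards against (toy types; the motivation stated here is an assumption, not spelled out in the commit):

```cpp
// Members are initialized in declaration order, not in the order they
// appear in the mem-initializer list; -Wreorder (part of -Wall) flags
// a mismatch because it can silently read uninitialized members.
struct PipelinerSketch {
  int stages;   // declared first, so initialized first
  int perStage; // depends on `stages` being ready

  // OK: initializer-list order matches declaration order.
  PipelinerSketch(int n) : stages(n), perStage(stages * 2) {}

  // If the members were declared the other way around, the compiler
  // would warn and perStage(stages * 2) would run before stages(n).
};

int main() { return PipelinerSketch(3).perStage == 6 ? 0 : 1; }
```
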
@@ -58,9 +58,6 @@ class LoopPipeliner {

   Value lookupOrDefault(Value origin, int stage);

-  /// return true if this op uses any of `loads`
-  bool isDirectUserOfAsyncLoad(Operation &op);
-
   /// returns a empty buffer of size <numStages, ...>
   triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
                                                  OpBuilder &builder);
@@ -84,7 +81,7 @@ public:
   /// create the new ForOp (add new args & insert prefetched ops)
   scf::ForOp createNewForOp();

-  friend class PipelinePass;
+  friend struct PipelinePass;
 };

 // helpers
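
Note: `PipelinePass` is evidently declared with the `struct` keyword, so the friend declaration now says `struct` too. Clang's `-Wmismatched-tags` flags class/struct disagreement: it is harmless on Itanium-ABI platforms but can break MSVC interop, where the tag participates in name mangling. A toy sketch with hypothetical names:

```cpp
struct PipelinePassSketch; // introduced as a struct

class LoopPipelinerSketch {
  int stages = 2;
  // `friend class PipelinePassSketch;` would trip -Wmismatched-tags,
  // because the type was forward-declared with the `struct` keyword.
  friend struct PipelinePassSketch;
};

struct PipelinePassSketch {
  static int stagesOf(const LoopPipelinerSketch &p) { return p.stages; }
};

int main() { return PipelinePassSketch::stagesOf({}) == 2 ? 0 : 1; }
```
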
@@ -123,19 +120,6 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
     }
   }
 }

-bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
-  for (Value loadOp : loads) {
-    assert(loadOp.hasOneUse() &&
-           "load should only have one use (ConvertLayout)");
-    Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
-    for (Value opOperand : op.getOperands()) {
-      if (opOperand == loadUseResult)
-        return true;
-    }
-  }
-  return false;
-}
-
 triton::gpu::AllocTensorOp
 LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
   // allocate a buffer for each pipelined tensor
@@ -356,8 +340,8 @@ void LoopPipeliner::emitPrologue() {
   } // for (int stage = 0; stage < numStages - 1; ++stage)

   // async.wait & extract_slice
-  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
-      loads[0].getLoc(), loads.size() * (numStages - 2));
+  builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
+                                           loads.size() * (numStages - 2));
   loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (Value loadOp : loads) {
     Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
@@ -380,8 +364,7 @@ void LoopPipeliner::emitEpilogue() {
  OpBuilder builder(forOp);
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPointAfter(forOp);
-  Operation *asyncWait =
-      builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
+  builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
 }

 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -575,8 +558,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   yieldValues.push_back(loopIterIdx);

   builder.setInsertionPointToEnd(newForOp.getBody());
-  auto test = builder.create<scf::YieldOp>(
-      forOp.getBody()->getTerminator()->getLoc(), yieldValues);
+  builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
+                               yieldValues);
   return newForOp;
 }

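
Note: several hunks here (`auto test = ...` above, and earlier `Operation *asyncWait = ...`) drop bindings that were never read again, since `-Wunused-variable` is fatal once `-Werror` is on; builder-style `create<...>` calls work fine as bare statements when only the side effect matters. A toy sketch of the same cleanup (hypothetical `Builder`, not the MLIR API):

```cpp
#include <vector>

struct Builder {
  std::vector<int> created;
  int create(int v) { // records the op and returns a handle
    created.push_back(v);
    return v;
  }
};

int main() {
  Builder b;
  // Fails under -Werror: the name is never read afterwards.
  //   int unused = b.create(1); // error: unused variable 'unused'
  // Fixed: call for the side effect only, with no dead binding.
  b.create(1);
  b.create(2);
  return b.created.size() == 2u ? 0 : 1;
}
```
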

@@ -30,7 +30,7 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
                          (ty.getElementType().getIntOrFloatBitWidth() / 8));
       perPhase = std::max<int>(perPhase, 1);
       // index of the inner dimension in `order`
-      int inner = (opIdx == 0) ? 0 : 1;
+      size_t inner = (opIdx == 0) ? 0 : 1;
      if (version == 1) {
        int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
        // TODO: handle rep (see
@@ -67,7 +67,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {

   void runOnOperation() override {
     Operation *op = getOperation();
-    MLIRContext *context = &getContext();
     op->walk([&](triton::DotOp dotOp) -> void {
       OpBuilder builder(dotOp);
       auto _retEncoding =

@@ -73,7 +73,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 //
 TritonGPUConversionTarget::TritonGPUConversionTarget(
     MLIRContext &context, TritonGPUTypeConverter &typeConverter)
-    : ConversionTarget(context), typeConverter(typeConverter) {
+    : ConversionTarget(context) {
   // TODO: we should also verify ops of TritonGPUDialect
   addLegalDialect<triton::gpu::TritonGPUDialect>();

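
Note: the constructor stops storing `typeConverter` because a field that is written but never read trips Clang's `-Wunused-private-field` (presumably the member declaration itself is removed in a part of the diff not shown here). A toy sketch:

```cpp
// Clang's -Wunused-private-field rejects fields that are initialized
// in the constructor but never read anywhere in the class.
class ConversionTargetSketch {
  int budget; // kept: actually read below
  // int typeConverter; // removed: stored but never used -> warning

public:
  explicit ConversionTargetSketch(int b) : budget(b) {}
  bool withinBudget(int cost) const { return cost <= budget; }
};

int main() { return ConversionTargetSketch(10).withinBudget(3) ? 0 : 1; }
```
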
@@ -90,7 +90,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });

   // We have requirements for the data layouts
-  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+  addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
     Attribute aEncoding =
         dotOp.a().getType().cast<RankedTensorType>().getEncoding();
     Attribute bEncoding =
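
Note: the lambda body never touches `this`, and Clang's `-Wunused-lambda-capture` makes the stray capture fatal under `-Werror`; dropping it also makes the legality check visibly self-contained. A toy sketch:

```cpp
#include <functional>

struct Target {
  std::function<bool(int)> pred;

  void install() {
    // Fails under -Werror=unused-lambda-capture:
    //   pred = [this](int v) { return v > 0; }; // `this` never used
    // Fixed: capture nothing the body does not need.
    pred = [](int v) { return v > 0; };
  }
};

int main() {
  Target t;
  t.install();
  return t.pred(1) ? 0 : 1;
}
```
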

@@ -63,7 +63,7 @@ static bool find_and_replace(std::string &str, const std::string &begin,
 static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
   // LLVM version in use may not officially support target hardware
   int max_nvvm_cc = 75;
-  int max_nvvm_ptx = 74;
+  // int max_nvvm_ptx = 74;
   // options
   auto options = llvm::cl::getRegisteredOptions();
   auto *short_ptr =

@@ -1,11 +1,12 @@
 import distutils
-import distutils.spawn
+import itertools
 import os
 import platform
 import re
 import shutil
 import subprocess
 import sys
+import sysconfig
 import tarfile
 import tempfile
 import urllib.request
@@ -16,6 +17,74 @@ from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext


+# Logic from https://github.com/Kitware/CMake/blob/master/Modules/FindPythonLibs.cmake
+# Code from https://stackoverflow.com/questions/47423246
+def get_python_library():
+    """Get path to the python library associated with the current python
+    interpreter."""
+    # determine direct path to libpython
+    python_version = sysconfig.get_python_version()
+    python_library = sysconfig.get_config_var('LIBRARY')
+
+    # if static (or nonexistent), try to find a suitable dynamic libpython
+    if (python_library is None or os.path.splitext(python_library)[1][-2:] == '.a'):
+
+        candidate_lib_prefixes = ['', 'lib']
+
+        candidate_extensions = ['.lib', '.so', '.a']
+        if sysconfig.get_config_var('WITH_DYLD'):
+            candidate_extensions.insert(0, '.dylib')
+
+        candidate_versions = [python_version]
+        if python_version:
+            candidate_versions.append('')
+        candidate_versions.insert(
+            0, "".join(python_version.split(".")[:2]))
+
+        abiflags = getattr(sys, 'abiflags', '')
+        candidate_abiflags = [abiflags]
+        if abiflags:
+            candidate_abiflags.append('')
+
+        # Ensure the value injected by virtualenv is
+        # returned on windows.
+        # Because calling `sysconfig.get_config_var('multiarchsubdir')`
+        # returns an empty string on Linux, `du_sysconfig` is only used to
+        # get the value of `LIBDIR`.
+        libdir = distutils.sysconfig.get_config_var('LIBDIR')
+        if sysconfig.get_config_var('MULTIARCH'):
+            masd = sysconfig.get_config_var('multiarchsubdir')
+            if masd:
+                if masd.startswith(os.sep):
+                    masd = masd[len(os.sep):]
+                libdir = os.path.join(libdir, masd)
+
+        if libdir is None:
+            libdir = os.path.abspath(os.path.join(
+                sysconfig.get_config_var('LIBDEST'), "..", "libs"))
+
+        candidates = (
+            os.path.join(
+                libdir,
+                ''.join((pre, 'python', ver, abi, ext))
+            )
+            for (pre, ext, ver, abi) in itertools.product(
+                candidate_lib_prefixes,
+                candidate_extensions,
+                candidate_versions,
+                candidate_abiflags
+            )
+        )
+
+        for candidate in candidates:
+            if os.path.exists(candidate):
+                # we found a (likely alternate) libpython
+                python_library = candidate
+                break
+
+    return python_library
+
+
 # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
 def check_env_flag(name: str, default: str = "") -> bool:
     return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
@@ -136,14 +205,19 @@ class CMakeBuild(build_ext):
         if not os.path.exists(llvm_build_dir):
             os.makedirs(llvm_build_dir)
         # python directories
-        python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
+        python_include_dir = distutils.sysconfig.get_python_inc()
+        python_link_dir = distutils.sysconfig.get_python_lib()
+        python_library = get_python_library()
         cmake_args = [
+            "-DLLVM_ENABLE_WERROR=ON",
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
             "-DTRITON_BUILD_TUTORIALS=OFF",
             "-DTRITON_BUILD_PYTHON_MODULE=ON",
             # '-DPYTHON_EXECUTABLE=' + sys.executable,
             # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
-            "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs),
+            "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
+            "-DPYTHON_LINK_DIRS=" + python_link_dir,
+            "-DPYTHON_LIBRARIES=" + python_library,
             "-DLLVM_EXTERNAL_LIT=" + lit_dir
         ] + thirdparty_cmake_args


@@ -26,6 +26,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/raw_ostream.h"

 #include "llvm/Support/SourceMgr.h"
@@ -1301,39 +1302,36 @@ void init_triton_translation(py::module &m) {
         py::gil_scoped_release allow_threads;

         // compile ptx with ptxas
-        char _fsrc[L_tmpnam];
-        char _flog[L_tmpnam];
-        std::tmpnam(_fsrc);
-        std::tmpnam(_flog);
-        std::string fsrc = _fsrc;
-        std::string flog = _flog;
-        std::string fbin = fsrc + ".o";
+        llvm::SmallString<64> fsrc;
+        llvm::SmallString<64> flog;
+        llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
+        llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
+        std::string fbin = std::string(fsrc) + ".o";
+        llvm::FileRemover srcRemover(fsrc);
+        llvm::FileRemover logRemover(flog);
+        llvm::FileRemover binRemover(fbin);
+        const char *_fsrc = fsrc.c_str();
+        const char *_flog = flog.c_str();
         const char *_fbin = fbin.c_str();
-        std::ofstream ofs(fsrc);
+        std::ofstream ofs(_fsrc);
         ofs << ptxCode << std::endl;
         ofs.close();
         std::string cmd;
         int err;
         cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
-              " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+              " " + _fsrc + " -o " + _fsrc + ".o 2> " + _flog;
         err = system(cmd.c_str());
         if (err != 0) {
           std::ifstream _log(_flog);
           std::string log(std::istreambuf_iterator<char>(_log), {});
-          unlink(_fsrc);
-          unlink(_flog);
           throw std::runtime_error("Internal Triton PTX codegen error: \n" +
                                    log);
         }
         std::ifstream _cubin(_fbin, std::ios::binary);
         std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
         _cubin.close();
-        unlink(_fsrc);
-        unlink(_flog);
-        unlink(_fbin);

         py::bytes bytes(cubin);
-        return bytes;
+        return std::move(bytes);
       });

   m.def("add_external_libs",
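
Note: `std::tmpnam` is racy (the name is chosen before any file exists) and deprecated enough that toolchains warn on it, which `-Werror` turns fatal. `llvm::sys::fs::createTemporaryFile` creates the file atomically, and the `llvm::FileRemover` objects replace the hand-written `unlink` calls with RAII cleanup that also runs on the exception path above. A minimal sketch of the same pattern outside Triton (the prefix/suffix strings here are arbitrary):

```cpp
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Atomically creates an empty file and returns its path; there is no
  // window between choosing a name and creating the file, unlike tmpnam.
  llvm::SmallString<64> path;
  if (std::error_code ec =
          llvm::sys::fs::createTemporaryFile("scratch", "txt", path)) {
    llvm::errs() << "cannot create temp file: " << ec.message() << "\n";
    return 1;
  }
  // RAII cleanup: the file is unlinked on every return path, including
  // exceptions, with no manual unlink() bookkeeping.
  llvm::FileRemover remover(path);
  llvm::outs() << "working in " << path << "\n";
  return 0;
}
```
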
@@ -1345,8 +1343,8 @@ void init_triton_translation(py::module &m) {

 void init_triton(py::module &m) {
   py::module subm = m.def_submodule("triton");
-  // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
-  init_triton_runtime(std::move(subm.def_submodule("runtime")));
-  init_triton_ir(std::move(subm.def_submodule("ir")));
+  // init_triton_codegen(subm.def_submodule("code_gen"));
+  init_triton_runtime(subm.def_submodule("runtime"));
+  init_triton_ir(subm.def_submodule("ir"));
   init_triton_translation(subm);
 }
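
Note: the two move-semantics fixes are mirror images. Wrapping a prvalue such as `subm.def_submodule(...)` in `std::move` buys nothing and blocks copy elision (the `-Wpessimizing-move` family of Clang diagnostics), while `return bytes;` in the previous hunk gained an explicit `std::move`, presumably because the deduced return type differs from the named local's type, in which case pre-C++20 the copy constructor is picked and Clang suggests the move via `-Wreturn-std-move`. A condensed sketch of both diagnostics with toy types (`Bytes`/`Object` are illustrative only):

```cpp
#include <string>
#include <utility>

struct Bytes {
  std::string data;
};

struct Object {
  std::string data;
  Object(const Bytes &b) : data(b.data) {}       // copies the buffer
  Object(Bytes &&b) : data(std::move(b.data)) {} // steals the buffer
};

Object make() {
  Bytes bytes{"cubin"};
  // -Wreturn-std-move territory: the return type (Object) differs from
  // the local's type (Bytes), so without std::move the copying
  // converting constructor runs; the explicit move picks Object(Bytes&&).
  return std::move(bytes);
}

int main() {
  // -Wpessimizing-move territory: Bytes{"x"} is already a prvalue, and
  // wrapping it in std::move would only suppress copy elision:
  //   Bytes b = std::move(Bytes{"x"}); // warning: pessimizing move
  Bytes b{"x"}; // fixed: constructed in place
  (void)b;
  return make().data == "cubin" ? 0 : 1;
}
```
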