[CI] run clang-format (#24)
This commit is contained in:
@@ -8,24 +8,23 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y){
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y) {
|
||||
// Base Case
|
||||
if (a == 0) {
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||
int gcd = gcd_impl(b % a, a, &x1, &y1);
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b/a) * x1;
|
||||
*x = y1 - (b / a) * x1;
|
||||
*y = x1;
|
||||
return gcd;
|
||||
}
|
||||
@@ -35,17 +34,17 @@ static int gcd(int a, int b) {
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
size_t rank = 1;
|
||||
if(TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
if (TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
rank = ty.getRank();
|
||||
int divHint = 1;
|
||||
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
|
||||
Operation* op = blockArg.getOwner()->getParentOp();
|
||||
if(FuncOp fun = dyn_cast<FuncOp>(op)){
|
||||
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if(attr)
|
||||
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
|
||||
Attribute attr =
|
||||
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if (attr)
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
}
|
||||
@@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
ContiguityT retContiguity;
|
||||
DivisibilityT retDivisibility;
|
||||
ConstancyT retConstancy;
|
||||
for(size_t d = 0; d < lhs.getRank(); d++){
|
||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
}
|
||||
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr;
|
||||
// This preserves the input axes (e.g., cast):
|
||||
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||
triton::PtrToIntOp, triton::IntToPtrOp>(op))
|
||||
curr = operands[0]->getValue();
|
||||
// Constant ranges
|
||||
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
|
||||
if (triton::MakeRangeOp make_range =
|
||||
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::ContiguityT contiguity = {end - start};
|
||||
@@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Constant
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
|
||||
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||
if(intAttr){
|
||||
if (intAttr) {
|
||||
size_t val = intAttr.getValue().getZExtValue();
|
||||
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if(splatAttr && splatAttr.getElementType().isInteger(32)){
|
||||
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
|
||||
curr = AxisInfo(
|
||||
AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
};
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Multiplication
|
||||
if (llvm::isa<arith::MulIOp>(op)){
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return 1;
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::MulIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)){
|
||||
if (llvm::isa<triton::SplatOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
@@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
}
|
||||
// Reshape
|
||||
// TODO: Replace by `unsqueeze`
|
||||
if (llvm::isa<triton::ReshapeOp>(op)){
|
||||
if (llvm::isa<triton::ReshapeOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ConstancyT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
if(retShape[d] == 1){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
if (retShape[d] == 1) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
else if(!is_skewed
|
||||
&& retShape[d] == opShape[current]){
|
||||
} else if (!is_skewed && retShape[d] == opShape[current]) {
|
||||
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||
constancy.push_back(opInfo.getConstancy()[current]);
|
||||
current++;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
is_skewed = true;
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
@@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Broadcast
|
||||
if (llvm::isa<triton::BroadcastOp>(op)){
|
||||
if (llvm::isa<triton::BroadcastOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
if(curr.getRank() == 0){
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all lattice elements
|
||||
@@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace mlir
|
@@ -2,14 +2,16 @@
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace triton{
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
@@ -1,42 +1,42 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
template<class Op>
|
||||
class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template<class SrcOp, class DstOp>
|
||||
template <class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
|
||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||
);
|
||||
DstOp res =
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
assert(value);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
op, retType,
|
||||
value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertArithmeticOp: public ConversionPattern {
|
||||
class ConvertArithmeticOp : public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
Dialect* dialect = op->getDialect();
|
||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Dialect *dialect = op->getDialect();
|
||||
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target){
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
@@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality(
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithConstantPattern,
|
||||
ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>,
|
||||
ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>,
|
||||
ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>,
|
||||
ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>,
|
||||
ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>,
|
||||
ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>,
|
||||
ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>,
|
||||
ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>,
|
||||
ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>,
|
||||
ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>,
|
||||
ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>
|
||||
>(typeConverter, context);
|
||||
patterns.add<
|
||||
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// Triton patterns
|
||||
//
|
||||
// TODO: Do we need to put them in anonymous namespace?
|
||||
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
|
||||
struct TritonMakeRangePattern
|
||||
: public OpConversionPattern<triton::MakeRangeOp> {
|
||||
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, retType, adaptor.start(), adaptor.end()
|
||||
);
|
||||
op, retType, adaptor.start(), adaptor.end());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.a().getType().cast<RankedTensorType>();
|
||||
@@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value b = adaptor.b();
|
||||
SmallVector<unsigned, 2> order{1, 0};
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||
);
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, retType,
|
||||
adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
|
||||
);
|
||||
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||
);
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -212,12 +209,11 @@ template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
|
||||
);
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::SplatOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonReducePattern,
|
||||
TritonMakeRangePattern,
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(typeConverter, context);
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -259,17 +250,19 @@ void populateTritonPatterns(
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
// Ref: ConvertForOpTypes
|
||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp =
|
||||
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||
newOp.getLoopBody().end());
|
||||
|
||||
// Now, update all the types.
|
||||
|
||||
// Convert the types of block arguments within the given region. This
|
||||
// replaces each block with a new block containing the updated signature. The
|
||||
// entry block may have a special conversion if `entryConversion` is
|
||||
// replaces each block with a new block containing the updated signature.
|
||||
// The entry block may have a special conversion if `entryConversion` is
|
||||
// provided. On success, the new entry block to the region is returned for
|
||||
// convenience. Otherwise, failure is returned.
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||
@@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||
// op.erase();
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||
op, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern
|
||||
>(typeConverter, context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
class ConvertTritonToTritonGPU
|
||||
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
public:
|
||||
ConvertTritonToTritonGPU(int numWarps) {
|
||||
this->numWarps = numWarps;
|
||||
}
|
||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -339,21 +326,21 @@ public:
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
if(failed(target.refineLayouts(mod, numWarps)))
|
||||
if (failed(target.refineLayouts(mod, numWarps)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||
|
@@ -7,7 +7,6 @@
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||
Type type, Location loc) {
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
}
|
@@ -13,14 +13,16 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
|
||||
@@ -48,50 +50,48 @@ namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
|
||||
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
// other
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
resultType,
|
||||
DenseElementsAttr::get(
|
||||
resultType, builder.getZeroAttr(elementType)
|
||||
)
|
||||
);
|
||||
ptr.getLoc(), resultType,
|
||||
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addOperands(other);
|
||||
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(
|
||||
evictAttrName(state.name),
|
||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name),
|
||||
builder.getBoolAttr(isVolatile));
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
||||
|
||||
using namespace mlir;
|
||||
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
}
|
||||
|
||||
Type PointerType::parse(AsmParser &parser) {
|
||||
|
@@ -17,21 +17,23 @@ namespace {
|
||||
class CombineDotOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineDotOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
|
||||
if (isCandidate(op->getOperand(0)).succeeded()) {
|
||||
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(1), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(1), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
} else if (isCandidate(op->getOperand(1)).succeeded()) {
|
||||
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(0), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(0), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -54,7 +56,7 @@ private:
|
||||
return true;
|
||||
// broadcast(constant_0)
|
||||
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
|
||||
return true;
|
||||
}
|
||||
@@ -68,18 +70,19 @@ private:
|
||||
class CombineGEPOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineGEPOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::triton::GEPOp>(op)) {
|
||||
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
|
||||
auto loc = op->getLoc();
|
||||
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
|
||||
);
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -92,20 +95,21 @@ public:
|
||||
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::SelectOp>(op)) {
|
||||
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
|
||||
mlir::Value cond = op->getOperand(0);
|
||||
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
|
||||
op, op->getResultTypes().front(),
|
||||
load.ptr(), load.mask(), op->getOperand(2),
|
||||
load.cache(), load.evict(), load.isVolatile()
|
||||
);
|
||||
op, op->getResultTypes().front(), load.ptr(), load.mask(),
|
||||
op->getOperand(2), load.cache(), load.evict(),
|
||||
load.isVolatile());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -120,11 +124,11 @@ public:
|
||||
class CombineBroadcastConstantOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineBroadcastConstantOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
|
||||
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
|
||||
Attribute value = cst.getValue();
|
||||
@@ -132,15 +136,14 @@ public:
|
||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||
if (!denseValue.isSplat())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
|
||||
value = DenseElementsAttr::get(resType,
|
||||
denseValue.getSplatValue<Attribute>());
|
||||
} else {
|
||||
if (!value.isa<FloatAttr, IntegerAttr>())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, value);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, value, resType
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@@ -11,19 +11,18 @@ using namespace mlir::triton::gpu;
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
/*SmallVector<unsigned, 2>*/auto &res,
|
||||
StringRef desc) {
|
||||
/*SmallVector<unsigned, 2>*/ auto &res,
|
||||
StringRef desc) {
|
||||
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||
<< desc;
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
|
||||
return failure();
|
||||
}
|
||||
for (Attribute i : arrayAttr) {
|
||||
auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||
<< desc;
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
res.push_back(intAttr.getUInt());
|
||||
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
|
||||
SmallVector<unsigned, 2> threadTileSize;
|
||||
SmallVector<unsigned, 2> warpTileSize;
|
||||
SmallVector<unsigned, 2> blockTileSize;
|
||||
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "threadTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "blockTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "order") {
|
||||
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "broadcastAxis") {
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
|
||||
threadTileSize,
|
||||
warpTileSize,
|
||||
blockTileSize,
|
||||
order,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpPerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "repetitions") {
|
||||
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "contigPerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
fragmentPerWarp,
|
||||
shapePerWarp,
|
||||
warpPerTile,
|
||||
shapePerTile,
|
||||
repetitions,
|
||||
contigPerThread,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(
|
||||
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
|
||||
shapePerTile, repetitions, contigPerThread, broadcastAxis);
|
||||
}
|
||||
|
||||
static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
unsigned maxPhase = 0;
|
||||
SmallVector<unsigned, 2> order;
|
||||
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr,
|
||||
unsigned &value,
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
|
||||
StringRef desc) -> LogicalResult {
|
||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
|
||||
vec,
|
||||
perPhase,
|
||||
maxPhase,
|
||||
order);
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(
|
||||
parser.getContext(), vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec()
|
||||
<< ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase()
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
|
||||
<< "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
public:
|
||||
public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
static void printMma(const auto &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
}
|
||||
|
||||
@@ -349,7 +340,8 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult verify(CopyAsyncOp op) {
|
||||
Type resType = op.getResult().getType();
|
||||
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
||||
|
||||
// verify TritonGPU ops
|
||||
LogicalResult
|
||||
TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
|
@@ -27,8 +27,8 @@ namespace {
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
@@ -6,12 +6,11 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
||||
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
||||
// and SCF's LoopPipelining.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
@@ -41,14 +40,15 @@ class LoopPipeliner {
|
||||
/// Block arguments that loads depend on
|
||||
DenseSet<BlockArgument> depArgs;
|
||||
/// Operations (inside the loop body) that loads depend on
|
||||
DenseSet<Operation*> depOps;
|
||||
DenseSet<Operation *> depOps;
|
||||
|
||||
/// collect values that v depends on and are defined inside the loop
|
||||
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
|
||||
|
||||
void setValueMapping(Value origin, Value newValue, int stage);
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
: forOp(forOp), numStages(numStages) {
|
||||
// cache yieldOp
|
||||
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
deps.insert(v);
|
||||
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
|
||||
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
|
||||
} else { // value
|
||||
// v might be in deps, but we still need to visit v.
|
||||
// This is because v might depends on value in previous iterations
|
||||
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
}
|
||||
|
||||
// for (triton::LoadOp loadOp : allLoads) {
|
||||
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
|
||||
// for (Value dep : loadDeps[loadOp])
|
||||
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
|
||||
// " values\n"; for (Value dep : loadDeps[loadOp])
|
||||
// llvm::errs() << dep << "\n";
|
||||
// llvm::errs() << "\n";
|
||||
// }
|
||||
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
if (isCandiate && loadOp.getResult().hasOneUse()) {
|
||||
isCandiate = false;
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
if (auto convertLayout =
|
||||
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding()
|
||||
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
|
||||
// we have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// update depArgs & depOps
|
||||
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
// special handling for loop condition as there is no condition in ForOp
|
||||
Value loopCond = builder.create<arith::CmpIOp>(
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
|
||||
// rematerialize peeled values
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
// TODO: check if the hardware supports copyasync
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
||||
newOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), loadsMapping[loadOp].getType(),
|
||||
loadOp.ptr(), loadOp.mask(), loadOp.other(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
|
||||
);
|
||||
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
|
||||
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
|
||||
loadOp.isVolatile());
|
||||
} else
|
||||
llvm_unreachable("This should be LoadOp");
|
||||
} else
|
||||
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
|
||||
// assert(I1 or TensorOf<[I1]>);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
loopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
mask,
|
||||
splatCond);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
Value newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
newOp->setOperand(1, newMask);
|
||||
}
|
||||
|
||||
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
// update mapping for loop-carried values (args)
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx))
|
||||
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
setValueMapping(
|
||||
forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs) {
|
||||
depArgsIdx[depArg] = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
|
||||
}
|
||||
|
||||
size_t nextIVIdx = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
|
||||
|
||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||
assert(newLoopArgs[i]);
|
||||
|
||||
// 1. signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||
forOp.getLowerBound(),
|
||||
forOp.getUpperBound(),
|
||||
forOp.getStep(),
|
||||
newLoopArgs);
|
||||
auto newForOp = builder.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newLoopArgs);
|
||||
|
||||
// 2. body of the new ForOp
|
||||
builder.setInsertionPointToStart(newForOp.getBody());
|
||||
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// 3. replace loads with block args (from prologue)
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
|
||||
assert(load.hasOneUse() &&
|
||||
"we assume that this load has one use (ConvertLayout)");
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
mapping.lookup(loadUse).replaceAllUsesWith(
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
|
||||
}
|
||||
|
||||
|
||||
// 4. prefetch the next iteration
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
DenseMap<BlockArgument, Value> depArgsMapping;
|
||||
size_t argIdx = 0;
|
||||
for (BlockArgument arg : depArgs) {
|
||||
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
nextMapping.map(arg,
|
||||
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
++argIdx;
|
||||
}
|
||||
// special handling for iv & loop condition
|
||||
Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx],
|
||||
newForOp.getStep());
|
||||
Value nextLoopCond = builder.create<arith::CmpIOp>(
|
||||
nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
Value nextIV = builder.create<arith::AddIOp>(
|
||||
newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
||||
Value nextLoopCond =
|
||||
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// update loading mask
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
splatCond,
|
||||
nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than once
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
// TODO: more elegant way to do this?
|
||||
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
nextMapping.lookupOrDefault(loadOp.other()),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
|
||||
);
|
||||
}
|
||||
else
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile());
|
||||
} else
|
||||
nextOp = builder.clone(*op, nextMapping);
|
||||
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
|
||||
// update mapping of results
|
||||
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
for (int stage = 1; stage < numStages - 1; ++stage) {
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[
|
||||
loadIdx + idx*(numStages-1) + stage
|
||||
]);
|
||||
yieldValues.push_back(
|
||||
newForOp
|
||||
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
|
||||
}
|
||||
yieldValues.push_back(nextMapping.lookup(load));
|
||||
}
|
||||
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
|
||||
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(
|
||||
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(nextIV);
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
PipelinePass() = default;
|
||||
PipelinePass(int numStages) {
|
||||
this->numStages = numStages;
|
||||
}
|
||||
PipelinePass(int numStages) { this->numStages = numStages; }
|
||||
|
||||
void runOnOperation() override {
|
||||
int numStages = this->numStages;
|
||||
|
@@ -1,7 +1,7 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
|
||||
//
|
||||
// TypeConverter
|
||||
//
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numThreads)
|
||||
: context(context), numThreads(numThreads) {
|
||||
// TODO: how does MLIR pick the right conversion?
|
||||
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// or assert no encoding?
|
||||
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
llvm::SmallVector<unsigned> broadcastAxis;
|
||||
int remainingThreads = numThreads;
|
||||
int remainingLanes = /*warp size*/32;
|
||||
int remainingLanes = /*warp size*/ 32;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
||||
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
|
||||
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
|
||||
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
//
|
||||
// This will be called when (newArgType != origArgType)
|
||||
// This will create newArg, and map(origArg, newArg)
|
||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
addArgumentMaterialization([&](OpBuilder &builder,
|
||||
RankedTensorType tensorType, ValueRange inputs,
|
||||
Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// If the origValue still has live user(s), use this to
|
||||
// convert origValue to newValue
|
||||
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// where, desiredType = typeConverter->convertType(origType)
|
||||
// NOTE: only for remapped values.
|
||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// TritonGPUConversion
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
// TODO: we should also verify ops of TritonGPUDialect
|
||||
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||
triton::TritonDialect,
|
||||
StandardOpsDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
|
||||
scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>(
|
||||
[&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
// We have requirements for the data layouts
|
||||
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
|
||||
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
Attribute aEncoding =
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding =
|
||||
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding &&
|
||||
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return true;
|
||||
// // TODO: we should delete this
|
||||
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// %dst = tt.broadcast %src
|
||||
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// %bcst = tt.broadcast %newSrc
|
||||
// %dst = convert_layout %bcst
|
||||
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
int numThreads) {
|
||||
int numThreads) {
|
||||
// collect broadcasts
|
||||
SmallVector<triton::BroadcastOp> broadcasts;
|
||||
mod.walk([&](triton::BroadcastOp op) {
|
||||
broadcasts.push_back(op);
|
||||
});
|
||||
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto broadcast : broadcasts) {
|
||||
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
broadcastAxis.push_back(ax);
|
||||
|
||||
Attribute originSrcEnc = tensorType.getEncoding();
|
||||
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
if (auto blockedEnc =
|
||||
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
|
||||
blockedEnc.getContext(),
|
||||
blockedEnc.getThreadTileSize(),
|
||||
blockedEnc.getWarpTileSize(),
|
||||
blockedEnc.getBlockTileSize(),
|
||||
blockedEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
|
||||
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
|
||||
blockedEnc.getOrder(), broadcastAxis);
|
||||
newSrcType = RankedTensorType::get(
|
||||
tensorType.getShape(),
|
||||
tensorType.getElementType(),
|
||||
newSrcEnc
|
||||
);
|
||||
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
|
||||
} else
|
||||
llvm_unreachable("src of broadcast should have blocked encoding");
|
||||
} else {
|
||||
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
|
||||
// create new src
|
||||
if (!isSrcScalar) // we don't need to convert layout for scalar values
|
||||
src = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
src.getLoc(), newSrcType, src
|
||||
);
|
||||
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
|
||||
newSrcType, src);
|
||||
|
||||
// create new broadcast
|
||||
// compute new type (encoding)
|
||||
auto originDstEnc = originDstTensorType.getEncoding()
|
||||
.dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
.dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto newEnc = TritonGPUBlockedEncodingAttr::get(
|
||||
originDstEnc.getContext(),
|
||||
originDstEnc.getThreadTileSize(),
|
||||
originDstEnc.getWarpTileSize(),
|
||||
originDstEnc.getBlockTileSize(),
|
||||
originDstEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
auto newType = RankedTensorType::get(
|
||||
originDstTensorType.getShape(),
|
||||
originDstTensorType.getElementType(),
|
||||
newEnc
|
||||
);
|
||||
Value newBroadcast = builder.create<triton::BroadcastOp>(
|
||||
broadcast.getLoc(), newType, src
|
||||
);
|
||||
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
|
||||
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
|
||||
originDstEnc.getOrder(), broadcastAxis);
|
||||
auto newType =
|
||||
RankedTensorType::get(originDstTensorType.getShape(),
|
||||
originDstTensorType.getElementType(), newEnc);
|
||||
Value newBroadcast =
|
||||
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
|
||||
// we don't want to change the encoding of the result
|
||||
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
broadcast.getLoc(), originDstType, newBroadcast
|
||||
);
|
||||
broadcast.getLoc(), originDstType, newBroadcast);
|
||||
|
||||
broadcast.replaceAllUsesWith(newDst);
|
||||
mapping.map(broadcast, newDst);
|
||||
|
@@ -5,7 +5,6 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -37,28 +36,30 @@ private:
|
||||
if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of shared layout";
|
||||
} else
|
||||
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
|
||||
return dotOp.emitError()
|
||||
<< name << "'s type should be of RankedTensorType";
|
||||
}
|
||||
|
||||
Attribute cLayout;
|
||||
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
|
||||
llvm::SmallVector<char>{'c', 'd'})) {
|
||||
llvm::SmallVector<char>{'c', 'd'})) {
|
||||
Type type = std::get<0>(it);
|
||||
char name = std::get<1>(it);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of distributed layout";
|
||||
return dotOp.emitError()
|
||||
<< name << " should be of distributed layout";
|
||||
if (name == 'c')
|
||||
cLayout = encoding;
|
||||
else if (encoding != cLayout)
|
||||
return dotOp.emitError() << "d & c should have the same layout";
|
||||
} else
|
||||
return dotOp.emitError() << name
|
||||
<< "'s type should be of RankedTensorType";
|
||||
return dotOp.emitError()
|
||||
<< name << "'s type should be of RankedTensorType";
|
||||
}
|
||||
|
||||
// signalPassFailure();
|
||||
@@ -89,7 +90,7 @@ private:
|
||||
}
|
||||
|
||||
void verifyImpl(Operation *op) {
|
||||
if(verifySingleOp(op).failed())
|
||||
if (verifySingleOp(op).failed())
|
||||
signalPassFailure();
|
||||
|
||||
// verify that all child regions are ok
|
||||
|
408
lib/driver/dispatch.cc
Executable file → Normal file
408
lib/driver/dispatch.cc
Executable file → Normal file
@@ -1,107 +1,152 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
//Helpers for function definition
|
||||
#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname); }\
|
||||
void* dispatch::fname ## _;
|
||||
// Helpers for function definition
|
||||
#define DEFINE0(init, hlib, ret, fname) \
|
||||
ret dispatch::fname() { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE1(init, hlib, ret, fname, t1) \
|
||||
ret dispatch::fname(t1 a) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE2(init, hlib, ret, fname, t1, t2) \
|
||||
ret dispatch::fname(t1 a, t2 b) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\
|
||||
void* dispatch::fname ## _;
|
||||
|
||||
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
|
||||
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\
|
||||
void* dispatch::fname ## _;
|
||||
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11, t12, t13) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k, t12 l, t13 m) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k, l, m); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
|
||||
t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \
|
||||
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
|
||||
t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \
|
||||
t18 r, t19 s) { \
|
||||
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
|
||||
e, f, g, h, i, j, k, l, m, n, o, p, q, r, \
|
||||
s); \
|
||||
} \
|
||||
void *dispatch::fname##_;
|
||||
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
|
||||
bool dispatch::cuinit(){
|
||||
if(cuda_==nullptr){
|
||||
#ifdef _WIN32
|
||||
bool dispatch::cuinit() {
|
||||
if (cuda_ == nullptr) {
|
||||
#ifdef _WIN32
|
||||
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
|
||||
#else
|
||||
#else
|
||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||
if(!cuda_)
|
||||
if (!cuda_)
|
||||
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
|
||||
#endif
|
||||
if(!cuda_)
|
||||
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
|
||||
#endif
|
||||
if (!cuda_)
|
||||
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is "
|
||||
"in your LD_LIBRARY_PATH.");
|
||||
}
|
||||
if(cuda_ == nullptr)
|
||||
if (cuda_ == nullptr)
|
||||
return false;
|
||||
CUresult (*fptr)(unsigned int);
|
||||
cuInit_ = dlsym(cuda_, "cuInit");
|
||||
@@ -112,21 +157,33 @@ bool dispatch::cuinit(){
|
||||
}
|
||||
|
||||
#define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
|
||||
#define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
|
||||
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
|
||||
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
|
||||
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
|
||||
#define CUDA_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
|
||||
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
|
||||
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \
|
||||
DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
|
||||
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
|
||||
DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
|
||||
DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11) \
|
||||
DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11)
|
||||
|
||||
// context management
|
||||
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
|
||||
CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
|
||||
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *)
|
||||
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuInit, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
|
||||
@@ -134,59 +191,71 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
|
||||
CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute,
|
||||
CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
|
||||
|
||||
// link management
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *,
|
||||
size_t, const char *, unsigned int, CUjit_option *, void **);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **,
|
||||
CUlinkState *);
|
||||
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
||||
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
|
||||
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *);
|
||||
// module management
|
||||
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
|
||||
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule,
|
||||
const char *)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
|
||||
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
|
||||
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
|
||||
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
|
||||
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *,
|
||||
unsigned int, CUjit_option *, void **)
|
||||
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule,
|
||||
const char *)
|
||||
// stream management
|
||||
CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
|
||||
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
|
||||
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
|
||||
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
|
||||
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
|
||||
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *)
|
||||
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, CUstream, void **, void **)
|
||||
// function management
|
||||
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
|
||||
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
|
||||
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute,
|
||||
CUfunction)
|
||||
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute,
|
||||
int)
|
||||
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
|
||||
// memory management
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
|
||||
CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t )
|
||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
|
||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t,
|
||||
CUstream)
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t,
|
||||
CUstream)
|
||||
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t)
|
||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t)
|
||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute,
|
||||
CUdeviceptr)
|
||||
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t,
|
||||
CUstream)
|
||||
// event management
|
||||
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
|
||||
CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
|
||||
CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
|
||||
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
||||
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
bool dispatch::nvmlinit(){
|
||||
#ifdef _WIN32
|
||||
if(nvml_==nullptr)
|
||||
bool dispatch::nvmlinit() {
|
||||
#ifdef _WIN32
|
||||
if (nvml_ == nullptr)
|
||||
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
|
||||
#else
|
||||
if(nvml_==nullptr)
|
||||
#else
|
||||
if (nvml_ == nullptr)
|
||||
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
|
||||
#endif
|
||||
#endif
|
||||
nvmlReturn_t (*fptr)();
|
||||
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
|
||||
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
|
||||
@@ -197,21 +266,27 @@ bool dispatch::nvmlinit(){
|
||||
|
||||
#define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname)
|
||||
#define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1)
|
||||
#define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
|
||||
#define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
|
||||
#define NVML_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
|
||||
#define NVML_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
|
||||
|
||||
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
|
||||
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *,
|
||||
nvmlDevice_t *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t,
|
||||
nvmlClockType_t, unsigned int *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t,
|
||||
nvmlClockType_t, unsigned int *)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t,
|
||||
unsigned int, unsigned int)
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
* ------------------- */
|
||||
bool dispatch::hipinit(){
|
||||
if(hip_==nullptr)
|
||||
bool dispatch::hipinit() {
|
||||
if (hip_ == nullptr)
|
||||
hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
|
||||
if(hip_ == nullptr)
|
||||
if (hip_ == nullptr)
|
||||
return false;
|
||||
hipError_t (*fptr)();
|
||||
hipInit_ = dlsym(hip_, "hipInit");
|
||||
@@ -222,23 +297,34 @@ bool dispatch::hipinit(){
|
||||
}
|
||||
|
||||
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
|
||||
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
|
||||
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
|
||||
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
|
||||
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
|
||||
#define HIP_DEFINE2(ret, fname, t1, t2) \
|
||||
DEFINE2(hipinit, hip_, ret, fname, t1, t2)
|
||||
#define HIP_DEFINE3(ret, fname, t1, t2, t3) \
|
||||
DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
|
||||
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \
|
||||
DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
|
||||
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
|
||||
DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
|
||||
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
|
||||
DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
|
||||
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
|
||||
DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
|
||||
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
|
||||
DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
|
||||
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
|
||||
DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
|
||||
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
|
||||
DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
|
||||
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
|
||||
DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
|
||||
t11)
|
||||
|
||||
// context management
|
||||
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
|
||||
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
|
||||
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*)
|
||||
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*)
|
||||
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *)
|
||||
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
|
||||
@@ -246,56 +332,64 @@ HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
|
||||
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t)
|
||||
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t,
|
||||
hipDevice_t)
|
||||
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
|
||||
// module management
|
||||
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*)
|
||||
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *,
|
||||
hipModule_t, const char *)
|
||||
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
|
||||
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
|
||||
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
|
||||
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **)
|
||||
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
|
||||
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *,
|
||||
unsigned int, hipJitOption *, void **)
|
||||
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t,
|
||||
const char *)
|
||||
// stream management
|
||||
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
|
||||
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
|
||||
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
|
||||
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **)
|
||||
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, hipStream_t, void **, void **)
|
||||
// function management
|
||||
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*)
|
||||
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *)
|
||||
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
|
||||
// memory management
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
|
||||
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t)
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t )
|
||||
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t)
|
||||
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t,
|
||||
hipStream_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *,
|
||||
size_t, hipStream_t)
|
||||
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t)
|
||||
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t)
|
||||
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute,
|
||||
hipDeviceptr_t)
|
||||
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t,
|
||||
hipStream_t)
|
||||
// event management
|
||||
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
|
||||
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
|
||||
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
|
||||
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* COMMON
|
||||
* ------------------- */
|
||||
|
||||
// Release
|
||||
void dispatch::release(){
|
||||
if(cuda_){
|
||||
void dispatch::release() {
|
||||
if (cuda_) {
|
||||
dlclose(cuda_);
|
||||
cuda_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void* dispatch::cuda_;
|
||||
void* dispatch::nvml_;
|
||||
void* dispatch::nvmlInit_v2_;
|
||||
void* dispatch::hip_;
|
||||
void *dispatch::cuda_;
|
||||
void *dispatch::nvml_;
|
||||
void *dispatch::nvmlInit_v2_;
|
||||
void *dispatch::hip_;
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
410
lib/driver/error.cc
Executable file → Normal file
410
lib/driver/error.cc
Executable file → Normal file
@@ -1,166 +1,270 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/driver/error.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
void check(CUresult err)
|
||||
{
|
||||
void check(CUresult err) {
|
||||
using namespace exception::cuda;
|
||||
switch(err)
|
||||
{
|
||||
case CUDA_SUCCESS : break;
|
||||
case CUDA_ERROR_INVALID_VALUE : throw invalid_value();
|
||||
case CUDA_ERROR_OUT_OF_MEMORY : throw out_of_memory();
|
||||
case CUDA_ERROR_NOT_INITIALIZED : throw not_initialized();
|
||||
case CUDA_ERROR_DEINITIALIZED : throw deinitialized();
|
||||
case CUDA_ERROR_PROFILER_DISABLED : throw profiler_disabled();
|
||||
case CUDA_ERROR_PROFILER_NOT_INITIALIZED : throw profiler_not_initialized();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STARTED : throw profiler_already_started();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STOPPED : throw profiler_already_stopped();
|
||||
case CUDA_ERROR_NO_DEVICE : throw no_device();
|
||||
case CUDA_ERROR_INVALID_DEVICE : throw invalid_device();
|
||||
case CUDA_ERROR_INVALID_IMAGE : throw invalid_image();
|
||||
case CUDA_ERROR_INVALID_CONTEXT : throw invalid_context();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT : throw context_already_current();
|
||||
case CUDA_ERROR_MAP_FAILED : throw map_failed();
|
||||
case CUDA_ERROR_UNMAP_FAILED : throw unmap_failed();
|
||||
case CUDA_ERROR_ARRAY_IS_MAPPED : throw array_is_mapped();
|
||||
case CUDA_ERROR_ALREADY_MAPPED : throw already_mapped();
|
||||
case CUDA_ERROR_NO_BINARY_FOR_GPU : throw no_binary_for_gpu();
|
||||
case CUDA_ERROR_ALREADY_ACQUIRED : throw already_acquired();
|
||||
case CUDA_ERROR_NOT_MAPPED : throw not_mapped();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY : throw not_mapped_as_array();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_POINTER : throw not_mapped_as_pointer();
|
||||
case CUDA_ERROR_ECC_UNCORRECTABLE : throw ecc_uncorrectable();
|
||||
case CUDA_ERROR_UNSUPPORTED_LIMIT : throw unsupported_limit();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE : throw context_already_in_use();
|
||||
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED : throw peer_access_unsupported();
|
||||
case CUDA_ERROR_INVALID_PTX : throw invalid_ptx();
|
||||
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT : throw invalid_graphics_context();
|
||||
case CUDA_ERROR_INVALID_SOURCE : throw invalid_source();
|
||||
case CUDA_ERROR_FILE_NOT_FOUND : throw file_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND : throw shared_object_symbol_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED : throw shared_object_init_failed();
|
||||
case CUDA_ERROR_OPERATING_SYSTEM : throw operating_system();
|
||||
case CUDA_ERROR_INVALID_HANDLE : throw invalid_handle();
|
||||
case CUDA_ERROR_NOT_FOUND : throw not_found();
|
||||
case CUDA_ERROR_NOT_READY : throw not_ready();
|
||||
case CUDA_ERROR_ILLEGAL_ADDRESS : throw illegal_address();
|
||||
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES : throw launch_out_of_resources();
|
||||
case CUDA_ERROR_LAUNCH_TIMEOUT : throw launch_timeout();
|
||||
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING : throw launch_incompatible_texturing();
|
||||
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED : throw peer_access_already_enabled();
|
||||
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED : throw peer_access_not_enabled();
|
||||
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE : throw primary_context_active();
|
||||
case CUDA_ERROR_CONTEXT_IS_DESTROYED : throw context_is_destroyed();
|
||||
case CUDA_ERROR_ASSERT : throw assert_error();
|
||||
case CUDA_ERROR_TOO_MANY_PEERS : throw too_many_peers();
|
||||
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED : throw host_memory_already_registered();
|
||||
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED : throw host_memory_not_registered();
|
||||
case CUDA_ERROR_HARDWARE_STACK_ERROR : throw hardware_stack_error();
|
||||
case CUDA_ERROR_ILLEGAL_INSTRUCTION : throw illegal_instruction();
|
||||
case CUDA_ERROR_MISALIGNED_ADDRESS : throw misaligned_address();
|
||||
case CUDA_ERROR_INVALID_ADDRESS_SPACE : throw invalid_address_space();
|
||||
case CUDA_ERROR_INVALID_PC : throw invalid_pc();
|
||||
case CUDA_ERROR_LAUNCH_FAILED : throw launch_failed();
|
||||
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
|
||||
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
|
||||
case CUDA_ERROR_UNKNOWN : throw unknown();
|
||||
default : throw unknown();
|
||||
switch (err) {
|
||||
case CUDA_SUCCESS:
|
||||
break;
|
||||
case CUDA_ERROR_INVALID_VALUE:
|
||||
throw invalid_value();
|
||||
case CUDA_ERROR_OUT_OF_MEMORY:
|
||||
throw out_of_memory();
|
||||
case CUDA_ERROR_NOT_INITIALIZED:
|
||||
throw not_initialized();
|
||||
case CUDA_ERROR_DEINITIALIZED:
|
||||
throw deinitialized();
|
||||
case CUDA_ERROR_PROFILER_DISABLED:
|
||||
throw profiler_disabled();
|
||||
case CUDA_ERROR_PROFILER_NOT_INITIALIZED:
|
||||
throw profiler_not_initialized();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STARTED:
|
||||
throw profiler_already_started();
|
||||
case CUDA_ERROR_PROFILER_ALREADY_STOPPED:
|
||||
throw profiler_already_stopped();
|
||||
case CUDA_ERROR_NO_DEVICE:
|
||||
throw no_device();
|
||||
case CUDA_ERROR_INVALID_DEVICE:
|
||||
throw invalid_device();
|
||||
case CUDA_ERROR_INVALID_IMAGE:
|
||||
throw invalid_image();
|
||||
case CUDA_ERROR_INVALID_CONTEXT:
|
||||
throw invalid_context();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT:
|
||||
throw context_already_current();
|
||||
case CUDA_ERROR_MAP_FAILED:
|
||||
throw map_failed();
|
||||
case CUDA_ERROR_UNMAP_FAILED:
|
||||
throw unmap_failed();
|
||||
case CUDA_ERROR_ARRAY_IS_MAPPED:
|
||||
throw array_is_mapped();
|
||||
case CUDA_ERROR_ALREADY_MAPPED:
|
||||
throw already_mapped();
|
||||
case CUDA_ERROR_NO_BINARY_FOR_GPU:
|
||||
throw no_binary_for_gpu();
|
||||
case CUDA_ERROR_ALREADY_ACQUIRED:
|
||||
throw already_acquired();
|
||||
case CUDA_ERROR_NOT_MAPPED:
|
||||
throw not_mapped();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY:
|
||||
throw not_mapped_as_array();
|
||||
case CUDA_ERROR_NOT_MAPPED_AS_POINTER:
|
||||
throw not_mapped_as_pointer();
|
||||
case CUDA_ERROR_ECC_UNCORRECTABLE:
|
||||
throw ecc_uncorrectable();
|
||||
case CUDA_ERROR_UNSUPPORTED_LIMIT:
|
||||
throw unsupported_limit();
|
||||
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE:
|
||||
throw context_already_in_use();
|
||||
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED:
|
||||
throw peer_access_unsupported();
|
||||
case CUDA_ERROR_INVALID_PTX:
|
||||
throw invalid_ptx();
|
||||
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT:
|
||||
throw invalid_graphics_context();
|
||||
case CUDA_ERROR_INVALID_SOURCE:
|
||||
throw invalid_source();
|
||||
case CUDA_ERROR_FILE_NOT_FOUND:
|
||||
throw file_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND:
|
||||
throw shared_object_symbol_not_found();
|
||||
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
|
||||
throw shared_object_init_failed();
|
||||
case CUDA_ERROR_OPERATING_SYSTEM:
|
||||
throw operating_system();
|
||||
case CUDA_ERROR_INVALID_HANDLE:
|
||||
throw invalid_handle();
|
||||
case CUDA_ERROR_NOT_FOUND:
|
||||
throw not_found();
|
||||
case CUDA_ERROR_NOT_READY:
|
||||
throw not_ready();
|
||||
case CUDA_ERROR_ILLEGAL_ADDRESS:
|
||||
throw illegal_address();
|
||||
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
|
||||
throw launch_out_of_resources();
|
||||
case CUDA_ERROR_LAUNCH_TIMEOUT:
|
||||
throw launch_timeout();
|
||||
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING:
|
||||
throw launch_incompatible_texturing();
|
||||
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED:
|
||||
throw peer_access_already_enabled();
|
||||
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED:
|
||||
throw peer_access_not_enabled();
|
||||
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE:
|
||||
throw primary_context_active();
|
||||
case CUDA_ERROR_CONTEXT_IS_DESTROYED:
|
||||
throw context_is_destroyed();
|
||||
case CUDA_ERROR_ASSERT:
|
||||
throw assert_error();
|
||||
case CUDA_ERROR_TOO_MANY_PEERS:
|
||||
throw too_many_peers();
|
||||
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED:
|
||||
throw host_memory_already_registered();
|
||||
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED:
|
||||
throw host_memory_not_registered();
|
||||
case CUDA_ERROR_HARDWARE_STACK_ERROR:
|
||||
throw hardware_stack_error();
|
||||
case CUDA_ERROR_ILLEGAL_INSTRUCTION:
|
||||
throw illegal_instruction();
|
||||
case CUDA_ERROR_MISALIGNED_ADDRESS:
|
||||
throw misaligned_address();
|
||||
case CUDA_ERROR_INVALID_ADDRESS_SPACE:
|
||||
throw invalid_address_space();
|
||||
case CUDA_ERROR_INVALID_PC:
|
||||
throw invalid_pc();
|
||||
case CUDA_ERROR_LAUNCH_FAILED:
|
||||
throw launch_failed();
|
||||
case CUDA_ERROR_NOT_PERMITTED:
|
||||
throw not_permitted();
|
||||
case CUDA_ERROR_NOT_SUPPORTED:
|
||||
throw not_supported();
|
||||
case CUDA_ERROR_UNKNOWN:
|
||||
throw unknown();
|
||||
default:
|
||||
throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
void check(hipError_t error) {
|
||||
using namespace exception::hip;
|
||||
switch(error)
|
||||
{
|
||||
case hipSuccess : break;
|
||||
case hipErrorInvalidValue : throw invalid_value();
|
||||
case hipErrorMemoryAllocation : throw out_of_memory();
|
||||
case hipErrorNotInitialized : throw not_initialized();
|
||||
case hipErrorDeinitialized : throw deinitialized();
|
||||
case hipErrorProfilerDisabled : throw profiler_disabled();
|
||||
case hipErrorProfilerNotInitialized : throw profiler_not_initialized();
|
||||
case hipErrorProfilerAlreadyStarted : throw profiler_already_started();
|
||||
case hipErrorProfilerAlreadyStopped : throw profiler_already_stopped();
|
||||
case hipErrorNoDevice : throw no_device();
|
||||
case hipErrorInvalidSymbol : throw invalid_symbol();
|
||||
case hipErrorInvalidDevice : throw invalid_device();
|
||||
case hipErrorInvalidImage : throw invalid_image();
|
||||
case hipErrorInvalidContext : throw invalid_context();
|
||||
case hipErrorContextAlreadyCurrent : throw context_already_current();
|
||||
case hipErrorMapFailed : throw map_failed();
|
||||
case hipErrorUnmapFailed : throw unmap_failed();
|
||||
case hipErrorArrayIsMapped : throw array_is_mapped();
|
||||
case hipErrorAlreadyMapped : throw already_mapped();
|
||||
case hipErrorNoBinaryForGpu : throw no_binary_for_gpu();
|
||||
case hipErrorAlreadyAcquired : throw already_acquired();
|
||||
case hipErrorNotMapped : throw not_mapped();
|
||||
case hipErrorNotMappedAsArray : throw not_mapped_as_array();
|
||||
case hipErrorNotMappedAsPointer : throw not_mapped_as_pointer();
|
||||
case hipErrorECCNotCorrectable : throw ecc_uncorrectable();
|
||||
case hipErrorUnsupportedLimit : throw unsupported_limit();
|
||||
case hipErrorContextAlreadyInUse : throw context_already_in_use();
|
||||
case hipErrorPeerAccessUnsupported : throw peer_access_unsupported();
|
||||
case hipErrorInvalidKernelFile : throw invalid_ptx();
|
||||
case hipErrorInvalidGraphicsContext : throw invalid_graphics_context();
|
||||
case hipErrorInvalidSource : throw invalid_source();
|
||||
case hipErrorFileNotFound : throw file_not_found();
|
||||
case hipErrorSharedObjectSymbolNotFound : throw shared_object_symbol_not_found();
|
||||
case hipErrorSharedObjectInitFailed : throw shared_object_init_failed();
|
||||
case hipErrorOperatingSystem : throw operating_system();
|
||||
case hipErrorInvalidResourceHandle : throw invalid_handle();
|
||||
case hipErrorNotFound : throw not_found();
|
||||
case hipErrorNotReady : throw not_ready();
|
||||
case hipErrorIllegalAddress : throw illegal_address();
|
||||
case hipErrorLaunchOutOfResources : throw launch_out_of_resources();
|
||||
case hipErrorLaunchTimeOut : throw launch_timeout();
|
||||
// case hipErrorLaunchIncompatibleTexturing : throw launch_incompatible_texturing();
|
||||
case hipErrorPeerAccessAlreadyEnabled : throw peer_access_already_enabled();
|
||||
case hipErrorPeerAccessNotEnabled : throw peer_access_not_enabled();
|
||||
// case hipErrorPrimaryContextActive : throw primary_context_active();
|
||||
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
|
||||
case hipErrorAssert : throw assert_error();
|
||||
// case hipErrorTooManyPeers : throw too_many_peers();
|
||||
case hipErrorHostMemoryAlreadyRegistered : throw host_memory_already_registered();
|
||||
case hipErrorHostMemoryNotRegistered : throw host_memory_not_registered();
|
||||
// case hipErrorHardwareStackError : throw hardware_stack_error();
|
||||
// case hipErrorIllegalInstruction : throw illegal_instruction();
|
||||
// case hipErrorMisalignedAddress : throw misaligned_address();
|
||||
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
|
||||
// case hipErrorInvalidPc : throw invalid_pc();
|
||||
case hipErrorLaunchFailure : throw launch_failed();
|
||||
// case hipErrorNotPermitted : throw not_permitted();
|
||||
case hipErrorNotSupported : throw not_supported();
|
||||
case hipErrorUnknown : throw unknown();
|
||||
default : throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
switch (error) {
|
||||
case hipSuccess:
|
||||
break;
|
||||
case hipErrorInvalidValue:
|
||||
throw invalid_value();
|
||||
case hipErrorMemoryAllocation:
|
||||
throw out_of_memory();
|
||||
case hipErrorNotInitialized:
|
||||
throw not_initialized();
|
||||
case hipErrorDeinitialized:
|
||||
throw deinitialized();
|
||||
case hipErrorProfilerDisabled:
|
||||
throw profiler_disabled();
|
||||
case hipErrorProfilerNotInitialized:
|
||||
throw profiler_not_initialized();
|
||||
case hipErrorProfilerAlreadyStarted:
|
||||
throw profiler_already_started();
|
||||
case hipErrorProfilerAlreadyStopped:
|
||||
throw profiler_already_stopped();
|
||||
case hipErrorNoDevice:
|
||||
throw no_device();
|
||||
case hipErrorInvalidSymbol:
|
||||
throw invalid_symbol();
|
||||
case hipErrorInvalidDevice:
|
||||
throw invalid_device();
|
||||
case hipErrorInvalidImage:
|
||||
throw invalid_image();
|
||||
case hipErrorInvalidContext:
|
||||
throw invalid_context();
|
||||
case hipErrorContextAlreadyCurrent:
|
||||
throw context_already_current();
|
||||
case hipErrorMapFailed:
|
||||
throw map_failed();
|
||||
case hipErrorUnmapFailed:
|
||||
throw unmap_failed();
|
||||
case hipErrorArrayIsMapped:
|
||||
throw array_is_mapped();
|
||||
case hipErrorAlreadyMapped:
|
||||
throw already_mapped();
|
||||
case hipErrorNoBinaryForGpu:
|
||||
throw no_binary_for_gpu();
|
||||
case hipErrorAlreadyAcquired:
|
||||
throw already_acquired();
|
||||
case hipErrorNotMapped:
|
||||
throw not_mapped();
|
||||
case hipErrorNotMappedAsArray:
|
||||
throw not_mapped_as_array();
|
||||
case hipErrorNotMappedAsPointer:
|
||||
throw not_mapped_as_pointer();
|
||||
case hipErrorECCNotCorrectable:
|
||||
throw ecc_uncorrectable();
|
||||
case hipErrorUnsupportedLimit:
|
||||
throw unsupported_limit();
|
||||
case hipErrorContextAlreadyInUse:
|
||||
throw context_already_in_use();
|
||||
case hipErrorPeerAccessUnsupported:
|
||||
throw peer_access_unsupported();
|
||||
case hipErrorInvalidKernelFile:
|
||||
throw invalid_ptx();
|
||||
case hipErrorInvalidGraphicsContext:
|
||||
throw invalid_graphics_context();
|
||||
case hipErrorInvalidSource:
|
||||
throw invalid_source();
|
||||
case hipErrorFileNotFound:
|
||||
throw file_not_found();
|
||||
case hipErrorSharedObjectSymbolNotFound:
|
||||
throw shared_object_symbol_not_found();
|
||||
case hipErrorSharedObjectInitFailed:
|
||||
throw shared_object_init_failed();
|
||||
case hipErrorOperatingSystem:
|
||||
throw operating_system();
|
||||
case hipErrorInvalidResourceHandle:
|
||||
throw invalid_handle();
|
||||
case hipErrorNotFound:
|
||||
throw not_found();
|
||||
case hipErrorNotReady:
|
||||
throw not_ready();
|
||||
case hipErrorIllegalAddress:
|
||||
throw illegal_address();
|
||||
case hipErrorLaunchOutOfResources:
|
||||
throw launch_out_of_resources();
|
||||
case hipErrorLaunchTimeOut:
|
||||
throw launch_timeout();
|
||||
// case hipErrorLaunchIncompatibleTexturing : throw
|
||||
// launch_incompatible_texturing();
|
||||
case hipErrorPeerAccessAlreadyEnabled:
|
||||
throw peer_access_already_enabled();
|
||||
case hipErrorPeerAccessNotEnabled:
|
||||
throw peer_access_not_enabled();
|
||||
// case hipErrorPrimaryContextActive : throw primary_context_active();
|
||||
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
|
||||
case hipErrorAssert:
|
||||
throw assert_error();
|
||||
// case hipErrorTooManyPeers : throw too_many_peers();
|
||||
case hipErrorHostMemoryAlreadyRegistered:
|
||||
throw host_memory_already_registered();
|
||||
case hipErrorHostMemoryNotRegistered:
|
||||
throw host_memory_not_registered();
|
||||
// case hipErrorHardwareStackError : throw hardware_stack_error();
|
||||
// case hipErrorIllegalInstruction : throw illegal_instruction();
|
||||
// case hipErrorMisalignedAddress : throw misaligned_address();
|
||||
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
|
||||
// case hipErrorInvalidPc : throw invalid_pc();
|
||||
case hipErrorLaunchFailure:
|
||||
throw launch_failed();
|
||||
// case hipErrorNotPermitted : throw not_permitted();
|
||||
case hipErrorNotSupported:
|
||||
throw not_supported();
|
||||
case hipErrorUnknown:
|
||||
throw unknown();
|
||||
default:
|
||||
throw unknown();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
@@ -1,73 +1,73 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
#include <fstream>
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/tools/sha1.hpp"
|
||||
#include "triton/tools/sys/exec.hpp"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/tools/sys/mkdir.hpp"
|
||||
#include "triton/tools/sys/exec.hpp"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/Support/CodeGen.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "llvm/Transforms/Scalar.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
|
||||
// begin AMD stuff
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/FormattedStream.h"
|
||||
#include "llvm/Support/Program.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
// end AMD stuff
|
||||
|
||||
// Stub definitions for four terminfo entry points (set_curterm, del_curterm,
// tigetnum, setupterm). Each stub ignores its arguments and returns 0.
// NOTE(review): presumably these exist so that references to terminfo
// symbols resolve without linking against a curses/terminfo library --
// confirm against the build configuration.
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
|
||||
|
||||
namespace triton{
|
||||
namespace driver{
|
||||
namespace triton {
|
||||
namespace driver {
|
||||
|
||||
void init_llvm() {
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
@@ -80,82 +80,93 @@ void init_llvm() {
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
/// Replaces one delimited region of `str` with `target`.
///
/// The region begins at the first occurrence of `begin` and extends up to and
/// including the first character of the first occurrence of `end` located at
/// or after that position. If `end` does not occur, the region extends to the
/// end of `str`.
///
/// \param str    String modified in place.
/// \param begin  Marker whose first occurrence starts the region.
/// \param end    Marker that terminates the region.
/// \param target Replacement text for the region.
/// \return true if `begin` was found (and a replacement was made), false if
///         `begin` does not occur, in which case `str` is left untouched.
static bool find_and_replace(std::string &str, const std::string &begin,
                             const std::string &end,
                             const std::string &target) {
  const size_t start_replace = str.find(begin);
  if (start_replace == std::string::npos)
    return false;
  const size_t end_replace = str.find(end, start_replace);
  // Without this guard a missing `end` makes `end_replace + 1` wrap to 0; for
  // a match at index 0 the resulting count is 0, so `target` would be inserted
  // without removing anything (and callers that loop until this returns false
  // would never terminate). npos asks replace() to consume the rest of `str`.
  const size_t count = (end_replace == std::string::npos)
                           ? std::string::npos
                           : end_replace + 1 - start_replace;
  str.replace(start_replace, count, target);
  return true;
}
|
||||
|
||||
std::string path_to_ptxas(int& version) {
|
||||
std::string path_to_ptxas(int &version) {
|
||||
std::vector<std::string> rets;
|
||||
std::string ret;
|
||||
// search pathes for ptxas
|
||||
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
||||
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
||||
if(!triton_ptxas.empty())
|
||||
if (!triton_ptxas.empty())
|
||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||
// see what path for ptxas are valid
|
||||
std::vector<std::string> working_ptxas;
|
||||
for(std::string prefix: ptxas_prefixes){
|
||||
for (std::string prefix : ptxas_prefixes) {
|
||||
std::string ptxas = prefix + "ptxas";
|
||||
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
||||
if(works) {
|
||||
if (works) {
|
||||
working_ptxas.push_back(ptxas);
|
||||
rets.push_back(ret);
|
||||
}
|
||||
}
|
||||
// error if no working ptxas was found
|
||||
if(working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
||||
if (working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, "
|
||||
"/usr/local/cuda/bin/ or PATH"
|
||||
" but a working version could not be found.");
|
||||
std::string ptxas = working_ptxas.front();
|
||||
// parse version
|
||||
std::regex version_regex("release (\\d+)\\.(\\d+)");
|
||||
std::smatch match;
|
||||
bool found = false;
|
||||
// currently choosing the first ptxas. Other logics can be implemented in future
|
||||
for(std::string ret : rets) {
|
||||
if(std::regex_search(ret, match, version_regex)){
|
||||
// currently choosing the first ptxas. Other logics can be implemented in
|
||||
// future
|
||||
for (std::string ret : rets) {
|
||||
if (std::regex_search(ret, match, version_regex)) {
|
||||
int major = std::stoi(match[1]);
|
||||
int minor = std::stoi(match[2]);
|
||||
version = major*1000 + minor*10;
|
||||
version = major * 1000 + minor * 10;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if ( not found) {
|
||||
if (not found) {
|
||||
throw std::runtime_error("Error in parsing version");
|
||||
}
|
||||
return ptxas;
|
||||
}
|
||||
|
||||
|
||||
/// Maps an encoded CUDA toolkit version to the value used for PTX versioning.
///
/// \param version Toolkit version encoded as major * 1000 + minor * 10 (the
///                encoding produced by path_to_ptxas), e.g. 11040 for 11.4.
/// \return The number paired with the highest threshold not exceeding
///         `version` (74 for >= 11040 down to 63 for >= 10000).
///         NOTE(review): presumably the PTX ISA version as major * 10 + minor
///         -- confirm against the caller.
/// \throws std::runtime_error if `version` is below 10000.
int vptx(int version) {
  struct Threshold {
    int min_version;
    int ptx;
  };
  // Ordered from newest to oldest so the first match is the tightest bound.
  static constexpr Threshold kThresholds[] = {
      {11040, 74}, {11030, 73}, {11020, 72}, {11010, 71},
      {11000, 70}, {10020, 65}, {10010, 64}, {10000, 63},
  };
  for (const Threshold &entry : kThresholds) {
    if (version >= entry.min_version)
      return entry.ptx;
  }
  throw std::runtime_error("Triton requires CUDA 10+");
}
|
||||
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
std::string llir_to_ptx(llvm::Module *module, int cc, int version) {
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
int max_nvvm_ptx = 74;
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||
auto *short_ptr =
|
||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||
assert(short_ptr);
|
||||
short_ptr->setValue(true);
|
||||
// compute capability
|
||||
@@ -170,7 +181,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
||||
std::string layout = "";
|
||||
std::string features = "";
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
||||
// max_nvvm_ptx));
|
||||
init_llvm();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
@@ -181,16 +193,18 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
if (layout.empty())
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
else
|
||||
module->setDataLayout(layout);
|
||||
@@ -200,19 +214,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
pass.run(*module);
|
||||
|
||||
// post-process
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".version", "\n",
|
||||
".version " + std::to_string(ptx_major) + "." +
|
||||
std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
|
||||
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
|
||||
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
|
||||
;
|
||||
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
|
||||
;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas,
|
||||
int cc) {
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[L_tmpnam];
|
||||
char _flog[L_tmpnam];
|
||||
@@ -221,15 +241,16 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
std::string fsrc = _fsrc;
|
||||
std::string flog = _flog;
|
||||
std::string fbin = fsrc + ".o";
|
||||
const char* _fbin = fbin.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(fsrc);
|
||||
ofs << ptx << std::endl;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc +
|
||||
" -o " + fsrc + ".o 2> " + flog;
|
||||
err = system(cmd.c_str());
|
||||
if(err != 0){
|
||||
if (err != 0) {
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
unlink(_fsrc);
|
||||
@@ -237,7 +258,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
|
||||
}
|
||||
CUmodule ret;
|
||||
std::ifstream _cubin(_fbin, std::ios::binary );
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
unlink(_fsrc);
|
||||
@@ -251,11 +272,11 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
||||
// HIP //
|
||||
/* ------------------------ */
|
||||
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) {
|
||||
init_llvm();
|
||||
|
||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
@@ -270,17 +291,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None,
|
||||
llvm::CodeGenOpt::Aggressive);
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
if (layout.empty())
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
else
|
||||
module->setDataLayout(layout);
|
||||
@@ -295,33 +317,37 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
std::error_code ec;
|
||||
|
||||
// Save GCN ISA binary.
|
||||
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
|
||||
std::string isabin_path =
|
||||
std::string("/tmp/") + module_name + std::string(".o");
|
||||
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
|
||||
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
|
||||
if (ec)
|
||||
{
|
||||
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
|
||||
if (ec) {
|
||||
std::cout << isabin_path << " was not created. error code: " << ec
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
|
||||
llvm::CGFT_ObjectFile);
|
||||
pass.run(*module);
|
||||
// Save GCN ISA.
|
||||
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
|
||||
std::string amdgcn_path =
|
||||
std::string("/tmp/") + module_name + std::string(".gcn");
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
std::ofstream amdgcn(amdgcn_path);
|
||||
amdgcn << result;
|
||||
amdgcn.close();
|
||||
|
||||
// generate HSACO file
|
||||
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
|
||||
std::string hsaco_path =
|
||||
std::string("/tmp/") + module_name + std::string(".hsaco");
|
||||
std::string error_message;
|
||||
int lld_result =
|
||||
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
|
||||
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
|
||||
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
|
||||
"-shared", "-o", hsaco_path, isabin_path},
|
||||
llvm::None, {}, 0, 0, &error_message);
|
||||
if (lld_result)
|
||||
{
|
||||
if (lld_result) {
|
||||
std::cout << "ld.lld execute fail: " << std::endl;
|
||||
std::cout << error_message << std::endl;
|
||||
std::cout << lld_result << std::endl;
|
||||
@@ -330,33 +356,29 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
return hsaco_path;
|
||||
}
|
||||
|
||||
|
||||
/// Reads the file at `path` into memory and hands its bytes to
/// dispatch::hipModuleLoadDataEx, returning the resulting module handle.
///
/// \param path Location of the HSACO file to load.
/// \return The hipModule_t filled in by hipModuleLoadDataEx.
///
/// NOTE(review): a failure to open `path` is not checked here; in that case
/// tellg() reports a failure position and the buffer allocation below is
/// sized from it -- confirm whether callers guarantee the file exists.
hipModule_t amdgpu_to_hipmodule(const std::string &path) {
  // Read HSACO. Opening with `ate` positions the stream at the end so that
  // tellg() yields the file size.
  std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
  std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();

  std::vector<unsigned char> hsaco(hsaco_file_size);
  hsaco_file.seekg(0, std::ios::beg);
  // data() rather than &hsaco[0]: indexing element 0 of an empty vector
  // (zero-length file) is undefined behaviour.
  hsaco_file.read(reinterpret_cast<char *>(hsaco.data()), hsaco_file_size);
  hsaco_file.close();

  // JIT options and their values; the two arrays are matched by position.
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
                        hipJitOptionErrorLogBuffer,
                        hipJitOptionInfoLogBufferSizeBytes,
                        hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  const unsigned int errbufsize = 8192;
  const unsigned int logbufsize = 8192;
  char _err[errbufsize];
  char _log[logbufsize];
  void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
                    (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};

  hipModule_t ret;
  // 5 == number of entries in `opt` / `optval`.
  dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
  return ret;
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace triton
|
||||
|
Reference in New Issue
Block a user