[CI] run clang-format (#24)

Philippe Tillet
2022-07-26 17:25:03 -07:00
committed by GitHub
parent 25357083e6
commit 6d62d88d4f
62 changed files with 13673 additions and 11367 deletions


@@ -8,24 +8,23 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
// Function for extended Euclidean Algorithm
static int gcd_impl(int a, int b, int *x, int *y) {
// Base Case
if (a == 0) {
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b % a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b / a) * x1;
*y = x1;
return gcd;
}
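// For illustration: gcd_impl(12, 18, &x, &y) returns 6 with x = -1, y = 1,
// satisfying Bezout's identity 12 * (-1) + 18 * 1 == 6. Only the gcd itself
// is consumed by the wrapper below; x and y are byproducts of the recursion.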
@@ -35,17 +34,17 @@ static int gcd(int a, int b) {
return gcd_impl(a, b, &x, &y);
}
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
size_t rank = 1;
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
@@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
return AxisInfo(contiguity, divisibility, constancy);
}
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
ContiguityT retContiguity;
DivisibilityT retDivisibility;
ConstancyT retConstancy;
for (size_t d = 0; d < lhs.getRank(); d++) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
return AxisInfo(retContiguity, retDivisibility, retConstancy);
}
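// For illustration: joining ({4}, {16}, {1}) with ({2}, {8}, {2}) yields
// contiguity {2}, divisibility {8}, constancy {1}; the gcd keeps only the
// properties both incoming facts guarantee.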
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
AxisInfo AxisInfoAnalysis::visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::ContiguityT newContiguity;
AxisInfo::DivisibilityT newDivisibility;
AxisInfo::ConstancyT newConstancy;
for (size_t d = 0; d < rank; d++) {
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
}
return AxisInfo(newContiguity, newDivisibility, newConstancy);
}
ChangeResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
triton::PtrToIntOp, triton::IntToPtrOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range =
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
int start = make_range.start();
int end = make_range.end();
AxisInfo::ContiguityT contiguity = {end - start};
@@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Constant
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
if (intAttr) {
size_t val = intAttr.getValue().getZExtValue();
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
}
// TODO: generalize to dense attr
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
auto value = splatAttr.getSplatValue<int>();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(
AxisInfo::ContiguityT(ty.getRank(), 1),
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
}
}
// Addition
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
};
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
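// For illustration: adding a range (contiguity 128, divisibility 128,
// constancy 1) to a splat (contiguity 1, divisibility 16, constancy 128)
// gives contiguity max(gcd(128, 128), gcd(1, 1)) = 128, divisibility
// gcd(128, 16) = 16, and constancy gcd(1, 128) = 1: shifting a contiguous
// range by a constant keeps it contiguous.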
// Multiplication
if (llvm::isa<arith::MulIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
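// For illustration: multiplying values divisible by 4 with values divisible
// by 8 gives values divisible by 4 * 8 = 32, hence the product rule above;
// contiguity is conservatively reset to 1 since consecutive values scaled
// by k end up k apart.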
// Splat
if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
@@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
}
// Reshape
// TODO: Replace by `unsqueeze`
if (llvm::isa<triton::ReshapeOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
AxisInfo::ConstancyT constancy;
bool is_skewed = false;
size_t current = 0;
for (size_t d = 0; d < retTy.getRank(); d++) {
if (retShape[d] == 1) {
contiguity.push_back(1);
divisibility.push_back(1);
constancy.push_back(1);
}
} else if (!is_skewed && retShape[d] == opShape[current]) {
contiguity.push_back(opInfo.getContiguity()[current]);
divisibility.push_back(opInfo.getDivisibility()[current]);
constancy.push_back(opInfo.getConstancy()[current]);
current++;
} else {
is_skewed = true;
contiguity.push_back(1);
divisibility.push_back(1);
@@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
curr = AxisInfo(contiguity, divisibility, constancy);
}
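// For illustration: reshaping 128 into 128x1 copies the axis info for the
// matching dim and pads the unit dim with (1, 1, 1); once a non-unit dim no
// longer lines up with the operand shape, is_skewed pessimizes every
// remaining dim to (1, 1, 1).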
// Broadcast
if (llvm::isa<triton::BroadcastOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
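// For illustration: broadcasting 128x1 to 128x64 leaves dim 0 untouched and
// gives dim 1 contiguity 1 and constancy 64, since all 64 elements along
// the broadcast dim hold the same value.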
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
// join all lattice elements
@@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
return result;
}
} // namespace mlir


@@ -2,14 +2,16 @@
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
namespace triton {
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif


@@ -1,42 +1,42 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
using namespace mlir;
using namespace mlir::triton;
namespace {
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res =
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
using OpConversionPattern<SrcOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
DstOp res =
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retType,
value.reshape(retType) // This is a hack. We just want to add encoding
);
return success();
}
};
class ConvertArithmeticOp : public ConversionPattern {
public:
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Dialect *dialect = op->getDialect();
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
return failure();
return success();
}
};
void populateArithmeticPatternsAndLegality(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
// --------------
// Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that
@@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality(
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
ArithGenericPattern<arith::CeilDivUIOp>,
ArithGenericPattern<arith::CeilDivSIOp>,
ArithGenericPattern<arith::FloorDivSIOp>,
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
ArithGenericPattern<arith::ShRUIOp>,
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
// Floating point
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
// MaxMin
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
// Floating point
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
ArithGenericPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
}
//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern
: public OpConversionPattern<triton::MakeRangeOp> {
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, retType, adaptor.start(), adaptor.end());
return success();
}
};
@@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
// a & b must be of smem layout
auto aType = adaptor.a().getType().cast<RankedTensorType>();
@@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.b();
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
return success();
}
};
@@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
return success();
}
};
@@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
return success();
}
};
@@ -212,12 +209,11 @@ template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
return success();
}
};
@@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::SplatOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
}
//
@@ -259,17 +250,19 @@ void populateTritonPatterns(
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
// Ref: ConvertForOpTypes
LogicalResult
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp =
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());
// Now, update all the types.
// Convert the types of block arguments within the given region. This
// replaces each block with a new block containing the updated signature.
// The entry block may have a special conversion if `entryConversion` is
// provided. On success, the new entry block to the region is returned for
// convenience. Otherwise, failure is returned.
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
@@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
// op.erase();
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
return success();
}
};
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
}
class ConvertTritonToTritonGPU
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -339,21 +326,21 @@ public:
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// update layouts
// broadcast src => multicast, dst => broadcasted
if (failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {


@@ -7,7 +7,6 @@
#include "mlir/IR/DialectImplementation.h"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
using namespace mlir;
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
>();
// We can also add interface here.
}
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
}


@@ -13,14 +13,16 @@ namespace triton {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type,
tensorType.getEncoding());
return Type();
}
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
return Type();
}
} // namespace triton
} // namespace mlir
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
@@ -48,50 +50,48 @@ namespace triton {
//-- StoreOp --
// Default mask
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
auto shape = ptrType.getShape();
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
state.addOperands(ptr);
state.addOperands(value);
state.addOperands(mask);
}
//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
// mask
::mlir::Value mask = builder.create<arith::ConstantOp>(
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
// other
Type resultType = RankedTensorType::get(shape, elementType);
::mlir::Value other = builder.create<arith::ConstantOp>(
ptr.getLoc(), resultType,
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}
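// With these defaults, a load built without an explicit mask or `other`
// behaves roughly like load(ptr, mask = splat(true), other = splat(0)) of
// the pointee type.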


@@ -1,6 +1,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir;
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
>();
}
Type PointerType::parse(AsmParser &parser) {


@@ -17,21 +17,23 @@ namespace {
class CombineDotOp : public mlir::RewritePattern {
public:
CombineDotOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
if (isCandidate(op->getOperand(0)).succeeded()) {
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
op->getOperand(1), dotOp.allowTF32());
return mlir::success();
} else if (isCandidate(op->getOperand(1)).succeeded()) {
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
op->getOperand(0), dotOp.allowTF32());
return mlir::success();
}
}
@@ -54,7 +56,7 @@ private:
return true;
// broadcast(constant_0)
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
return true;
}
@@ -68,18 +70,19 @@ private:
class CombineGEPOp : public mlir::RewritePattern {
public:
CombineGEPOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::triton::GEPOp>(op)) {
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
auto loc = op->getLoc();
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
loc, op->getOperand(1), gep2->getOperand(1));
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
return mlir::success();
}
}
@@ -92,20 +95,21 @@ public:
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
public:
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (llvm::isa<mlir::SelectOp>(op)) {
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
mlir::Value cond = op->getOperand(0);
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
op, op->getResultTypes().front(), load.ptr(), load.mask(),
op->getOperand(2), load.cache(), load.evict(),
load.isVolatile());
return mlir::success();
}
}
@@ -120,11 +124,11 @@ public:
class CombineBroadcastConstantOp : public mlir::RewritePattern {
public:
CombineBroadcastConstantOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
Attribute value = cst.getValue();
@@ -132,15 +136,14 @@ public:
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
if (!denseValue.isSplat())
return failure();
value = DenseElementsAttr::get(resType,
denseValue.getSplatValue<Attribute>());
} else {
if (!value.isa<FloatAttr, IntegerAttr>())
return failure();
value = DenseElementsAttr::get(resType, value);
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
return success();
}
}


@@ -11,19 +11,18 @@ using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
/*SmallVector<unsigned, 2>*/ auto &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), "expected an array for ")
<< desc;
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
return failure();
}
for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
return failure();
}
res.push_back(intAttr.getUInt());
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
.failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
.failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
.failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else if (attr.getName() == "broadcastAxis") {
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
}
}
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
}
static void printBlocked(AsmPrinter &printer, auto *attr) {
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
<< "}>";
}
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
}
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
parseBlocked(parser, type);
}
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
.failed())
return {};
} else if (attr.getName() == "shapePerWarp") {
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
.failed())
return {};
} else if (attr.getName() == "warpPerTile") {
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
return {};
} else if (attr.getName() == "shapePerTile") {
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
.failed())
return {};
} else if (attr.getName() == "repetitions") {
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
return {};
} else if (attr.getName() == "contigPerThread") {
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
shapePerTile, repetitions, contigPerThread, broadcastAxis);
}
static void printMma(AsmPrinter &printer, auto *attr) {
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
<< "}>";
}
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
}
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseMma(parser, type);
}
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUSharedEncodingAttr>(
parser.getContext(), vec, perPhase, maxPhase, order);
}
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "vec = " << getVec()
<< ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase()
<< ", order = [" << getOrder() << "]"
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
<< "]"
<< "}>";
}
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
OpAsmDialectInterface::getAlias(attr, os);
}
private:
static void printMma(const auto &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
addInterfaces<TritonGPUOpAsmInterface>();
}
@@ -349,7 +340,8 @@ namespace triton {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
return Type();
}
} // namespace triton
} // namespace mlir
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
// verify TritonGPU ops
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}


@@ -27,8 +27,8 @@ namespace {
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();


@@ -6,12 +6,11 @@
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
#define GEN_PASS_CLASSES
@@ -41,14 +40,15 @@ class LoopPipeliner {
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation *> depOps;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depends on value in previous iterations
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
}
// for (triton::LoadOp loadOp : allLoads) {
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
// " values\n"; for (Value dep : loadDeps[loadOp])
// llvm::errs() << dep << "\n";
// llvm::errs() << "\n";
// }
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding()
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
loads.insert(loadOp);
}
// we have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
loadOp.isVolatile());
} else
llvm_unreachable("This should be LoadOp");
} else
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), loopCond);
Value newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
newOp->setOperand(1, newMask);
}
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
}
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
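// Layout of newLoopArgs at this point (as assembled above): the original
// iter args, then numStages - 1 buffered values per pipelined load starting
// at loadIdx, then the dep args starting at depArgsBeginIdx, and finally
// the next induction value at nextIVIdx.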
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newLoopArgs);
// 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody());
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
}
// 4. prefetch the next iteration
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg,
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());
} else
nextOp = builder.clone(*op, nextMapping);
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
// update mapping of results
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
for (int stage = 1; stage < numStages - 1; ++stage) {
yieldValues.push_back(newForOp.getRegionIterArgs()[
loadIdx + idx*(numStages-1) + stage
]);
yieldValues.push_back(
newForOp
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
}
yieldValues.push_back(nextMapping.lookup(load));
}
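  // Note: this yield implements a per-load ring rotation: drop the oldest
  // buffered stage, shift the remaining (numStages - 2) buffers forward, and
  // append the freshly prefetched value. Sketched on plain values
  // (illustrative only):
  //   std::rotate(buffers.begin(), buffers.begin() + 1, buffers.end());
  //   buffers.back() = freshlyPrefetched; // value for iteration i + 1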
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) {
this->numStages = numStages;
}
PipelinePass(int numStages) { this->numStages = numStages; }
void runOnOperation() override {
int numStages = this->numStages;

View File

@@ -1,7 +1,7 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <algorithm>
using namespace mlir;
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
//
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numThreads)
: context(context), numThreads(numThreads) {
// TODO: how does MLIR pick the right conversion?
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// or assert no encoding?
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32;
int remainingLanes = /*warp size*/ 32;
for (int64_t dim = 0; dim < rank; ++dim) {
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
context, threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding);
});
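  // Note: a hypothetical use of the conversion registered above, through the
  // standard TypeConverter entry point (ctx, the shape and numThreads are
  // illustrative):
  //   TritonGPUTypeConverter converter(ctx, /*numThreads=*/128);
  //   Type dst = converter.convertType(
  //       RankedTensorType::get({128, 128}, FloatType::getF32(ctx)));
  //   // same shape and element type, now with a blocked encoding attached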
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
StandardOpsDialect, scf::SCFDialect>(
[&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
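  // Note: since the callback only forwards the converter's verdict, an
  // equivalent, slightly tighter registration would be (illustrative):
  //   addDynamicallyLegalDialect<arith::ArithmeticDialect,
  //                              triton::TritonDialect, StandardOpsDialect,
  //                              scf::SCFDialect>(
  //       [&](Operation *op) { return typeConverter.isLegal(op); });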
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding &&
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// return true;
return false;
});
}
// %dst = tt.broadcast %src
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) {
broadcasts.push_back(op);
});
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
if (auto blockedEnc =
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(),
blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(),
blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(),
broadcastAxis
);
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(), broadcastAxis);
newSrcType = RankedTensorType::get(
tensorType.getShape(),
tensorType.getElementType(),
newSrcEnc
);
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(
src.getLoc(), newSrcType, src
);
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
newSrcType, src);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(),
originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(),
originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(),
broadcastAxis
);
auto newType = RankedTensorType::get(
originDstTensorType.getShape(),
originDstTensorType.getElementType(),
newEnc
);
Value newBroadcast = builder.create<triton::BroadcastOp>(
broadcast.getLoc(), newType, src
);
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(), broadcastAxis);
auto newType =
RankedTensorType::get(originDstTensorType.getShape(),
originDstTensorType.getElementType(), newEnc);
Value newBroadcast =
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast
);
broadcast.getLoc(), originDstType, newBroadcast);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);

View File

@@ -5,7 +5,6 @@
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -37,28 +36,30 @@ private:
if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return dotOp.emitError() << name << " should be of shared layout";
} else
return dotOp.emitError() << name << "'s type should be of RankedTensorType";
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
Attribute cLayout;
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
llvm::SmallVector<char>{'c', 'd'})) {
llvm::SmallVector<char>{'c', 'd'})) {
Type type = std::get<0>(it);
char name = std::get<1>(it);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout";
return dotOp.emitError()
<< name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;
else if (encoding != cLayout)
return dotOp.emitError() << "d & c should have the same layout";
} else
return dotOp.emitError() << name
<< "'s type should be of RankedTensorType";
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
// signalPassFailure();
@@ -89,7 +90,7 @@ private:
}
void verifyImpl(Operation *op) {
if(verifySingleOp(op).failed())
if (verifySingleOp(op).failed())
signalPassFailure();
// verify that all child regions are ok

408
lib/driver/dispatch.cc Executable file → Normal file
View File

@@ -1,107 +1,152 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace triton {
namespace driver {
//Helpers for function definition
#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname); }\
void* dispatch::fname ## _;
// Helpers for function definition
#define DEFINE0(init, hlib, ret, fname) \
ret dispatch::fname() { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname); \
} \
void *dispatch::fname##_;
#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a); }\
void* dispatch::fname ## _;
#define DEFINE1(init, hlib, ret, fname, t1) \
ret dispatch::fname(t1 a) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a); \
} \
void *dispatch::fname##_;
#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b); }\
void* dispatch::fname ## _;
#define DEFINE2(init, hlib, ret, fname, t1, t2) \
ret dispatch::fname(t1 a, t2 b) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b); \
} \
void *dispatch::fname##_;
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c); }\
void* dispatch::fname ## _;
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \
ret dispatch::fname(t1 a, t2 b, t3 c) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c); \
} \
void *dispatch::fname##_;
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d); }\
void* dispatch::fname ## _;
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d); \
} \
void *dispatch::fname##_;
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\
void* dispatch::fname ## _;
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e); \
} \
void *dispatch::fname##_;
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\
void* dispatch::fname ## _;
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f); \
} \
void *dispatch::fname##_;
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\
void* dispatch::fname ## _;
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g); \
} \
void *dispatch::fname##_;
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\
void* dispatch::fname ## _;
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h); \
} \
void *dispatch::fname##_;
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\
void* dispatch::fname ## _;
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i); \
} \
void *dispatch::fname##_;
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\
void* dispatch::fname ## _;
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j); \
} \
void *dispatch::fname##_;
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\
void* dispatch::fname ## _;
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k); \
} \
void *dispatch::fname##_;
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\
void* dispatch::fname ## _;
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\
void* dispatch::fname ## _;
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m); \
} \
void *dispatch::fname##_;
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \
t18 r, t19 s) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m, n, o, p, q, r, \
s); \
} \
void *dispatch::fname##_;
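// Note: for reference, instantiating one of these helpers, e.g.
//   DEFINE1(cuinit, cuda_, CUresult, cuInit, unsigned int)
// expands to roughly the following -- each wrapper resolves its symbol
// through f_impl on first call and caches it in the trailing fname##_:
//
//   CUresult dispatch::cuInit(unsigned int a) {
//     return f_impl<dispatch::cuinit>(cuda_, cuInit, cuInit_, "cuInit", a);
//   }
//   void *dispatch::cuInit_;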
/* ------------------- *
* CUDA
* ------------------- */
bool dispatch::cuinit(){
if(cuda_==nullptr){
#ifdef _WIN32
bool dispatch::cuinit() {
if (cuda_ == nullptr) {
#ifdef _WIN32
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
#else
#else
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
if(!cuda_)
if (!cuda_)
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
#endif
if(!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
#endif
if (!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is "
"in your LD_LIBRARY_PATH.");
}
if(cuda_ == nullptr)
if (cuda_ == nullptr)
return false;
CUresult (*fptr)(unsigned int);
cuInit_ = dlsym(cuda_, "cuInit");
@@ -112,21 +157,33 @@ bool dispatch::cuinit(){
}
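// Note: cuinit (and nvmlinit/hipinit below) all follow the same lazy
// dlopen/dlsym idiom; stripped to its essence (sketch, error handling
// elided):
//   void *handle = dlopen("libcuda.so", RTLD_LAZY);  // opened once, cached
//   void *sym = dlsym(handle, "cuInit");             // resolved by name
//   auto fn = reinterpret_cast<CUresult (*)(unsigned int)>(sym);
//   CUresult err = fn(0);                            // forward the call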
#define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
#define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
#define CUDA_DEFINE2(ret, fname, t1, t2) \
DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11) \
DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
// context management
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *)
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
CUDA_DEFINE1(CUresult, cuInit, unsigned int)
CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
@@ -134,59 +191,71 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute,
CUdevice)
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
// link management
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *,
size_t, const char *, unsigned int, CUjit_option *, void **);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **,
CUlinkState *);
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *);
// module management
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule,
const char *)
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *,
unsigned int, CUjit_option *, void **)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule,
const char *)
// stream management
CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *)
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void **, void **)
// function management
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute,
CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute,
int)
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
// memory management
CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t )
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t,
CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t,
CUstream)
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t)
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute,
CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t,
CUstream)
// event management
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
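// Note: together these wrappers let Triton call the CUDA driver without
// linking against it at build time. A hypothetical call site (judging from
// the f_impl<dispatch::cuinit> template argument, the first call loads and
// initializes libcuda automatically):
//   CUdeviceptr buf;
//   triton::driver::check(
//       triton::driver::dispatch::cuMemAlloc_v2(&buf, 1024));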
/* ------------------- *
* NVML
* ------------------- */
bool dispatch::nvmlinit(){
#ifdef _WIN32
if(nvml_==nullptr)
bool dispatch::nvmlinit() {
#ifdef _WIN32
if (nvml_ == nullptr)
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
#else
if(nvml_==nullptr)
#else
if (nvml_ == nullptr)
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
#endif
#endif
nvmlReturn_t (*fptr)();
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
@@ -197,21 +266,27 @@ bool dispatch::nvmlinit(){
#define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname)
#define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1)
#define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
#define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
#define NVML_DEFINE2(ret, fname, t1, t2) \
DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
#define NVML_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *,
nvmlDevice_t *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t,
unsigned int, unsigned int)
/* ------------------- *
* HIP
* ------------------- */
bool dispatch::hipinit(){
if(hip_==nullptr)
bool dispatch::hipinit() {
if (hip_ == nullptr)
hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
if(hip_ == nullptr)
if (hip_ == nullptr)
return false;
hipError_t (*fptr)();
hipInit_ = dlsym(hip_, "hipInit");
@@ -222,23 +297,34 @@ bool dispatch::hipinit(){
}
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11)
#define HIP_DEFINE2(ret, fname, t1, t2) \
DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
// context management
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *)
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *)
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
@@ -246,56 +332,64 @@ HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t,
hipDevice_t)
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
// module management
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*)
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *,
hipModule_t, const char *)
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *,
unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t,
const char *)
// stream management
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, hipStream_t, void **, void **)
// function management
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*)
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *)
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
// memory management
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t )
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t)
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t,
hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *,
size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t)
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t)
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute,
hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t,
hipStream_t)
// event management
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
/* ------------------- *
* COMMON
* ------------------- */
// Release
void dispatch::release(){
if(cuda_){
void dispatch::release() {
if (cuda_) {
dlclose(cuda_);
cuda_ = nullptr;
}
}
void* dispatch::cuda_;
void* dispatch::nvml_;
void* dispatch::nvmlInit_v2_;
void* dispatch::hip_;
void *dispatch::cuda_;
void *dispatch::nvml_;
void *dispatch::nvmlInit_v2_;
void *dispatch::hip_;
}
}
} // namespace driver
} // namespace triton

410
lib/driver/error.cc Executable file → Normal file
View File

@@ -1,166 +1,270 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/error.h"
namespace triton
{
namespace driver
{
namespace triton {
namespace driver {
void check(CUresult err)
{
void check(CUresult err) {
using namespace exception::cuda;
switch(err)
{
case CUDA_SUCCESS : break;
case CUDA_ERROR_INVALID_VALUE : throw invalid_value();
case CUDA_ERROR_OUT_OF_MEMORY : throw out_of_memory();
case CUDA_ERROR_NOT_INITIALIZED : throw not_initialized();
case CUDA_ERROR_DEINITIALIZED : throw deinitialized();
case CUDA_ERROR_PROFILER_DISABLED : throw profiler_disabled();
case CUDA_ERROR_PROFILER_NOT_INITIALIZED : throw profiler_not_initialized();
case CUDA_ERROR_PROFILER_ALREADY_STARTED : throw profiler_already_started();
case CUDA_ERROR_PROFILER_ALREADY_STOPPED : throw profiler_already_stopped();
case CUDA_ERROR_NO_DEVICE : throw no_device();
case CUDA_ERROR_INVALID_DEVICE : throw invalid_device();
case CUDA_ERROR_INVALID_IMAGE : throw invalid_image();
case CUDA_ERROR_INVALID_CONTEXT : throw invalid_context();
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT : throw context_already_current();
case CUDA_ERROR_MAP_FAILED : throw map_failed();
case CUDA_ERROR_UNMAP_FAILED : throw unmap_failed();
case CUDA_ERROR_ARRAY_IS_MAPPED : throw array_is_mapped();
case CUDA_ERROR_ALREADY_MAPPED : throw already_mapped();
case CUDA_ERROR_NO_BINARY_FOR_GPU : throw no_binary_for_gpu();
case CUDA_ERROR_ALREADY_ACQUIRED : throw already_acquired();
case CUDA_ERROR_NOT_MAPPED : throw not_mapped();
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY : throw not_mapped_as_array();
case CUDA_ERROR_NOT_MAPPED_AS_POINTER : throw not_mapped_as_pointer();
case CUDA_ERROR_ECC_UNCORRECTABLE : throw ecc_uncorrectable();
case CUDA_ERROR_UNSUPPORTED_LIMIT : throw unsupported_limit();
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE : throw context_already_in_use();
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED : throw peer_access_unsupported();
case CUDA_ERROR_INVALID_PTX : throw invalid_ptx();
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT : throw invalid_graphics_context();
case CUDA_ERROR_INVALID_SOURCE : throw invalid_source();
case CUDA_ERROR_FILE_NOT_FOUND : throw file_not_found();
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND : throw shared_object_symbol_not_found();
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED : throw shared_object_init_failed();
case CUDA_ERROR_OPERATING_SYSTEM : throw operating_system();
case CUDA_ERROR_INVALID_HANDLE : throw invalid_handle();
case CUDA_ERROR_NOT_FOUND : throw not_found();
case CUDA_ERROR_NOT_READY : throw not_ready();
case CUDA_ERROR_ILLEGAL_ADDRESS : throw illegal_address();
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES : throw launch_out_of_resources();
case CUDA_ERROR_LAUNCH_TIMEOUT : throw launch_timeout();
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING : throw launch_incompatible_texturing();
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED : throw peer_access_already_enabled();
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED : throw peer_access_not_enabled();
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE : throw primary_context_active();
case CUDA_ERROR_CONTEXT_IS_DESTROYED : throw context_is_destroyed();
case CUDA_ERROR_ASSERT : throw assert_error();
case CUDA_ERROR_TOO_MANY_PEERS : throw too_many_peers();
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED : throw host_memory_already_registered();
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED : throw host_memory_not_registered();
case CUDA_ERROR_HARDWARE_STACK_ERROR : throw hardware_stack_error();
case CUDA_ERROR_ILLEGAL_INSTRUCTION : throw illegal_instruction();
case CUDA_ERROR_MISALIGNED_ADDRESS : throw misaligned_address();
case CUDA_ERROR_INVALID_ADDRESS_SPACE : throw invalid_address_space();
case CUDA_ERROR_INVALID_PC : throw invalid_pc();
case CUDA_ERROR_LAUNCH_FAILED : throw launch_failed();
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
case CUDA_ERROR_UNKNOWN : throw unknown();
default : throw unknown();
switch (err) {
case CUDA_SUCCESS:
break;
case CUDA_ERROR_INVALID_VALUE:
throw invalid_value();
case CUDA_ERROR_OUT_OF_MEMORY:
throw out_of_memory();
case CUDA_ERROR_NOT_INITIALIZED:
throw not_initialized();
case CUDA_ERROR_DEINITIALIZED:
throw deinitialized();
case CUDA_ERROR_PROFILER_DISABLED:
throw profiler_disabled();
case CUDA_ERROR_PROFILER_NOT_INITIALIZED:
throw profiler_not_initialized();
case CUDA_ERROR_PROFILER_ALREADY_STARTED:
throw profiler_already_started();
case CUDA_ERROR_PROFILER_ALREADY_STOPPED:
throw profiler_already_stopped();
case CUDA_ERROR_NO_DEVICE:
throw no_device();
case CUDA_ERROR_INVALID_DEVICE:
throw invalid_device();
case CUDA_ERROR_INVALID_IMAGE:
throw invalid_image();
case CUDA_ERROR_INVALID_CONTEXT:
throw invalid_context();
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT:
throw context_already_current();
case CUDA_ERROR_MAP_FAILED:
throw map_failed();
case CUDA_ERROR_UNMAP_FAILED:
throw unmap_failed();
case CUDA_ERROR_ARRAY_IS_MAPPED:
throw array_is_mapped();
case CUDA_ERROR_ALREADY_MAPPED:
throw already_mapped();
case CUDA_ERROR_NO_BINARY_FOR_GPU:
throw no_binary_for_gpu();
case CUDA_ERROR_ALREADY_ACQUIRED:
throw already_acquired();
case CUDA_ERROR_NOT_MAPPED:
throw not_mapped();
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY:
throw not_mapped_as_array();
case CUDA_ERROR_NOT_MAPPED_AS_POINTER:
throw not_mapped_as_pointer();
case CUDA_ERROR_ECC_UNCORRECTABLE:
throw ecc_uncorrectable();
case CUDA_ERROR_UNSUPPORTED_LIMIT:
throw unsupported_limit();
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE:
throw context_already_in_use();
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED:
throw peer_access_unsupported();
case CUDA_ERROR_INVALID_PTX:
throw invalid_ptx();
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT:
throw invalid_graphics_context();
case CUDA_ERROR_INVALID_SOURCE:
throw invalid_source();
case CUDA_ERROR_FILE_NOT_FOUND:
throw file_not_found();
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND:
throw shared_object_symbol_not_found();
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
throw shared_object_init_failed();
case CUDA_ERROR_OPERATING_SYSTEM:
throw operating_system();
case CUDA_ERROR_INVALID_HANDLE:
throw invalid_handle();
case CUDA_ERROR_NOT_FOUND:
throw not_found();
case CUDA_ERROR_NOT_READY:
throw not_ready();
case CUDA_ERROR_ILLEGAL_ADDRESS:
throw illegal_address();
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
throw launch_out_of_resources();
case CUDA_ERROR_LAUNCH_TIMEOUT:
throw launch_timeout();
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING:
throw launch_incompatible_texturing();
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED:
throw peer_access_already_enabled();
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED:
throw peer_access_not_enabled();
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE:
throw primary_context_active();
case CUDA_ERROR_CONTEXT_IS_DESTROYED:
throw context_is_destroyed();
case CUDA_ERROR_ASSERT:
throw assert_error();
case CUDA_ERROR_TOO_MANY_PEERS:
throw too_many_peers();
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED:
throw host_memory_already_registered();
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED:
throw host_memory_not_registered();
case CUDA_ERROR_HARDWARE_STACK_ERROR:
throw hardware_stack_error();
case CUDA_ERROR_ILLEGAL_INSTRUCTION:
throw illegal_instruction();
case CUDA_ERROR_MISALIGNED_ADDRESS:
throw misaligned_address();
case CUDA_ERROR_INVALID_ADDRESS_SPACE:
throw invalid_address_space();
case CUDA_ERROR_INVALID_PC:
throw invalid_pc();
case CUDA_ERROR_LAUNCH_FAILED:
throw launch_failed();
case CUDA_ERROR_NOT_PERMITTED:
throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED:
throw not_supported();
case CUDA_ERROR_UNKNOWN:
throw unknown();
default:
throw unknown();
}
}
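// Note: the intended usage is to wrap every raw driver call so that status
// codes surface as typed C++ exceptions (assuming the types live under
// triton::driver::exception::cuda, as the using-directive above suggests):
//   try {
//     triton::driver::check(triton::driver::dispatch::cuInit(0));
//   } catch (const triton::driver::exception::cuda::no_device &) {
//     // no usable GPU: fall back or rethrow
//   }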
void check(hipError_t error) {
using namespace exception::hip;
switch(error)
{
case hipSuccess : break;
case hipErrorInvalidValue : throw invalid_value();
case hipErrorMemoryAllocation : throw out_of_memory();
case hipErrorNotInitialized : throw not_initialized();
case hipErrorDeinitialized : throw deinitialized();
case hipErrorProfilerDisabled : throw profiler_disabled();
case hipErrorProfilerNotInitialized : throw profiler_not_initialized();
case hipErrorProfilerAlreadyStarted : throw profiler_already_started();
case hipErrorProfilerAlreadyStopped : throw profiler_already_stopped();
case hipErrorNoDevice : throw no_device();
case hipErrorInvalidSymbol : throw invalid_symbol();
case hipErrorInvalidDevice : throw invalid_device();
case hipErrorInvalidImage : throw invalid_image();
case hipErrorInvalidContext : throw invalid_context();
case hipErrorContextAlreadyCurrent : throw context_already_current();
case hipErrorMapFailed : throw map_failed();
case hipErrorUnmapFailed : throw unmap_failed();
case hipErrorArrayIsMapped : throw array_is_mapped();
case hipErrorAlreadyMapped : throw already_mapped();
case hipErrorNoBinaryForGpu : throw no_binary_for_gpu();
case hipErrorAlreadyAcquired : throw already_acquired();
case hipErrorNotMapped : throw not_mapped();
case hipErrorNotMappedAsArray : throw not_mapped_as_array();
case hipErrorNotMappedAsPointer : throw not_mapped_as_pointer();
case hipErrorECCNotCorrectable : throw ecc_uncorrectable();
case hipErrorUnsupportedLimit : throw unsupported_limit();
case hipErrorContextAlreadyInUse : throw context_already_in_use();
case hipErrorPeerAccessUnsupported : throw peer_access_unsupported();
case hipErrorInvalidKernelFile : throw invalid_ptx();
case hipErrorInvalidGraphicsContext : throw invalid_graphics_context();
case hipErrorInvalidSource : throw invalid_source();
case hipErrorFileNotFound : throw file_not_found();
case hipErrorSharedObjectSymbolNotFound : throw shared_object_symbol_not_found();
case hipErrorSharedObjectInitFailed : throw shared_object_init_failed();
case hipErrorOperatingSystem : throw operating_system();
case hipErrorInvalidResourceHandle : throw invalid_handle();
case hipErrorNotFound : throw not_found();
case hipErrorNotReady : throw not_ready();
case hipErrorIllegalAddress : throw illegal_address();
case hipErrorLaunchOutOfResources : throw launch_out_of_resources();
case hipErrorLaunchTimeOut : throw launch_timeout();
// case hipErrorLaunchIncompatibleTexturing : throw launch_incompatible_texturing();
case hipErrorPeerAccessAlreadyEnabled : throw peer_access_already_enabled();
case hipErrorPeerAccessNotEnabled : throw peer_access_not_enabled();
// case hipErrorPrimaryContextActive : throw primary_context_active();
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
case hipErrorAssert : throw assert_error();
// case hipErrorTooManyPeers : throw too_many_peers();
case hipErrorHostMemoryAlreadyRegistered : throw host_memory_already_registered();
case hipErrorHostMemoryNotRegistered : throw host_memory_not_registered();
// case hipErrorHardwareStackError : throw hardware_stack_error();
// case hipErrorIllegalInstruction : throw illegal_instruction();
// case hipErrorMisalignedAddress : throw misaligned_address();
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
// case hipErrorInvalidPc : throw invalid_pc();
case hipErrorLaunchFailure : throw launch_failed();
// case hipErrorNotPermitted : throw not_permitted();
case hipErrorNotSupported : throw not_supported();
case hipErrorUnknown : throw unknown();
default : throw unknown();
}
}
}
switch (error) {
case hipSuccess:
break;
case hipErrorInvalidValue:
throw invalid_value();
case hipErrorMemoryAllocation:
throw out_of_memory();
case hipErrorNotInitialized:
throw not_initialized();
case hipErrorDeinitialized:
throw deinitialized();
case hipErrorProfilerDisabled:
throw profiler_disabled();
case hipErrorProfilerNotInitialized:
throw profiler_not_initialized();
case hipErrorProfilerAlreadyStarted:
throw profiler_already_started();
case hipErrorProfilerAlreadyStopped:
throw profiler_already_stopped();
case hipErrorNoDevice:
throw no_device();
case hipErrorInvalidSymbol:
throw invalid_symbol();
case hipErrorInvalidDevice:
throw invalid_device();
case hipErrorInvalidImage:
throw invalid_image();
case hipErrorInvalidContext:
throw invalid_context();
case hipErrorContextAlreadyCurrent:
throw context_already_current();
case hipErrorMapFailed:
throw map_failed();
case hipErrorUnmapFailed:
throw unmap_failed();
case hipErrorArrayIsMapped:
throw array_is_mapped();
case hipErrorAlreadyMapped:
throw already_mapped();
case hipErrorNoBinaryForGpu:
throw no_binary_for_gpu();
case hipErrorAlreadyAcquired:
throw already_acquired();
case hipErrorNotMapped:
throw not_mapped();
case hipErrorNotMappedAsArray:
throw not_mapped_as_array();
case hipErrorNotMappedAsPointer:
throw not_mapped_as_pointer();
case hipErrorECCNotCorrectable:
throw ecc_uncorrectable();
case hipErrorUnsupportedLimit:
throw unsupported_limit();
case hipErrorContextAlreadyInUse:
throw context_already_in_use();
case hipErrorPeerAccessUnsupported:
throw peer_access_unsupported();
case hipErrorInvalidKernelFile:
throw invalid_ptx();
case hipErrorInvalidGraphicsContext:
throw invalid_graphics_context();
case hipErrorInvalidSource:
throw invalid_source();
case hipErrorFileNotFound:
throw file_not_found();
case hipErrorSharedObjectSymbolNotFound:
throw shared_object_symbol_not_found();
case hipErrorSharedObjectInitFailed:
throw shared_object_init_failed();
case hipErrorOperatingSystem:
throw operating_system();
case hipErrorInvalidResourceHandle:
throw invalid_handle();
case hipErrorNotFound:
throw not_found();
case hipErrorNotReady:
throw not_ready();
case hipErrorIllegalAddress:
throw illegal_address();
case hipErrorLaunchOutOfResources:
throw launch_out_of_resources();
case hipErrorLaunchTimeOut:
throw launch_timeout();
// case hipErrorLaunchIncompatibleTexturing : throw
// launch_incompatible_texturing();
case hipErrorPeerAccessAlreadyEnabled:
throw peer_access_already_enabled();
case hipErrorPeerAccessNotEnabled:
throw peer_access_not_enabled();
// case hipErrorPrimaryContextActive : throw primary_context_active();
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
case hipErrorAssert:
throw assert_error();
// case hipErrorTooManyPeers : throw too_many_peers();
case hipErrorHostMemoryAlreadyRegistered:
throw host_memory_already_registered();
case hipErrorHostMemoryNotRegistered:
throw host_memory_not_registered();
// case hipErrorHardwareStackError : throw hardware_stack_error();
// case hipErrorIllegalInstruction : throw illegal_instruction();
// case hipErrorMisalignedAddress : throw misaligned_address();
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
// case hipErrorInvalidPc : throw invalid_pc();
case hipErrorLaunchFailure:
throw launch_failed();
// case hipErrorNotPermitted : throw not_permitted();
case hipErrorNotSupported:
throw not_supported();
case hipErrorUnknown:
throw unknown();
default:
throw unknown();
}
}
} // namespace driver
} // namespace triton

View File

@@ -1,73 +1,73 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <fstream>
#if __has_include(<unistd.h>)
#include <unistd.h>
#endif
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/exec.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <memory>
#include <regex>
// begin AMD stuff
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
// end AMD stuff
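// Stubs for the terminfo symbols that LLVM's terminal-color detection can
// reference; providing no-op definitions avoids a link-time dependency on
// ncurses (assumption: nothing here relies on real terminfo behavior).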
extern "C"{
int set_curterm(char* nterm){ return 0; }
int del_curterm(char* nterm){ return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace triton {
namespace driver {
void init_llvm() {
LLVMInitializeNVPTXTargetInfo();
@@ -80,82 +80,93 @@ void init_llvm() {
LLVMInitializeAMDGPUAsmPrinter();
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string &str, const std::string &begin,
                             const std::string &end,
                             const std::string &target) {
  size_t start_replace = str.find(begin);
  size_t end_replace = str.find(end, start_replace);
  if (start_replace == std::string::npos)
    return false;
  str.replace(start_replace, end_replace + 1 - start_replace, target);
  return true;
}
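A quick sketch of how `find_and_replace` behaves: it splices `target` over the span from the first occurrence of `begin` through the next occurrence of `end` (inclusive), and returns false when `begin` is absent. With illustrative values:

std::string s = ".version 6.4\n.target sm_52\n";
find_and_replace(s, ".version", "\n", ".version 7.4\n");
// s == ".version 7.4\n.target sm_52\n"
find_and_replace(s, ".address_size", "\n", ""); // returns false; s unchanged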
std::string path_to_ptxas(int &version) {
  std::vector<std::string> rets;
  std::string ret;
  // search paths for ptxas
  std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
  std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
  if (!triton_ptxas.empty())
    ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
  // see which ptxas candidates actually work
  std::vector<std::string> working_ptxas;
  for (std::string prefix : ptxas_prefixes) {
    std::string ptxas = prefix + "ptxas";
    bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
    if (works) {
      working_ptxas.push_back(ptxas);
      rets.push_back(ret);
    }
  }
  // error if no working ptxas was found
  if (working_ptxas.empty())
    throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, "
                             "/usr/local/cuda/bin/ or PATH"
                             " but a working version could not be found.");
  std::string ptxas = working_ptxas.front();
  // parse version
  std::regex version_regex("release (\\d+)\\.(\\d+)");
  std::smatch match;
  bool found = false;
  // currently the first working ptxas is chosen; other strategies can be
  // implemented in the future
  for (std::string ret : rets) {
    if (std::regex_search(ret, match, version_regex)) {
      int major = std::stoi(match[1]);
      int minor = std::stoi(match[2]);
      version = major * 1000 + minor * 10;
      found = true;
      break;
    }
  }
  if (not found) {
    throw std::runtime_error("Error in parsing version");
  }
  return ptxas;
}
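Worked example of the version encoding above (ptxas outputs assumed):

// "release 11.4" -> major = 11, minor = 4 -> version = 11 * 1000 + 4 * 10 = 11040
// "release 10.2" -> major = 10, minor = 2 -> version = 10 * 1000 + 2 * 10 = 10020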
int vptx(int version) {
if (version >= 11040)
return 74;
if (version >= 11030)
return 73;
if (version >= 11020)
return 72;
if (version >= 11010)
return 71;
if (version >= 11000)
return 70;
if (version >= 10020)
return 65;
if (version >= 10010)
return 64;
if (version >= 10000)
return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
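Combined with that encoding, the thresholds above give, for example:

// vptx(11040) == 74 (CUDA 11.4 -> PTX ISA 7.4)
// vptx(10020) == 65 (CUDA 10.2 -> PTX ISA 6.5)
// vptx(9000) throws ("Triton requires CUDA 10+")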
std::string llir_to_ptx(llvm::Module *module, int cc, int version) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
auto *short_ptr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
@@ -170,7 +181,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
@@ -181,16 +193,18 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
@@ -200,19 +214,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
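To illustrate the post-processing above (directive values assumed; `sm` and the ptx_major/ptx_minor computation sit in the elided part of this function): the `.version` and `.target` directives emitted by LLVM are overwritten so that ptxas will accept a newer PTX ISA than the bundled LLVM officially supports, and the inline-asm marker comments are stripped:

// before: ".version 7.0" / ".target sm_80", plus "\t// begin inline asm" lines
// after:  ".version 7.4" / ".target sm_80", with the marker lines removed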
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas,
int cc) {
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
@@ -221,15 +241,16 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc +
" -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
@@ -237,7 +258,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
@@ -251,11 +272,11 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) {
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
@@ -270,17 +291,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
@@ -295,33 +317,37 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::string isabin_path =
std::string("/tmp/") + module_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec) {
std::cout << isabin_path << " was not created. error code: " << ec
<< std::endl;
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string amdgcn_path =
std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
// generate HSACO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string hsaco_path =
std::string("/tmp/") + module_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
"-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result) {
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
@@ -330,33 +356,29 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
return hsaco_path;
}
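For reference, with the paths hard-coded above, the link step is equivalent to running:

// /opt/rocm/llvm/bin/ld.lld -flavor gnu -shared -o /tmp/<module>.hsaco /tmp/<module>.o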
hipModule_t amdgpu_to_hipmodule(const std::string &path) {
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
hipModule_t ret;
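  // 5 below is the number of entries in opt/optval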
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret;
}
} // namespace driver
} // namespace triton
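Putting the HIP path together: a minimal sketch, assuming the caller supplies a lowered llvm::Module and a GCN arch string (e.g. "gfx90a"; both assumed):

// llvm::Module *mod = ...; // produced by earlier compilation stages (assumed)
// std::string hsaco = triton::driver::llir_to_amdgpu(mod, "gfx90a");
// hipModule_t hip_module = triton::driver::amdgpu_to_hipmodule(hsaco);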