[general] removed LLVM #include's in all Triton headers

This commit is contained in:
Philippe Tillet
2019-08-16 15:56:58 -07:00
parent 4de22df930
commit c7cb5f82ad
21 changed files with 454 additions and 284 deletions

View File

@@ -154,8 +154,6 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
stream->synchronize();
// run
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
std::cout << function.make_tensorflow_src({2}, "(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1") << std::endl;
exit(EXIT_FAILURE);
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };

View File

@@ -1,7 +1,6 @@
#ifndef TDL_INCLUDE_CODEGEN_SELECTION_H
#define TDL_INCLUDE_CODEGEN_SELECTION_H
#include "llvm/IR/IRBuilder.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
@@ -16,6 +15,28 @@ namespace llvm{
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// type aliases
namespace triton{
namespace codegen{
// Short names for the forward-declared LLVM types, so that declarations
// in this header do not have to spell out the llvm:: qualifier.
using Builder     = llvm::IRBuilder<llvm::ConstantFolder,
                                    llvm::IRBuilderDefaultInserter>;
using LLVMContext = llvm::LLVMContext;
using Type        = llvm::Type;
using Value       = llvm::Value;
using Module      = llvm::Module;
using Instruction = llvm::Instruction;
using Constant    = llvm::Constant;
using ArrayType   = llvm::ArrayType;
using Function    = llvm::Function;
}
}
namespace triton{
@@ -35,12 +56,12 @@ class info;
}
class target;
typedef std::vector<llvm::Value*> indices_t;
typedef std::vector<Value*> indices_t;
struct distributed_axis {
size_t contiguous;
std::vector<llvm::Value*> values;
llvm::Value* thread_id;
std::vector<Value*> values;
Value* thread_id;
};
class tile {
@@ -48,40 +69,40 @@ protected:
typedef std::vector<unsigned> shapes_t;
public:
tile(llvm::Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, llvm::Value *v) = 0;
virtual llvm::Value* get_value(indices_t idx) = 0;
llvm::Type *get_ty() const { return ty_; }
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, Value *v) = 0;
virtual Value* get_value(indices_t idx) = 0;
Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
llvm::Type *ty_;
Type *ty_;
shapes_t shapes_;
};
class shared_tile: public tile {
private:
void extract_constant(llvm::Value *arg, llvm::Value *&non_cst, llvm::Value *&cst);
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
public:
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr);
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, llvm::Value *);
llvm::Value* get_ptr_to(indices_t idx);
llvm::Value* get_value(indices_t idx);
llvm::Value* get_pointer() { return ptr_; }
llvm::Value* get_offset() { return offset_; }
static llvm::Value* shared_offset(llvm::IRBuilder<>& builder, const shapes_t& shapes, indices_t idx);
void set_value(indices_t, Value *);
Value* get_ptr_to(indices_t idx);
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx);
private:
llvm::Value *ptr_;
Value *ptr_;
bool return_vector_;
llvm::Value *offset_;
llvm::IRBuilder<> &builder_;
std::map<indices_t, llvm::Value*> ptr_cache_;
Value *offset_;
Builder &builder_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
};
@@ -90,16 +111,16 @@ class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::map<indices_t, llvm::Value*> values_map_t;
typedef std::map<indices_t, Value*> values_map_t;
private:
void init_indices();
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
Type *make_vector_ty(Type *ty, size_t vector_size);
public:
distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize);
void set_value(indices_t idx, llvm::Value *v);
llvm::Value* get_value(indices_t idx);
distributed_tile(Type *ty, const shapes_t& shapes, const axes_t &axes, Builder &builder, bool vectorize);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn);
@@ -111,25 +132,15 @@ private:
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
size_t vector_size_;
llvm::IRBuilder<> &builder_;
Builder &builder_;
};
// Selection pass
class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::value *, Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::IRBuilder<> Builder;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
private:
// utils
Type *make_vector_ty(Type *ty, size_t vector_size);

View File

@@ -4,14 +4,36 @@
#include <map>
#include <set>
#include <vector>
#include "llvm/IR/IRBuilder.h"
namespace llvm{
class Instruction;
class Value;
class Module;
class LLVMContext;
class Function;
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// type aliases
namespace triton{
namespace codegen{
// Short names for the forward-declared LLVM types, so that declarations
// in this header do not have to spell out the llvm:: qualifier.
using Builder     = llvm::IRBuilder<llvm::ConstantFolder,
                                    llvm::IRBuilderDefaultInserter>;
using LLVMContext = llvm::LLVMContext;
using Type        = llvm::Type;
using Value       = llvm::Value;
using Module      = llvm::Module;
using Instruction = llvm::Instruction;
using Constant    = llvm::Constant;
using ArrayType   = llvm::ArrayType;
using Function    = llvm::Function;
}
}
namespace triton{
@@ -21,13 +43,13 @@ class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
virtual llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
bool is_gpu() const;
private:
@@ -37,37 +59,37 @@ private:
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(): target(true){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
};
}

View File

@@ -59,7 +59,7 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
@@ -71,7 +71,7 @@ public:
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
// Binary instructions
value *create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
@@ -89,7 +89,7 @@ public:
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
// Comparison (int)
value *create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
@@ -101,7 +101,7 @@ public:
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
// Comparison (float)
value *create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");

View File

@@ -1,9 +1,9 @@
#ifndef TDL_INCLUDE_IR_CONSTANT_H
#define TDL_INCLUDE_IR_CONSTANT_H
#include "enums.h"
#include "value.h"
#include <cassert>
#include "llvm/IR/Instructions.h"
namespace triton{
namespace ir{
@@ -65,8 +65,7 @@ private:
};
class constant_expression: public constant_int {
typedef llvm::BinaryOperator::BinaryOps op_t;
using llop = llvm::BinaryOperator::BinaryOps;
typedef binary_op_t op_t;
private:
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);

View File

@@ -2,15 +2,19 @@
#define TDL_INCLUDE_IR_INSTRUCTIONS_H
#include <vector>
#include <map>
#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "llvm/IR/Instructions.h"
namespace triton{
namespace ir{
class constant_int;
class constant;
class constant_range;
class basic_block;
class context;
@@ -95,19 +99,18 @@ private:
//===----------------------------------------------------------------------===//
class binary_operator: public instruction{
public:
typedef llvm::BinaryOperator::BinaryOps op_t;
using llop = llvm::BinaryOperator::BinaryOps;
typedef binary_op_t op_t;
private:
std::string repr_impl() const;
protected:
// Constructors
binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
public:
// Get operand
op_t get_op() const { return op_; }
binary_op_t get_op() const { return op_; }
// Bool
bool is_terminator() const;
@@ -127,14 +130,14 @@ public:
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory methods
static binary_operator *create(op_t op, value *lhs, value *rhs,
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
public:
op_t op_;
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
@@ -146,30 +149,28 @@ public:
class cmp_inst: public instruction{
public:
typedef llvm::CmpInst::Predicate pred_t;
using llop = llvm::CmpInst;
typedef cmp_pred_t pred_t;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(pred_t pred);
static bool is_int_predicate(pred_t pred);
cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(cmp_pred_t pred);
static bool is_int_predicate(cmp_pred_t pred);
static type* make_cmp_result_type(type *ty);
public:
pred_t get_pred() const { return pred_; }
cmp_pred_t get_pred() const { return pred_; }
private:
pred_t pred_;
cmp_pred_t pred_;
};
class icmp_inst: public cmp_inst{
using cmp_inst::cmp_inst;
public:
static icmp_inst* create(pred_t pred, value *lhs, value *rhs,
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
};
@@ -177,7 +178,7 @@ class fcmp_inst: public cmp_inst{
using cmp_inst::cmp_inst;
public:
static fcmp_inst* create(pred_t pred, value *lhs, value *rhs,
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
};
@@ -196,33 +197,28 @@ protected:
//===----------------------------------------------------------------------===//
class cast_inst: public unary_inst{
using ic = llvm::Instruction::CastOps;
private:
std::string repr_impl() const;
public:
typedef llvm::CastInst::CastOps op_t;
protected:
cast_inst(type *ty, value *v, const std::string &name, instruction *next, op_t op)
cast_inst(type *ty, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, v, name, next), op_(op) { }
private:
static bool is_valid(op_t op, value *arg, type *ty);
static bool is_valid(cast_op_t op, value *arg, type *ty);
public:
// accessors
op_t get_op() const { return op_; }
cast_op_t get_op() const { return op_; }
// factory methods
static cast_inst *create(op_t op, value *arg, type *ty,
static cast_inst *create(cast_op_t op, value *arg, type *ty,
const std::string &name = "", instruction *next = nullptr);
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
const std::string &name = "", instruction *next = nullptr);
private:
op_t op_;
cast_op_t op_;
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
@@ -232,19 +228,19 @@ class name : public cast_inst{ \
: cast_inst(ty, v, name, next, op){ } \
};
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, llvm::Instruction::CastOps::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, llvm::Instruction::CastOps::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, llvm::Instruction::CastOps::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, cast_op_t::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes
@@ -591,8 +587,8 @@ private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
// Builds the textual representation of the transposition: the literal
// "trans<" followed by the repr() of every entry of perm_, each followed by
// a ',', with the last character of the string finally replaced by '>'.
std::string repr_impl() const {
std::string res = "trans<";
// append "<repr>," for every permutation entry
for(ir::constant_int *x: perm_)
res += x->repr() + ",";
//for(ir::constant_int *x: perm_)
// res += x->repr() + ",";
// NOTE(review): this overwrites the LAST character of res. That is the
// trailing ',' only if at least one entry was appended; when nothing was
// appended (perm_ empty, or the loop above disabled) it is the '<' of
// "trans<" that gets replaced and the result is "trans>" -- confirm intended.
res[res.size()-1] = '>';
return res;
}

View File

@@ -25,6 +25,7 @@
namespace llvm {
class Module;
}
namespace triton {

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"

View File

@@ -1,3 +1,5 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/analysis/tune.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -7,7 +9,6 @@
#include "triton/ir/constant.h"
#include "triton/driver/device.h"
#include <cstdlib>
namespace triton{

View File

@@ -203,6 +203,88 @@ Value* shared_tile::get_value(indices_t idx) {
return result;
}
/* convert ir::binary_op_t to the LLVM binary opcode of the same name */
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
  using llop = llvm::Instruction::BinaryOps;
  using ttop = ir::binary_op_t;
  // no `default:` label on purpose: the compiler can then still warn when an
  // enumerator is added to ir::binary_op_t but not mapped here
  switch(op) {
    case ttop::Add: return llop::Add;
    case ttop::FAdd: return llop::FAdd;
    case ttop::Sub: return llop::Sub;
    case ttop::FSub: return llop::FSub;
    case ttop::Mul: return llop::Mul;
    case ttop::FMul: return llop::FMul;
    case ttop::UDiv: return llop::UDiv;
    case ttop::SDiv: return llop::SDiv;
    case ttop::FDiv: return llop::FDiv;
    case ttop::URem: return llop::URem;
    case ttop::SRem: return llop::SRem;
    case ttop::FRem: return llop::FRem;
    case ttop::Shl: return llop::Shl;
    case ttop::LShr: return llop::LShr;
    case ttop::AShr: return llop::AShr;
    case ttop::And: return llop::And;
    case ttop::Or: return llop::Or;
    case ttop::Xor: return llop::Xor;
  }
  // only reached when `op` does not hold one of the enumerators above;
  // without this statement the function would flow off its end without
  // returning a value
  llvm_unreachable("unhandled ir::binary_op_t");
}
/* convert ir::cast_op_t to the LLVM cast opcode of the same name */
llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) {
  using llop = llvm::Instruction::CastOps;
  using ttop = ir::cast_op_t;
  // no `default:` label on purpose: the compiler can then still warn when an
  // enumerator is added to ir::cast_op_t but not mapped here
  switch(op){
    case ttop::Trunc: return llop::Trunc;
    case ttop::ZExt: return llop::ZExt;
    case ttop::SExt: return llop::SExt;
    case ttop::FPTrunc: return llop::FPTrunc;
    case ttop::FPExt: return llop::FPExt;
    case ttop::UIToFP: return llop::UIToFP;
    case ttop::SIToFP: return llop::SIToFP;
    case ttop::FPToUI: return llop::FPToUI;
    case ttop::FPToSI: return llop::FPToSI;
    case ttop::PtrToInt: return llop::PtrToInt;
    case ttop::IntToPtr: return llop::IntToPtr;
    case ttop::BitCast: return llop::BitCast;
    case ttop::AddrSpaceCast: return llop::AddrSpaceCast;
  }
  // only reached when `op` does not hold one of the enumerators above;
  // without this statement the function would flow off its end without
  // returning a value
  llvm_unreachable("unhandled ir::cast_op_t");
}
/* convert ir::cmp_pred_t to the LLVM comparison predicate of the same name */
llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
  using llop = llvm::CmpInst::Predicate;
  using ttop = ir::cmp_pred_t;
  // no `default:` label on purpose: the compiler can then still warn when an
  // enumerator is added to ir::cmp_pred_t but not mapped here
  switch(pred){
    // floating-point predicates
    case ttop::FIRST_FCMP_PREDICATE: return llop::FIRST_FCMP_PREDICATE;
    case ttop::FCMP_FALSE: return llop::FCMP_FALSE;
    case ttop::FCMP_OEQ: return llop::FCMP_OEQ;
    case ttop::FCMP_OGT: return llop::FCMP_OGT;
    case ttop::FCMP_OGE: return llop::FCMP_OGE;
    case ttop::FCMP_OLT: return llop::FCMP_OLT;
    case ttop::FCMP_OLE: return llop::FCMP_OLE;
    case ttop::FCMP_ONE: return llop::FCMP_ONE;
    case ttop::FCMP_ORD: return llop::FCMP_ORD;
    case ttop::FCMP_UNO: return llop::FCMP_UNO;
    case ttop::FCMP_UEQ: return llop::FCMP_UEQ;
    case ttop::FCMP_UGT: return llop::FCMP_UGT;
    case ttop::FCMP_UGE: return llop::FCMP_UGE;
    case ttop::FCMP_ULT: return llop::FCMP_ULT;
    case ttop::FCMP_ULE: return llop::FCMP_ULE;
    case ttop::FCMP_UNE: return llop::FCMP_UNE;
    case ttop::FCMP_TRUE: return llop::FCMP_TRUE;
    case ttop::LAST_FCMP_PREDICATE: return llop::LAST_FCMP_PREDICATE;
    // integer predicates
    case ttop::FIRST_ICMP_PREDICATE: return llop::FIRST_ICMP_PREDICATE;
    case ttop::ICMP_EQ: return llop::ICMP_EQ;
    case ttop::ICMP_NE: return llop::ICMP_NE;
    case ttop::ICMP_UGT: return llop::ICMP_UGT;
    case ttop::ICMP_UGE: return llop::ICMP_UGE;
    case ttop::ICMP_ULT: return llop::ICMP_ULT;
    case ttop::ICMP_ULE: return llop::ICMP_ULE;
    case ttop::ICMP_SGT: return llop::ICMP_SGT;
    case ttop::ICMP_SGE: return llop::ICMP_SGE;
    case ttop::ICMP_SLT: return llop::ICMP_SLT;
    case ttop::ICMP_SLE: return llop::ICMP_SLE;
    case ttop::LAST_ICMP_PREDICATE: return llop::LAST_ICMP_PREDICATE;
  }
  // only reached when `pred` does not hold one of the enumerators above;
  // without this statement the function would flow off its end without
  // returning a value
  llvm_unreachable("unhandled ir::cmp_pred_t");
}
/* convert ir::type to Type */
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
@@ -283,24 +365,24 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return builder.Insert(BinaryOperator::Create(ii->get_op(), lhs, rhs));
return builder.Insert(BinaryOperator::Create(llvm_op(ii->get_op()), lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
ir::cmp_pred_t pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return builder.Insert(CmpInst::Create(Instruction::ICmp, pred, lhs, rhs));
return builder.Insert(CmpInst::Create(Instruction::ICmp, llvm_pred(pred), lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
ir::cmp_pred_t pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return builder.Insert(FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs));
return builder.Insert(FCmpInst::Create(Instruction::FCmp, llvm_pred(pred), lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
Value *arg = value(ii->get_operand(0));
Type *dst_ty = type(ii->get_type()->get_scalar_ty());
return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
return builder.Insert(CastInst::Create(llvm_op(ii->get_op()), arg, dst_ty));
}
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
// get pointer

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
@@ -187,7 +188,7 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::buil
auto z = dynamic_cast<ir::binary_operator*>(idx);
if(!z)
return false;
bool is_sub = z->get_op() == ir::binary_operator::llop::Sub;
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
bool is_lhs_0 = lhs && (lhs->get_value()==0);
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();

View File

@@ -36,7 +36,7 @@ namespace transform{
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
bool is_bin_add = bin_op && bin_op->get_op()==llvm::Instruction::Add;
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
if(is_bin_add)
return (ir::instruction*)x;
return nullptr;

View File

@@ -1,10 +1,10 @@
#include <string>
#include <algorithm>
#include "triton/ir/basic_block.h"
#include "triton/ir/builder.h"
#include "triton/ir/constant.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/Instruction.h"
namespace triton{
namespace ir{
@@ -93,14 +93,14 @@ value *builder::create_ret_void() {
return create_cast(OPCODE, src, dst_ty, name);\
}
DEFINE_CAST_INSTR(si_to_fp, llvm::Instruction::SIToFP)
DEFINE_CAST_INSTR(ui_to_fp, llvm::Instruction::UIToFP)
DEFINE_CAST_INSTR(fp_to_si, llvm::Instruction::FPToSI)
DEFINE_CAST_INSTR(fp_to_ui, llvm::Instruction::FPToUI)
DEFINE_CAST_INSTR(fp_ext, llvm::Instruction::FPExt)
DEFINE_CAST_INSTR(fp_trunc, llvm::Instruction::FPTrunc)
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
value* builder::create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name){
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name){
return insert(cast_inst::create(op, v, dst_ty), name);
}
@@ -131,11 +131,11 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
}
// Binary
DEFINE_BINARY_FLOAT(fmul, llvm::Instruction::FMul)
DEFINE_BINARY_FLOAT(fdiv, llvm::Instruction::FDiv)
DEFINE_BINARY_FLOAT(frem, llvm::Instruction::FRem)
DEFINE_BINARY_FLOAT(fadd, llvm::Instruction::FAdd)
DEFINE_BINARY_FLOAT(fsub, llvm::Instruction::FSub)
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
// Unary
DEFINE_UNARY_FLOAT(fneg)
@@ -145,7 +145,7 @@ DEFINE_UNARY_FLOAT(fneg)
//===----------------------------------------------------------------------===//
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
value *rhs, const std::string &name,
bool has_nuw, bool has_nsw) {
auto *clhs = dynamic_cast<constant_int*>(lhs);
@@ -180,18 +180,18 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
}
// Binary
DEFINE_NOWRAP_BINARY(mul, llvm::Instruction::Mul)
DEFINE_NOWRAP_BINARY(add, llvm::Instruction::Add)
DEFINE_NOWRAP_BINARY(sub, llvm::Instruction::Sub)
DEFINE_NOWRAP_BINARY(shl, llvm::Instruction::Shl)
DEFINE_NOWRAP_BINARY(ashr, llvm::Instruction::AShr)
DEFINE_BINARY_INT(sdiv, llvm::Instruction::SDiv)
DEFINE_BINARY_INT(udiv, llvm::Instruction::UDiv)
DEFINE_BINARY_INT(srem, llvm::Instruction::SRem)
DEFINE_BINARY_INT(urem, llvm::Instruction::URem)
DEFINE_BINARY_INT(and, llvm::Instruction::And)
DEFINE_BINARY_INT(or, llvm::Instruction::Or)
DEFINE_BINARY_INT(xor, llvm::Instruction::Xor)
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
DEFINE_NOWRAP_BINARY(add, binary_op_t::Add)
DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub)
DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl)
DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr)
DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv)
DEFINE_BINARY_INT(udiv, binary_op_t::UDiv)
DEFINE_BINARY_INT(srem, binary_op_t::SRem)
DEFINE_BINARY_INT(urem, binary_op_t::URem)
DEFINE_BINARY_INT(and, binary_op_t::And)
DEFINE_BINARY_INT(or, binary_op_t::Or)
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
// Unary
DEFINE_UNARY_INT(neg)
DEFINE_UNARY_INT(not)
@@ -209,7 +209,7 @@ value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list, cons
// icmp instructions
//===----------------------------------------------------------------------===//
value *builder::create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name){
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
return insert(icmp_inst::create(pred, lhs, rhs), name);
}
@@ -219,25 +219,25 @@ value *builder::create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const
}
// Signed
DEFINE_ICMP_INSTR(SLE, llvm::ICmpInst::ICMP_SLE)
DEFINE_ICMP_INSTR(SLT, llvm::ICmpInst::ICMP_SLT)
DEFINE_ICMP_INSTR(SGE, llvm::ICmpInst::ICMP_SGE)
DEFINE_ICMP_INSTR(SGT, llvm::ICmpInst::ICMP_SGT)
DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE)
DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT)
DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE)
DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT)
// Unsigned
DEFINE_ICMP_INSTR(ULE, llvm::ICmpInst::ICMP_ULE)
DEFINE_ICMP_INSTR(ULT, llvm::ICmpInst::ICMP_ULT)
DEFINE_ICMP_INSTR(UGE, llvm::ICmpInst::ICMP_UGE)
DEFINE_ICMP_INSTR(UGT, llvm::ICmpInst::ICMP_UGT)
DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE)
DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT)
DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE)
DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT)
// General
DEFINE_ICMP_INSTR(EQ, llvm::ICmpInst::ICMP_EQ)
DEFINE_ICMP_INSTR(NE, llvm::ICmpInst::ICMP_NE)
DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ)
DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
//===----------------------------------------------------------------------===//
// fcmp instructions
//===----------------------------------------------------------------------===//
// Creates a floating-point comparison: builds the instruction with
// fcmp_inst::create(pred, lhs, rhs) and passes it to insert() together with
// `name`. Returns whatever insert() returns.
value *builder::create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name){
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
return insert(fcmp_inst::create(pred, lhs, rhs), name);
}
@@ -247,12 +247,12 @@ value *builder::create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const
}
// Ordered
DEFINE_FCMP_INSTR(OLE, llvm::FCmpInst::FCMP_OLE)
DEFINE_FCMP_INSTR(OLT, llvm::FCmpInst::FCMP_OLT)
DEFINE_FCMP_INSTR(OGE, llvm::FCmpInst::FCMP_OGE)
DEFINE_FCMP_INSTR(OGT, llvm::FCmpInst::FCMP_OGT)
DEFINE_FCMP_INSTR(OEQ, llvm::FCmpInst::FCMP_OEQ)
DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE)
DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE)
DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT)
DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE)
DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)

View File

@@ -145,19 +145,19 @@ uint64_t constant_expression::get_value() const {
// Folds the two constant operands according to op_ and returns the result.
// Throws std::runtime_error for any operator not listed below.
// NOTE(review): both operands are held as uint64_t, so SDiv/SRem/AShr are
// evaluated with unsigned semantics and a zero divisor is not checked --
// confirm that no caller relies on signed folding.
uint64_t lhs = lhs_->get_value();
uint64_t rhs = rhs_->get_value();
switch(op_) {
case llop::Add : return lhs + rhs;
case llop::Sub : return lhs - rhs;
case llop::Mul : return lhs * rhs;
case llop::UDiv : return lhs / rhs;
case llop::SDiv : return lhs / rhs;
case llop::URem : return lhs % rhs;
case llop::SRem : return lhs % rhs;
case llop::Shl : return lhs << rhs;
case llop::LShr : return lhs >> rhs;
case llop::AShr : return lhs >> rhs;
case llop::And : return lhs & rhs;
case llop::Or : return lhs | rhs;
case llop::Xor : return lhs ^ rhs;
case op_t::Add : return lhs + rhs;
case op_t::Sub : return lhs - rhs;
case op_t::Mul : return lhs * rhs;
case op_t::UDiv : return lhs / rhs;
case op_t::SDiv : return lhs / rhs;
case op_t::URem : return lhs % rhs;
case op_t::SRem : return lhs % rhs;
case op_t::Shl : return lhs << rhs;
case op_t::LShr : return lhs >> rhs;
case op_t::AShr : return lhs >> rhs;
// And/Or are bitwise, like Xor: logical &&/|| would collapse every
// non-zero result to 1 (e.g. 6 And 3 must fold to 2, not 1).
case op_t::And : return lhs & rhs;
case op_t::Or : return lhs | rhs;
case op_t::Xor : return lhs ^ rhs;
default: throw std::runtime_error("unsupported constexpr binary operator");
}
}

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/ir/context.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
@@ -87,60 +88,60 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
// Returns the lowercase mnemonic for this operator's op_ (e.g. "add",
// "fsub", "lshr"). Throws std::runtime_error if op_ matches none of the
// listed operators.
std::string binary_operator::repr_impl() const {
switch(op_) {
case llop::Add : return "add";
case llop::FAdd : return "fadd";
case llop::Sub : return "sub";
case llop::FSub : return "fsub";
case llop::Mul : return "mul";
case llop::FMul : return "fmul";
case llop::UDiv : return "udiv";
case llop::SDiv : return "sdiv";
case llop::FDiv : return "fdiv";
case llop::URem : return "urem";
case llop::SRem : return "srem";
case llop::FRem : return "frem";
case llop::Shl : return "shl";
case llop::LShr : return "lshr";
case llop::AShr : return "ashr";
case llop::And : return "and";
case llop::Or : return "or";
case llop::Xor : return "xor";
case Add : return "add";
case FAdd : return "fadd";
case Sub : return "sub";
case FSub : return "fsub";
case Mul : return "mul";
case FMul : return "fmul";
case UDiv : return "udiv";
case SDiv : return "sdiv";
case FDiv : return "fdiv";
case URem : return "urem";
case SRem : return "srem";
case FRem : return "frem";
case Shl : return "shl";
case LShr : return "lshr";
case AShr : return "ashr";
case And : return "and";
case Or : return "or";
case Xor : return "xor";
default: throw std::runtime_error("unknown binary operator");
}
}
// True when op_ is an integer division (UDiv or SDiv).
bool binary_operator::is_int_div() const {
return op_ == llop::UDiv || op_ == llop::SDiv;
return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
}
// True when op_ is an integer remainder (URem or SRem).
bool binary_operator::is_int_rem() const {
return op_ == llop::URem || op_ == llop::SRem;
return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
}
// True when op_ is a left shift (Shl).
bool binary_operator::is_shl() const {
return op_ == llop::Shl;
return op_ == binary_op_t::Shl;
}
// True when op_ is a right shift, either logical (LShr) or arithmetic (AShr).
bool binary_operator::is_shr() const {
return op_ == llop::LShr || op_ == llop::AShr;
return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
}
// True when op_ is the integer multiply (Mul); FMul is not matched.
bool binary_operator::is_int_mult() const {
return op_ == llop::Mul;
return op_ == binary_op_t::Mul;
}
// True when op_ is the integer Add or Sub; FAdd/FSub are not matched.
bool binary_operator::is_int_add_sub() const {
return op_ == llop::Add || op_ == llop::Sub;
return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
}
// Constructs a binary operator of result type `ty`: stores `op` in op_ and
// registers `lhs` and `rhs` as operands 0 and 1.
// NOTE(review): the literals 2 and 1 passed to the instruction base are
// presumably the operand count and result count -- confirm against the
// instruction constructor.
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, 1, name, next), op_(op){
set_operand(0, lhs);
set_operand(1, rhs);
}
binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(lhs->get_type() == rhs->get_type() &&
"Cannot create binary operator with two operands of differing type!");
return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
@@ -149,19 +150,19 @@ binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const
// Builds a floating-point negation as `zero FSub arg`, where `zero` is the
// value returned by constant_fp::get_zero_value_for_negation for arg's type.
// Asserts that arg's scalar type is floating point.
binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(llvm::Instruction::FSub, zero, arg, name, next);
return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
}
// Builds an integer negation as `zero Sub arg`.
// Asserts that arg's scalar type is an integer type.
// NOTE(review): the zero operand is obtained from
// constant_fp::get_zero_value_for_negation even though the operand is
// asserted to be integer -- confirm that helper handles integer types.
binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(llvm::Instruction::Sub, zero, arg, name, next);
return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
}
// Builds a bitwise NOT as `arg Xor mask`, where `mask` is the all-ones
// constant of arg's type.
// NOTE(review): the assertion checks arg's type directly with
// is_integer_ty(), whereas create_neg goes through get_scalar_ty() --
// confirm whether tile operands are meant to be rejected here.
binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->is_integer_ty());
constant *mask = constant::get_all_ones_value(arg->get_type());
return binary_operator::create(llvm::Instruction::Xor, arg, mask, name, next);
return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
}
//===----------------------------------------------------------------------===//
@@ -171,37 +172,37 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
// cmp_inst
// Returns the textual name of this comparison's predicate pred_:
// "fcmp_<cc>" for floating-point predicates, "icmp_<cc>" for integer ones,
// and plain "false"/"true" for FCMP_FALSE/FCMP_TRUE.
// Throws std::runtime_error("unreachable") for any other value.
std::string cmp_inst::repr_impl() const {
switch (pred_) {
case llop::FCMP_FALSE : return "false";
case llop::FCMP_OEQ : return "fcmp_oeq";
case llop::FCMP_OGT : return "fcmp_ogt";
case llop::FCMP_OGE : return "fcmp_oge";
case llop::FCMP_OLT : return "fcmp_olt";
case llop::FCMP_OLE : return "fcmp_ole";
case llop::FCMP_ONE : return "fcmp_one";
case llop::FCMP_ORD : return "fcmp_ord";
case llop::FCMP_UNO : return "fcmp_uno";
case llop::FCMP_UEQ : return "fcmp_ueq";
case llop::FCMP_UGT : return "fcmp_ugt";
case llop::FCMP_UGE : return "fcmp_uge";
case llop::FCMP_ULT : return "fcmp_ult";
case llop::FCMP_ULE : return "fcmp_ule";
case llop::FCMP_UNE : return "fcmp_une";
case llop::FCMP_TRUE : return "true";
case llop::ICMP_EQ : return "icmp_eq";
case llop::ICMP_NE : return "icmp_ne";
case llop::ICMP_UGT : return "icmp_ugt";
case llop::ICMP_UGE : return "icmp_uge";
case llop::ICMP_ULT : return "icmp_ult";
case llop::ICMP_ULE : return "icmp_ule";
case llop::ICMP_SGT : return "icmp_sgt";
case llop::ICMP_SGE : return "icmp_sge";
case llop::ICMP_SLT : return "icmp_slt";
case llop::ICMP_SLE : return "icmp_sle";
case FCMP_FALSE : return "false";
case FCMP_OEQ : return "fcmp_oeq";
case FCMP_OGT : return "fcmp_ogt";
case FCMP_OGE : return "fcmp_oge";
case FCMP_OLT : return "fcmp_olt";
case FCMP_OLE : return "fcmp_ole";
case FCMP_ONE : return "fcmp_one";
case FCMP_ORD : return "fcmp_ord";
case FCMP_UNO : return "fcmp_uno";
case FCMP_UEQ : return "fcmp_ueq";
case FCMP_UGT : return "fcmp_ugt";
case FCMP_UGE : return "fcmp_uge";
case FCMP_ULT : return "fcmp_ult";
case FCMP_ULE : return "fcmp_ule";
case FCMP_UNE : return "fcmp_une";
case FCMP_TRUE : return "true";
case ICMP_EQ : return "icmp_eq";
case ICMP_NE : return "icmp_ne";
case ICMP_UGT : return "icmp_ugt";
case ICMP_UGE : return "icmp_uge";
case ICMP_ULT : return "icmp_ult";
case ICMP_ULE : return "icmp_ule";
case ICMP_SGT : return "icmp_sgt";
case ICMP_SGE : return "icmp_sge";
case ICMP_SLT : return "icmp_slt";
case ICMP_SLE : return "icmp_sle";
default: throw std::runtime_error("unreachable");
}
}
cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
cmp_inst::cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, 2, 1, name, next), pred_(pred) {
set_operand(0, lhs);
set_operand(1, rhs);
@@ -215,23 +216,23 @@ type* cmp_inst::make_cmp_result_type(type *ty){
}
// True when `pred` lies in the inclusive range
// [FIRST_FCMP_PREDICATE, LAST_FCMP_PREDICATE].
bool cmp_inst::is_fp_predicate(pred_t pred) {
return pred >= llop::FIRST_FCMP_PREDICATE && pred <= llop::LAST_FCMP_PREDICATE;
bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE;
}
// True when `pred` lies in the inclusive range
// [FIRST_ICMP_PREDICATE, LAST_ICMP_PREDICATE].
bool cmp_inst::is_int_predicate(pred_t pred) {
return pred >= llop::FIRST_ICMP_PREDICATE && pred <= llop::LAST_ICMP_PREDICATE;
bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
}
// icmp_inst
// Allocates a new icmp_inst comparing `lhs` and `rhs` with `pred`.
// The result type is derived from lhs's type via make_cmp_result_type.
// Asserts that `pred` is an integer predicate. The instruction is created
// with a raw `new`; ownership of the returned pointer is not visible here.
icmp_inst* icmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_int_predicate(pred));
type *res_ty = make_cmp_result_type(lhs->get_type());
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
}
// fcmp_inst
fcmp_inst* fcmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_fp_predicate(pred));
type *res_ty = make_cmp_result_type(lhs->get_type());
return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
@@ -252,45 +253,45 @@ unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction
// Returns the snake_case name of this cast's opcode op_ (e.g. "trunc",
// "fp_to_si", "addr_space_cast"). Throws std::runtime_error("unreachable")
// for any opcode not listed.
std::string cast_inst::repr_impl() const {
switch (op_){
case ic::Trunc: return "trunc";
case ic::ZExt: return "zext";
case ic::SExt: return "sext";
case ic::FPTrunc: return "fp_trunc";
case ic::FPExt: return "fp_ext";
case ic::UIToFP: return "ui_to_fp";
case ic::SIToFP: return "si_to_fp";
case ic::FPToUI: return "fp_to_ui";
case ic::FPToSI: return "fp_to_si";
case ic::PtrToInt: return "ptr_to_int";
case ic::IntToPtr: return "int_to_ptr";
case ic::BitCast: return "bitcast";
case ic::AddrSpaceCast: return "addr_space_cast";
case cast_op_t::Trunc: return "trunc";
case cast_op_t::ZExt: return "zext";
case cast_op_t::SExt: return "sext";
case cast_op_t::FPTrunc: return "fp_trunc";
case cast_op_t::FPExt: return "fp_ext";
case cast_op_t::UIToFP: return "ui_to_fp";
case cast_op_t::SIToFP: return "si_to_fp";
case cast_op_t::FPToUI: return "fp_to_ui";
case cast_op_t::FPToSI: return "fp_to_si";
case cast_op_t::PtrToInt: return "ptr_to_int";
case cast_op_t::IntToPtr: return "int_to_ptr";
case cast_op_t::BitCast: return "bitcast";
case cast_op_t::AddrSpaceCast: return "addr_space_cast";
default: throw std::runtime_error("unreachable");
}
}
// TODO
// Placeholder validity check for a cast of `arg` to `ty`: it only asserts
// that source and destination agree on being tile types, then always
// returns true. `op` is not inspected.
bool cast_inst::is_valid(op_t op, value *arg, type *ty) {
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty());
return true;
}
// Factory for cast instructions: asserts is_valid(op, arg, ty), then
// allocates (raw `new`) the cast_inst subclass that corresponds to `op`,
// forwarding (ty, arg, name, next) to its constructor.
// Throws std::runtime_error("unreachable") for an opcode not listed.
cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &name, instruction *next){
cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
assert(is_valid(op, arg, ty) && "Invalid cast!");
// Construct and return the appropriate CastInst subclass
switch (op) {
case ic::Trunc: return new trunc_inst (ty, arg, name, next);
case ic::ZExt: return new z_ext_inst (ty, arg, name, next);
case ic::SExt: return new s_ext_inst (ty, arg, name, next);
case ic::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
case ic::FPExt: return new fp_ext_inst (ty, arg, name, next);
case ic::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
case ic::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
case ic::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
case ic::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
case ic::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
case ic::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
case ic::BitCast: return new bit_cast_inst (ty, arg, name, next);
case ic::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next);
case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next);
case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next);
case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next);
case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next);
case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
default: throw std::runtime_error("unreachable");
}
}
@@ -300,9 +301,9 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
op_t op = (arg_bits == dst_bits ? ic::BitCast :
(arg_bits > dst_bits ? ic::Trunc :
(is_signed ? ic::SExt : ic::ZExt)));
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
(arg_bits > dst_bits ? cast_op_t::Trunc :
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
return create(op, arg, ty, name, next);
}

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/type.h"

View File

@@ -1,3 +1,4 @@
#include <algorithm>
#include "triton/lang/statement.h"
#include "triton/lang/declaration.h"
#include "triton/ir/function.h"

Binary file not shown.

View File

@@ -1,4 +1,10 @@
import libtriton
import tensorflow as tf
import distutils
import distutils.log
import setuptools.command.build_ext
import setuptools
import os
src = """
const tunable int TM = {128};
@@ -9,7 +15,7 @@ void matmul(restrict read_only align(16) half *A,
restrict read_only align(16) half *B,
restrict read_only align(16) half *C,
int M, int N, int K,
multiple_of(8) int lda, multiple_of(8)" int ldb, int ldc) {
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
@@ -39,4 +45,51 @@ void matmul(restrict read_only align(16) half *A,
}
"""
print(libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1'))
with open('test.cpp', 'w+') as test:
src = libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1')
test.writelines(src)
triton_include_dirs = ['/home/philippe/development/triton/include']
tensorflow_include_dirs = [tf.sysconfig.get_include()]
llvm_include_dirs = ['/usr/include/llvm-8/', '/usr/include/llvm-c-8/']
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
triton_library_dirs = [os.path.realpath(libtriton.__file__)]
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
extra_compile_args = []
extra_link_args = []
library_dirs = tensorflow_library_dirs
libraries = ['tensorflow_framework']
ext = setuptools.Extension(
name = 'test',
language = 'c++',
sources = ['/home/philippe/development/triton/python/examples/test.cpp'],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args,
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries
)
build_path = '.'
args = ['build_ext']
#args.append('--build-temp=' + build_path)
#args.append('--build-lib=' + build_path)
args.append('-q')
args = dict(
name = 'test',
ext_modules = [ext],
script_args = args,
cmdclass = {
'build_ext': setuptools.command.build_ext.build_ext
}
)
setuptools.setup(**args)
library_dir = os.path.dirname(os.path.realpath(__file__))
module = tf.load_op_library(os.path.join(library_dir, 'build/lib.linux-x86_64-3.6/test.cpython-36m-x86_64-linux-gnu.so'))
print(module.matmul)