[general] removed LLVM #includes from all Triton headers
This commit is contained in:
@@ -154,8 +154,6 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
stream->synchronize();
|
||||
// run
|
||||
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
|
||||
std::cout << function.make_tensorflow_src({2}, "(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1") << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };
|
||||
|
@@ -1,7 +1,6 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_SELECTION_H
|
||||
#define TDL_INCLUDE_CODEGEN_SELECTION_H
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -16,6 +15,28 @@ namespace llvm{
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
@@ -35,12 +56,12 @@ class info;
|
||||
}
|
||||
class target;
|
||||
|
||||
typedef std::vector<llvm::Value*> indices_t;
|
||||
typedef std::vector<Value*> indices_t;
|
||||
|
||||
struct distributed_axis {
|
||||
size_t contiguous;
|
||||
std::vector<llvm::Value*> values;
|
||||
llvm::Value* thread_id;
|
||||
std::vector<Value*> values;
|
||||
Value* thread_id;
|
||||
};
|
||||
|
||||
class tile {
|
||||
@@ -48,40 +69,40 @@ protected:
|
||||
typedef std::vector<unsigned> shapes_t;
|
||||
|
||||
public:
|
||||
tile(llvm::Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
|
||||
virtual void set_value(indices_t idx, llvm::Value *v) = 0;
|
||||
virtual llvm::Value* get_value(indices_t idx) = 0;
|
||||
llvm::Type *get_ty() const { return ty_; }
|
||||
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
|
||||
virtual void set_value(indices_t idx, Value *v) = 0;
|
||||
virtual Value* get_value(indices_t idx) = 0;
|
||||
Type *get_ty() const { return ty_; }
|
||||
shapes_t get_shapes() const { return shapes_; }
|
||||
|
||||
protected:
|
||||
llvm::Type *ty_;
|
||||
Type *ty_;
|
||||
shapes_t shapes_;
|
||||
};
|
||||
|
||||
class shared_tile: public tile {
|
||||
private:
|
||||
void extract_constant(llvm::Value *arg, llvm::Value *&non_cst, llvm::Value *&cst);
|
||||
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
|
||||
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
|
||||
|
||||
|
||||
public:
|
||||
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
|
||||
shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr);
|
||||
void set_vector_size(unsigned vector_size);
|
||||
void set_return_mode(bool return_vector);
|
||||
void set_value(indices_t, llvm::Value *);
|
||||
llvm::Value* get_ptr_to(indices_t idx);
|
||||
llvm::Value* get_value(indices_t idx);
|
||||
llvm::Value* get_pointer() { return ptr_; }
|
||||
llvm::Value* get_offset() { return offset_; }
|
||||
static llvm::Value* shared_offset(llvm::IRBuilder<>& builder, const shapes_t& shapes, indices_t idx);
|
||||
void set_value(indices_t, Value *);
|
||||
Value* get_ptr_to(indices_t idx);
|
||||
Value* get_value(indices_t idx);
|
||||
Value* get_pointer() { return ptr_; }
|
||||
Value* get_offset() { return offset_; }
|
||||
static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx);
|
||||
|
||||
private:
|
||||
llvm::Value *ptr_;
|
||||
Value *ptr_;
|
||||
bool return_vector_;
|
||||
llvm::Value *offset_;
|
||||
llvm::IRBuilder<> &builder_;
|
||||
std::map<indices_t, llvm::Value*> ptr_cache_;
|
||||
Value *offset_;
|
||||
Builder &builder_;
|
||||
std::map<indices_t, Value*> ptr_cache_;
|
||||
unsigned vector_size_;
|
||||
};
|
||||
|
||||
@@ -90,16 +111,16 @@ class distributed_tile: public tile{
|
||||
typedef std::vector<distributed_axis> axes_t;
|
||||
typedef std::vector<indices_t> ordered_indices_vec_t;
|
||||
typedef std::map<indices_t, unsigned> indices_map_t;
|
||||
typedef std::map<indices_t, llvm::Value*> values_map_t;
|
||||
typedef std::map<indices_t, Value*> values_map_t;
|
||||
|
||||
private:
|
||||
void init_indices();
|
||||
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
|
||||
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||
|
||||
public:
|
||||
distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize);
|
||||
void set_value(indices_t idx, llvm::Value *v);
|
||||
llvm::Value* get_value(indices_t idx);
|
||||
distributed_tile(Type *ty, const shapes_t& shapes, const axes_t &axes, Builder &builder, bool vectorize);
|
||||
void set_value(indices_t idx, Value *v);
|
||||
Value* get_value(indices_t idx);
|
||||
unsigned get_linear_index(indices_t idx);
|
||||
indices_t get_ordered_indices(unsigned id);
|
||||
void for_each(std::function<void(indices_t)> fn);
|
||||
@@ -111,25 +132,15 @@ private:
|
||||
values_map_t values_;
|
||||
ordered_indices_vec_t ordered_indices_;
|
||||
size_t vector_size_;
|
||||
llvm::IRBuilder<> &builder_;
|
||||
Builder &builder_;
|
||||
};
|
||||
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::IRBuilder<> Builder;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
|
||||
private:
|
||||
// utils
|
||||
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||
|
@@ -4,14 +4,36 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
|
||||
namespace llvm{
|
||||
class Instruction;
|
||||
class Value;
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
class Function;
|
||||
class Type;
|
||||
class Value;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
@@ -21,13 +43,13 @@ class target {
|
||||
public:
|
||||
target(bool is_gpu): is_gpu_(is_gpu){}
|
||||
virtual ~target() {}
|
||||
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
|
||||
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
||||
virtual llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
||||
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
|
||||
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
|
||||
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
|
||||
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
@@ -37,37 +59,37 @@ private:
|
||||
class amd_cl_target: public target {
|
||||
public:
|
||||
amd_cl_target(): target(true){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
public:
|
||||
nvidia_cu_target(): target(true){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
public:
|
||||
cpu_target(): target(false){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -59,7 +59,7 @@ public:
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
|
||||
@@ -71,7 +71,7 @@ public:
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
|
||||
// Binary instructions
|
||||
value *create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
|
||||
@@ -89,7 +89,7 @@ public:
|
||||
// GEP
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
|
||||
// Comparison (int)
|
||||
value *create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
|
||||
@@ -101,7 +101,7 @@ public:
|
||||
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
|
||||
// Comparison (float)
|
||||
value *create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");
|
||||
|
@@ -1,9 +1,9 @@
|
||||
#ifndef TDL_INCLUDE_IR_CONSTANT_H
|
||||
#define TDL_INCLUDE_IR_CONSTANT_H
|
||||
|
||||
#include "enums.h"
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "llvm/IR/Instructions.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -65,8 +65,7 @@ private:
|
||||
};
|
||||
|
||||
class constant_expression: public constant_int {
|
||||
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||
using llop = llvm::BinaryOperator::BinaryOps;
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
|
||||
|
@@ -2,15 +2,19 @@
|
||||
#define TDL_INCLUDE_IR_INSTRUCTIONS_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class constant_int;
|
||||
class constant;
|
||||
class constant_range;
|
||||
class basic_block;
|
||||
class context;
|
||||
|
||||
@@ -95,19 +99,18 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
class binary_operator: public instruction{
|
||||
public:
|
||||
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||
using llop = llvm::BinaryOperator::BinaryOps;
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
// Constructors
|
||||
binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
|
||||
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// Get operand
|
||||
op_t get_op() const { return op_; }
|
||||
binary_op_t get_op() const { return op_; }
|
||||
|
||||
// Bool
|
||||
bool is_terminator() const;
|
||||
@@ -127,14 +130,14 @@ public:
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
|
||||
// Factory methods
|
||||
static binary_operator *create(op_t op, value *lhs, value *rhs,
|
||||
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
op_t op_;
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
};
|
||||
@@ -146,30 +149,28 @@ public:
|
||||
|
||||
class cmp_inst: public instruction{
|
||||
public:
|
||||
typedef llvm::CmpInst::Predicate pred_t;
|
||||
using llop = llvm::CmpInst;
|
||||
|
||||
typedef cmp_pred_t pred_t;
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cmp_inst(type *ty, pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(pred_t pred);
|
||||
static bool is_int_predicate(pred_t pred);
|
||||
cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(cmp_pred_t pred);
|
||||
static bool is_int_predicate(cmp_pred_t pred);
|
||||
static type* make_cmp_result_type(type *ty);
|
||||
|
||||
public:
|
||||
pred_t get_pred() const { return pred_; }
|
||||
cmp_pred_t get_pred() const { return pred_; }
|
||||
|
||||
private:
|
||||
pred_t pred_;
|
||||
cmp_pred_t pred_;
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst{
|
||||
using cmp_inst::cmp_inst;
|
||||
|
||||
public:
|
||||
static icmp_inst* create(pred_t pred, value *lhs, value *rhs,
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
@@ -177,7 +178,7 @@ class fcmp_inst: public cmp_inst{
|
||||
using cmp_inst::cmp_inst;
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(pred_t pred, value *lhs, value *rhs,
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
@@ -196,33 +197,28 @@ protected:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cast_inst: public unary_inst{
|
||||
using ic = llvm::Instruction::CastOps;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
public:
|
||||
typedef llvm::CastInst::CastOps op_t;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value *v, const std::string &name, instruction *next, op_t op)
|
||||
cast_inst(type *ty, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(op_t op, value *arg, type *ty);
|
||||
static bool is_valid(cast_op_t op, value *arg, type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
op_t get_op() const { return op_; }
|
||||
cast_op_t get_op() const { return op_; }
|
||||
|
||||
// factory methods
|
||||
static cast_inst *create(op_t op, value *arg, type *ty,
|
||||
static cast_inst *create(cast_op_t op, value *arg, type *ty,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
op_t op_;
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
|
||||
@@ -232,19 +228,19 @@ class name : public cast_inst{ \
|
||||
: cast_inst(ty, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, llvm::Instruction::CastOps::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, llvm::Instruction::CastOps::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, llvm::Instruction::CastOps::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, cast_op_t::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
@@ -591,8 +587,8 @@ private:
|
||||
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const {
|
||||
std::string res = "trans<";
|
||||
for(ir::constant_int *x: perm_)
|
||||
res += x->repr() + ",";
|
||||
//for(ir::constant_int *x: perm_)
|
||||
// res += x->repr() + ",";
|
||||
res[res.size()-1] = '>';
|
||||
return res;
|
||||
}
|
||||
|
@@ -25,6 +25,7 @@
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
@@ -1,3 +1,5 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
@@ -7,7 +9,6 @@
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
namespace triton{
|
||||
|
@@ -203,6 +203,88 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/* convert ir::binary_op_t to the corresponding llvm::Instruction::BinaryOps */
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
  using llop = llvm::Instruction::BinaryOps;
  using ttop = ir::binary_op_t;
  switch(op) {
    case ttop::Add:  return llop::Add;
    case ttop::FAdd: return llop::FAdd;
    case ttop::Sub:  return llop::Sub;
    case ttop::FSub: return llop::FSub;
    case ttop::Mul:  return llop::Mul;
    case ttop::FMul: return llop::FMul;
    case ttop::UDiv: return llop::UDiv;
    case ttop::SDiv: return llop::SDiv;
    case ttop::FDiv: return llop::FDiv;
    case ttop::URem: return llop::URem;
    case ttop::SRem: return llop::SRem;
    case ttop::FRem: return llop::FRem;
    case ttop::Shl:  return llop::Shl;
    case ttop::LShr: return llop::LShr;
    case ttop::AShr: return llop::AShr;
    case ttop::And:  return llop::And;
    case ttop::Or:   return llop::Or;
    case ttop::Xor:  return llop::Xor;
  }
  // No case matched (e.g. a value cast from an out-of-range integer).
  // Flowing off the end of a non-void function is undefined behaviour, so
  // trap explicitly. llvm_unreachable is declared in
  // llvm/Support/ErrorHandling.h.
  llvm_unreachable("unhandled ir::binary_op_t");
}
|
||||
|
||||
/* convert ir::cast_op_t to the corresponding llvm::Instruction::CastOps */
llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) {
  using llop = llvm::Instruction::CastOps;
  using ttop = ir::cast_op_t;
  switch(op){
    case ttop::Trunc:         return llop::Trunc;
    case ttop::ZExt:          return llop::ZExt;
    case ttop::SExt:          return llop::SExt;
    case ttop::FPTrunc:       return llop::FPTrunc;
    case ttop::FPExt:         return llop::FPExt;
    case ttop::UIToFP:        return llop::UIToFP;
    case ttop::SIToFP:        return llop::SIToFP;
    case ttop::FPToUI:        return llop::FPToUI;
    case ttop::FPToSI:        return llop::FPToSI;
    case ttop::PtrToInt:      return llop::PtrToInt;
    case ttop::IntToPtr:      return llop::IntToPtr;
    case ttop::BitCast:       return llop::BitCast;
    case ttop::AddrSpaceCast: return llop::AddrSpaceCast;
  }
  // No case matched (e.g. a value cast from an out-of-range integer).
  // Flowing off the end of a non-void function is undefined behaviour, so
  // trap explicitly. llvm_unreachable is declared in
  // llvm/Support/ErrorHandling.h.
  llvm_unreachable("unhandled ir::cast_op_t");
}
|
||||
|
||||
/* convert ir::cmp_pred_t to the corresponding llvm::CmpInst::Predicate */
llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
  using llop = llvm::CmpInst::Predicate;
  using ttop = ir::cmp_pred_t;
  switch(pred){
    // floating-point predicates
    case ttop::FIRST_FCMP_PREDICATE: return llop::FIRST_FCMP_PREDICATE;
    case ttop::FCMP_FALSE:           return llop::FCMP_FALSE;
    case ttop::FCMP_OEQ:             return llop::FCMP_OEQ;
    case ttop::FCMP_OGT:             return llop::FCMP_OGT;
    case ttop::FCMP_OGE:             return llop::FCMP_OGE;
    case ttop::FCMP_OLT:             return llop::FCMP_OLT;
    case ttop::FCMP_OLE:             return llop::FCMP_OLE;
    case ttop::FCMP_ONE:             return llop::FCMP_ONE;
    case ttop::FCMP_ORD:             return llop::FCMP_ORD;
    case ttop::FCMP_UNO:             return llop::FCMP_UNO;
    case ttop::FCMP_UEQ:             return llop::FCMP_UEQ;
    case ttop::FCMP_UGT:             return llop::FCMP_UGT;
    case ttop::FCMP_UGE:             return llop::FCMP_UGE;
    case ttop::FCMP_ULT:             return llop::FCMP_ULT;
    case ttop::FCMP_ULE:             return llop::FCMP_ULE;
    case ttop::FCMP_UNE:             return llop::FCMP_UNE;
    case ttop::FCMP_TRUE:            return llop::FCMP_TRUE;
    case ttop::LAST_FCMP_PREDICATE:  return llop::LAST_FCMP_PREDICATE;
    // integer predicates
    case ttop::FIRST_ICMP_PREDICATE: return llop::FIRST_ICMP_PREDICATE;
    case ttop::ICMP_EQ:              return llop::ICMP_EQ;
    case ttop::ICMP_NE:              return llop::ICMP_NE;
    case ttop::ICMP_UGT:             return llop::ICMP_UGT;
    case ttop::ICMP_UGE:             return llop::ICMP_UGE;
    case ttop::ICMP_ULT:             return llop::ICMP_ULT;
    case ttop::ICMP_ULE:             return llop::ICMP_ULE;
    case ttop::ICMP_SGT:             return llop::ICMP_SGT;
    case ttop::ICMP_SGE:             return llop::ICMP_SGE;
    case ttop::ICMP_SLT:             return llop::ICMP_SLT;
    case ttop::ICMP_SLE:             return llop::ICMP_SLE;
    case ttop::LAST_ICMP_PREDICATE:  return llop::LAST_ICMP_PREDICATE;
  }
  // No case matched (e.g. a value cast from an out-of-range integer).
  // Flowing off the end of a non-void function is undefined behaviour, so
  // trap explicitly. llvm_unreachable is declared in
  // llvm/Support/ErrorHandling.h.
  llvm_unreachable("unhandled ir::cmp_pred_t");
}
|
||||
|
||||
/* convert ir::type to Type */
|
||||
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
@@ -283,24 +365,24 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
|
||||
Value *lhs = value(ii->get_operand(0));
|
||||
Value *rhs = value(ii->get_operand(1));
|
||||
return builder.Insert(BinaryOperator::Create(ii->get_op(), lhs, rhs));
|
||||
return builder.Insert(BinaryOperator::Create(llvm_op(ii->get_op()), lhs, rhs));
|
||||
}
|
||||
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
|
||||
CmpInst::Predicate pred = ii->get_pred();
|
||||
ir::cmp_pred_t pred = ii->get_pred();
|
||||
Value *lhs = value(ii->get_operand(0));
|
||||
Value *rhs = value(ii->get_operand(1));
|
||||
return builder.Insert(CmpInst::Create(Instruction::ICmp, pred, lhs, rhs));
|
||||
return builder.Insert(CmpInst::Create(Instruction::ICmp, llvm_pred(pred), lhs, rhs));
|
||||
}
|
||||
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
|
||||
CmpInst::Predicate pred = ii->get_pred();
|
||||
ir::cmp_pred_t pred = ii->get_pred();
|
||||
Value *lhs = value(ii->get_operand(0));
|
||||
Value *rhs = value(ii->get_operand(1));
|
||||
return builder.Insert(FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs));
|
||||
return builder.Insert(FCmpInst::Create(Instruction::FCmp, llvm_pred(pred), lhs, rhs));
|
||||
}
|
||||
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
|
||||
Value *arg = value(ii->get_operand(0));
|
||||
Type *dst_ty = type(ii->get_type()->get_scalar_ty());
|
||||
return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
|
||||
return builder.Insert(CastInst::Create(llvm_op(ii->get_op()), arg, dst_ty));
|
||||
}
|
||||
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
|
||||
// get pointer
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
@@ -187,7 +188,7 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::buil
|
||||
auto z = dynamic_cast<ir::binary_operator*>(idx);
|
||||
if(!z)
|
||||
return false;
|
||||
bool is_sub = z->get_op() == ir::binary_operator::llop::Sub;
|
||||
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
|
||||
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
|
||||
bool is_lhs_0 = lhs && (lhs->get_value()==0);
|
||||
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
|
||||
|
@@ -36,7 +36,7 @@ namespace transform{
|
||||
|
||||
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
|
||||
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
|
||||
bool is_bin_add = bin_op && bin_op->get_op()==llvm::Instruction::Add;
|
||||
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
|
||||
if(is_bin_add)
|
||||
return (ir::instruction*)x;
|
||||
return nullptr;
|
||||
|
@@ -1,10 +1,10 @@
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/Instruction.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -93,14 +93,14 @@ value *builder::create_ret_void() {
|
||||
return create_cast(OPCODE, src, dst_ty, name);\
|
||||
}
|
||||
|
||||
DEFINE_CAST_INSTR(si_to_fp, llvm::Instruction::SIToFP)
|
||||
DEFINE_CAST_INSTR(ui_to_fp, llvm::Instruction::UIToFP)
|
||||
DEFINE_CAST_INSTR(fp_to_si, llvm::Instruction::FPToSI)
|
||||
DEFINE_CAST_INSTR(fp_to_ui, llvm::Instruction::FPToUI)
|
||||
DEFINE_CAST_INSTR(fp_ext, llvm::Instruction::FPExt)
|
||||
DEFINE_CAST_INSTR(fp_trunc, llvm::Instruction::FPTrunc)
|
||||
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
||||
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
||||
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
|
||||
DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
|
||||
DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
|
||||
DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
|
||||
|
||||
value* builder::create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name){
|
||||
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name){
|
||||
return insert(cast_inst::create(op, v, dst_ty), name);
|
||||
}
|
||||
|
||||
@@ -131,11 +131,11 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
|
||||
}
|
||||
|
||||
// Binary
|
||||
DEFINE_BINARY_FLOAT(fmul, llvm::Instruction::FMul)
|
||||
DEFINE_BINARY_FLOAT(fdiv, llvm::Instruction::FDiv)
|
||||
DEFINE_BINARY_FLOAT(frem, llvm::Instruction::FRem)
|
||||
DEFINE_BINARY_FLOAT(fadd, llvm::Instruction::FAdd)
|
||||
DEFINE_BINARY_FLOAT(fsub, llvm::Instruction::FSub)
|
||||
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
|
||||
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
|
||||
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
|
||||
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
|
||||
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
|
||||
// Unary
|
||||
DEFINE_UNARY_FLOAT(fneg)
|
||||
|
||||
@@ -145,7 +145,7 @@ DEFINE_UNARY_FLOAT(fneg)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
|
||||
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||
value *rhs, const std::string &name,
|
||||
bool has_nuw, bool has_nsw) {
|
||||
auto *clhs = dynamic_cast<constant_int*>(lhs);
|
||||
@@ -180,18 +180,18 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
|
||||
}
|
||||
|
||||
// Binary
|
||||
DEFINE_NOWRAP_BINARY(mul, llvm::Instruction::Mul)
|
||||
DEFINE_NOWRAP_BINARY(add, llvm::Instruction::Add)
|
||||
DEFINE_NOWRAP_BINARY(sub, llvm::Instruction::Sub)
|
||||
DEFINE_NOWRAP_BINARY(shl, llvm::Instruction::Shl)
|
||||
DEFINE_NOWRAP_BINARY(ashr, llvm::Instruction::AShr)
|
||||
DEFINE_BINARY_INT(sdiv, llvm::Instruction::SDiv)
|
||||
DEFINE_BINARY_INT(udiv, llvm::Instruction::UDiv)
|
||||
DEFINE_BINARY_INT(srem, llvm::Instruction::SRem)
|
||||
DEFINE_BINARY_INT(urem, llvm::Instruction::URem)
|
||||
DEFINE_BINARY_INT(and, llvm::Instruction::And)
|
||||
DEFINE_BINARY_INT(or, llvm::Instruction::Or)
|
||||
DEFINE_BINARY_INT(xor, llvm::Instruction::Xor)
|
||||
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
|
||||
DEFINE_NOWRAP_BINARY(add, binary_op_t::Add)
|
||||
DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub)
|
||||
DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl)
|
||||
DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr)
|
||||
DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv)
|
||||
DEFINE_BINARY_INT(udiv, binary_op_t::UDiv)
|
||||
DEFINE_BINARY_INT(srem, binary_op_t::SRem)
|
||||
DEFINE_BINARY_INT(urem, binary_op_t::URem)
|
||||
DEFINE_BINARY_INT(and, binary_op_t::And)
|
||||
DEFINE_BINARY_INT(or, binary_op_t::Or)
|
||||
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
|
||||
// Unary
|
||||
DEFINE_UNARY_INT(neg)
|
||||
DEFINE_UNARY_INT(not)
|
||||
@@ -209,7 +209,7 @@ value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list, cons
|
||||
// icmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name){
|
||||
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
|
||||
return insert(icmp_inst::create(pred, lhs, rhs), name);
|
||||
}
|
||||
|
||||
@@ -219,25 +219,25 @@ value *builder::create_icmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const
|
||||
}
|
||||
|
||||
// Signed
|
||||
DEFINE_ICMP_INSTR(SLE, llvm::ICmpInst::ICMP_SLE)
|
||||
DEFINE_ICMP_INSTR(SLT, llvm::ICmpInst::ICMP_SLT)
|
||||
DEFINE_ICMP_INSTR(SGE, llvm::ICmpInst::ICMP_SGE)
|
||||
DEFINE_ICMP_INSTR(SGT, llvm::ICmpInst::ICMP_SGT)
|
||||
DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE)
|
||||
DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT)
|
||||
DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE)
|
||||
DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT)
|
||||
// Unsigned
|
||||
DEFINE_ICMP_INSTR(ULE, llvm::ICmpInst::ICMP_ULE)
|
||||
DEFINE_ICMP_INSTR(ULT, llvm::ICmpInst::ICMP_ULT)
|
||||
DEFINE_ICMP_INSTR(UGE, llvm::ICmpInst::ICMP_UGE)
|
||||
DEFINE_ICMP_INSTR(UGT, llvm::ICmpInst::ICMP_UGT)
|
||||
DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE)
|
||||
DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT)
|
||||
DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE)
|
||||
DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT)
|
||||
// General
|
||||
DEFINE_ICMP_INSTR(EQ, llvm::ICmpInst::ICMP_EQ)
|
||||
DEFINE_ICMP_INSTR(NE, llvm::ICmpInst::ICMP_NE)
|
||||
DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ)
|
||||
DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// fcmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name){
|
||||
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
|
||||
return insert(fcmp_inst::create(pred, lhs, rhs), name);
|
||||
}
|
||||
|
||||
@@ -247,12 +247,12 @@ value *builder::create_fcmp(cmp_inst::pred_t pred, value *lhs, value *rhs, const
|
||||
}
|
||||
|
||||
// Ordered
|
||||
DEFINE_FCMP_INSTR(OLE, llvm::FCmpInst::FCMP_OLE)
|
||||
DEFINE_FCMP_INSTR(OLT, llvm::FCmpInst::FCMP_OLT)
|
||||
DEFINE_FCMP_INSTR(OGE, llvm::FCmpInst::FCMP_OGE)
|
||||
DEFINE_FCMP_INSTR(OGT, llvm::FCmpInst::FCMP_OGT)
|
||||
DEFINE_FCMP_INSTR(OEQ, llvm::FCmpInst::FCMP_OEQ)
|
||||
DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE)
|
||||
DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE)
|
||||
DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT)
|
||||
DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE)
|
||||
DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
|
||||
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
|
||||
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
|
||||
|
||||
|
||||
|
||||
|
@@ -145,19 +145,19 @@ uint64_t constant_expression::get_value() const {
|
||||
uint64_t lhs = lhs_->get_value();
|
||||
uint64_t rhs = rhs_->get_value();
|
||||
switch(op_) {
|
||||
case llop::Add : return lhs + rhs;
|
||||
case llop::Sub : return lhs - rhs;
|
||||
case llop::Mul : return lhs * rhs;
|
||||
case llop::UDiv : return lhs / rhs;
|
||||
case llop::SDiv : return lhs / rhs;
|
||||
case llop::URem : return lhs % rhs;
|
||||
case llop::SRem : return lhs % rhs;
|
||||
case llop::Shl : return lhs << rhs;
|
||||
case llop::LShr : return lhs >> rhs;
|
||||
case llop::AShr : return lhs >> rhs;
|
||||
case llop::And : return lhs && rhs;
|
||||
case llop::Or : return lhs || rhs;
|
||||
case llop::Xor : return lhs ^ rhs;
|
||||
case op_t::Add : return lhs + rhs;
|
||||
case op_t::Sub : return lhs - rhs;
|
||||
case op_t::Mul : return lhs * rhs;
|
||||
case op_t::UDiv : return lhs / rhs;
|
||||
case op_t::SDiv : return lhs / rhs;
|
||||
case op_t::URem : return lhs % rhs;
|
||||
case op_t::SRem : return lhs % rhs;
|
||||
case op_t::Shl : return lhs << rhs;
|
||||
case op_t::LShr : return lhs >> rhs;
|
||||
case op_t::AShr : return lhs >> rhs;
|
||||
case op_t::And : return lhs && rhs;
|
||||
case op_t::Or : return lhs || rhs;
|
||||
case op_t::Xor : return lhs ^ rhs;
|
||||
default: throw std::runtime_error("unsupported constexpr binary operator");
|
||||
}
|
||||
}
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
@@ -87,60 +88,60 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
|
||||
|
||||
std::string binary_operator::repr_impl() const {
|
||||
switch(op_) {
|
||||
case llop::Add : return "add";
|
||||
case llop::FAdd : return "fadd";
|
||||
case llop::Sub : return "sub";
|
||||
case llop::FSub : return "fsub";
|
||||
case llop::Mul : return "mul";
|
||||
case llop::FMul : return "fmul";
|
||||
case llop::UDiv : return "udiv";
|
||||
case llop::SDiv : return "sdiv";
|
||||
case llop::FDiv : return "fdiv";
|
||||
case llop::URem : return "urem";
|
||||
case llop::SRem : return "srem";
|
||||
case llop::FRem : return "frem";
|
||||
case llop::Shl : return "shl";
|
||||
case llop::LShr : return "lshr";
|
||||
case llop::AShr : return "ashr";
|
||||
case llop::And : return "and";
|
||||
case llop::Or : return "or";
|
||||
case llop::Xor : return "xor";
|
||||
case Add : return "add";
|
||||
case FAdd : return "fadd";
|
||||
case Sub : return "sub";
|
||||
case FSub : return "fsub";
|
||||
case Mul : return "mul";
|
||||
case FMul : return "fmul";
|
||||
case UDiv : return "udiv";
|
||||
case SDiv : return "sdiv";
|
||||
case FDiv : return "fdiv";
|
||||
case URem : return "urem";
|
||||
case SRem : return "srem";
|
||||
case FRem : return "frem";
|
||||
case Shl : return "shl";
|
||||
case LShr : return "lshr";
|
||||
case AShr : return "ashr";
|
||||
case And : return "and";
|
||||
case Or : return "or";
|
||||
case Xor : return "xor";
|
||||
default: throw std::runtime_error("unknown binary operator");
|
||||
}
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_div() const {
|
||||
return op_ == llop::UDiv || op_ == llop::SDiv;
|
||||
return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_rem() const {
|
||||
return op_ == llop::URem || op_ == llop::SRem;
|
||||
return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shl() const {
|
||||
return op_ == llop::Shl;
|
||||
return op_ == binary_op_t::Shl;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shr() const {
|
||||
return op_ == llop::LShr || op_ == llop::AShr;
|
||||
return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_mult() const {
|
||||
return op_ == llop::Mul;
|
||||
return op_ == binary_op_t::Mul;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_add_sub() const {
|
||||
return op_ == llop::Add || op_ == llop::Sub;
|
||||
return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
|
||||
}
|
||||
|
||||
|
||||
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 2, 1, name, next), op_(op){
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(lhs->get_type() == rhs->get_type() &&
|
||||
"Cannot create binary operator with two operands of differing type!");
|
||||
return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
|
||||
@@ -149,19 +150,19 @@ binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const
|
||||
binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
|
||||
assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
|
||||
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
|
||||
return binary_operator::create(llvm::Instruction::FSub, zero, arg, name, next);
|
||||
return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
|
||||
}
|
||||
|
||||
binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
|
||||
assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
|
||||
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
|
||||
return binary_operator::create(llvm::Instruction::Sub, zero, arg, name, next);
|
||||
return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
|
||||
}
|
||||
|
||||
binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
|
||||
assert(arg->get_type()->is_integer_ty());
|
||||
constant *mask = constant::get_all_ones_value(arg->get_type());
|
||||
return binary_operator::create(llvm::Instruction::Xor, arg, mask, name, next);
|
||||
return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -171,37 +172,37 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
|
||||
// cmp_inst
|
||||
std::string cmp_inst::repr_impl() const {
|
||||
switch (pred_) {
|
||||
case llop::FCMP_FALSE : return "false";
|
||||
case llop::FCMP_OEQ : return "fcmp_oeq";
|
||||
case llop::FCMP_OGT : return "fcmp_ogt";
|
||||
case llop::FCMP_OGE : return "fcmp_oge";
|
||||
case llop::FCMP_OLT : return "fcmp_olt";
|
||||
case llop::FCMP_OLE : return "fcmp_ole";
|
||||
case llop::FCMP_ONE : return "fcmp_one";
|
||||
case llop::FCMP_ORD : return "fcmp_ord";
|
||||
case llop::FCMP_UNO : return "fcmp_uno";
|
||||
case llop::FCMP_UEQ : return "fcmp_ueq";
|
||||
case llop::FCMP_UGT : return "fcmp_ugt";
|
||||
case llop::FCMP_UGE : return "fcmp_uge";
|
||||
case llop::FCMP_ULT : return "fcmp_ult";
|
||||
case llop::FCMP_ULE : return "fcmp_ule";
|
||||
case llop::FCMP_UNE : return "fcmp_une";
|
||||
case llop::FCMP_TRUE : return "true";
|
||||
case llop::ICMP_EQ : return "icmp_eq";
|
||||
case llop::ICMP_NE : return "icmp_ne";
|
||||
case llop::ICMP_UGT : return "icmp_ugt";
|
||||
case llop::ICMP_UGE : return "icmp_uge";
|
||||
case llop::ICMP_ULT : return "icmp_ult";
|
||||
case llop::ICMP_ULE : return "icmp_ule";
|
||||
case llop::ICMP_SGT : return "icmp_sgt";
|
||||
case llop::ICMP_SGE : return "icmp_sge";
|
||||
case llop::ICMP_SLT : return "icmp_slt";
|
||||
case llop::ICMP_SLE : return "icmp_sle";
|
||||
case FCMP_FALSE : return "false";
|
||||
case FCMP_OEQ : return "fcmp_oeq";
|
||||
case FCMP_OGT : return "fcmp_ogt";
|
||||
case FCMP_OGE : return "fcmp_oge";
|
||||
case FCMP_OLT : return "fcmp_olt";
|
||||
case FCMP_OLE : return "fcmp_ole";
|
||||
case FCMP_ONE : return "fcmp_one";
|
||||
case FCMP_ORD : return "fcmp_ord";
|
||||
case FCMP_UNO : return "fcmp_uno";
|
||||
case FCMP_UEQ : return "fcmp_ueq";
|
||||
case FCMP_UGT : return "fcmp_ugt";
|
||||
case FCMP_UGE : return "fcmp_uge";
|
||||
case FCMP_ULT : return "fcmp_ult";
|
||||
case FCMP_ULE : return "fcmp_ule";
|
||||
case FCMP_UNE : return "fcmp_une";
|
||||
case FCMP_TRUE : return "true";
|
||||
case ICMP_EQ : return "icmp_eq";
|
||||
case ICMP_NE : return "icmp_ne";
|
||||
case ICMP_UGT : return "icmp_ugt";
|
||||
case ICMP_UGE : return "icmp_uge";
|
||||
case ICMP_ULT : return "icmp_ult";
|
||||
case ICMP_ULE : return "icmp_ule";
|
||||
case ICMP_SGT : return "icmp_sgt";
|
||||
case ICMP_SGE : return "icmp_sge";
|
||||
case ICMP_SLT : return "icmp_slt";
|
||||
case ICMP_SLE : return "icmp_sle";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
cmp_inst::cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: instruction(ty, 2, 1, name, next), pred_(pred) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
@@ -215,23 +216,23 @@ type* cmp_inst::make_cmp_result_type(type *ty){
|
||||
}
|
||||
|
||||
|
||||
bool cmp_inst::is_fp_predicate(pred_t pred) {
|
||||
return pred >= llop::FIRST_FCMP_PREDICATE && pred <= llop::LAST_FCMP_PREDICATE;
|
||||
bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE;
|
||||
}
|
||||
|
||||
bool cmp_inst::is_int_predicate(pred_t pred) {
|
||||
return pred >= llop::FIRST_ICMP_PREDICATE && pred <= llop::LAST_ICMP_PREDICATE;
|
||||
bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
|
||||
}
|
||||
|
||||
// icmp_inst
|
||||
icmp_inst* icmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_int_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// fcmp_inst
|
||||
fcmp_inst* fcmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_fp_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
@@ -252,45 +253,45 @@ unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction
|
||||
|
||||
std::string cast_inst::repr_impl() const {
|
||||
switch (op_){
|
||||
case ic::Trunc: return "trunc";
|
||||
case ic::ZExt: return "zext";
|
||||
case ic::SExt: return "sext";
|
||||
case ic::FPTrunc: return "fp_trunc";
|
||||
case ic::FPExt: return "fp_ext";
|
||||
case ic::UIToFP: return "ui_to_fp";
|
||||
case ic::SIToFP: return "si_to_fp";
|
||||
case ic::FPToUI: return "fp_to_ui";
|
||||
case ic::FPToSI: return "fp_to_si";
|
||||
case ic::PtrToInt: return "ptr_to_int";
|
||||
case ic::IntToPtr: return "int_to_ptr";
|
||||
case ic::BitCast: return "bitcast";
|
||||
case ic::AddrSpaceCast: return "addr_space_cast";
|
||||
case cast_op_t::Trunc: return "trunc";
|
||||
case cast_op_t::ZExt: return "zext";
|
||||
case cast_op_t::SExt: return "sext";
|
||||
case cast_op_t::FPTrunc: return "fp_trunc";
|
||||
case cast_op_t::FPExt: return "fp_ext";
|
||||
case cast_op_t::UIToFP: return "ui_to_fp";
|
||||
case cast_op_t::SIToFP: return "si_to_fp";
|
||||
case cast_op_t::FPToUI: return "fp_to_ui";
|
||||
case cast_op_t::FPToSI: return "fp_to_si";
|
||||
case cast_op_t::PtrToInt: return "ptr_to_int";
|
||||
case cast_op_t::IntToPtr: return "int_to_ptr";
|
||||
case cast_op_t::BitCast: return "bitcast";
|
||||
case cast_op_t::AddrSpaceCast: return "addr_space_cast";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
// TODO
|
||||
bool cast_inst::is_valid(op_t op, value *arg, type *ty) {
|
||||
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
|
||||
assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty());
|
||||
return true;
|
||||
}
|
||||
|
||||
cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &name, instruction *next){
|
||||
cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
|
||||
assert(is_valid(op, arg, ty) && "Invalid cast!");
|
||||
// Construct and return the appropriate CastInst subclass
|
||||
switch (op) {
|
||||
case ic::Trunc: return new trunc_inst (ty, arg, name, next);
|
||||
case ic::ZExt: return new z_ext_inst (ty, arg, name, next);
|
||||
case ic::SExt: return new s_ext_inst (ty, arg, name, next);
|
||||
case ic::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
|
||||
case ic::FPExt: return new fp_ext_inst (ty, arg, name, next);
|
||||
case ic::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
|
||||
case ic::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
|
||||
case ic::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
|
||||
case ic::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
|
||||
case ic::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
|
||||
case ic::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
|
||||
case ic::BitCast: return new bit_cast_inst (ty, arg, name, next);
|
||||
case ic::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
|
||||
case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
|
||||
case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
|
||||
case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
|
||||
case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next);
|
||||
case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
@@ -300,9 +301,9 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
|
||||
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
op_t op = (arg_bits == dst_bits ? ic::BitCast :
|
||||
(arg_bits > dst_bits ? ic::Trunc :
|
||||
(is_signed ? ic::SExt : ic::ZExt)));
|
||||
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
|
||||
(arg_bits > dst_bits ? cast_op_t::Trunc :
|
||||
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
|
||||
return create(op, arg, ty, name, next);
|
||||
}
|
||||
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include "triton/lang/statement.h"
|
||||
#include "triton/lang/declaration.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
BIN
python/dist/triton-0.1-py3.6-linux-x86_64.egg
vendored
BIN
python/dist/triton-0.1-py3.6-linux-x86_64.egg
vendored
Binary file not shown.
@@ -1,4 +1,10 @@
|
||||
import libtriton
|
||||
import tensorflow as tf
|
||||
import distutils
|
||||
import distutils.log
|
||||
import setuptools.command.build_ext
|
||||
import setuptools
|
||||
import os
|
||||
|
||||
src = """
|
||||
const tunable int TM = {128};
|
||||
@@ -9,7 +15,7 @@ void matmul(restrict read_only align(16) half *A,
|
||||
restrict read_only align(16) half *B,
|
||||
restrict read_only align(16) half *C,
|
||||
int M, int N, int K,
|
||||
multiple_of(8) int lda, multiple_of(8)" int ldb, int ldc) {
|
||||
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
@@ -39,4 +45,51 @@ void matmul(restrict read_only align(16) half *A,
|
||||
}
|
||||
"""
|
||||
|
||||
print(libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1'))
|
||||
with open('test.cpp', 'w+') as test:
|
||||
src = libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1')
|
||||
test.writelines(src)
|
||||
|
||||
triton_include_dirs = ['/home/philippe/development/triton/include']
|
||||
tensorflow_include_dirs = [tf.sysconfig.get_include()]
|
||||
llvm_include_dirs = ['/usr/include/llvm-8/', '/usr/include/llvm-c-8/']
|
||||
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
|
||||
|
||||
triton_library_dirs = [os.path.realpath(libtriton.__file__)]
|
||||
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
|
||||
|
||||
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
|
||||
extra_compile_args = []
|
||||
extra_link_args = []
|
||||
library_dirs = tensorflow_library_dirs
|
||||
libraries = ['tensorflow_framework']
|
||||
|
||||
ext = setuptools.Extension(
|
||||
name = 'test',
|
||||
language = 'c++',
|
||||
sources = ['/home/philippe/development/triton/python/examples/test.cpp'],
|
||||
include_dirs = include_dirs,
|
||||
extra_compile_args = extra_compile_args,
|
||||
extra_link_args = extra_link_args,
|
||||
library_dirs = library_dirs,
|
||||
libraries = libraries
|
||||
)
|
||||
|
||||
build_path = '.'
|
||||
args = ['build_ext']
|
||||
#args.append('--build-temp=' + build_path)
|
||||
#args.append('--build-lib=' + build_path)
|
||||
args.append('-q')
|
||||
args = dict(
|
||||
name = 'test',
|
||||
ext_modules = [ext],
|
||||
script_args = args,
|
||||
cmdclass = {
|
||||
'build_ext': setuptools.command.build_ext.build_ext
|
||||
}
|
||||
|
||||
)
|
||||
|
||||
setuptools.setup(**args)
|
||||
library_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
module = tf.load_op_library(os.path.join(library_dir, 'build/lib.linux-x86_64-3.6/test.cpython-36m-x86_64-linux-gnu.so'))
|
||||
print(module.matmul)
|
Reference in New Issue
Block a user