153 lines
3.9 KiB
C++
153 lines
3.9 KiB
C++
#pragma once
|
|
|
|
#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
|
|
#define _TRITON_SELECTION_MACHINE_VALUE_H_
|
|
|
|
#include <vector>
|
|
#include <map>
|
|
#include <functional>
|
|
|
|
namespace llvm{
|
|
class Type;
|
|
class Value;
|
|
class Instruction;
|
|
class Constant;
|
|
class LLVMContext;
|
|
class Module;
|
|
class ConstantFolder;
|
|
class IRBuilderDefaultInserter;
|
|
template <typename T, typename Inserter>
|
|
class IRBuilder;
|
|
class ArrayType;
|
|
class Function;
|
|
}
|
|
|
|
namespace triton{
|
|
namespace codegen{
|
|
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
|
llvm::IRBuilderDefaultInserter> Builder;
|
|
typedef llvm::LLVMContext LLVMContext;
|
|
typedef llvm::Type Type;
|
|
typedef llvm::Value Value;
|
|
typedef llvm::Module Module;
|
|
typedef llvm::Instruction Instruction;
|
|
typedef llvm::Constant Constant;
|
|
typedef llvm::ArrayType ArrayType;
|
|
typedef llvm::Function Function;
|
|
}
|
|
}
|
|
|
|
namespace triton{
|
|
namespace codegen{
|
|
|
|
namespace analysis{
|
|
class liveness;
|
|
class tiles;
|
|
class align;
|
|
class allocation;
|
|
class cts;
|
|
class axes;
|
|
class layouts;
|
|
}
|
|
|
|
class distributed_axis;
|
|
class machine_data_layout;
|
|
class tile;
|
|
class shared_tile;
|
|
class distributed_tile;
|
|
class target;
|
|
typedef std::vector<Value*> indices_t;
|
|
|
|
}
|
|
}
|
|
|
|
namespace triton{
|
|
namespace codegen{
|
|
|
|
struct distributed_axis {
|
|
int contiguous;
|
|
std::vector<Value*> values;
|
|
Value* thread_id;
|
|
};
|
|
|
|
class tile {
|
|
protected:
|
|
typedef std::vector<unsigned> shapes_t;
|
|
|
|
public:
|
|
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
|
|
virtual void set_value(indices_t idx, Value *v) = 0;
|
|
virtual Value* get_value(indices_t idx) = 0;
|
|
Type *get_ty() const { return ty_; }
|
|
shapes_t get_shapes() const { return shapes_; }
|
|
|
|
protected:
|
|
Type *ty_;
|
|
shapes_t shapes_;
|
|
};
|
|
|
|
class shared_tile: public tile {
|
|
private:
|
|
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
|
|
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
|
|
|
|
|
|
public:
|
|
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
|
|
void set_vector_size(unsigned vector_size);
|
|
void set_return_mode(bool return_vector);
|
|
void set_value(indices_t, Value *);
|
|
Value* get_ptr_to(indices_t idx);
|
|
Value* get_value(indices_t idx);
|
|
Value* get_pointer() { return ptr_; }
|
|
Value* get_offset() { return offset_; }
|
|
const std::vector<int>& get_perm() { return perm_; }
|
|
const std::vector<int>& get_order() { return order_; }
|
|
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
|
|
|
|
private:
|
|
Value *ptr_;
|
|
bool return_vector_;
|
|
Builder &builder_;
|
|
Value *offset_;
|
|
std::map<indices_t, Value*> ptr_cache_;
|
|
unsigned vector_size_;
|
|
std::vector<int> order_;
|
|
std::vector<int> perm_;
|
|
};
|
|
|
|
// Distribtued tile
|
|
class distributed_tile: public tile{
|
|
typedef std::vector<distributed_axis> axes_t;
|
|
typedef std::vector<indices_t> ordered_indices_vec_t;
|
|
typedef std::map<indices_t, unsigned> indices_map_t;
|
|
typedef std::map<indices_t, Value*> values_map_t;
|
|
|
|
private:
|
|
void init_indices();
|
|
|
|
public:
|
|
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
|
|
void set_value(indices_t idx, Value *v);
|
|
Value* get_value(indices_t idx);
|
|
const std::vector<int>& get_order() { return order_; }
|
|
unsigned get_linear_index(indices_t idx);
|
|
indices_t get_ordered_indices(unsigned id);
|
|
void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
|
|
void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
|
|
|
|
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
|
|
private:
|
|
axes_t axes_;
|
|
std::vector<int> order_;
|
|
indices_map_t indices_;
|
|
values_map_t values_;
|
|
ordered_indices_vec_t ordered_indices_;
|
|
Builder &builder_;
|
|
};
|
|
|
|
}
|
|
}
|
|
|
|
#endif
|