some more cleaning
@@ -21,16 +21,13 @@ namespace codegen{
 namespace analysis{

 class align {
+private:
 struct cst_info {
 unsigned num_cst;
 unsigned value;
 };

-private:
 // helpers
-bool is_first_axis_unit(ir::value *v);
 std::vector<unsigned> get_shapes(ir::value *v);

 // populate is_constant
 std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
 std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
@@ -61,10 +58,8 @@ private:

 public:
 void run(ir::module &mod);
-unsigned get_starting_multiple(ir::value* v) const;
-unsigned get_max_contiguous(ir::value* v) const;
-std::vector<unsigned> get_max_contiguous_vec(ir::value* v) const;
-void copy(ir::value *dst, ir::value *src);
+unsigned get(ir::value* v, unsigned ax) const;
+std::vector<unsigned> contiguous(ir::value* v) const;

 private:
 std::map<ir::value*, std::vector<cst_info>> is_constant_;
@@ -49,20 +49,20 @@ private:


 public:
-grids(size_t num_warps, transform::coalesce* reorder);
-fragment_t get_fragment(ir::value *value, unsigned ax);
-void copy(ir::value *dst, ir::value *src);
+grids(size_t num_warps, transform::coalesce* coalesce);
 void run(ir::module &mod);
-unsigned get_param_group(ir::value *value, unsigned ax);
-const std::vector<ir::value*> get_grids() const { return grids_; }
+const std::vector<ir::value*> get() const { return grids_; }
+fragment_t fragment_of(ir::value *value, unsigned ax);
+unsigned group_of(ir::value *value, unsigned ax);
 int mts(ir::value *value, unsigned ax);
 int nts(ir::value *value, unsigned ax);
 int fpw(ir::value *value, unsigned ax);
 int wpt(ir::value *value, unsigned ax);
+void copy(ir::value *dst, ir::value *src);

 private:

-transform::coalesce* reorder_;
+transform::coalesce* coalesce_;
 // number of warps
 size_t num_warps_;
 // grids
@@ -45,11 +45,9 @@ public:
 public:
 // constructor
 liveness(meminfo *info): info_(info){ }

 // accessors
 const intervals_map_t& intervals() const { return intervals_; }
 segment get_interval(ir::value* v) const { return intervals_.at(v); }

 // run
 void run(ir::module &mod);

@@ -24,15 +24,12 @@ class memalloc {
 public:
 memalloc(liveness *live, meminfo *buffer_info, grids *params)
 : liveness_(live), buffer_info_(buffer_info), params_(params){ }

 // utilities
-unsigned get_num_bytes(ir::value *x);
+unsigned num_bytes(ir::value *x);
 unsigned is_ld_padded(ir::value* x);

 // accessors
-unsigned get_offset(ir::value *x) const { return offsets_.at(x); }
-unsigned get_allocated_size() const { return allocated_size_; }
+unsigned offset(ir::value *x) const { return offsets_.at(x); }
+unsigned allocated_size() const { return allocated_size_; }

 // run
 void run();

@@ -30,14 +30,6 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
 return map[i] = value;
 }


-bool align::is_first_axis_unit(ir::value *x){
-if(x->get_type()->is_tile_ty())
-return x->get_type()->get_tile_shapes()[0] == 1;
-else
-return true;
-}

 /*
 * is constant
 */
@@ -471,26 +463,19 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
 return populate_starting_multiple_default(v);
 }

-unsigned align::get_starting_multiple(ir::value* v) const {
-return starting_multiple_.at(v)[0];
+unsigned align::get(ir::value *v, unsigned ax) const {
+unsigned starting_multiple = starting_multiple_.at(v)[ax];
+unsigned max_contiguous = max_contiguous_.at(v)[ax];
+return std::min(starting_multiple, max_contiguous);
 }

-unsigned align::get_max_contiguous(ir::value* v) const {
-return max_contiguous_.at(v)[0];
-}
-
-std::vector<unsigned> align::get_max_contiguous_vec(ir::value* v) const {
+std::vector<unsigned> align::contiguous(ir::value* v) const {
 return max_contiguous_.at(v);
 }

-void align::copy(ir::value *dst, ir::value *src) {
-starting_multiple_[dst] = starting_multiple_[src];
-max_contiguous_[dst] = max_contiguous_[src];
-is_constant_[dst] = is_constant_[src];
-}


 void align::run(ir::module &mod) {

 // populate constant
 for(ir::function *fn: mod.get_function_list())
 for(ir::basic_block *block: fn->blocks())
@@ -16,7 +16,7 @@ namespace triton{
 namespace codegen{
 namespace analysis{

-grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), reorder_(reorder)
+grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), coalesce_(reorder)
 { }

 bool is_hmma(ir::value *v){
@@ -168,12 +168,12 @@ void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_v
 }
 }

-unsigned grids::get_param_group(ir::value *value, unsigned ax) {
+unsigned grids::group_of(ir::value *value, unsigned ax) {
 unsigned result = groups_.at(value).at(ax);
 return result;
 }

-grids::fragment_t grids::get_fragment(ir::value *value, unsigned ax) {
+grids::fragment_t grids::fragment_of(ir::value *value, unsigned ax) {
 return fragments_.at({value, ax});
 }

@@ -233,7 +233,7 @@ void grids::run(ir::module &mod) {
 for(ir::value *i: grids_){
 if(!i->get_type()->is_tile_ty())
 continue;
-auto order = reorder_->get_order(i);
+auto order = coalesce_->get_order(i);
 auto shapes = i->get_type()->get_tile_shapes();
 unsigned size = i->get_type()->get_tile_num_elements();
 /* HMMA parameters*/
@@ -329,7 +329,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
 for(size_t d = 0; d < shapes.size(); d++){
 if(shapes[d] == 1)
 continue;
-unsigned x = get_param_group(v, d);
+unsigned x = group_of(v, d);
 ir::value *&r = references[x];
 if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
 r = v;
@@ -20,7 +20,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) {
 }
 for(ir::user* user: x->get_users())
 if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
-bool is_hmma = params_->get_fragment(user, 0) == grids::HMMA_FRAGMENT_C;
+bool is_hmma = params_->fragment_of(user, 0) == grids::HMMA_FRAGMENT_C;
 bool is_op_0 = x == dot->get_operand(0);
 bool is_op_1 = x == dot->get_operand(1);
 if(is_hmma && is_op_0){
@@ -45,7 +45,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) {
 return 0;
 }

-unsigned memalloc::get_num_bytes(ir::value *x) {
+unsigned memalloc::num_bytes(ir::value *x) {
 if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
 unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
 size_t axis = red->get_axis();
@@ -56,7 +56,7 @@ unsigned memalloc::get_num_bytes(ir::value *x) {
 for(auto x: shapes)
 num_elements *= x;
 size_t depth;
-if(params_->get_fragment(x, 0) == grids::HMMA_FRAGMENT_C)
+if(params_->fragment_of(x, 0) == grids::HMMA_FRAGMENT_C)
 depth = params_->wpt(op, axis);
 else
 depth = params_->mts(op, axis);
@@ -102,7 +102,7 @@ void memalloc::run(){
 return res;
 });
 if(j_it != J.end()){
-unsigned size = get_num_bytes(*j_it);
+unsigned size = num_bytes(*j_it);
 segment xj = liveness_->get_interval(*j_it);
 starts[*j_it] = w;
 H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -123,8 +123,8 @@ void memalloc::run(){
 if(x == y)
 continue;
 unsigned X0 = starts[x], Y0 = starts[y];
-unsigned NX = get_num_bytes(x);
-unsigned NY = get_num_bytes(y);
+unsigned NX = num_bytes(x);
+unsigned NY = num_bytes(y);
 segment XS = {X0, X0 + NX};
 segment YS = {Y0, Y0 + NY};
 if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
@@ -156,7 +156,7 @@ void memalloc::run(){
 for(ir::value *x: V){
 unsigned Adj = 0;
 for(ir::value *y: interferences[x])
-Adj = std::max(Adj, starts[y] + get_num_bytes(y));
+Adj = std::max(Adj, starts[y] + num_bytes(y));
 offsets_[x] = starts[x] + colors[x] * Adj;
 if(buffer_info_->is_double(x)){
 ir::phi_node *phi = (ir::phi_node*)x;
@@ -170,7 +170,7 @@ void memalloc::run(){
 // Save maximum size of induced memory space
 allocated_size_ = 0;
 for(auto &x: offsets_){
-allocated_size_ = std::max<size_t>(allocated_size_, x.second + get_num_bytes(x.first));
+allocated_size_ = std::max<size_t>(allocated_size_, x.second + num_bytes(x.first));
 }
 }

@@ -430,7 +430,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
 Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
 BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
 BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
-Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
+Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->offset(ii)));
 ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
 tgt_->add_memfence(module, builder);
 tgt_->add_barrier(module, builder);
@@ -538,6 +538,10 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
 throw std::runtime_error("unknown conversion from ir::value to Value");
 }

+/* -------------------
+ * ---- Init Axes ----
+ * ------------------- */
+
 // Grid construction
 std::vector<Value*> delinearize(Value *trailing, const std::vector<unsigned>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
 size_t dim = shapes.size();
@@ -600,7 +604,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
 unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
 idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
 }
-axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
+axes_[params_->group_of(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
 }
 }

@@ -705,27 +709,23 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre


 /* axes */
-axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
-axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
+axes_[params_->group_of(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
+axes_[params_->group_of(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
 if(is_batched)
-axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
+axes_[params_->group_of(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
 }


 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
-if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN)
+if(params_->fragment_of(v, 0) == analysis::grids::STRIDED_SCAN)
 init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
 else
 init_hmma_axes(v, builder, u_thread_id, u_warp_id);
 }

-bool static inline has_phi_user(ir::value *v) {
-for(ir::user *usr: v->get_users()){
-if(dynamic_cast<ir::phi_node*>(usr))
-return true;
-}
-return false;
-}
+/* -------------------
+ * ---- Init Tiles ----
+ * ------------------- */

 void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
 auto shapes = v->get_type()->get_tile_shapes();
@@ -748,7 +748,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
 PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
 PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
 // next pointer
-Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
+Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(phi)));
 pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
 Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
 tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
@@ -761,8 +761,12 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
 }
 }
 else {
-if(!has_phi_user(v)){
-size_t offset = alloc_->get_offset(v);
+bool has_phi_user = false;
+for(ir::user *usr: v->get_users())
+if(dynamic_cast<ir::phi_node*>(usr))
+has_phi_user = true;
+if(has_phi_user){
+size_t offset = alloc_->offset(v);
 Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
 ptr = builder.CreateBitCast(ptr, ptr_ty);
 tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
@@ -776,7 +780,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
 std::vector<distributed_axis> axes(shapes.size());
 for(size_t d = 0; d < shapes.size(); d++){
 if(shapes[d] > 1){
-unsigned x = params_->get_param_group(v, d);
+unsigned x = params_->group_of(v, d);
 axes[d] = axes_.at(x);
 }
 else{
@@ -827,7 +831,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
 Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
 Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
 // create grid
-for(ir::value* i: params_->get_grids())
+for(ir::value* i: params_->get())
 init_axes(i, builder, u_thread_warp_id, u_warp_id);
 // create tile
 std::set<ir::value*> seen;
@@ -839,6 +843,10 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
 }
 }

+/* ----------------------------
+ * ---- Lower Instructions ----
+ * ---------------------------- */
+
 void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
 distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
 distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
@@ -907,7 +915,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
 Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
 for(auto& x: partial) {
 // current element being computed
-Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+Value *lane = axes_.at(params_->group_of(op, axis)).thread_id;
 Value *&result = x.second;
 indices_t write_idx = x.first;
 write_idx.insert(write_idx.begin() + axis, lane);
@@ -1233,7 +1241,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
 if(NK != 1) {
 shared_tile *TA = (shared_tile*)tmap_.at(A);
 shared_tile *TB = (shared_tile*)tmap_.at(B);
-if(params_->get_fragment(dot, 0) == analysis::grids::STRIDED_SCAN)
+if(params_->fragment_of(dot, 0) == analysis::grids::STRIDED_SCAN)
 lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
 else
 lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
@@ -1249,9 +1257,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
 // find vector size
 distributed_tile* result = (distributed_tile*)tmap_.at(x);
 ir::value *ptr = x->get_pointer_operand();
-unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
-unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
-unsigned alignment = std::min(starting_multiple, max_contiguous);
+unsigned alignment = alignment_->get(ptr, 0);
 unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
 distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
 distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -1322,9 +1328,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
 distributed_tile* result = (distributed_tile*)tmap_.at(x);
 // find vector size
 ir::value *ptr = x->get_pointer_operand();
-unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
-unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
-unsigned alignment = std::min(starting_multiple, max_contiguous);
+unsigned alignment = alignment_->get(ptr, 0);
 unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
 distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
 // vector loads
@@ -1408,6 +1412,10 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
 }
 }

+/* ----------------------------
+ * ---- Generate LLVM code ----
+ * ---------------------------- */
+
 inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
 switch(attr.get_kind()){
 case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias);
@@ -1487,7 +1495,7 @@ void selection::run(ir::module &src, Module &dst) {
 // allocate shared memory
 Value *sh_mem_ptr = nullptr;
 if(tgt_->is_gpu())
-if(unsigned alloc_size = alloc_->get_allocated_size()){
+if(unsigned alloc_size = alloc_->allocated_size()){
 Type *int_8_ty = Type::getInt8Ty(dst_ctx);
 ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
 Type *ptr_ty = PointerType::get(int_8_ty, 3);
@@ -1540,7 +1548,7 @@ void selection::run(ir::module &src, Module &dst) {
 }
 else {
 unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
-offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*num_bytes)), llvm_inc_block);
+offset->addIncoming(dst_builder.getInt32(alloc_->num_bytes(phi)/(2*num_bytes)), llvm_inc_block);
 }
 ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
 }
@@ -57,7 +57,7 @@ void coalesce::run(ir::module &mod) {
 std::map<ir::value*, ir::value*> replaced;
 for(ir::io_inst *i: io) {
 ir::value *ptr = i->get_pointer_operand();
-auto max_contiguous = align_->get_max_contiguous_vec(ptr);
+auto max_contiguous = align_->contiguous(ptr);
 std::vector<unsigned> order(max_contiguous.size());
 std::iota(order.begin(), order.end(), 0);
 std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
@@ -102,7 +102,6 @@ void coalesce::run(ir::module &mod) {
 n_op = builder.insert(n_op);
 replaced.insert({i_op, n_op});
 order_[n_op] = order;
-align_->copy(n_op, i_op);
 mem_->copy(n_op, i_op);
 if(original)
 n_op->erase_use(original);
@@ -32,8 +32,8 @@ bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {

 void membar::add_reference(ir::value *v, interval_vec_t &res){
 if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
-unsigned offset = alloc_->get_offset(v);
-unsigned num_bytes = alloc_->get_num_bytes(v);
+unsigned offset = alloc_->offset(v);
+unsigned num_bytes = alloc_->num_bytes(v);
 res.push_back(interval_t(offset, offset + num_bytes));
 }
 }
@@ -94,9 +94,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
 params_->copy(new_value, old_value);
 params_->copy(new_lhs, old_value);
 params_->copy(new_rhs, old_value);
-align_->copy(new_value, old_value);
-align_->copy(new_lhs, old_value);
-align_->copy(new_rhs, old_value);
 }
 }
 }
@@ -134,9 +131,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
 params_->copy(new_value, old_value);
 params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
 params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
-align_->copy(new_value, old_value);
-align_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
-align_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
 }
 }

@@ -192,9 +186,6 @@ void reassociate::run(ir::module &mod) {
 params_->copy(dyn_range, old_range);
 params_->copy(static_range, old_range);
 params_->copy(new_range, old_range);
-align_->copy(dyn_range, old_range);
-align_->copy(static_range, old_range);
-align_->copy(new_range, old_range);
 }
 }

@@ -226,9 +217,6 @@ void reassociate::run(ir::module &mod) {
 params_->copy(ndyn, rt);
 params_->copy(nsta, rt);
 params_->copy(broadcast, rt);
-align_->copy(ndyn, rt);
-align_->copy(nsta, rt);
-align_->copy(broadcast, rt);
 infos[rt] = cst_info{ndyn, nsta};
 }
 }
@@ -250,8 +238,6 @@ void reassociate::run(ir::module &mod) {
 ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
 params_->copy(dyn_ptr, pz);
 params_->copy(sta_ptr, pz);
-align_->copy(dyn_ptr, pz);
-align_->copy(sta_ptr, pz);
 pz->replace_all_uses_with(sta_ptr);
 infos[sta_ptr].dyn_ptr = dyn_ptr;
 infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
@@ -268,8 +254,6 @@ void reassociate::run(ir::module &mod) {
 ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
 params_->copy(pz_dyn, pz);
 params_->copy(pz_sta, pz);
-align_->copy(pz_dyn, pz);
-align_->copy(pz_sta, pz);
 pz->replace_all_uses_with(pz_sta);
 infos[pz_sta].dyn_ptr = pz_dyn;
 infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
@@ -320,11 +304,6 @@ void reassociate::run(ir::module &mod) {
 params_->copy(neg_off, off);
 params_->copy(phi_dyn, phi);
 params_->copy(phi_sta, phi);
-align_->copy(pz_dyn, pz);
-align_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
-align_->copy(neg_off, off);
-align_->copy(phi_dyn, phi);
-align_->copy(phi_sta, phi);
 infos[phi_sta].dyn_ptr = phi_dyn;
 infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
 replaced.insert(phi);
@@ -197,24 +197,25 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
 codegen::analysis::meminfo shmem_info;
 codegen::analysis::liveness shmem_liveness(&shmem_info);
 codegen::analysis::align alignment_info;
-codegen::transform::coalesce reorder(&alignment_info, &shmem_info);
-codegen::analysis::grids grids(opt.num_warps, &reorder);
+codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
+codegen::analysis::grids grids(opt.num_warps, &coalesce);
 codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
 codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
 codegen::transform::vectorize vectorize(&grids);
 codegen::transform::dce dce;
 codegen::transform::peephole peephole;
 codegen::transform::reassociate reassociate(&alignment_info, &grids);
-codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &reorder, target.get(), opt.num_warps);
+codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &coalesce, target.get(), opt.num_warps);
 // run passes
 peephole.run(module);
 dce.run(module);
 alignment_info.run(module);
 if(target->is_gpu())
 shmem_info.run(module);
-reorder.run(module);
+coalesce.run(module);
 dce.run(module);
 grids.run(module);
+alignment_info.run(module);
 reassociate.run(module);
 dce.run(module);
 peephole.run(module);
@@ -222,13 +223,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
 shmem_info.run(module);
 shmem_liveness.run(module);
 shmem_allocation.run();
-if(shmem_allocation.get_allocated_size() > context->device()->max_shared_memory())
+if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
 return std::unique_ptr<driver::module>();
 shmem_barriers.run(module);
 }
 dce.run(module);
 vectorize.run(module);
 dce.run(module);
+alignment_info.run(module);
 // ir::print(module, std::cout);
 // generate llvm code
 llvm::LLVMContext ctx;