more cleaning

This commit is contained in:
Philippe Tillet
2019-10-17 00:36:46 -04:00
parent 4bfe998cc8
commit ae24621825
10 changed files with 205 additions and 322 deletions

View File

@@ -36,6 +36,7 @@ struct double_buffer_info_t {
};
class layout_visitor;
class layout_t;
class layout_hmma_884_t;
class layout_scanline_t;
class layout_shared_t;
@@ -43,6 +44,7 @@ class layout_shared_t;
class layout_visitor {
public:
virtual void visit_layout(layout_t *);
virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
virtual void visit_layout_scanline(layout_scanline_t*) = 0;
virtual void visit_layout_shared(layout_shared_t*) = 0;

View File

@@ -197,16 +197,13 @@ public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
unsigned &pack_size_0, unsigned &pack_size_1,
unsigned &num_packs_0, unsigned &num_packs_1,
analysis::layout_hmma_884_t* layout);
Value *&offset_a_i_, *&offset_a_k_;
Value *&offset_b_j_, *&offset_b_k_;
unsigned &pack_size_0_;
unsigned& pack_size_1_;
unsigned &num_packs_0_;
unsigned& num_packs_1_;
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
unsigned pack_size_1_;
unsigned num_packs_0_;
unsigned num_packs_1_;
};
class machine_layout_scanline_t: public machine_layout_distributed_t {
@@ -219,15 +216,18 @@ public:
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
public:
generator(LLVMContext *ctx,
Module *dst,
@@ -241,18 +241,12 @@ public:
analysis::align *alignment,
analysis::allocation *alloc,
Value *sh_mem_ptr,
Value *offset_a_i, Value *offset_a_k,
Value *offset_b_j, Value *offset_b_k,
unsigned num_packs_0, unsigned num_packs_1,
unsigned pack_size_0, unsigned pack_size_1,
unsigned num_warps)
: ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
num_warps_(num_warps) { }
machine_layout_t *get_machine_layout(const analysis::layout_t *layout) { return machine_layouts_.at(layout); }
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
@@ -301,6 +295,8 @@ public:
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
void visit_layout_scanline(analysis::layout_scanline_t*);
@@ -308,7 +304,6 @@ public:
private:
LLVMContext *ctx_;
Function *fn_;
Builder *builder_;
Module *mod_;
@@ -322,78 +317,9 @@ private:
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *sh_mem_ptr_;
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned num_packs_0_, num_packs_1_;
unsigned pack_size_0_, pack_size_1_;
unsigned num_warps_;
};
// Visitor used for the fix-up step that follows code generation.
// Only visit_phi_node() and visit_layout_shared() are defined out of line;
// every other hook below has an empty inline body, so accepting any other
// instruction or layout kind through this visitor does nothing.
// selection::run feeds it every layout ("finalize double-buffering") and then
// every instruction of every block ("finalize phi").
class finalizer: public ir::visitor, public analysis::layout_visitor {
private:
// Helpers over vmap_/tmap_; their definitions are not visible in this view.
// NOTE(review): presumably mirror generator::for_each/get_value/set_value — confirm.
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
public:
// Takes the IR builder and the two lowering maps (both maps by reference).
finalizer(Builder *builder, std::map<ir::value *, Value *>& vmap, std::map<ir::value *, tile *>& tmap);
// The only instruction hook with a real implementation (defined elsewhere).
void visit_phi_node(ir::phi_node*);
// --- instruction hooks that are deliberate no-ops for this visitor ---
void visit_binary_operator(ir::binary_operator*) { }
void visit_getelementptr_inst(ir::getelementptr_inst*) { }
void visit_icmp_inst(ir::icmp_inst*) { }
void visit_fcmp_inst(ir::fcmp_inst*) { }
void visit_cast_inst(ir::cast_inst*) { }
void visit_return_inst(ir::return_inst*) { }
void visit_cond_branch_inst(ir::cond_branch_inst*) { }
void visit_uncond_branch_inst(ir::uncond_branch_inst*) { }
void visit_unmasked_load_inst(ir::unmasked_load_inst*) { }
void visit_masked_load_inst(ir::masked_load_inst*) { }
void visit_unmasked_store_inst(ir::unmasked_store_inst*) { }
void visit_masked_store_inst(ir::masked_store_inst*) { }
void visit_reshape_inst(ir::reshape_inst*) { }
void visit_splat_inst(ir::splat_inst*) { }
void visit_broadcast_inst(ir::broadcast_inst*) { }
void visit_downcast_inst(ir::downcast_inst*) { }
void visit_get_program_id_inst(ir::get_program_id_inst*) { }
void visit_get_num_program_inst(ir::get_num_program_inst*) { }
void visit_atomic_cas_inst(ir::atomic_cas_inst*) { }
void visit_atomic_exch_inst(ir::atomic_exch_inst*) { }
void visit_atomic_add_inst(ir::atomic_add_inst*) { }
void visit_dot_inst(ir::dot_inst*) { }
void visit_trans_inst(ir::trans_inst*) { }
void visit_sqrt_inst(ir::sqrt_inst*) { }
void visit_reduce_inst(ir::reduce_inst*) { }
void visit_select_inst(ir::select_inst*) { }
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*) { }
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { }
void visit_barrier_inst(ir::barrier_inst*) { }
void visit_make_range_dyn(ir::make_range_dyn*) { }
void visit_make_range(ir::make_range*) { }
void visit_make_range_sta(ir::make_range_sta*) { }
void visit_undef_value(ir::undef_value*) { }
void visit_constant_int(ir::constant_int*) { }
void visit_constant_fp(ir::constant_fp*) { }
void visit_alloc_const(ir::alloc_const*) { }
void visit_function(ir::function*) { }
// --- layout hooks: only shared layouts have a real implementation ---
void visit_layout_hmma_884(analysis::layout_hmma_884_t*) { }
void visit_layout_scanline(analysis::layout_scanline_t*) { }
void visit_layout_shared(analysis::layout_shared_t*);
private:
// IR builder; presumably the one handed to the constructor.
Builder *builder_;
// ir::value -> LLVM Value map, held by reference (not copied).
std::map<ir::value *, Value *>& vmap_;
// ir::value -> tile map, held by reference (not copied).
std::map<ir::value *, tile *>& tmap_;
// NOTE(review): looks like a visited set for already-processed values — confirm in the .cpp.
std::set<ir::value*> seen_;
};
// Selection pass
@@ -405,9 +331,6 @@ private:
// LLVM conversions
Value* alloc_shared(Builder &builder, Module& dst);
// lower scalar instruction
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc,
analysis::align *alignment, analysis::axes *axes,
@@ -428,11 +351,6 @@ private:
analysis::align *alignment_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned num_packs_0_, num_packs_1_;
unsigned pack_size_0_, pack_size_1_;
unsigned num_warps_;
};

View File

@@ -6,6 +6,7 @@
#include <string>
#include <list>
#include "value.h"
#include "visitor.h"
namespace triton{
namespace ir{
@@ -66,6 +67,9 @@ public:
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
// visitor
// Hands this basic block to `v` by calling v->visit_basic_block(this).
void accept(visitor *v) { v->visit_basic_block(this); }
private:
context &ctx_;
std::string name_;

View File

@@ -26,6 +26,8 @@ public:
function* get_parent() const;
unsigned get_arg_no() const;
void accept(visitor *v);
private:
function *parent_;
unsigned arg_no_;

View File

@@ -33,6 +33,8 @@ public:
void set_name(const std::string &name);
const std::string &get_name() const { return name_; }
type* get_type() const { return ty_; }
// visitor
virtual void accept(visitor *v) = 0;
private:
std::string name_;
@@ -75,8 +77,6 @@ public:
void replace_all_uses_with(value *target);
void replace_uses_of_with(value *before, value *after);
// Visitor
virtual void accept(visitor *v) = 0;
private:
ops_t ops_;

View File

@@ -7,6 +7,8 @@
namespace triton{
namespace ir{
class value;
class instruction;
class phi_node;
@@ -81,10 +83,18 @@ class alloc_const;
class function;
class basic_block;
class argument;
class visitor {
public:
virtual ~visitor() {}
virtual void visit_value(ir::value*);
virtual void visit_basic_block(basic_block*) = 0;
virtual void visit_argument(argument*) = 0;
virtual void visit_phi_node(phi_node*) = 0;
virtual void visit_binary_operator(binary_operator*) = 0;
virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;

View File

@@ -124,6 +124,10 @@ inline bool is_trans(ir::value *v) {
}
// Generic entry point for visiting a layout: calls layout->accept(this).
// NOTE(review): accept() presumably routes to the matching visit_layout_*
// hook of this visitor — its definition is not visible here, confirm.
void layout_visitor::visit_layout(layout_t *layout) {
layout->accept(this);
}
layout_t::layout_t(layout_type_t _type,
const std::vector<int> &_axes,
@@ -145,6 +149,7 @@ layout_t::layout_t(layout_type_t _type,
}
}
// Restricts x to the range [lo, hi]: a value below lo is first raised to lo,
// then the result is capped at hi (so hi wins if the caller passes lo > hi).
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
  const unsigned raised = (x < lo) ? lo : x;
  return (hi < raised) ? hi : raised;
}

View File

@@ -385,31 +385,6 @@ bool is_trans(ir::value *v) {
}
// Lowers `src` through `gen`, at most once per value (tracked in `seen`).
// Order of work: create the machine tile for tile-typed values, lower the
// operands of non-phi instructions, then let the value accept the generator.
// Phi nodes skip operand lowering and are emitted ahead of the first non-PHI
// instruction of the current LLVM block.
void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen, std::set<ir::value*>& seen) {
  // a failed insert means the value was already lowered
  const bool newly_seen = seen.insert(src).second;
  if(!newly_seen)
    return;

  // tile-typed values get their tile before any code is emitted for them
  if(src->get_type()->is_tile_ty()) {
    auto *machine_layout = gen->get_machine_layout(layouts_->get(src));
    tmap_[src] = machine_layout->create(src);
  }

  BasicBlock *insert_block = builder.GetInsertBlock();
  auto *as_phi = dynamic_cast<ir::phi_node*>(src);

  // operands first -- only for instructions that are not phi nodes
  if(!as_phi) {
    if(auto *as_inst = dynamic_cast<ir::instruction*>(src)) {
      for(ir::value *operand: as_inst->ops())
        lower_value(operand, builder, gen, seen);
    }
  }

  // operand lowering may have moved the builder; go back to where we started
  builder.SetInsertPoint(insert_block);

  // re-evaluated on every call because emitting code changes the block contents
  const auto phi_goes_to_front = [&]() {
    return as_phi && !insert_block->empty() && insert_block->getFirstNonPHI();
  };

  if(phi_goes_to_front())
    builder.SetInsertPoint(&*insert_block->getFirstNonPHI());

  // only users are dispatched to the generator
  if(auto *as_user = dynamic_cast<ir::user*>(src))
    as_user->accept(gen);

  // put the builder back at the end of the block after emitting a phi
  if(phi_goes_to_front())
    builder.SetInsertPoint(insert_block);
}
/* ----------------------------
* ---- Generate LLVM code ----
@@ -445,57 +420,44 @@ Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
void selection::run(ir::module &src, Module &dst) {
vmap_.clear();
tmap_.clear();
LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx);
LLVMContext &ctx = dst.getContext();
IRBuilder<> builder(ctx);
// allocate shared memory
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
// iterate over functions
std::set<ir::value*> seen;
// create tile
generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
finalizer fin(&dst_builder, vmap_, tmap_);
Value *sh_mem_ptr = alloc_shared(builder, dst);
// visit
generator visitor(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
for(ir::alloc_const *x: src.allocs())
x->accept(&gen);
for(ir::function *fn: src.get_function_list()) {
fn->accept(&gen);
// initialize layouts
for(auto x: layouts_->get_all())
x.second->accept(&gen);
// generate LLVM-IR code
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list())
lower_value(i, dst_builder, &gen, seen);
vmap_[block] = dst_builder.GetInsertBlock();
}
// finalize double-buffering
for(const auto& x: layouts_->get_all())
x.second->accept(&fin);
// finalize phi
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
inst->accept(&fin);
}
visitor.visit_value(x);
for(ir::function *fn: src.get_function_list())
visitor.visit_value(fn);
}
// Generates code for `v` exactly once (tracked in seen_).
// Steps: create the machine tile for tile-typed values, visit the operands of
// non-phi instructions, then dispatch `v` to this generator if it is a user.
// Phi nodes do not have their operands visited here and are emitted in front
// of the first non-PHI instruction of the current LLVM block.
void generator::visit_value(ir::value* v) {
  // nothing to do when the value has been handled before
  const bool first_visit = seen_.insert(v).second;
  if(!first_visit)
    return;

  // create machine tile
  if(v->get_type()->is_tile_ty()) {
    auto *machine_layout = machine_layouts_.at(layouts_->get(v));
    tmap_[v] = machine_layout->create(v);
  }

  BasicBlock *start_block = builder_->GetInsertBlock();
  auto *phi = dynamic_cast<ir::phi_node*>(v);

  // visit operands (skipped for phi nodes)
  if(!phi) {
    if(auto *inst = dynamic_cast<ir::instruction*>(v)) {
      for(ir::value *operand: inst->ops())
        visit_value(operand);
    }
  }

  // operand visits may have moved the builder; return to the starting block
  builder_->SetInsertPoint(start_block);

  // evaluated lazily each time: emitting code changes what the block contains
  const auto needs_phi_position = [&]() {
    return phi && !start_block->empty() && start_block->getFirstNonPHI();
  };

  // change insert point for phi node
  if(needs_phi_position())
    builder_->SetInsertPoint(&*start_block->getFirstNonPHI());

  // visit user
  if(auto *usr = dynamic_cast<ir::user*>(v))
    usr->accept(this);

  // revert insert point
  if(needs_phi_position())
    builder_->SetInsertPoint(start_block);
}
void generator::visit_phi_node(ir::phi_node* phi) {
Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_);
@@ -574,19 +536,19 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
// vector loads
std::map<unsigned, Value*> packets;
result->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
Value *ptr = pointers->get_value(idx);
ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
@@ -594,25 +556,26 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
}
});
// extract result element
result->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
result->set_value(idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
});
}
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
std::map<unsigned, Value*> packets;
result->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
@@ -664,7 +627,8 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
}
});
// extract result element
result->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
@@ -714,13 +678,13 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
void generator::visit_reshape_inst(ir::reshape_inst* reshape) {
distributed_tile* result = (distributed_tile*)tmap_.at(reshape);
ir::value* in = reshape->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
for_each(reshape, [&](indices_t out_idx){
distributed_tile* result = (distributed_tile*)tmap_.at(reshape);
unsigned pos = result->get_linear_index(out_idx);
ir::value* in = reshape->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
indices_t in_idx = in_tile->get_ordered_indices(pos);
result->set_value(out_idx, in_tile->get_value(in_idx));
set_value(reshape, out_idx, get_value(in, in_idx));
});
}
@@ -732,17 +696,16 @@ void generator::visit_splat_inst(ir::splat_inst* splat) {
}
void generator::visit_broadcast_inst(ir::broadcast_inst* bcast) {
distributed_tile* result = (distributed_tile*)tmap_.at(bcast);
ir::value* in = bcast->get_operand(0);
const auto& in_shapes = in->get_type()->get_tile_shapes();
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
for_each(bcast, [&](indices_t out_idx){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++){
if(in_shapes[k] == 1)
in_idx[k] = builder_->getInt32(0);
}
result->set_value(out_idx, in_tile->get_value(in_idx));
set_value(bcast, out_idx, in_tile->get_value(in_idx));
});
}
@@ -812,17 +775,17 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
throw std::runtime_error("unsupported");
}
void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes();
TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4*pack_size_1_);
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
TA->set_vector_size(4*hmma->pack_size_0_);
TB->set_vector_size(4*hmma->pack_size_1_);
TA->set_return_mode(true);
TB->set_return_mode(true);
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
TC->for_each([&](indices_t idx){
for_each(dot, [&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(TD->get_value(idx));
@@ -833,10 +796,6 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t
Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Value *offset_a_i = offset_a_i_;
Value *offset_a_k = offset_a_k_;
Value *offset_b_j = offset_b_j_;
Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
@@ -849,10 +808,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
Value *offset_a_i = hmma->offset_a_i_;
Value *offset_a_k = hmma->offset_a_k_;
if(is_a_row){
offset_a_i = builder_->CreateAdd(offset_a_i, builder_->CreateURem(u_thread_id, builder_->getInt32(4)));
offset_a_k = builder_->getInt32(0);
}
Value *offset_b_j = hmma->offset_b_j_;
Value *offset_b_k = hmma->offset_b_k_;
if(!is_b_row){
offset_b_j = builder_->CreateAdd(offset_b_j, builder_->CreateURem(u_thread_id, builder_->getInt32(4)));
offset_b_k = builder_->getInt32(0);
@@ -881,33 +845,33 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t
for(auto& x: fcs){
std::vector<Value *>& fc = x.second;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned pack_i = 0; pack_i < hmma->num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < hmma->num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder_->getInt32(K);
Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*pack_size_1_));
Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*hmma->pack_size_0_));
Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*hmma->pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder_->CreateAdd(offset_a_k, _K)};
indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++)
for(unsigned jj = 0; jj < hmma->pack_size_1_; jj++){
Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 1)), fp16x2_ty);
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
(pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder_->CreateExtractValue(nc, {0});
@@ -925,23 +889,23 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t
// write back
unsigned i = 0;
TC->for_each([&](indices_t idx){
for_each(dot, [&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
TC->set_value(idx, fcs.at(key)[i++]);
set_value(dot, idx, fcs.at(key)[i++]);
});
TA->set_return_mode(false);
TB->set_return_mode(false);
}
void generator::visit_scanline_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
TC->for_each([&](indices_t idx){
TA->set_vector_size(axes_.at(a_axes_->get(dot, 0)).contiguous);
TB->set_vector_size(axes_.at(a_axes_->get(dot, 1)).contiguous);
for_each(dot, [&](indices_t idx){
Value *res = TD->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
// input indices
@@ -961,13 +925,13 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, distributed_tile *TC, shar
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
}
TC->set_value(idx, res);
set_value(dot, idx, res);
});
}
void generator::visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add) {
TC->for_each([&](indices_t idx){
for_each(dot, [&](indices_t idx){
Value *res = TD->get_value(idx);
indices_t a_idx = {idx[0], builder_->getInt32(0)};
indices_t b_idx = {builder_->getInt32(0), idx[1]};
@@ -980,14 +944,13 @@ void generator::visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed
if(b->getType() != c_ty)
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
TC->set_value(idx, res);
set_value(dot, idx, res);
});
}
void generator::visit_dot_inst(ir::dot_inst* dot) {
Function *fn = builder_->GetInsertBlock()->getParent();
distributed_tile* TC = (distributed_tile*)tmap_.at(dot);
Module *module = fn->getParent();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
@@ -1004,14 +967,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(layouts_->get(dot)->type == analysis::HMMA_884)
visit_hmma_dot(dot, TC, TA, TB, TD, NK);
visit_hmma_dot(dot, TA, TB, TD, NK);
else
visit_scanline_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add);
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
}
else {
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
visit_outer_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add);
visit_outer_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
}
}
@@ -1052,15 +1015,14 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
ir::value *arg = cts->get_operand(0);
auto arg_order = layouts_->get(arg)->order;
// tiles
shared_tile* result = (shared_tile*)tmap_.at(cts);
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
if(x_order == arg_order){
size_t ld = arg_order[0];
vector_size = layouts_->get(arg)->nts.at(ld);
}
std::map<unsigned, Value*> packets;
in->for_each([&](indices_t idx){
for_each(arg, [&](indices_t idx){
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
unsigned linear = in->get_linear_index(idx);
unsigned id = linear / vector_size;
Value *in_value = in->get_value(idx);
@@ -1068,19 +1030,19 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size);
});
in->for_each([&](indices_t idx){
for_each(arg, [&](indices_t idx){
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
shared_tile* result = (shared_tile*)tmap_.at(cts);
unsigned linear = in->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0)
result->set_value(idx, packets[id]);
});
}
void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst* cfs) {
distributed_tile* result = (distributed_tile*)tmap_.at(cfs);
shared_tile* arg = (shared_tile*)tmap_.at(cfs->get_operand(0));
result->for_each([&](indices_t idx){
result->set_value(idx, arg->get_value(idx));
for_each(cfs, [&](indices_t idx){
set_value(cfs, idx, get_value(cfs->get_operand(0), idx));
});
}
@@ -1090,33 +1052,30 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(0);
result->set_value(idx, res);
set_value(x, idx, res);
});
}
void generator::visit_make_range_sta(ir::make_range_sta* x) {
distributed_tile *T = (distributed_tile *)tmap_.at(x);
T->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(1);
assert(isa<Constant>(res));
T->set_value(idx, res);
set_value(x, idx, res);
});
}
void generator::visit_make_range(ir::make_range* x) {
distributed_tile *T = (distributed_tile *)tmap_.at(x);
T->for_each([&](indices_t idx){
for_each(x, [&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
set_value(x, idx, idx[0]);
});
}
@@ -1149,18 +1108,17 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
void generator::visit_function(ir::function* fn) {
LLVMContext &ctx = builder_->getContext();
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_);
FunctionType *dst_fn_ty = fn_ty;
if(!tgt_->is_gpu()){
Type *dst_fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> dst_fn_args_ty;
Type *fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
dst_fn_args_ty.push_back(builder_->getInt32Ty());
dst_fn_args_ty.push_back(builder_->getInt32Ty());
dst_fn_args_ty.push_back(builder_->getInt32Ty());
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
fn_args_ty.push_back(fn_ty->getParamType(i));
fn_args_ty.push_back(builder_->getInt32Ty());
fn_args_ty.push_back(builder_->getInt32Ty());
fn_args_ty.push_back(builder_->getInt32Ty());
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
}
Function *ret = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
// set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
@@ -1176,7 +1134,7 @@ void generator::visit_function(ir::function* fn) {
ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
};
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
// map parameters
// set arguments
for(unsigned i = 0; i < fn->args().size(); i++)
vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
// create blocks
@@ -1185,15 +1143,22 @@ void generator::visit_function(ir::function* fn) {
vmap_[block] = dst_block;
}
builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
fn_ = ret;
// initialize layouts
for(auto x: layouts_->get_all())
visit_layout(x.second);
// generate LLVM-IR code
for(ir::basic_block *block: fn->blocks())
visit_basic_block(block);
// finalize
finalize_function(fn);
}
// Builds a machine_layout_hmma_884_t for `layout` and stores it in machine_layouts_
// under the analysis layout as key.
// NOTE(review): machine_layouts_[layout] is assigned twice below, so the object
// allocated by the first `new` is overwritten by the second assignment and never
// released here. The two calls also pass different argument lists (with and without
// the offset_* / pack_size_* / num_packs_* members) -- this looks like an old and a
// new variant of the same statement; confirm which constructor signature is current
// and keep only that call.
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
// variant passing the generator's offset / pack-size / pack-count members
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_,
pack_size_0_, pack_size_1_,
num_packs_0_, num_packs_1_,
layout);
// variant passing only the module/builder/target/type/axes and the layout
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
@@ -1205,10 +1170,24 @@ void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
}
// Visits every instruction of `block`, in list order, with the builder positioned
// on the LLVM block that vmap_ holds for `block`.
void generator::visit_basic_block(ir::basic_block * block) {
  // point the builder at the LLVM block associated with `block`
  BasicBlock *llvm_block = (BasicBlock*)vmap_[block];
  builder_->SetInsertPoint(llvm_block);
  // visit the instructions one by one
  for(ir::instruction *inst: block->get_inst_list()) {
    visit_value(inst);
  }
  // re-map `block` to whichever block the builder is inserting into afterwards
  vmap_[block] = builder_->GetInsertBlock();
}
// Visitor hook for ir::argument values; this implementation is a no-op.
void generator::visit_argument(ir::argument* arg) {
  (void)arg; // nothing is done with the argument here
}
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
if(!x->get_type()->is_tile_ty())
return fn({});
else {
// if(tmap_.find(x) == tmap_.end())
// tmap_[x] = machine_layouts_.at(layouts_->get(x))->create(x);
if(auto *dt = dynamic_cast<distributed_tile*>(tmap_.at(x)))
dt->for_each(fn);
}
@@ -1313,13 +1292,8 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
unsigned &pack_size_0, unsigned &pack_size_1,
unsigned &num_packs_0, unsigned &num_packs_1,
analysis::layout_hmma_884_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout),
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1) {
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
@@ -1467,34 +1441,18 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build
}
}
finalizer::finalizer(Builder *builder, std::map<ir::value *, Value *>& vmap, std::map<ir::value *, tile *>& tmap)
: builder_(builder), vmap_(vmap), tmap_(tmap) {
// Post-processing pass over `fn`: visits every layout known to layouts_, then
// calls finalize_phi_node on each phi node found in the function's blocks.
void generator::finalize_function(ir::function* fn) {
  // finalize double-buffering
  for(const auto& entry: layouts_->get_all()) {
    visit_layout(entry.second);
  }
  // finalize phi
  for(ir::basic_block *bb: fn->blocks()) {
    for(ir::instruction *inst: bb->get_inst_list()) {
      auto *phi = dynamic_cast<ir::phi_node*>(inst);
      // only phi nodes need finalization
      if(phi == nullptr)
        continue;
      finalize_phi_node(phi);
    }
  }
}
// Invokes `fn` for the indices of `x`.
// - non-tile values: `fn` is called once with an empty index list;
// - tile values: the call is delegated to the distributed tile registered for `x`
//   in tmap_; tiles of any other kind are silently ignored.
void finalizer::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
  // scalar case
  if(!x->get_type()->is_tile_ty()) {
    fn({});
    return;
  }
  // tile case: only distributed tiles are iterated
  auto *dist = dynamic_cast<distributed_tile*>(tmap_.at(x));
  if(dist != nullptr)
    dist->for_each(fn);
}
// Returns the LLVM value associated with `x`.
// Non-tile values are looked up in vmap_ (`idx` is unused); tile values are read
// from the tile registered in tmap_ at position `idx`.
Value* finalizer::get_value(ir::value *x, const indices_t& idx) {
  const bool is_tile = x->get_type()->is_tile_ty();
  if(!is_tile)
    return vmap_.at(x);
  return tmap_.at(x)->get_value(idx);
}
// Associates the LLVM value `v` with `x`.
// Non-tile values are stored directly in vmap_ (`idx` is unused); tile values are
// written into the tile registered in tmap_ at position `idx`.
void finalizer::set_value(ir::value *x, const indices_t& idx, Value* v) {
  const bool is_tile = x->get_type()->is_tile_ty();
  if(!is_tile) {
    vmap_[x] = v;
    return;
  }
  tmap_.at(x)->set_value(idx, v);
}
void finalizer::visit_phi_node(ir::phi_node* phi) {
void generator::finalize_phi_node(ir::phi_node* phi) {
auto it = tmap_.find(phi);
if(it != tmap_.end() && dynamic_cast<shared_tile*>(it->second))
return;
@@ -1510,32 +1468,6 @@ void finalizer::visit_phi_node(ir::phi_node* phi) {
}
// Fills in the incoming values of the pointer and offset phi nodes belonging to a
// double-buffered shared layout. Layouts without double-buffer info are left alone.
// For every incoming edge of the double-buffer phi:
//  - on the latch edge, the offset phi receives the negation of itself, emitted
//    right before the terminator of the incoming LLVM block;
//  - on any other edge, it receives the constant layout->size / (2 * num_bytes),
//    where num_bytes is the primitive size of layout->ty in bytes;
//  - the pointer phi always receives the pointer of the incoming shared tile.
void finalizer::visit_layout_shared(analysis::layout_shared_t* layout) {
  if(!layout->double_buffer)
    return;
  auto info = *layout->double_buffer;
  ir::phi_node *phi = info.phi;
  // pointer / offset phi nodes of the tile attached to the double-buffer phi
  shared_tile *phi_tile = (shared_tile*)tmap_.at(phi);
  PHINode *ptr_phi = (PHINode*)phi_tile->get_pointer();
  PHINode *offset_phi = (PHINode*)phi_tile->get_offset();
  for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
    ir::basic_block* src_block = phi->get_incoming_block(i);
    ir::value* src_value = phi->get_incoming_value(i);
    BasicBlock *llvm_src = (BasicBlock*)vmap_.at(src_block);
    shared_tile *src_tile = (shared_tile*)tmap_.at(src_value);
    Value *incoming_offset;
    if(src_value != info.latch) {
      unsigned num_bytes = layout->ty->get_primitive_size_in_bits() / 8;
      incoming_offset = builder_->getInt32(layout->size / (2*num_bytes));
    }
    else {
      // the negation must be emitted inside the incoming block, before its terminator
      builder_->SetInsertPoint(llvm_src->getTerminator());
      incoming_offset = builder_->CreateNeg(offset_phi);
    }
    offset_phi->addIncoming(incoming_offset, llvm_src);
    ptr_phi->addIncoming(src_tile->get_pointer(), llvm_src);
  }
}
}
}

View File

@@ -25,6 +25,10 @@ unsigned argument::get_arg_no() const {
return arg_no_;
}
// Hands this argument over to the visitor's visit_argument hook.
void argument::accept(visitor *v) {
  v->visit_argument(this);
}
/* function */
function::function(function_type *ty, linkage_types_t linkage,

View File

@@ -32,6 +32,10 @@ void value::replace_all_uses_with(value *target){
throw std::runtime_error("not implemented");
}
// Visits `v` by calling its accept() method with this visitor.
void visitor::visit_value(ir::value* v) {
  v->accept(this);
}
//===----------------------------------------------------------------------===//
// user class
@@ -69,5 +73,7 @@ void user::replace_uses_of_with(value *before, value *after) {
before->erase_use(this);
}
}
}