more cleaning

This commit is contained in:
Philippe Tillet
2019-10-17 08:17:23 -04:00
parent a157177267
commit a0182f41dd
3 changed files with 56 additions and 72 deletions

View File

@@ -178,7 +178,7 @@ public:
class machine_layout_distributed_t: public machine_layout_t {
public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout);
tile* create(ir::value *v);
@@ -186,7 +186,6 @@ public:
Builder *builder_;
target *tgt_;
Type *ty_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_;
};
@@ -196,7 +195,7 @@ class machine_layout_hmma_884_t: public machine_layout_distributed_t {
public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
@@ -210,7 +209,7 @@ class machine_layout_scanline_t: public machine_layout_distributed_t {
public:
machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout);
};
@@ -230,22 +229,12 @@ private:
void finalize_phi_node(ir::phi_node*);
public:
generator(LLVMContext *ctx,
Module *dst,
Builder *builder,
analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap,
generator(Module *dst,
target *tgt,
analysis::layout *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
Value *sh_mem_ptr,
unsigned num_warps)
: ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
num_warps_(num_warps) { }
unsigned num_warps);
void visit_value(ir::value* v);
@@ -305,14 +294,13 @@ public:
private:
LLVMContext *ctx_;
Builder *builder_;
std::unique_ptr<Builder> builder_;
Module *mod_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
target *tgt_;
analysis::layout *layouts_;
analysis::align *alignment_;
@@ -329,30 +317,22 @@ class selection{
typedef std::map<ir::value *, Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
private:
// LLVM conversions
Value* alloc_shared(Builder &builder, Module& dst);
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc,
analysis::align *alignment, analysis::axes *axes,
analysis::align *alignment,
analysis::layout *layouts, target *tgt, unsigned num_warps)
: liveness_(liveness), alloc_(alloc),
alignment_(alignment), a_axes_(axes), layouts_(layouts),
alignment_(alignment), layouts_(layouts),
tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst);
private:
vmap_t vmap_;
tmap_t tmap_;
analysis::liveness *liveness_;
analysis::allocation *alloc_;
analysis::axes *a_axes_;
analysis::layout *layouts_;
analysis::align *alignment_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
unsigned num_warps_;
};

View File

@@ -401,34 +401,9 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
}
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
Value *ret = nullptr;
LLVMContext &ctx = builder.getContext();
if(tgt_->is_gpu())
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(ctx);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
ret = builder.CreateBitCast(sh_mem_array, ptr_ty);
}
return ret;
}
void selection::run(ir::module &src, Module &dst) {
vmap_.clear();
tmap_.clear();
LLVMContext &ctx = dst.getContext();
IRBuilder<> builder(ctx);
// allocate shared memory
Value *sh_mem_ptr = alloc_shared(builder, dst);
// create tile
generator gen(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
generator gen(&dst, tgt_, layouts_, alignment_, alloc_, num_warps_ );
for(ir::alloc_const *x: src.allocs())
gen.visit_value(x);
@@ -438,6 +413,32 @@ void selection::run(ir::module &src, Module &dst) {
/// Builds a code generator that emits LLVM IR into `dst`.
///
/// The constructor creates an IR builder (owned by this object) on the LLVM
/// context of `dst`. When the target reports that it is a GPU and the
/// allocation analysis reports a non-zero allocated size, it also declares an
/// external global byte array named "__shared_ptr" in address space 3 and
/// stores an i8 pointer to that array in `sh_mem_ptr_`.
///
/// @param dst        LLVM module that receives the generated code; its context
///                   is used for every type and for the builder.
/// @param tgt        Target description; queried here through `is_gpu()`.
/// @param layouts    Layout analysis, stored for later use by the visitors.
/// @param alignment  Alignment analysis, stored for later use by the visitors.
/// @param alloc      Allocation analysis; `allocated_size()` gives the size of
///                   the "__shared_ptr" array.
/// @param num_warps  Stored as-is in `num_warps_`.
generator::generator(Module *dst,
                     target *tgt,
                     analysis::layout *layouts,
                     analysis::align *alignment,
                     analysis::allocation *alloc,
                     unsigned num_warps)
  : ctx_(&dst->getContext()), mod_(dst),
    builder_(new Builder(dst->getContext())),
    tgt_(tgt),
    layouts_(layouts), alignment_(alignment), alloc_(alloc),
    num_warps_(num_warps) {
  // Always give the pointer a defined value: on non-GPU targets, or when
  // nothing was allocated, the branch below is skipped and the member would
  // otherwise be read uninitialized by visit_layout_shared().
  sh_mem_ptr_ = nullptr;
  if(tgt_->is_gpu()) {
    if(unsigned alloc_size = alloc_->allocated_size()) {
      Type *int_8_ty = Type::getInt8Ty(*ctx_);
      ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
      // i8 pointer in address space 3 (presumably the GPU shared-memory space
      // -- confirm against the target backend).
      Type *ptr_ty = PointerType::get(int_8_ty, 3);
      // Ownership of the global is taken by `dst`, so the raw `new` is intended.
      GlobalVariable *sh_mem_array =
        new GlobalVariable(*dst, array_ty, false, GlobalVariable::ExternalLinkage,
                           nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
      sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
    }
  }
}
void generator::visit_value(ir::value* v) {
if(!seen_.insert(v).second)
return;
@@ -544,11 +545,12 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
// vector loads
std::map<unsigned, Value*> packets;
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
@@ -562,6 +564,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
// extract result element
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
@@ -573,13 +576,13 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
std::map<unsigned, Value*> packets;
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
@@ -633,6 +636,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// extract result element
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
@@ -907,8 +911,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
}
void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add) {
TA->set_vector_size(axes_.at(a_axes_->get(dot, 0)).contiguous);
TB->set_vector_size(axes_.at(a_axes_->get(dot, 1)).contiguous);
TA->set_vector_size(TD->axis(0).contiguous);
TB->set_vector_size(TD->axis(1).contiguous);
for_each(dot, [&](indices_t idx){
Value *res = TD->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
@@ -1162,16 +1166,16 @@ void generator::visit_function(ir::function* fn) {
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
}
void generator::visit_basic_block(ir::basic_block * block) {
@@ -1270,9 +1274,9 @@ tile* machine_layout_shared_t::create(ir::value *v) {
}
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), axes_(axes), layout_(layout) {
}
@@ -1282,7 +1286,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
unsigned x = layout_->axes[d];
axes[d] = axes_.at(x);
}
else{
@@ -1294,10 +1298,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty, analysis::axes *a_axes,
target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
@@ -1413,9 +1417,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);

View File

@@ -217,7 +217,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::reassociate reassociate(&align);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts;
codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps);
codegen::selection selection(&liveness, &allocation, &align, &layouts, target.get(), opt.num_warps);
// run passes
// ir::print(module, std::cout);
peephole.run(module);