more cleaning
This commit is contained in:
@@ -178,7 +178,7 @@ public:
|
||||
class machine_layout_distributed_t: public machine_layout_t {
|
||||
public:
|
||||
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t* layout);
|
||||
|
||||
tile* create(ir::value *v);
|
||||
@@ -186,7 +186,6 @@ public:
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
Type *ty_;
|
||||
analysis::axes *a_axes_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
analysis::layout_t* layout_;
|
||||
};
|
||||
@@ -196,7 +195,7 @@ class machine_layout_hmma_884_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_hmma_884_t* layout);
|
||||
Value *offset_a_i_, *offset_a_k_;
|
||||
Value *offset_b_j_, *offset_b_k_;
|
||||
@@ -210,7 +209,7 @@ class machine_layout_scanline_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_scanline_t* layout);
|
||||
};
|
||||
|
||||
@@ -230,22 +229,12 @@ private:
|
||||
void finalize_phi_node(ir::phi_node*);
|
||||
|
||||
public:
|
||||
generator(LLVMContext *ctx,
|
||||
Module *dst,
|
||||
Builder *builder,
|
||||
analysis::axes *a_axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap,
|
||||
generator(Module *dst,
|
||||
target *tgt,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
Value *sh_mem_ptr,
|
||||
unsigned num_warps)
|
||||
: ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
|
||||
num_warps_(num_warps) { }
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
|
||||
@@ -305,14 +294,13 @@ public:
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder *builder_;
|
||||
std::unique_ptr<Builder> builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
|
||||
analysis::axes *a_axes_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
std::map<ir::value *, Value *> vmap_;
|
||||
std::map<ir::value *, tile *> tmap_;
|
||||
target *tgt_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
@@ -329,30 +317,22 @@ class selection{
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
private:
|
||||
// LLVM conversions
|
||||
Value* alloc_shared(Builder &builder, Module& dst);
|
||||
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::align *alignment,
|
||||
analysis::layout *layouts, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
alignment_(alignment), layouts_(layouts),
|
||||
tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
unsigned num_warps_;
|
||||
};
|
||||
|
||||
|
@@ -401,34 +401,9 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
|
||||
}
|
||||
|
||||
|
||||
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
|
||||
Value *ret = nullptr;
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
if(tgt_->is_gpu())
|
||||
if(unsigned alloc_size = alloc_->allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(ctx);
|
||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
ret = builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void selection::run(ir::module &src, Module &dst) {
|
||||
vmap_.clear();
|
||||
tmap_.clear();
|
||||
|
||||
LLVMContext &ctx = dst.getContext();
|
||||
IRBuilder<> builder(ctx);
|
||||
|
||||
// allocate shared memory
|
||||
Value *sh_mem_ptr = alloc_shared(builder, dst);
|
||||
|
||||
// create tile
|
||||
generator gen(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
|
||||
generator gen(&dst, tgt_, layouts_, alignment_, alloc_, num_warps_ );
|
||||
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
gen.visit_value(x);
|
||||
@@ -438,6 +413,32 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
|
||||
|
||||
|
||||
generator::generator(Module *dst,
|
||||
target *tgt,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
unsigned num_warps)
|
||||
: ctx_(&dst->getContext()), mod_(dst),
|
||||
builder_(new Builder(dst->getContext())),
|
||||
tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc),
|
||||
num_warps_(num_warps) {
|
||||
|
||||
if(tgt_->is_gpu())
|
||||
if(unsigned alloc_size = alloc_->allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(*ctx_);
|
||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(*dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_value(ir::value* v) {
|
||||
if(!seen_.insert(v).second)
|
||||
return;
|
||||
@@ -544,11 +545,12 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
|
||||
// vector loads
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
@@ -562,6 +564,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
// extract result element
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
@@ -573,13 +576,13 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
@@ -633,6 +636,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// extract result element
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
|
||||
@@ -907,8 +911,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
}
|
||||
void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add) {
|
||||
TA->set_vector_size(axes_.at(a_axes_->get(dot, 0)).contiguous);
|
||||
TB->set_vector_size(axes_.at(a_axes_->get(dot, 1)).contiguous);
|
||||
TA->set_vector_size(TD->axis(0).contiguous);
|
||||
TB->set_vector_size(TD->axis(1).contiguous);
|
||||
for_each(dot, [&](indices_t idx){
|
||||
Value *res = TD->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
@@ -1162,16 +1166,16 @@ void generator::visit_function(ir::function* fn) {
|
||||
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
}
|
||||
|
||||
void generator::visit_basic_block(ir::basic_block * block) {
|
||||
@@ -1270,9 +1274,9 @@ tile* machine_layout_shared_t::create(ir::value *v) {
|
||||
}
|
||||
|
||||
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t *layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), axes_(axes), layout_(layout) {
|
||||
|
||||
}
|
||||
|
||||
@@ -1282,7 +1286,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
unsigned x = layout_->axes[d];
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -1294,10 +1298,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
||||
}
|
||||
|
||||
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty, analysis::axes *a_axes,
|
||||
target *tgt, Type *ty,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_hmma_884_t* layout)
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
@@ -1413,9 +1417,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
||||
|
||||
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
||||
std::map<unsigned, distributed_axis> &axes,
|
||||
analysis::layout_scanline_t* layout)
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
|
@@ -217,7 +217,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::cts cts;
|
||||
codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps);
|
||||
codegen::selection selection(&liveness, &allocation, &align, &layouts, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
// ir::print(module, std::cout);
|
||||
peephole.run(module);
|
||||
|
Reference in New Issue
Block a user