more cleaning
This commit is contained in:
@@ -401,34 +401,9 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
|
||||
}
|
||||
|
||||
|
||||
// Declares the external "__shared_ptr" byte array in `dst` (address space 3)
// and returns it bit-cast to an i8 pointer in the same address space.
// Returns nullptr when the target is not a GPU or when the allocation
// analysis reports a size of zero.
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
  LLVMContext &llvm_ctx = builder.getContext();
  // nothing to declare on non-GPU targets
  if(!tgt_->is_gpu())
    return nullptr;
  // nothing to declare when no bytes were allocated
  const unsigned n_bytes = alloc_->allocated_size();
  if(n_bytes == 0)
    return nullptr;
  Type *byte_ty = Type::getInt8Ty(llvm_ctx);
  ArrayType *buffer_ty = ArrayType::get(byte_ty, n_bytes);
  Type *byte_ptr_ty = PointerType::get(byte_ty, 3);
  // the global is registered with `dst` by the GlobalVariable constructor
  GlobalVariable *buffer =
      new GlobalVariable(dst, buffer_ty, false, GlobalVariable::ExternalLinkage,
                         nullptr, "__shared_ptr", nullptr,
                         GlobalVariable::NotThreadLocal, 3);
  return builder.CreateBitCast(buffer, byte_ptr_ty);
}
|
||||
|
||||
void selection::run(ir::module &src, Module &dst) {
|
||||
vmap_.clear();
|
||||
tmap_.clear();
|
||||
|
||||
LLVMContext &ctx = dst.getContext();
|
||||
IRBuilder<> builder(ctx);
|
||||
|
||||
// allocate shared memory
|
||||
Value *sh_mem_ptr = alloc_shared(builder, dst);
|
||||
|
||||
// create tile
|
||||
generator gen(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
|
||||
generator gen(&dst, tgt_, layouts_, alignment_, alloc_, num_warps_ );
|
||||
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
gen.visit_value(x);
|
||||
@@ -438,6 +413,32 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
|
||||
|
||||
|
||||
// Builds a code generator that emits into `dst`.
//
// dst        LLVM module receiving the generated code; its context is used
//            for ctx_ and for the freshly created builder_.
// tgt        target description, queried here via is_gpu().
// layouts    layout analysis, stored for later use.
// alignment  alignment analysis, stored for later use.
// alloc      allocation analysis; allocated_size() gives the byte size of the
//            "__shared_ptr" array declared below.
// num_warps  stored in num_warps_.
//
// On GPU targets with a non-zero allocated size, an external i8 array named
// "__shared_ptr" is declared in address space 3 and sh_mem_ptr_ is set to
// that array bit-cast to an i8 pointer in address space 3. Otherwise
// sh_mem_ptr_ is nullptr.
generator::generator(Module *dst,
                     target *tgt,
                     analysis::layout *layouts,
                     analysis::align *alignment,
                     analysis::allocation *alloc,
                     unsigned num_warps)
  : ctx_(&dst->getContext()), mod_(dst),
    builder_(new Builder(dst->getContext())),
    tgt_(tgt),
    layouts_(layouts), alignment_(alignment), alloc_(alloc),
    num_warps_(num_warps) {
  // Give sh_mem_ptr_ a defined value on every path: it is only assigned
  // inside the GPU / non-zero-size branch below.
  sh_mem_ptr_ = nullptr;
  if(tgt_->is_gpu())
  if(unsigned alloc_size = alloc_->allocated_size()){
    Type *int_8_ty = Type::getInt8Ty(*ctx_);
    ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
    Type *ptr_ty = PointerType::get(int_8_ty, 3);
    // the global is registered with *dst by the GlobalVariable constructor
    GlobalVariable *sh_mem_array =
      new GlobalVariable(*dst, array_ty, false, GlobalVariable::ExternalLinkage,
                         nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
    sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
  }
}
|
||||
|
||||
|
||||
void generator::visit_value(ir::value* v) {
|
||||
if(!seen_.insert(v).second)
|
||||
return;
|
||||
@@ -544,11 +545,12 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
|
||||
// vector loads
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
@@ -562,6 +564,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
// extract result element
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
@@ -573,13 +576,13 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(axes_.at(a_axes_->get(x, ld)).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
@@ -633,6 +636,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// extract result element
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
|
||||
@@ -907,8 +911,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
}
|
||||
void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add) {
|
||||
TA->set_vector_size(axes_.at(a_axes_->get(dot, 0)).contiguous);
|
||||
TB->set_vector_size(axes_.at(a_axes_->get(dot, 1)).contiguous);
|
||||
TA->set_vector_size(TD->axis(0).contiguous);
|
||||
TB->set_vector_size(TD->axis(1).contiguous);
|
||||
for_each(dot, [&](indices_t idx){
|
||||
Value *res = TD->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
@@ -1162,16 +1166,16 @@ void generator::visit_function(ir::function* fn) {
|
||||
|
||||
|
||||
// Creates a machine_layout_hmma_884_t for `layout` and stores it in
// machine_layouts_ under the layout pointer. The element type passed on is
// type(layout->ty->get_scalar_ty(), *ctx_).
// NOTE(review): the two assignments below look like the before/after lines of
// a diff hunk (the second drops a_axes_ and passes &*builder_ instead of
// builder_) — confirm that only one of them is meant to remain.
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
|
||||
}
|
||||
|
||||
// Creates a machine_layout_scanline_t for `layout` and stores it in
// machine_layouts_ under the layout pointer. The element type passed on is
// type(layout->ty->get_scalar_ty(), *ctx_).
// NOTE(review): the two assignments below look like the before/after lines of
// a diff hunk (the second drops a_axes_ and passes &*builder_ instead of
// builder_) — confirm that only one of them is meant to remain.
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
|
||||
}
|
||||
|
||||
// Creates a machine_layout_shared_t for `layout` and stores it in
// machine_layouts_ under the layout pointer. The new object receives alloc_,
// sh_mem_ptr_, vmap_ and tmap_ in addition to the module, builder and target.
// NOTE(review): the two assignments below look like the before/after lines of
// a diff hunk (they differ only in builder_ vs &*builder_) — confirm that
// only one of them is meant to remain.
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
}
|
||||
|
||||
void generator::visit_basic_block(ir::basic_block * block) {
|
||||
@@ -1270,9 +1274,9 @@ tile* machine_layout_shared_t::create(ir::value *v) {
|
||||
}
|
||||
|
||||
// Constructor: stores the module, builder, target, element type, axis map and
// layout in the corresponding members. The body does no further work.
// NOTE(review): the parameter list and the initializer list each appear twice,
// once with a_axes and once without — these look like the before/after lines
// of a diff hunk; confirm which version is meant to remain.
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t *layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), axes_(axes), layout_(layout) {
|
||||
|
||||
}
|
||||
|
||||
@@ -1282,7 +1286,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
unsigned x = layout_->axes[d];
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -1294,10 +1298,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
||||
}
|
||||
|
||||
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty, analysis::axes *a_axes,
|
||||
target *tgt, Type *ty,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_hmma_884_t* layout)
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
@@ -1413,9 +1417,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
||||
|
||||
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
||||
std::map<unsigned, distributed_axis> &axes,
|
||||
analysis::layout_scanline_t* layout)
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
|
Reference in New Issue
Block a user