more cleaning
This commit is contained in:
@@ -146,6 +146,21 @@ private:
|
||||
Builder &builder_;
|
||||
};
|
||||
|
||||
/**
 * Common base of the per-layout objects that the generator stores in its
 * `machine_layouts_` map.
 *
 * Derived objects are allocated with `new` and held as `machine_layout_t*`,
 * so the destructor is virtual: destroying one of them through the base
 * pointer is then well-defined and runs the derived destructor.
 */
class machine_layout_t {
public:
  virtual ~machine_layout_t() = default;
};
|
||||
|
||||
/**
 * Entry recorded in the generator's `machine_layouts_` map for an
 * `analysis::layout_shared_t` (see generator::visit_layout_shared).
 * Declares no members of its own.
 */
class machine_layout_shared_t : public machine_layout_t {};
|
||||
|
||||
/**
 * Entry recorded in the generator's `machine_layouts_` map for an
 * `analysis::layout_hmma_884_t` (see generator::visit_layout_hmma_884).
 * Declares no members of its own.
 */
class machine_layout_hmma_884_t : public machine_layout_t {};
|
||||
|
||||
/**
 * Entry recorded in the generator's `machine_layouts_` map for an
 * `analysis::layout_scanline_t` (see generator::visit_layout_scanline).
 * Declares no members of its own.
 */
class machine_layout_scanline_t : public machine_layout_t {};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
@@ -160,9 +175,7 @@ private:
|
||||
void set_value(ir::value *x, const indices_t& idx, Value* v);
|
||||
|
||||
public:
|
||||
|
||||
generator(LLVMContext *ctx,
|
||||
Function *fn,
|
||||
Module *dst,
|
||||
Builder *builder,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
@@ -178,7 +191,7 @@ public:
|
||||
unsigned num_packs_0, unsigned num_packs_1,
|
||||
unsigned pack_size_0, unsigned pack_size_1,
|
||||
unsigned num_warps)
|
||||
: ctx_(ctx), fn_(fn), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
: ctx_(ctx), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
|
||||
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
|
||||
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
|
||||
@@ -243,6 +256,7 @@ private:
|
||||
Builder *builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<analysis::layout_t*, machine_layout_t*> machine_layouts_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
@@ -258,6 +272,8 @@ private:
|
||||
unsigned num_warps_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
@@ -266,8 +282,6 @@ class selection{
|
||||
private:
|
||||
// LLVM conversions
|
||||
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
|
||||
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
||||
Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst);
|
||||
Value* alloc_shared(Builder &builder, Module& dst);
|
||||
|
||||
// grid construction
|
||||
|
@@ -344,16 +344,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
}
|
||||
|
||||
|
||||
/* convert ir::alloc_const to llvm::GlobalVariable */
|
||||
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {
|
||||
unsigned size = ((ir::constant_int*)v->get_operand(0))->get_value();
|
||||
Type *element_ty = llvm_type(v->get_type()->get_pointer_element_ty(), module->getContext());
|
||||
Type *array_ty = llvm::ArrayType::get(element_ty, size);
|
||||
Value *array = new llvm::GlobalVariable(*module, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
|
||||
nullptr, v->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
return builder.CreateBitCast(array, element_ty->getPointerTo(4));
|
||||
}
|
||||
|
||||
|
||||
/* -------------------
|
||||
* ---- Init Axes ----
|
||||
@@ -384,17 +374,17 @@ inline int32_t ceil(int32_t num, int32_t div){
|
||||
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
|
||||
if(tmap_.find(v) != tmap_.end())
|
||||
return;
|
||||
auto order = layouts_->get(v)->order;
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned pad = layouts_->get(v)->pad;
|
||||
if(pad > 0)
|
||||
shapes[order[0]] += pad;
|
||||
analysis::layout_shared_t *layout = (analysis::layout_shared_t*)layouts_->get(v);
|
||||
auto order = layout->order;
|
||||
auto shapes = layout->shapes;
|
||||
shapes[order[0]] += layout->pad;
|
||||
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
// double-buffered
|
||||
if(layouts_->get(v)->double_buffer) {
|
||||
auto info = *layouts_->get(v)->double_buffer;
|
||||
if(layout->double_buffer) {
|
||||
auto info = *layout->double_buffer;
|
||||
ir::phi_node *phi = info.phi;
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
|
||||
if(parent->empty())
|
||||
@@ -461,9 +451,8 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
|
||||
if(src->get_type()->is_tile_ty()){
|
||||
builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin());
|
||||
auto *i = dynamic_cast<ir::instruction*>(src);
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(src)){
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED)
|
||||
create_shared_tile(i, builder, sh_mem_ptr_);
|
||||
}
|
||||
else
|
||||
create_distributed_tile(src, builder);
|
||||
}
|
||||
@@ -502,47 +491,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
|
||||
}
|
||||
|
||||
|
||||
/*
 * Create the LLVM function that corresponds to `fn` inside module `dst`.
 *
 * fn      : source IR function.
 * builder : supplies the LLVMContext and integer types; on exit its insertion
 *           point is the LLVM block mapped to fn's first basic block.
 * dst     : LLVM module that receives the function and its metadata.
 *
 * Side effects: fills vmap_ with fn's arguments and basic blocks, and appends
 * a "maxntidx" entry to the module's "nvvm.annotations" named metadata.
 *
 * Returns the newly created llvm::Function.
 */
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
  LLVMContext &ctx = builder.getContext();
  FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), ctx);
  FunctionType *dst_fn_ty = fn_ty;
  if(!tgt_->is_gpu()){
    // Non-GPU targets get three extra trailing i32 parameters.
    // NOTE(review): their meaning is not visible here - presumably grid indices; confirm.
    Type *dst_fn_ret_ty = fn_ty->getReturnType();
    std::vector<Type*> dst_fn_args_ty;
    for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
      dst_fn_args_ty.push_back(fn_ty->getParamType(i));
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
  }
  Function *ret = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
  // set attributes (only those that have an LLVM equivalent)
  for(const auto &attr_pair: fn->attrs()){
    unsigned id = attr_pair.first;
    for(ir::attribute attr: attr_pair.second)
      if(attr.is_llvm_attr())
        ret->addAttribute(id, llvm_attr(ctx, attr));
  }
  // set metadata: target-specific kernel marking, then a thread-count
  // annotation of num_warps_*32
  tgt_->set_kernel(builder, ctx, &dst, ret);
  Metadata *md_args[] = {
    ValueAsMetadata::get(ret),
    MDString::get(ctx, "maxntidx"),
    ValueAsMetadata::get(builder.getInt32(num_warps_*32))
  };
  dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
  // map parameters: i-th source argument -> i-th LLVM argument
  for(unsigned i = 0; i < fn->args().size(); i++)
    vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
  // create blocks
  for(ir::basic_block *block: fn->blocks()) {
    BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
    vmap_[block] = dst_block;
  }
  builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
  // The function is declared to return Function*, and the caller uses the
  // result; falling off the end here would be undefined behaviour.
  return ret;
}
|
||||
|
||||
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
|
||||
Value *ret = nullptr;
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
@@ -566,24 +514,22 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
LLVMContext &dst_ctx = dst.getContext();
|
||||
IRBuilder<> dst_builder(dst_ctx);
|
||||
|
||||
// constant memory
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
|
||||
|
||||
// allocate shared memory
|
||||
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
|
||||
|
||||
// iterate over functions
|
||||
std::set<ir::value*> seen;
|
||||
|
||||
// create tile
|
||||
generator gen(&dst_ctx, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
|
||||
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
x->accept(&gen);
|
||||
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
|
||||
// create LLVM function
|
||||
Function *ffn = llvm_fn(fn, dst_builder, dst);
|
||||
|
||||
// create tile
|
||||
generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
|
||||
fn->accept(&gen);
|
||||
|
||||
// initialize layouts
|
||||
for(auto x: layouts_->get_all())
|
||||
@@ -656,18 +602,6 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
}
|
||||
|
||||
|
||||
/* -----------------------------------------------------
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
* ------------------------------------------------------ */
|
||||
|
||||
|
||||
|
||||
@@ -1355,8 +1289,46 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_function(ir::function*) {
|
||||
|
||||
/*
 * Create the LLVM function for `fn` in mod_ and make it the generator's
 * current function.
 *
 * Steps, in order:
 *   1. build the LLVM signature (non-GPU targets get three extra trailing
 *      i32 parameters);
 *   2. create the function with external linkage under fn's name;
 *   3. copy over the attributes that have an LLVM equivalent;
 *   4. let the target mark it as a kernel and add a "maxntidx" entry of
 *      num_warps_*32 to "nvvm.annotations";
 *   5. register arguments and basic blocks in vmap_;
 *   6. point builder_ at the block mapped to fn's first basic block and
 *      store the function in fn_.
 */
void generator::visit_function(ir::function* fn) {
  LLVMContext &llvm_ctx = builder_->getContext();

  // signature
  FunctionType *src_sig = (FunctionType*)type(fn->get_fn_type());
  FunctionType *lowered_sig = src_sig;
  if(!tgt_->is_gpu()){
    // NOTE(review): the purpose of the three extra i32 parameters is not
    // visible in this function - confirm against the non-GPU target.
    std::vector<Type*> param_tys;
    for(unsigned p = 0; p < src_sig->getNumParams(); p++)
      param_tys.push_back(src_sig->getParamType(p));
    for(unsigned extra = 0; extra < 3; extra++)
      param_tys.push_back(builder_->getInt32Ty());
    lowered_sig = FunctionType::get(src_sig->getReturnType(), param_tys, false);
  }
  Function *lowered = Function::Create(lowered_sig, Function::ExternalLinkage, fn->get_name(), mod_);

  // attributes
  for(auto entry: fn->attrs()){
    unsigned slot = entry.first;
    for(ir::attribute a: entry.second){
      if(!a.is_llvm_attr())
        continue;
      lowered->addAttribute(slot, llvm_attr(llvm_ctx, a));
    }
  }

  // metadata
  tgt_->set_kernel(*builder_, llvm_ctx, mod_, lowered);
  Metadata *annotation[] = {
    ValueAsMetadata::get(lowered),
    MDString::get(llvm_ctx, "maxntidx"),
    ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
  };
  mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(llvm_ctx, annotation));

  // arguments: i-th source argument -> i-th LLVM argument
  for(unsigned a = 0; a < fn->args().size(); a++)
    vmap_[fn->args()[a]] = &*(lowered->arg_begin() + a);

  // basic blocks
  for(ir::basic_block *bb: fn->blocks())
    vmap_[bb] = BasicBlock::Create(llvm_ctx, bb->get_name(), lowered);

  // start emitting in the first block
  BasicBlock *entry_bb = (BasicBlock*)vmap_[fn->blocks()[0]];
  builder_->SetInsertPoint(entry_bb);
  fn_ = lowered;
}
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
@@ -1469,6 +1441,8 @@ void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t();
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
@@ -1498,10 +1472,13 @@ void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
}
|
||||
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
|
||||
}
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t();
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t*) {
|
||||
/*
 * Record a freshly allocated machine_layout_shared_t for `layout` in
 * machine_layouts_, replacing whatever pointer was stored under that key.
 * NOTE(review): nothing here releases the allocation - confirm who owns the
 * objects held by machine_layouts_.
 */
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
  machine_layout_t *entry = new machine_layout_shared_t();
  machine_layouts_[layout] = entry;
}
|
||||
|
||||
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
|
||||
|
Reference in New Issue
Block a user