more cleaning
@@ -343,16 +343,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

/* convert ir::constant to Constant */
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
  Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
  if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
    return ConstantInt::get(dst_ty, cc->get_value());
  if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
    return ConstantFP::get(dst_ty, cc->get_value());
  // unknown constant
  throw std::runtime_error("unknown conversion from ir::constant to Constant");
}

/* convert ir::alloc_const to llvm::GlobalVariable */
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {

@@ -387,145 +377,6 @@ inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}

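// Strided-scan (scanline) layout: along each axis k a thread owns nts[k]
// contiguous elements, and the pattern repeats every per_block = nts[k] * mts[k]
// elements, so each thread covers shapes[k] / mts[k] elements of that axis.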
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  auto order = layout.order;
  const auto& shapes = layout.shapes;
  size_t dim = shapes.size();
  std::vector<int> nts = layout.nts;
  std::vector<int> mts = layout.mts;
  Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
  std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    std::string str_k = std::to_string(k);
    Value *contiguous_k = builder.getInt32(nts[k]);
    Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
    unsigned per_block = nts[k] * mts[k];
    unsigned per_thread = nts[k] * shapes[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      unsigned offset = n / nts[k] * per_block + n % nts[k];
      idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
    axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
  }
}

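// HMMA (mma.884) layout: each warp covers an (fpw_0*8) x (fpw_1*8) tile and
// wpt warps are arranged along each dimension; the per-thread i/j/k offsets
// below are derived from the thread's quad and quad-pair position in the warp.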
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  const auto& shapes = layout.shapes;
  if(shapes.size() > 3)
    throw std::runtime_error("unsupported");

  bool is_batched = shapes.size() >= 3;

  Value *_1 = builder.getInt32(1);
  Value *_2 = builder.getInt32(2);
  Value *_3 = builder.getInt32(3);
  Value *_4 = builder.getInt32(4);
  Value *_16 = builder.getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout.fpw.at(0);
  unsigned fpw_1 = layout.fpw.at(1);
  unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout.wpt.at(0);
  unsigned wpt_1 = layout.wpt.at(1);
  unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1;
  // hmma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // hmma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetition
  unsigned num_rep_0 = shapes[0] / hmma_bts_0;
  unsigned num_rep_1 = shapes[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
                                           builder.getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
                                           builder.getInt32(fpw_1 * pack_size_1_));

  // Quad pair id
  Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
  pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
  pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
  Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
  Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
  Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
  Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder.CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
                                        builder.CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));


  /* axes */
  axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}

void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  if(layout.type == analysis::HMMA_884)
    init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
  else if(layout.type == analysis::SCANLINE)
    init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
}

/* -------------------
 * ---- Init Tiles ----
 * ------------------- */
@@ -549,7 +400,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
  if(parent->empty())
    builder.SetInsertPoint(parent);
  else
    builder.SetInsertPoint(&*parent->getFirstInsertionPt());
    builder.SetInsertPoint(&*parent->getFirstNonPHI());
  // create double-buffered pointer
  PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
  PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);

@@ -587,41 +438,6 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
  tmap_.insert({v, T});
}

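// Recursively creates tiles for v and its tile-typed operands: values with a
// SHARED layout (other than reductions) get a shared-memory tile, everything
// else gets a distributed tile.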
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
                            std::set<ir::value*> &seen, Value *sh_mem_ptr) {
  if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
    return;
  if(auto *user = dynamic_cast<ir::user*>(v))
    for(ir::value *op: user->ops())
      create_tile(op, builder, seen, sh_mem_ptr);
  auto *i = dynamic_cast<ir::instruction*>(v);
  if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(v))
    create_shared_tile(i, builder, sh_mem_ptr);
  else
    create_distributed_tile(v, builder);
}

void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
  // fetch linear ID
  Module *mod = builder.GetInsertBlock()->getParent()->getParent();
  Value *warp_size = builder.getInt32(32);
  Value* u_thread_id = tgt_->get_local_id(mod, builder, 0);
  Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
  Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
  // create grid
  for(auto x: layouts_->get_all())
    init_axes(*x.second, builder, u_thread_warp_id, u_warp_id);
  // create tile
  std::set<ir::value*> seen;
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()){
    if(!i->get_type()->is_tile_ty())
      continue;
    create_tile(i, builder, seen, sh_mem_ptr);
  }
}

bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {

@@ -641,51 +457,34 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
  if(!seen.insert(src).second)
    return;

  BasicBlock *current = builder.GetInsertBlock();
  if(src->get_type()->is_tile_ty()){
    builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin());
    auto *i = dynamic_cast<ir::instruction*>(src);
    if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(src)){
      create_shared_tile(i, builder, sh_mem_ptr_);
    }
    else
      create_distributed_tile(src, builder);
  }
  builder.SetInsertPoint(current);


  auto *inst = dynamic_cast<ir::instruction*>(src);
  if(inst && !dynamic_cast<ir::phi_node*>(src))
    for(ir::value *op: inst->ops())
      lower_value(op, builder, gen, seen);

  BasicBlock *current = builder.GetInsertBlock();
  builder.SetInsertPoint(current);
  auto *phi = dynamic_cast<ir::phi_node*>(src);
  bool phi_inserted = phi && !current->empty();
  if(phi_inserted && current->getFirstNonPHI())
  if(phi && !current->empty() && current->getFirstNonPHI())
    builder.SetInsertPoint(&*current->getFirstNonPHI());

  if(auto *usr = dynamic_cast<ir::user*>(src))
    usr->accept(gen);

  if(dynamic_cast<ir::make_range*>(src)){
    distributed_tile *T = (distributed_tile *)tmap_.at(src);
    T->for_each([&](indices_t idx){
      assert(idx.size() == 1);
      T->set_value(idx, idx[0]);
    });
  }
  else if(dynamic_cast<ir::make_range_sta*>(src)){
    distributed_tile *T = (distributed_tile *)tmap_.at(src);
    T->for_each([&](indices_t idx){
      assert(idx.size() == 1);
      BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
      assert(bin_add);
      Value *res = bin_add->getOperand(1);
      assert(isa<Constant>(res));
      T->set_value(idx, res);
    });
  }
  else if(auto *cst = dynamic_cast<ir::constant*>(src)){
    vmap_[cst] = llvm_constant(cst, builder.getContext());
  }
  else if(inst){
    inst->accept(gen);
  }

  if(phi_inserted && current->getFirstNonPHI())
  if(phi && !current->empty() && current->getFirstNonPHI())
    builder.SetInsertPoint(current);

  // if(dynamic_cast<ir::phi_node*>(src))
  //   for(ir::value *op: inst->ops())
  //     lower_value(op, builder, seen);


}

/* ----------------------------

@@ -702,12 +501,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
  }
}

ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) {
  unsigned size = 1;
  for(auto shape: ty->get_tile_shapes())
    size *= shape;
  return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
}

Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
  LLVMContext &ctx = builder.getContext();
@@ -777,6 +570,9 @@ void selection::run(ir::module &src, Module &dst) {
  for(ir::alloc_const *x: src.allocs())
    vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);

  // allocate shared memory
  sh_mem_ptr_ = alloc_shared(dst_builder, dst);

  // iterate over functions
  std::set<ir::value*> seen;

@@ -785,14 +581,13 @@
  // create LLVM function
  Function *ffn = llvm_fn(fn, dst_builder, dst);

  // allocate shared memory
  sh_mem_ptr_ = alloc_shared(dst_builder, dst);
  // create tile
  generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
                offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );

  // initialize layouts
  init_layouts(fn, dst_builder, sh_mem_ptr_);

  generator gen(&dst_ctx, ffn, &dst_builder, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
                offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
  for(auto x: layouts_->get_all())
    x.second->accept(&gen);

  // generate LLVM-IR code
  std::map<ir::basic_block*, BasicBlock*> last_block;
@@ -1536,6 +1331,179 @@ Type *generator::type(ir::type *ty) {
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

void generator::visit_undef_value(ir::undef_value *ud) {
  vmap_[ud] = llvm::UndefValue::get(type(ud->get_type()));
}

void generator::visit_constant_int(ir::constant_int *cst){
  Type *ty = type(cst->get_type()->get_scalar_ty());
  vmap_[cst] = ConstantInt::get(ty, cst->get_value());
}

void generator::visit_constant_fp(ir::constant_fp *cst){
  Type *ty = type(cst->get_type()->get_scalar_ty());
  vmap_[cst] = ConstantFP::get(ty, cst->get_value());
}

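// Constant allocation: emits a module-level array in address space 4 and maps
// the ir::alloc_const to a pointer to its first element.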
void generator::visit_alloc_const(ir::alloc_const *alloc) {
  unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
  Type *element_ty = type(alloc->get_type()->get_pointer_element_ty());
  Type *array_ty = llvm::ArrayType::get(element_ty, size);
  Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
                                          nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
  vmap_[alloc] = builder_->CreateBitCast(array, element_ty->getPointerTo(4));
}


void generator::visit_function(ir::function*) {

}

void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  const auto& shapes = layout->shapes;
  if(shapes.size() > 3)
    throw std::runtime_error("unsupported");

  bool is_batched = shapes.size() >= 3;

  Value *_1 = builder_->getInt32(1);
  Value *_2 = builder_->getInt32(2);
  Value *_3 = builder_->getInt32(3);
  Value *_4 = builder_->getInt32(4);
  Value *_16 = builder_->getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout->fpw.at(0);
  unsigned fpw_1 = layout->fpw.at(1);
  unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout->wpt.at(0);
  unsigned wpt_1 = layout->wpt.at(1);
  unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
  // hmma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // hmma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetition
  unsigned num_rep_0 = shapes[0] / hmma_bts_0;
  unsigned num_rep_1 = shapes[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_1 * pack_size_1_));

  // Quad pair id
  Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
                                          builder_->CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));


  /* axes */
  axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}

void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  auto order = layout->order;
  const auto& shapes = layout->shapes;
  size_t dim = shapes.size();
  std::vector<int> nts = layout->nts;
  std::vector<int> mts = layout->mts;
  Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
  std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    std::string str_k = std::to_string(k);
    Value *contiguous_k = builder_->getInt32(nts[k]);
    Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
    unsigned per_block = nts[k] * mts[k];
    unsigned per_thread = nts[k] * shapes[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      unsigned offset = n / nts[k] * per_block + n % nts[k];
      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
    axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
  }
}

void generator::visit_layout_shared(analysis::layout_shared_t*) {

}

void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
  if(!x->get_type()->is_tile_ty())
    return fn({});