more cleaning

This commit is contained in:
Philippe Tillet
2019-10-13 14:43:17 -04:00
parent e787ce0cab
commit ee387ff567
11 changed files with 277 additions and 300 deletions

View File

@@ -343,16 +343,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
throw std::runtime_error("unknown conversion from ir::type to Type");
}
/* convert ir::constant to Constant */
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
return ConstantInt::get(dst_ty, cc->get_value());
if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
return ConstantFP::get(dst_ty, cc->get_value());
// unknown constant
throw std::runtime_error("unknown conversion from ir::constant to Constant");
}
/* convert ir::alloc_const to llvm::GlobalVariable */
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {
@@ -387,145 +377,6 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = layout.order;
const auto& shapes = layout.shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout.nts;
std::vector<int> mts = layout.mts;
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder.getInt32(nts[k]);
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = layout.shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
Value *_4 = builder.getInt32(4);
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = layout.fpw.at(0);
unsigned fpw_1 = layout.fpw.at(1);
unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout.wpt.at(0);
unsigned wpt_1 = layout.wpt.at(1);
unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
builder.getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
builder.getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(layout.type == analysis::HMMA_884)
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
else if(layout.type == analysis::SCANLINE)
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
}
/* -------------------
* ---- Init Tiles ----
* ------------------- */
@@ -549,7 +400,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
if(parent->empty())
builder.SetInsertPoint(parent);
else
builder.SetInsertPoint(&*parent->getFirstInsertionPt());
builder.SetInsertPoint(&*parent->getFirstNonPHI());
// create double-buffered pointer
PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
@@ -587,41 +438,6 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
tmap_.insert({v, T});
}
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
if(auto *user = dynamic_cast<ir::user*>(v))
for(ir::value *op: user->ops())
create_tile(op, builder, seen, sh_mem_ptr);
auto *i = dynamic_cast<ir::instruction*>(v);
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(v))
create_shared_tile(i, builder, sh_mem_ptr);
else
create_distributed_tile(v, builder);
}
void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
// fetch linear ID
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
Value *warp_size = builder.getInt32(32);
Value* u_thread_id = tgt_->get_local_id(mod, builder, 0);
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
for(auto x: layouts_->get_all())
init_axes(*x.second, builder, u_thread_warp_id, u_warp_id);
// create tile
std::set<ir::value*> seen;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
create_tile(i, builder, seen, sh_mem_ptr);
}
}
bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
@@ -641,51 +457,34 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
if(!seen.insert(src).second)
return;
BasicBlock *current = builder.GetInsertBlock();
if(src->get_type()->is_tile_ty()){
builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin());
auto *i = dynamic_cast<ir::instruction*>(src);
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(src)){
create_shared_tile(i, builder, sh_mem_ptr_);
}
else
create_distributed_tile(src, builder);
}
builder.SetInsertPoint(current);
auto *inst = dynamic_cast<ir::instruction*>(src);
if(inst && !dynamic_cast<ir::phi_node*>(src))
for(ir::value *op: inst->ops())
lower_value(op, builder, gen, seen);
BasicBlock *current = builder.GetInsertBlock();
builder.SetInsertPoint(current);
auto *phi = dynamic_cast<ir::phi_node*>(src);
bool phi_inserted = phi && !current->empty();
if(phi_inserted && current->getFirstNonPHI())
if(phi && !current->empty() && current->getFirstNonPHI())
builder.SetInsertPoint(&*current->getFirstNonPHI());
if(auto *usr = dynamic_cast<ir::user*>(src))
usr->accept(gen);
if(dynamic_cast<ir::make_range*>(src)){
distributed_tile *T = (distributed_tile *)tmap_.at(src);
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
});
}
else if(dynamic_cast<ir::make_range_sta*>(src)){
distributed_tile *T = (distributed_tile *)tmap_.at(src);
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(1);
assert(isa<Constant>(res));
T->set_value(idx, res);
});
}
else if(auto *cst = dynamic_cast<ir::constant*>(src)){
vmap_[cst] = llvm_constant(cst, builder.getContext());
}
else if(inst){
inst->accept(gen);
}
if(phi_inserted && current->getFirstNonPHI())
if(phi && !current->empty() && current->getFirstNonPHI())
builder.SetInsertPoint(current);
// if(dynamic_cast<ir::phi_node*>(src))
// for(ir::value *op: inst->ops())
// lower_value(op, builder, seen);
}
/* ----------------------------
@@ -702,12 +501,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
}
}
ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) {
unsigned size = 1;
for(auto shape: ty->get_tile_shapes())
size *= shape;
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
}
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
LLVMContext &ctx = builder.getContext();
@@ -777,6 +570,9 @@ void selection::run(ir::module &src, Module &dst) {
for(ir::alloc_const *x: src.allocs())
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
// allocate shared memory
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
// iterate over functions
std::set<ir::value*> seen;
@@ -785,14 +581,13 @@ void selection::run(ir::module &src, Module &dst) {
// create LLVM function
Function *ffn = llvm_fn(fn, dst_builder, dst);
// allocate shared memory
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
// create tile
generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
// initialize layouts
init_layouts(fn, dst_builder, sh_mem_ptr_);
generator gen(&dst_ctx, ffn, &dst_builder, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
for(auto x: layouts_->get_all())
x.second->accept(&gen);
// generate LLVM-IR code
std::map<ir::basic_block*, BasicBlock*> last_block;
@@ -1536,6 +1331,179 @@ Type *generator::type(ir::type *ty) {
throw std::runtime_error("unknown conversion from ir::type to Type");
}
void generator::visit_undef_value(ir::undef_value *ud) {
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type()));
}
void generator::visit_constant_int(ir::constant_int *cst){
Type *ty = type(cst->get_type()->get_scalar_ty());
vmap_[cst] = ConstantInt::get(ty, cst->get_value());
}
void generator::visit_constant_fp(ir::constant_fp *cst){
Type *ty = type(cst->get_type()->get_scalar_ty());
vmap_[cst] = ConstantFP::get(ty, cst->get_value());
}
void generator::visit_alloc_const(ir::alloc_const *alloc) {
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty());
Type *array_ty = llvm::ArrayType::get(element_ty, size);
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
vmap_[alloc] = builder_->CreateBitCast(array, element_ty->getPointerTo(4));
}
void generator::visit_function(ir::function*) {
}
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->order;
const auto& shapes = layout->shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout->nts;
std::vector<int> mts = layout->mts;
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts[k]);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
void generator::visit_layout_shared(analysis::layout_shared_t*) {
}
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
if(!x->get_type()->is_tile_ty())
return fn({});