[codegen] [selection] machine layouts now create machine tiles

This commit is contained in:
Philippe Tillet
2019-10-15 16:12:08 -04:00
parent 3d5ab4bc0d
commit 1b5b76b629
4 changed files with 169 additions and 140 deletions

View File

@@ -53,6 +53,7 @@ struct layout_t {
const std::vector<int>& _axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
size_t _id,
analysis::align* align);
@@ -79,6 +80,7 @@ struct layout_hmma_884_t: public layout_t {
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
size_t _id,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
@@ -89,6 +91,7 @@ struct layout_scanline_t: public layout_t {
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *_ty,
size_t _id,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }

View File

@@ -147,44 +147,54 @@ private:
};
class machine_layout_t {
virtual tile* create(ir::value *v) = 0;
};
class machine_layout_shared_t: public machine_layout_t {
public:
shared_tile* create(ir::value *v);
};
class machine_layout_hmma_884_t: public machine_layout_t {
class machine_layout_distributed_t: public machine_layout_t {
public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout);
distributed_tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
Type *ty_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_;
};
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, std::map<unsigned, distributed_axis>& axes,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
unsigned &pack_size_0, unsigned &pack_size_1,
unsigned &num_packs_0, unsigned &num_packs_1,
analysis::layout_hmma_884_t* layout);
Module *mod_;
Builder *builder_;
target *tgt_;
std::map<unsigned, distributed_axis>& axes_;
Value *&offset_a_i_, *&offset_a_k_;
Value *&offset_b_j_, *&offset_b_k_;
unsigned &pack_size_0_;
unsigned& pack_size_1_;
unsigned &num_packs_0_;
unsigned& num_packs_1_;
analysis::layout_hmma_884_t* layout_;
};
class machine_layout_scanline_t: public machine_layout_t {
class machine_layout_scanline_t: public machine_layout_distributed_t {
public:
machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, std::map<unsigned, distributed_axis>& axes,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout);
Module *mod_;
Builder *builder_;
target *tgt_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_scanline_t* layout_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
@@ -194,7 +204,6 @@ private:
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
Type *type(ir::type *ty);
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
@@ -203,6 +212,7 @@ public:
generator(LLVMContext *ctx,
Module *dst,
Builder *builder,
analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap,
@@ -216,12 +226,13 @@ public:
unsigned num_packs_0, unsigned num_packs_1,
unsigned pack_size_0, unsigned pack_size_1,
unsigned num_warps)
: ctx_(ctx), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
: ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
num_warps_(num_warps) { }
machine_layout_t *get_machine_layout(const analysis::layout_t *layout) { return machine_layouts_.at(layout); }
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
@@ -281,7 +292,8 @@ private:
Builder *builder_;
Module *mod_;
std::map<analysis::layout_t*, machine_layout_t*> machine_layouts_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
@@ -311,7 +323,6 @@ private:
// grid construction
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
void create_distributed_tile(ir::value *v, Builder &builder);
// lower scalar instruction
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);

View File

@@ -128,9 +128,9 @@ inline bool is_trans(ir::value *v) {
layout_t::layout_t(layout_type_t _type,
const std::vector<int> &_axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values,
const std::vector<ir::value *> &_values, ir::type *_ty,
size_t _id,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id) {
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id), ty(_ty) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values)
@@ -152,8 +152,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, size_t _id,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _id, align) {
const std::vector<ir::value *> &values, ir::type *_ty, size_t _id,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) {
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
@@ -194,9 +194,9 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
layout_scanline_t::layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
const std::vector<ir::value *> &values, ir::type *_ty,
size_t _id,
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _id, align){
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, _id, align){
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts.resize(shapes.size());
@@ -263,9 +263,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
const std::vector<ir::value *> &values,
ir::type *ty,
size_t _id,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) {
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, _id, align) {
this->ty = ty;
size = 0;
// double-buffering
@@ -333,7 +332,7 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, id, align_);
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
@@ -341,7 +340,7 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
}
else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_);
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
}
void layout::run(ir::module &mod) {

View File

@@ -343,6 +343,42 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
throw std::runtime_error("unknown conversion from ir::type to Type");
}
Type *type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
/* -------------------
@@ -410,24 +446,6 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
}
}
void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder.getInt32(0)};
}
}
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false);
tmap_.insert({v, T});
}
bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
@@ -454,7 +472,7 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
if(i && layouts_->get(i)->type == analysis::SHARED)
create_shared_tile(i, builder, sh_mem_ptr_);
else
create_distributed_tile(src, builder);
tmap_[src] = ((machine_layout_distributed_t*)gen->get_machine_layout(layouts_->get(src)))->create(src);
}
builder.SetInsertPoint(current);
@@ -521,7 +539,7 @@ void selection::run(ir::module &src, Module &dst) {
std::set<ir::value*> seen;
// create tile
generator gen(&dst_ctx, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
for(ir::alloc_const *x: src.allocs())
@@ -606,7 +624,7 @@ void selection::run(ir::module &src, Module &dst) {
void generator::visit_phi_node(ir::phi_node* phi) {
Type *ty = type(phi->get_type()->get_scalar_ty());
Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_);
unsigned num_ops = phi->get_num_operands();
for_each(phi, [&](indices_t idx){
set_value(phi, idx, builder_->Insert(PHINode::Create(ty, num_ops)));
@@ -628,7 +646,7 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) {
std::vector<Value*> idx_vals;
std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals),
[&](ir::value* x){ return get_value(x, idx);});
Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty());
Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_);
Value *ret = builder_->Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals));
set_value(gep, idx, ret);
});
@@ -657,7 +675,7 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) {
void generator::visit_cast_inst(ir::cast_inst* cast) {
for_each(cast, [&](indices_t idx){
Value *arg = get_value(cast->get_operand(0), idx);
Type *dst_ty = type(cast->get_type()->get_scalar_ty());
Type *dst_ty = type(cast->get_type()->get_scalar_ty(), *ctx_);
Value *ret = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), arg, dst_ty));
set_value(cast, idx, ret);
});
@@ -1102,7 +1120,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
ir::value *D = dot->get_operand(2);
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
Type *c_ty = type(D->get_type()->get_scalar_ty());
Type *c_ty = type(D->get_type()->get_scalar_ty(), *ctx_);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = 1;
@@ -1228,60 +1246,25 @@ void generator::visit_make_range(ir::make_range* x) {
});
}
Type *generator::type(ir::type *ty) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = type(tt->get_return_ty());
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[this](ir::type* t){ return type(t);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = type(ty->get_pointer_element_ty());
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(*ctx_, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
void generator::visit_undef_value(ir::undef_value *ud) {
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type()));
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type(), *ctx_));
}
void generator::visit_constant_int(ir::constant_int *cst){
Type *ty = type(cst->get_type()->get_scalar_ty());
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
vmap_[cst] = ConstantInt::get(ty, cst->get_value());
}
void generator::visit_constant_fp(ir::constant_fp *cst){
Type *ty = type(cst->get_type()->get_scalar_ty());
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
vmap_[cst] = ConstantFP::get(ty, cst->get_value());
}
void generator::visit_alloc_const(ir::alloc_const *alloc) {
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty());
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty(), *ctx_);
Type *array_ty = llvm::ArrayType::get(element_ty, size);
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
@@ -1291,7 +1274,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
void generator::visit_function(ir::function* fn) {
LLVMContext &ctx = builder_->getContext();
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type());
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_);
FunctionType *dst_fn_ty = fn_ty;
if(!tgt_->is_gpu()){
Type *dst_fn_ret_ty = fn_ty->getReturnType();
@@ -1331,16 +1314,86 @@ void generator::visit_function(ir::function* fn) {
fn_ = ret;
}
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_,
pack_size_0_, pack_size_1_,
num_packs_0_, num_packs_1_,
layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
machine_layouts_[layout] = new machine_layout_shared_t();
}
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
if(!x->get_type()->is_tile_ty())
return fn({});
else {
if(auto *dt = dynamic_cast<distributed_tile*>(tmap_.at(x)))
dt->for_each(fn);
}
}
Value* generator::get_value(ir::value *x, const indices_t& idx) {
if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
return vmap_.at(x);
}
void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
if(x->get_type()->is_tile_ty())
tmap_.at(x)->set_value(idx, v);
else
vmap_[x] = v;
}
shared_tile* machine_layout_shared_t::create(ir::value *v) {
}
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
distributed_tile* machine_layout_distributed_t::create(ir::value *v) {
Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, std::map<unsigned, distributed_axis>& axes,
target *tgt, Type *ty, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
unsigned &pack_size_0, unsigned &pack_size_1,
unsigned &num_packs_0, unsigned &num_packs_1,
analysis::layout_hmma_884_t* layout)
: mod_(mod), builder_(builder), tgt_(tgt), axes_(axes),
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout),
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1),
layout_(layout) {
pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
@@ -1454,10 +1507,11 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, std::map<unsigned, distributed_axis> &axes,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout)
: mod_(mod), builder_(builder), tgt_(tgt), axes_(axes), layout_(layout)
{
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
@@ -1486,44 +1540,6 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build
}
}
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, axes_, offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_,
pack_size_0_, pack_size_1_,
num_packs_0_, num_packs_1_,
layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
machine_layouts_[layout] = new machine_layout_shared_t();
}
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
if(!x->get_type()->is_tile_ty())
return fn({});
else {
if(auto *dt = dynamic_cast<distributed_tile*>(tmap_.at(x)))
dt->for_each(fn);
}
}
Value* generator::get_value(ir::value *x, const indices_t& idx) {
if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
return vmap_.at(x);
}
void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
if(x->get_type()->is_tile_ty())
tmap_.at(x)->set_value(idx, v);
else
vmap_[x] = v;
}
}