[codegen] [selection] machine layouts now create machine tiles
This commit is contained in:
@@ -53,6 +53,7 @@ struct layout_t {
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned> &_shapes,
|
||||
const std::vector<ir::value *> &_values,
|
||||
ir::type *_ty,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
|
||||
@@ -79,6 +80,7 @@ struct layout_hmma_884_t: public layout_t {
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &_values,
|
||||
ir::type *_ty,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
|
||||
@@ -89,6 +91,7 @@ struct layout_scanline_t: public layout_t {
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
ir::type *_ty,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
|
@@ -147,44 +147,54 @@ private:
|
||||
};
|
||||
|
||||
class machine_layout_t {
|
||||
|
||||
virtual tile* create(ir::value *v) = 0;
|
||||
};
|
||||
|
||||
class machine_layout_shared_t: public machine_layout_t {
|
||||
|
||||
public:
|
||||
shared_tile* create(ir::value *v);
|
||||
};
|
||||
|
||||
class machine_layout_hmma_884_t: public machine_layout_t {
|
||||
class machine_layout_distributed_t: public machine_layout_t {
|
||||
public:
|
||||
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t* layout);
|
||||
|
||||
distributed_tile* create(ir::value *v);
|
||||
Module *mod_;
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
Type *ty_;
|
||||
analysis::axes *a_axes_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
analysis::layout_t* layout_;
|
||||
};
|
||||
|
||||
|
||||
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, std::map<unsigned, distributed_axis>& axes,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
|
||||
unsigned &pack_size_0, unsigned &pack_size_1,
|
||||
unsigned &num_packs_0, unsigned &num_packs_1,
|
||||
analysis::layout_hmma_884_t* layout);
|
||||
Module *mod_;
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
Value *&offset_a_i_, *&offset_a_k_;
|
||||
Value *&offset_b_j_, *&offset_b_k_;
|
||||
unsigned &pack_size_0_;
|
||||
unsigned& pack_size_1_;
|
||||
unsigned &num_packs_0_;
|
||||
unsigned& num_packs_1_;
|
||||
analysis::layout_hmma_884_t* layout_;
|
||||
};
|
||||
|
||||
class machine_layout_scanline_t: public machine_layout_t {
|
||||
class machine_layout_scanline_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, std::map<unsigned, distributed_axis>& axes,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_scanline_t* layout);
|
||||
Module *mod_;
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
analysis::layout_scanline_t* layout_;
|
||||
};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
@@ -194,7 +204,6 @@ private:
|
||||
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add);
|
||||
|
||||
Type *type(ir::type *ty);
|
||||
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
|
||||
Value* get_value(ir::value *x, const indices_t& idx);
|
||||
void set_value(ir::value *x, const indices_t& idx, Value* v);
|
||||
@@ -203,6 +212,7 @@ public:
|
||||
generator(LLVMContext *ctx,
|
||||
Module *dst,
|
||||
Builder *builder,
|
||||
analysis::axes *a_axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap,
|
||||
@@ -216,12 +226,13 @@ public:
|
||||
unsigned num_packs_0, unsigned num_packs_1,
|
||||
unsigned pack_size_0, unsigned pack_size_1,
|
||||
unsigned num_warps)
|
||||
: ctx_(ctx), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
: ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
|
||||
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
|
||||
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
|
||||
num_warps_(num_warps) { }
|
||||
|
||||
machine_layout_t *get_machine_layout(const analysis::layout_t *layout) { return machine_layouts_.at(layout); }
|
||||
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
@@ -281,7 +292,8 @@ private:
|
||||
Builder *builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<analysis::layout_t*, machine_layout_t*> machine_layouts_;
|
||||
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
|
||||
analysis::axes *a_axes_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
@@ -311,7 +323,6 @@ private:
|
||||
|
||||
// grid construction
|
||||
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
|
||||
void create_distributed_tile(ir::value *v, Builder &builder);
|
||||
|
||||
// lower scalar instruction
|
||||
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);
|
||||
|
@@ -128,9 +128,9 @@ inline bool is_trans(ir::value *v) {
|
||||
layout_t::layout_t(layout_type_t _type,
|
||||
const std::vector<int> &_axes,
|
||||
const std::vector<unsigned> &_shapes,
|
||||
const std::vector<ir::value *> &_values,
|
||||
const std::vector<ir::value *> &_values, ir::type *_ty,
|
||||
size_t _id,
|
||||
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id) {
|
||||
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id), ty(_ty) {
|
||||
// io pointer
|
||||
std::set<ir::value*> ptr;
|
||||
for(ir::value* v: values)
|
||||
@@ -152,8 +152,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values, size_t _id,
|
||||
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _id, align) {
|
||||
const std::vector<ir::value *> &values, ir::type *_ty, size_t _id,
|
||||
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) {
|
||||
|
||||
unsigned shape_0 = shapes[order[0]];
|
||||
unsigned shape_1 = shapes[order[1]];
|
||||
@@ -194,9 +194,9 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
|
||||
layout_scanline_t::layout_scanline_t(size_t num_warps,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
const std::vector<ir::value *> &values, ir::type *_ty,
|
||||
size_t _id,
|
||||
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _id, align){
|
||||
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, _id, align){
|
||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = num_warps * 32;
|
||||
nts.resize(shapes.size());
|
||||
@@ -263,9 +263,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
const std::vector<ir::value *> &values,
|
||||
ir::type *ty,
|
||||
size_t _id,
|
||||
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) {
|
||||
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, _id, align) {
|
||||
|
||||
this->ty = ty;
|
||||
size = 0;
|
||||
|
||||
// double-buffering
|
||||
@@ -333,7 +332,7 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
});
|
||||
// type
|
||||
if(it_hmma_c != values.end())
|
||||
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, id, align_);
|
||||
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
|
||||
else if(it_cts != values.end()){
|
||||
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
@@ -341,7 +340,7 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
|
||||
}
|
||||
else
|
||||
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_);
|
||||
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
|
||||
}
|
||||
|
||||
void layout::run(ir::module &mod) {
|
||||
|
@@ -343,6 +343,42 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
Type *type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *return_ty = type(tt->get_return_ty(), ctx);
|
||||
std::vector<Type*> param_tys;
|
||||
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
|
||||
[&ctx](ir::type* t){ return type(t, ctx);});
|
||||
return FunctionType::get(return_ty, param_tys, false);
|
||||
}
|
||||
// pointer
|
||||
if(ty->is_pointer_ty()){
|
||||
Type *elt_ty = type(ty->get_pointer_element_ty(), ctx);
|
||||
unsigned addr_space = ty->get_pointer_address_space();
|
||||
return PointerType::get(elt_ty, addr_space);
|
||||
}
|
||||
// integer
|
||||
if(ty->is_integer_ty()){
|
||||
unsigned bitwidth = ty->get_integer_bitwidth();
|
||||
return IntegerType::get(ctx, bitwidth);
|
||||
}
|
||||
// primitive types
|
||||
switch(ty->get_type_id()){
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
|
||||
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
|
||||
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
|
||||
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
|
||||
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
|
||||
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
|
||||
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
|
||||
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
|
||||
default: break;
|
||||
}
|
||||
// unknown type
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
|
||||
/* -------------------
|
||||
@@ -410,24 +446,6 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
||||
}
|
||||
}
|
||||
|
||||
void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
axes[d].contiguous = 1;
|
||||
axes[d].values = {builder.getInt32(0)};
|
||||
}
|
||||
}
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false);
|
||||
tmap_.insert({v, T});
|
||||
}
|
||||
|
||||
|
||||
bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
@@ -454,7 +472,7 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED)
|
||||
create_shared_tile(i, builder, sh_mem_ptr_);
|
||||
else
|
||||
create_distributed_tile(src, builder);
|
||||
tmap_[src] = ((machine_layout_distributed_t*)gen->get_machine_layout(layouts_->get(src)))->create(src);
|
||||
}
|
||||
builder.SetInsertPoint(current);
|
||||
|
||||
@@ -521,7 +539,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
std::set<ir::value*> seen;
|
||||
|
||||
// create tile
|
||||
generator gen(&dst_ctx, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
|
||||
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
@@ -606,7 +624,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
|
||||
|
||||
void generator::visit_phi_node(ir::phi_node* phi) {
|
||||
Type *ty = type(phi->get_type()->get_scalar_ty());
|
||||
Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_);
|
||||
unsigned num_ops = phi->get_num_operands();
|
||||
for_each(phi, [&](indices_t idx){
|
||||
set_value(phi, idx, builder_->Insert(PHINode::Create(ty, num_ops)));
|
||||
@@ -628,7 +646,7 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) {
|
||||
std::vector<Value*> idx_vals;
|
||||
std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals),
|
||||
[&](ir::value* x){ return get_value(x, idx);});
|
||||
Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty());
|
||||
Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_);
|
||||
Value *ret = builder_->Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals));
|
||||
set_value(gep, idx, ret);
|
||||
});
|
||||
@@ -657,7 +675,7 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) {
|
||||
void generator::visit_cast_inst(ir::cast_inst* cast) {
|
||||
for_each(cast, [&](indices_t idx){
|
||||
Value *arg = get_value(cast->get_operand(0), idx);
|
||||
Type *dst_ty = type(cast->get_type()->get_scalar_ty());
|
||||
Type *dst_ty = type(cast->get_type()->get_scalar_ty(), *ctx_);
|
||||
Value *ret = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), arg, dst_ty));
|
||||
set_value(cast, idx, ret);
|
||||
});
|
||||
@@ -1102,7 +1120,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
ir::value *D = dot->get_operand(2);
|
||||
|
||||
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
|
||||
Type *c_ty = type(D->get_type()->get_scalar_ty());
|
||||
Type *c_ty = type(D->get_type()->get_scalar_ty(), *ctx_);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = 1;
|
||||
@@ -1228,60 +1246,25 @@ void generator::visit_make_range(ir::make_range* x) {
|
||||
});
|
||||
}
|
||||
|
||||
Type *generator::type(ir::type *ty) {
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *return_ty = type(tt->get_return_ty());
|
||||
std::vector<Type*> param_tys;
|
||||
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
|
||||
[this](ir::type* t){ return type(t);});
|
||||
return FunctionType::get(return_ty, param_tys, false);
|
||||
}
|
||||
// pointer
|
||||
if(ty->is_pointer_ty()){
|
||||
Type *elt_ty = type(ty->get_pointer_element_ty());
|
||||
unsigned addr_space = ty->get_pointer_address_space();
|
||||
return PointerType::get(elt_ty, addr_space);
|
||||
}
|
||||
// integer
|
||||
if(ty->is_integer_ty()){
|
||||
unsigned bitwidth = ty->get_integer_bitwidth();
|
||||
return IntegerType::get(*ctx_, bitwidth);
|
||||
}
|
||||
// primitive types
|
||||
switch(ty->get_type_id()){
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
||||
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_);
|
||||
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
||||
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
|
||||
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
|
||||
default: break;
|
||||
}
|
||||
// unknown type
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_undef_value(ir::undef_value *ud) {
|
||||
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type()));
|
||||
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type(), *ctx_));
|
||||
}
|
||||
|
||||
void generator::visit_constant_int(ir::constant_int *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty());
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
vmap_[cst] = ConstantInt::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_constant_fp(ir::constant_fp *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty());
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
vmap_[cst] = ConstantFP::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
|
||||
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty());
|
||||
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty(), *ctx_);
|
||||
Type *array_ty = llvm::ArrayType::get(element_ty, size);
|
||||
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
|
||||
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
@@ -1291,7 +1274,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type());
|
||||
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_);
|
||||
FunctionType *dst_fn_ty = fn_ty;
|
||||
if(!tgt_->is_gpu()){
|
||||
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
||||
@@ -1331,16 +1314,86 @@ void generator::visit_function(ir::function* fn) {
|
||||
fn_ = ret;
|
||||
}
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_,
|
||||
pack_size_0_, pack_size_1_,
|
||||
num_packs_0_, num_packs_1_,
|
||||
layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_shared_t();
|
||||
}
|
||||
|
||||
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
|
||||
if(!x->get_type()->is_tile_ty())
|
||||
return fn({});
|
||||
else {
|
||||
if(auto *dt = dynamic_cast<distributed_tile*>(tmap_.at(x)))
|
||||
dt->for_each(fn);
|
||||
}
|
||||
}
|
||||
|
||||
Value* generator::get_value(ir::value *x, const indices_t& idx) {
|
||||
if(x->get_type()->is_tile_ty())
|
||||
return tmap_.at(x)->get_value(idx);
|
||||
return vmap_.at(x);
|
||||
}
|
||||
|
||||
void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
|
||||
if(x->get_type()->is_tile_ty())
|
||||
tmap_.at(x)->set_value(idx, v);
|
||||
else
|
||||
vmap_[x] = v;
|
||||
}
|
||||
|
||||
|
||||
|
||||
shared_tile* machine_layout_shared_t::create(ir::value *v) {
|
||||
|
||||
}
|
||||
|
||||
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t *layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
||||
|
||||
}
|
||||
|
||||
distributed_tile* machine_layout_distributed_t::create(ir::value *v) {
|
||||
Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
axes[d].contiguous = 1;
|
||||
axes[d].values = {builder_->getInt32(0)};
|
||||
}
|
||||
}
|
||||
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
|
||||
}
|
||||
|
||||
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, std::map<unsigned, distributed_axis>& axes,
|
||||
target *tgt, Type *ty, analysis::axes *a_axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k,
|
||||
unsigned &pack_size_0, unsigned &pack_size_1,
|
||||
unsigned &num_packs_0, unsigned &num_packs_1,
|
||||
analysis::layout_hmma_884_t* layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), axes_(axes),
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout),
|
||||
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
|
||||
pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1),
|
||||
layout_(layout) {
|
||||
pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
@@ -1454,10 +1507,11 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
||||
|
||||
|
||||
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, std::map<unsigned, distributed_axis> &axes,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
||||
analysis::layout_scanline_t* layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), axes_(axes), layout_(layout)
|
||||
{
|
||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
@@ -1486,44 +1540,6 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, axes_, offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_,
|
||||
pack_size_0_, pack_size_1_,
|
||||
num_packs_0_, num_packs_1_,
|
||||
layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_shared_t();
|
||||
}
|
||||
|
||||
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
|
||||
if(!x->get_type()->is_tile_ty())
|
||||
return fn({});
|
||||
else {
|
||||
if(auto *dt = dynamic_cast<distributed_tile*>(tmap_.at(x)))
|
||||
dt->for_each(fn);
|
||||
}
|
||||
}
|
||||
|
||||
Value* generator::get_value(ir::value *x, const indices_t& idx) {
|
||||
if(x->get_type()->is_tile_ty())
|
||||
return tmap_.at(x)->get_value(idx);
|
||||
return vmap_.at(x);
|
||||
}
|
||||
|
||||
void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
|
||||
if(x->get_type()->is_tile_ty())
|
||||
tmap_.at(x)->set_value(idx, v);
|
||||
else
|
||||
vmap_[x] = v;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user