[intermediate representation] transitioning towards more flexible tile
shapes
This commit is contained in:
@@ -44,6 +44,7 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
|
||||
return VectorType::get(ty, vector_size);
|
||||
}
|
||||
|
||||
|
||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
||||
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
||||
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
||||
@@ -149,6 +150,16 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
return builder_.CreateLoad(ptr);
|
||||
}
|
||||
|
||||
/* Utils */
|
||||
std::vector<unsigned> selection::extract_shapes(ir::value *v) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> result(shapes.size());
|
||||
for(ir::constant_int* cst: shapes)
|
||||
result.push_back(cst->get_value());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/* convert ir::type to Type */
|
||||
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
@@ -299,11 +310,12 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
|
||||
}
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
std::cout << v->get_name() << " " << typeid(*v).name() << std::endl;
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
||||
@@ -336,7 +348,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
||||
for(unsigned shape: extract_shapes(v)) {
|
||||
result += (shape > 1)?shape:0;
|
||||
}
|
||||
return result;
|
||||
@@ -353,7 +365,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
||||
return;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
@@ -385,7 +397,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, references, seen, sh_mem_ptr);
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||
// create shared tile
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||
@@ -429,7 +441,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
// create distributed tile
|
||||
else {
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
const auto &shapes = extract_shapes(v);
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
@@ -530,7 +542,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
distributed_tile* result = (distributed_tile*)ti;
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
return;
|
||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(ins);
|
||||
// global_range
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||
static std::array<Intrinsic::ID, 3> ctaid = {
|
||||
@@ -568,7 +580,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
||||
const auto& in_shapes = extract_shapes(in);
|
||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||
result->for_each([&](indices_t out_idx){
|
||||
indices_t in_idx = out_idx;
|
||||
@@ -615,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = tmap_.at(C)->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
||||
unsigned NK = extract_shapes(A)[1];
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
|
Reference in New Issue
Block a user