more cleaning

This commit is contained in:
Philippe Tillet
2019-10-13 14:43:17 -04:00
parent e787ce0cab
commit ee387ff567
11 changed files with 277 additions and 300 deletions

View File

@@ -35,6 +35,18 @@ struct double_buffer_info_t {
ir::phi_node* phi; ir::phi_node* phi;
}; };
class layout_visitor;
// forward-declared with the same class-key (`struct`) as their definitions
struct layout_hmma_884_t;
struct layout_scanline_t;
struct layout_shared_t;

/* Visitor interface over the concrete layout kinds.
 * Each layout's accept() forwards to the matching visit_* method. */
class layout_visitor {
public:
  // polymorphic base: virtual destructor so that destroying an implementation
  // through a layout_visitor pointer is well-defined
  virtual ~layout_visitor() = default;
  // one entry point per concrete layout type
  virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
  virtual void visit_layout_scanline(layout_scanline_t*) = 0;
  virtual void visit_layout_shared(layout_shared_t*) = 0;
};
struct layout_t { struct layout_t {
layout_t(layout_type_t _type, layout_t(layout_type_t _type,
@@ -43,6 +55,9 @@ struct layout_t {
const std::vector<ir::value *> &_values, const std::vector<ir::value *> &_values,
size_t _id, size_t _id,
analysis::align* align); analysis::align* align);
virtual void accept(layout_visitor* vst) = 0;
layout_type_t type; layout_type_t type;
std::vector<int> axes; std::vector<int> axes;
std::vector<unsigned> shapes; std::vector<unsigned> shapes;
@@ -66,6 +81,7 @@ struct layout_hmma_884_t: public layout_t {
const std::vector<ir::value *> &_values, const std::vector<ir::value *> &_values,
size_t _id, size_t _id,
analysis::align* align); analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
}; };
struct layout_scanline_t: public layout_t { struct layout_scanline_t: public layout_t {
@@ -75,6 +91,7 @@ struct layout_scanline_t: public layout_t {
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
size_t _id, size_t _id,
analysis::align* align); analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
}; };
struct layout_shared_t: public layout_t { struct layout_shared_t: public layout_t {
@@ -85,9 +102,11 @@ struct layout_shared_t: public layout_t {
ir::type *ty, ir::type *ty,
size_t _id, size_t _id,
analysis::align* align); analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
}; };
class layout { class layout {
typedef ir::value* node_t; typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t; typedef std::map <node_t, std::set<node_t>> graph_t;

View File

@@ -147,7 +147,7 @@ private:
}; };
class generator: public ir::visitor { class generator: public ir::visitor, public analysis::layout_visitor {
private: private:
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
@@ -163,7 +163,9 @@ public:
generator(LLVMContext *ctx, generator(LLVMContext *ctx,
Function *fn, Function *fn,
Module *dst,
Builder *builder, Builder *builder,
std::map<unsigned, distributed_axis>& axes,
std::map<ir::value *, Value *>& vmap, std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap, std::map<ir::value *, tile *>& tmap,
target *tgt, target *tgt,
@@ -176,7 +178,7 @@ public:
unsigned num_packs_0, unsigned num_packs_1, unsigned num_packs_0, unsigned num_packs_1,
unsigned pack_size_0, unsigned pack_size_1, unsigned pack_size_0, unsigned pack_size_1,
unsigned num_warps) unsigned num_warps)
: ctx_(ctx), fn_(fn), builder_(builder), vmap_(vmap), tmap_(tmap), tgt_(tgt), : ctx_(ctx), fn_(fn), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
@@ -221,14 +223,27 @@ public:
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*); void visit_barrier_inst(ir::barrier_inst*);
void visit_make_range_dyn(ir::make_range_dyn*); void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range_sta(ir::make_range_sta*);
void visit_make_range(ir::make_range*); void visit_make_range(ir::make_range*);
void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
void visit_layout_scanline(analysis::layout_scanline_t*);
void visit_layout_shared(analysis::layout_shared_t*);
private: private:
LLVMContext *ctx_; LLVMContext *ctx_;
Function *fn_; Function *fn_;
Builder *builder_; Builder *builder_;
Module *mod_;
std::map<unsigned, distributed_axis>& axes_;
std::map<ir::value *, Value *>& vmap_; std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_; std::map<ir::value *, tile *>& tmap_;
target *tgt_; target *tgt_;
@@ -249,29 +264,15 @@ class selection{
typedef std::map<ir::value *, tile *> tmap_t; typedef std::map<ir::value *, tile *> tmap_t;
private: private:
// utils
Type *make_vector_ty(Type *ty, size_t vector_size);
std::vector<unsigned> extract_shapes(ir::value *v);
// LLVM conversions // LLVM conversions
Type* llvm_type(ir::type *ty, LLVMContext &ctx); Type* llvm_type(ir::type *ty, LLVMContext &ctx);
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder); Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst); Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst);
Value* alloc_shared(Builder &builder, Module& dst); Value* alloc_shared(Builder &builder, Module& dst);
// grid construction // grid construction
void create_grids(std::vector<ir::value *> &grids,
std::map<unsigned, ir::value *> &references,
ir::function *fn);
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
void create_distributed_tile(ir::value *v, Builder &builder); void create_distributed_tile(ir::value *v, Builder &builder);
void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction // lower scalar instruction
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen); void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);

View File

@@ -6,6 +6,7 @@
#include "enums.h" #include "enums.h"
#include "value.h" #include "value.h"
#include <cassert> #include <cassert>
#include "visitor.h"
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -32,6 +33,7 @@ private:
public: public:
static undef_value* get(type* ty); static undef_value* get(type* ty);
std::string repr() const { return "undef"; } std::string repr() const { return "undef"; }
void accept(visitor* vst) { vst->visit_undef_value(this); }
}; };
@@ -44,31 +46,13 @@ public:
virtual uint64_t get_value() const { return value_; } virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value); static constant_int *get(type *ty, uint64_t value);
std::string repr() const { return std::to_string(value_); } std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_int(this); }
protected: protected:
uint64_t value_; uint64_t value_;
}; };
/* Metaparameter (int) */ /* Constant fp */
/* constant_int whose value may still be unset.
 * `space_` holds a list of unsigned values (presumably the admissible
 * candidates -- confirm against the out-of-line create() definitions). */
class metaparameter: public constant_int {
private:
  metaparameter(type *ty, const std::vector<unsigned>& space);

public:
  // factories (defined out of line)
  static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
  static metaparameter *create(context &ctx, type *ty, const std::vector<unsigned>& space);

  // value: set_value() marks the parameter as having a value
  void set_value(uint64_t value) {
    value_ = value;
    has_value_ = true;
  }
  bool has_value() {
    return has_value_;
  }
  // asserts that a value has been set before returning it
  uint64_t get_value() const {
    assert(has_value_);
    return value_;
  }

  // search space accessors
  const std::vector<unsigned>& get_space() {
    return space_;
  }
  void set_space(const std::vector<unsigned> &space) {
    space_ = space;
  }

  // textual form: the value when set, "?" otherwise
  std::string repr() const {
    if(has_value_)
      return std::to_string(value_);
    return "?";
  }

private:
  std::vector<unsigned> space_;
  bool has_value_;
};
/* constant fp */
class constant_fp: public constant{ class constant_fp: public constant{
constant_fp(type *ty, double value); constant_fp(type *ty, double value);
@@ -79,13 +63,14 @@ public:
static constant* get(context &ctx, double v); static constant* get(context &ctx, double v);
static constant* get(type *ty, double v); static constant* get(type *ty, double v);
std::string repr() const { return std::to_string(value_); } std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_fp(this); }
private: private:
double value_; double value_;
}; };
/* global value */ /* Global Value */
class global_value: public constant { class global_value: public constant {
public: public:
enum linkage_types_t { enum linkage_types_t {
@@ -109,7 +94,6 @@ public:
linkage_types_t linkage, const std::string &name, linkage_types_t linkage, const std::string &name,
unsigned addr_space = 0); unsigned addr_space = 0);
std::string repr() const { return get_name(); } std::string repr() const { return get_name(); }
}; };
/* global variable */ /* global variable */
@@ -118,6 +102,8 @@ public:
alloc_const(type *ty, constant_int *size, alloc_const(type *ty, constant_int *size,
const std::string &name = ""); const std::string &name = "");
std::string repr() const { return get_name(); } std::string repr() const { return get_name(); }
void accept(visitor* vst) { vst->visit_alloc_const(this); }
}; };

View File

@@ -14,7 +14,6 @@ class constant;
class constant_int; class constant_int;
class constant_fp; class constant_fp;
class undef_value; class undef_value;
class metaparameter;
/* Context impl */ /* Context impl */
class context_impl { class context_impl {
@@ -36,8 +35,6 @@ public:
std::map<std::pair<type*, double>, constant_fp*> fp_constants_; std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values // undef values
std::map<type*, undef_value*> uv_constants_; std::map<type*, undef_value*> uv_constants_;
// Metaparameters
std::vector<metaparameter*> mp_constants_;
}; };
} }

View File

@@ -112,6 +112,9 @@ public:
const attr_map_t &attrs() { return attrs_; } const attr_map_t &attrs() { return attrs_; }
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; } std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
// visitor
void accept(visitor *v) { v->visit_function(this); }
private: private:
module *parent_; module *parent_;
bool init_; bool init_;

View File

@@ -71,8 +71,6 @@ public:
} }
// instruction id // instruction id
value_id_t get_id() const { return id_; } value_id_t get_id() const { return id_; }
// visit
virtual void accept(visitor *v) = 0;
private: private:
basic_block *parent_; basic_block *parent_;
@@ -759,6 +757,7 @@ public:
static make_range_sta *get(make_range* range); static make_range_sta *get(make_range* range);
make_range* get_range() const; make_range* get_range() const;
std::string repr() const { return "nv_static_program_idx"; } std::string repr() const { return "nv_static_program_idx"; }
_TRITON_DEFINE_ACCEPT(make_range_sta)
private: private:
make_range *range_; make_range *range_;

View File

@@ -13,6 +13,7 @@ namespace ir{
class type; class type;
class use; class use;
class user; class user;
class visitor;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// value class // value class
@@ -74,6 +75,9 @@ public:
void replace_all_uses_with(value *target); void replace_all_uses_with(value *target);
void replace_uses_of_with(value *before, value *after); void replace_uses_of_with(value *before, value *after);
// Visitor
virtual void accept(visitor *v) = 0;
private: private:
ops_t ops_; ops_t ops_;
unsigned num_ops_; unsigned num_ops_;

View File

@@ -61,10 +61,25 @@ class copy_to_shared_inst;
class copy_from_shared_inst; class copy_from_shared_inst;
class barrier_inst; class barrier_inst;
class make_range_dyn; class make_range_dyn;
class make_range_sta;
class make_range; class make_range;
class make_range_sta;
// constants and globals (each forward-declared exactly once)
class undef_value;
class constant_int;
class constant_fp;
class global_value;
class global_object;
class alloc_const;
// functions
class function;
class visitor { class visitor {
public: public:
@@ -108,8 +123,15 @@ public:
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_barrier_inst(barrier_inst*) = 0; virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_make_range_dyn(make_range_dyn*) = 0; virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range_sta(make_range_sta*) = 0;
virtual void visit_make_range(make_range*) = 0; virtual void visit_make_range(make_range*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_make_range_sta(make_range_sta*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0;
}; };
} }

View File

@@ -43,7 +43,6 @@ namespace ir {
class module; class module;
class function; class function;
class context; class context;
class metaparameter;
} }
namespace runtime{ namespace runtime{

View File

@@ -343,16 +343,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
throw std::runtime_error("unknown conversion from ir::type to Type"); throw std::runtime_error("unknown conversion from ir::type to Type");
} }
/* Build the LLVM Constant for an ir::constant (integer or floating point).
 * Throws std::runtime_error for any other kind of constant. */
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
  // LLVM type of the constant's scalar type
  Type *scalar_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
  // integer constant
  auto *as_int = dynamic_cast<ir::constant_int*>(cst);
  if(as_int != nullptr)
    return ConstantInt::get(scalar_ty, as_int->get_value());
  // floating-point constant
  auto *as_fp = dynamic_cast<ir::constant_fp*>(cst);
  if(as_fp != nullptr)
    return ConstantFP::get(scalar_ty, as_fp->get_value());
  // neither of the above: no conversion available
  throw std::runtime_error("unknown conversion from ir::constant to Constant");
}
/* convert ir::alloc_const to llvm::GlobalVariable */ /* convert ir::alloc_const to llvm::GlobalVariable */
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) { Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {
@@ -387,145 +377,6 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div; return (num + div - 1)/div;
} }
/* Fill axes_ with one distributed_axis per dimension of `layout`.
 * builder     : IR builder used to emit the index computations
 * u_thread_id : thread index (the caller in this file passes the thread id modulo 32)
 * u_warp_id   : warp index (the caller in this file passes the thread id divided by 32)
 * Instructions are emitted in statement order, so the order below matters. */
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = layout.order;
const auto& shapes = layout.shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout.nts;
std::vector<int> mts = layout.mts;
// recombine the two ids: warp_id * 32 + thread_id
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
// one id per dimension; presumably splits the linear id along `order` with extents `mts` -- confirm in delinearize
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder.getInt32(nts[k]);
// first index of this thread along dimension k: thread_id[k] * nts[k]
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
// number of indices generated along dimension k
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
// groups of nts[k] consecutive offsets, successive groups per_block apart
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
// register the axis under the layout's axis id for dimension k
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
/* Fill axes_ for an HMMA layout and set the related members:
 * pack_size_{0,1}_, num_packs_{0,1}_, offset_a_{i,k}_ and offset_b_{j,k}_.
 * builder     : IR builder used to emit the index computations
 * u_thread_id : thread index (the caller in this file passes the thread id modulo 32)
 * u_warp_id   : warp index (the caller in this file passes the thread id divided by 32)
 * Throws std::runtime_error("unsupported") when the layout has more than 3 dimensions.
 * Instructions are emitted in statement order, so the order below matters. */
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = layout.shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
// a third dimension is treated as the batch dimension
bool is_batched = shapes.size() >= 3;
// integer constants reused below
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
Value *_4 = builder.getInt32(4);
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = layout.fpw.at(0);
unsigned fpw_1 = layout.fpw.at(1);
unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout.wpt.at(0);
unsigned wpt_1 = layout.wpt.at(1);
unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
// min(num_rep, 1) is either 0 or 1
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
// NOTE(review): if a shape is smaller than its block tile size, num_rep is 0,
// pack_size is 0 and the divisions below divide by zero -- confirm callers rule this out
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
builder.getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
builder.getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
/* inter warp offset */
// split the warp id into per-dimension ids: dimension 0 first, then 1, remainder is 2
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
// two indices per (pack, ii) pair
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
// four indices per (pack, jj) pair: two adjacent pairs
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
// one index per repetition along the batch dimension
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
// the third axis only exists for batched layouts
if(is_batched)
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
/* Dispatch axis initialization on the layout type.
 * HMMA_884 and SCANLINE layouts are handled; any other type is a no-op. */
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  if(layout.type == analysis::HMMA_884) {
    init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
    return;
  }
  if(layout.type == analysis::SCANLINE) {
    init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
  }
}
/* ------------------- /* -------------------
* ---- Init Tiles ---- * ---- Init Tiles ----
* ------------------- */ * ------------------- */
@@ -549,7 +400,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
if(parent->empty()) if(parent->empty())
builder.SetInsertPoint(parent); builder.SetInsertPoint(parent);
else else
builder.SetInsertPoint(&*parent->getFirstInsertionPt()); builder.SetInsertPoint(&*parent->getFirstNonPHI());
// create double-buffered pointer // create double-buffered pointer
PHINode *ptr = builder.CreatePHI(ptr_ty, 2); PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
@@ -587,41 +438,6 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
tmap_.insert({v, T}); tmap_.insert({v, T});
} }
/* Create the tile for `v`, after creating the tiles of its operands.
 * Non-tile values and values already present in `seen` are ignored. */
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
                            std::set<ir::value*> &seen, Value *sh_mem_ptr) {
  // only tile-typed values get a tile
  if(!v->get_type()->is_tile_ty())
    return;
  // each value is processed once
  if(!seen.insert(v).second)
    return;
  // operands first
  auto *as_user = dynamic_cast<ir::user*>(v);
  if(as_user) {
    for(ir::value *operand: as_user->ops())
      create_tile(operand, builder, seen, sh_mem_ptr);
  }
  // shared tile for instructions with a SHARED layout (reductions excluded),
  // distributed tile for everything else
  auto *inst = dynamic_cast<ir::instruction*>(v);
  const bool use_shared = inst
                       && layouts_->get(inst)->type == analysis::SHARED
                       && !dynamic_cast<ir::reduce_inst*>(v);
  if(use_shared)
    create_shared_tile(inst, builder, sh_mem_ptr);
  else
    create_distributed_tile(v, builder);
}
/* Initialize the axes of every layout, then create the tiles of every
 * tile-typed instruction of `fn`. */
void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
  // thread and warp identifiers, derived from local id 0 with 32 threads per warp
  Module *module = builder.GetInsertBlock()->getParent()->getParent();
  Value *threads_per_warp = builder.getInt32(32);
  Value *local_id = tgt_->get_local_id(module, builder, 0);
  Value *lane_id = builder.CreateURem(local_id, threads_per_warp);
  Value *warp_id = builder.CreateUDiv(local_id, threads_per_warp);
  // axes of each layout
  for(const auto &entry: layouts_->get_all())
    init_axes(*entry.second, builder, lane_id, warp_id);
  // tiles of each tile-typed instruction
  std::set<ir::value*> visited;
  for(ir::basic_block *bb: fn->blocks()) {
    for(ir::instruction *inst: bb->get_inst_list()) {
      if(inst->get_type()->is_tile_ty())
        create_tile(inst, builder, visited, sh_mem_ptr);
    }
  }
}
bool is_trans(ir::value *v) { bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) { if(dynamic_cast<ir::trans_inst *>(v)) {
@@ -641,51 +457,34 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
if(!seen.insert(src).second) if(!seen.insert(src).second)
return; return;
BasicBlock *current = builder.GetInsertBlock();
if(src->get_type()->is_tile_ty()){
builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin());
auto *i = dynamic_cast<ir::instruction*>(src);
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(src)){
create_shared_tile(i, builder, sh_mem_ptr_);
}
else
create_distributed_tile(src, builder);
}
builder.SetInsertPoint(current);
auto *inst = dynamic_cast<ir::instruction*>(src); auto *inst = dynamic_cast<ir::instruction*>(src);
if(inst && !dynamic_cast<ir::phi_node*>(src)) if(inst && !dynamic_cast<ir::phi_node*>(src))
for(ir::value *op: inst->ops()) for(ir::value *op: inst->ops())
lower_value(op, builder, gen, seen); lower_value(op, builder, gen, seen);
BasicBlock *current = builder.GetInsertBlock(); builder.SetInsertPoint(current);
auto *phi = dynamic_cast<ir::phi_node*>(src); auto *phi = dynamic_cast<ir::phi_node*>(src);
bool phi_inserted = phi && !current->empty(); if(phi && !current->empty() && current->getFirstNonPHI())
if(phi_inserted && current->getFirstNonPHI())
builder.SetInsertPoint(&*current->getFirstNonPHI()); builder.SetInsertPoint(&*current->getFirstNonPHI());
if(auto *usr = dynamic_cast<ir::user*>(src))
usr->accept(gen);
if(dynamic_cast<ir::make_range*>(src)){ if(phi && !current->empty() && current->getFirstNonPHI())
distributed_tile *T = (distributed_tile *)tmap_.at(src);
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
});
}
else if(dynamic_cast<ir::make_range_sta*>(src)){
distributed_tile *T = (distributed_tile *)tmap_.at(src);
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(1);
assert(isa<Constant>(res));
T->set_value(idx, res);
});
}
else if(auto *cst = dynamic_cast<ir::constant*>(src)){
vmap_[cst] = llvm_constant(cst, builder.getContext());
}
else if(inst){
inst->accept(gen);
}
if(phi_inserted && current->getFirstNonPHI())
builder.SetInsertPoint(current); builder.SetInsertPoint(current);
// if(dynamic_cast<ir::phi_node*>(src))
// for(ir::value *op: inst->ops())
// lower_value(op, builder, seen);
} }
/* ---------------------------- /* ----------------------------
@@ -702,12 +501,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
} }
} }
/* LLVM array type holding every element of a tile:
 * element type is the tile's scalar type, length is the product of its shapes. */
ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) {
  // total number of elements in the tile
  unsigned num_elements = 1;
  for(auto extent: ty->get_tile_shapes()) {
    num_elements *= extent;
  }
  Type *element_ty = llvm_type(ty->get_scalar_ty(), ctx);
  return ArrayType::get(element_ty, num_elements);
}
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) { Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
LLVMContext &ctx = builder.getContext(); LLVMContext &ctx = builder.getContext();
@@ -777,6 +570,9 @@ void selection::run(ir::module &src, Module &dst) {
for(ir::alloc_const *x: src.allocs()) for(ir::alloc_const *x: src.allocs())
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder); vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
// allocate shared memory
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
// iterate over functions // iterate over functions
std::set<ir::value*> seen; std::set<ir::value*> seen;
@@ -785,14 +581,13 @@ void selection::run(ir::module &src, Module &dst) {
// create LLVM function // create LLVM function
Function *ffn = llvm_fn(fn, dst_builder, dst); Function *ffn = llvm_fn(fn, dst_builder, dst);
// allocate shared memory // create tile
sh_mem_ptr_ = alloc_shared(dst_builder, dst); generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
// initialize layouts // initialize layouts
init_layouts(fn, dst_builder, sh_mem_ptr_); for(auto x: layouts_->get_all())
x.second->accept(&gen);
generator gen(&dst_ctx, ffn, &dst_builder, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
// generate LLVM-IR code // generate LLVM-IR code
std::map<ir::basic_block*, BasicBlock*> last_block; std::map<ir::basic_block*, BasicBlock*> last_block;
@@ -1536,6 +1331,179 @@ Type *generator::type(ir::type *ty) {
throw std::runtime_error("unknown conversion from ir::type to Type"); throw std::runtime_error("unknown conversion from ir::type to Type");
} }
/* Lower an undef value: map it to an LLVM undef of the converted type. */
void generator::visit_undef_value(ir::undef_value *ud) {
  Type *llvm_ty = type(ud->get_type());
  vmap_[ud] = llvm::UndefValue::get(llvm_ty);
}
/* Lower an integer constant: map it to an LLVM ConstantInt of its scalar type. */
void generator::visit_constant_int(ir::constant_int *cst){
  Type *scalar_ty = type(cst->get_type()->get_scalar_ty());
  const uint64_t raw_value = cst->get_value();
  vmap_[cst] = ConstantInt::get(scalar_ty, raw_value);
}
/* Lower a floating-point constant: map it to an LLVM ConstantFP of its scalar type. */
void generator::visit_constant_fp(ir::constant_fp *cst){
  Type *scalar_ty = type(cst->get_type()->get_scalar_ty());
  const auto raw_value = cst->get_value();
  vmap_[cst] = ConstantFP::get(scalar_ty, raw_value);
}
/* Lower an alloc_const: create a global array in mod_ and map the value to a
 * pointer to its element type. */
void generator::visit_alloc_const(ir::alloc_const *alloc) {
  // element count is operand 0, treated as an integer constant
  ir::constant_int *count_cst = (ir::constant_int*)alloc->get_operand(0);
  unsigned num_elements = count_cst->get_value();
  // array type: num_elements x pointee type of the alloc
  Type *elt_ty = type(alloc->get_type()->get_pointer_element_ty());
  Type *storage_ty = llvm::ArrayType::get(elt_ty, num_elements);
  // non-constant global, external linkage, no initializer, address space 4
  // (presumably the target's constant memory -- confirm)
  Value *global = new llvm::GlobalVariable(*mod_, storage_ty, false,
                                           llvm::GlobalVariable::ExternalLinkage,
                                           nullptr, alloc->get_name(), nullptr,
                                           llvm::GlobalVariable::NotThreadLocal, 4);
  // expose it as a pointer to the element type in the same address space
  Type *elt_ptr_ty = elt_ty->getPointerTo(4);
  vmap_[alloc] = builder_->CreateBitCast(global, elt_ptr_ty);
}
/* Visiting a function is a no-op: the body is intentionally empty and the
 * argument is unused. */
void generator::visit_function(ir::function*) {
}
// Emits the thread-dependent index arithmetic for an hmma_884 layout.
//
// Reads `layout->shapes`, `layout->fpw`, `layout->wpt` and `layout->axes`, and:
//   - stores the pack sizes / pack counts in pack_size_{0,1}_ and num_packs_{0,1}_;
//   - stores the per-thread operand offsets in offset_a_i_, offset_a_k_,
//     offset_b_j_ and offset_b_k_;
//   - registers one distributed_axis per layout axis in axes_ (two axes, plus a
//     third one when the layout has three dimensions).
//
// Throws std::runtime_error("unsupported") when the layout has more than three
// dimensions. `fpw.at()` / `wpt.at()` additionally throw std::out_of_range if
// those vectors are shorter than the indices read below.
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
// Split the local thread id along dimension 0 into a lane index (id % 32)
// and a warp index (id / 32).
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
// Because sizes above 3 were rejected just above, this is true only for
// exactly three dimensions.
bool is_batched = shapes.size() >= 3;
// Small integer constants reused by the index arithmetic below.
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
// hmma warp tile size: 8 elements per fragment along axes 0 and 1; along
// axis 2 the fragment count is used directly.
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size: warp tile size times the number of warps per tile
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition: how many block tiles fit in each dimension
// (integer division, so any remainder is dropped)
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
// min(num_rep, 1) evaluates to 1 for any non-zero repetition count.
// NOTE(review): it evaluates to 0 when num_rep is 0 (shape smaller than the
// block tile), and the divisions just below would then divide by zero --
// confirm the layout analysis rules that case out.
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
// (lane & 16) / 4 is 0 for lanes 0-15 and 4 for lanes 16-31; that value is
// then scaled by fpw * pack_size of the corresponding axis.
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
// Both start from (lane % 16) / 4, a value in [0, 3]. The `a` id keeps that
// value modulo fpw_0; the `b` id takes the quotient by fpw_0, modulo fpw_1.
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
// Decompose the linear warp index with axis 0 varying fastest:
//   warp_id_0 = w % wpt_0
//   warp_id_1 = (w / wpt_0) % wpt_1
//   warp_id_2 = (w / wpt_0) / wpt_1
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
// i component = warp offset + quad-pair offset + in-pair offset;
// k component = lane & 3.
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
// Same construction as for `a`, using the axis-1 quantities.
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
// offset_c_i adds (lane & 1) to offset_a_i_. offset_c_j adds (lane & 2) to
// warp_offset_j + pair_b_off; note that in_pair_off_b is not part of it.
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
// Two entries per (pack, ii) pair, spaced 2 apart (i*2).
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
// Each (pack, jj, j) triple contributes two consecutive entries (+0 and +1).
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
// One entry per repetition along axis 2: warp_id_2 + pack * hmma_bts_2.
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
// Register the index lists under the layout's axis ids. The first field of
// distributed_axis is 1 for every axis here; the third is the warp index
// along that axis. The third axis is only registered for batched layouts.
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
// Emits the per-thread index values for every axis of a scanline layout and
// registers them in axes_.
//
// The local thread id (dimension 0) is split into lane and warp, recombined
// as warp * 32 + lane, and handed to delinearize() together with
// `layout->order` and `layout->mts` to obtain one thread coordinate per axis.
// For axis k, each thread gets nts[k] * shapes[k] / (nts[k] * mts[k]) index
// values, laid out as runs of nts[k] consecutive offsets that repeat every
// nts[k] * mts[k] elements.
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
  // Lane / warp decomposition of the local thread id, 32 lanes per warp.
  Value *lanes_per_warp = builder_->getInt32(32);
  Value *raw_tid = tgt_->get_local_id(mod_, *builder_, 0);
  Value *lane = builder_->CreateURem(raw_tid, lanes_per_warp);
  Value *warp = builder_->CreateUDiv(raw_tid, lanes_per_warp);

  // Layout parameters. `order` and `mts` stay local, mutable copies because
  // they are passed on to delinearize().
  auto order = layout->order;
  const auto& shapes = layout->shapes;
  const size_t rank = shapes.size();
  std::vector<int> nts = layout->nts;
  std::vector<int> mts = layout->mts;

  // Linear thread id = warp * 32 + lane, then one coordinate per axis.
  Value *warp_base = builder_->CreateMul(warp, builder_->getInt32(32));
  Value *linear_tid = builder_->CreateAdd(warp_base, lane);
  std::vector<Value*> coords = delinearize(linear_tid, order, mts, *builder_);

  // Create axes
  for(unsigned k = 0; k < rank; ++k) {
    const std::string name_prefix = "idx_" + std::to_string(k) + "_";
    // First element handled by this thread along axis k.
    Value *run_len = builder_->getInt32(nts[k]);
    Value *thread_base = builder_->CreateMul(coords[k], run_len);
    // Elements covered by one sweep of all threads, and elements per thread.
    const unsigned per_block = nts[k] * mts[k];
    const unsigned per_thread = nts[k] * shapes[k] / per_block;

    std::vector<Value*> indices;
    indices.reserve(per_thread);
    for(unsigned n = 0; n < per_thread; ++n) {
      const unsigned sweep = n / nts[k];     // which repetition of the block
      const unsigned in_run = n % nts[k];    // position inside the contiguous run
      Value *delta = builder_->getInt32(sweep * per_block + in_run);
      indices.push_back(builder_->CreateAdd(thread_base, delta, name_prefix + std::to_string(n)));
    }
    axes_[layout->axes[k]] = distributed_axis{nts[k], indices, coords[k]};
  }
}
// Visitor hook for shared layouts.
// The body is empty: nothing is computed and no entry is added to axes_ here.
void generator::visit_layout_shared(analysis::layout_shared_t*) {
}
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) { void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
if(!x->get_type()->is_tile_ty()) if(!x->get_type()->is_tile_ty())
return fn({}); return fn({});

View File

@@ -76,27 +76,6 @@ constant *constant_fp::get(type *ty, double v){
return result; return result;
} }
// metaparameter
// Builds a metaparameter of type `ty` over the candidate values in `space`.
// The base constant_int is constructed with (ty, 0), `space` is copied into
// space_, and has_value_ starts out false.
metaparameter::metaparameter(type *ty, const std::vector<unsigned> &space)
: constant_int(ty, 0), space_(space), has_value_(false){ }
// Creates a metaparameter whose search space is lo, 2*lo, 4*lo, ... up to and
// including the largest such value that does not exceed `hi`, and appends it
// to the context's mp_constants_ list.
//
//   ctx     context whose implementation object records the new metaparameter
//   ty      type forwarded to the metaparameter constructor
//   lo, hi  inclusive bounds of the geometric (factor 2) search space
//
// Returns the newly allocated metaparameter. The space is empty when lo > hi.
// A lower bound of 0 yields the single-value space {0}: doubling 0 never
// advances, and the previous loop never terminated in that case.
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
  context_impl *impl = ctx.p_impl.get();
  std::vector<unsigned> space;
  for(unsigned i = lo; i <= hi; i *= 2) {
    space.push_back(i);
    // Stop once doubling cannot produce another in-range value:
    //  - i == 0 would stay 0 forever;
    //  - i > hi / 2 implies 2 * i > hi, and for large `hi` the product would
    //    wrap around to a small value and restart the sequence.
    if(i == 0 || i > hi / 2)
      break;
  }
  metaparameter *result = new metaparameter(ty, space);
  impl->mp_constants_.push_back(result);
  return result;
}
// Creates a metaparameter over an explicitly supplied list of candidate
// values and appends it to the context's mp_constants_ list.
// Returns the newly allocated metaparameter.
metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<unsigned> &space) {
  auto *mp = new metaparameter(ty, space);
  // Record the new object in the context implementation before handing it out.
  ctx.p_impl->mp_constants_.push_back(mp);
  return mp;
}
// undef value // undef value
undef_value::undef_value(type *ty) undef_value::undef_value(type *ty)