more cleaning
This commit is contained in:
@@ -35,6 +35,18 @@ struct double_buffer_info_t {
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
class layout_visitor;
|
||||
class layout_hmma_884_t;
|
||||
class layout_scanline_t;
|
||||
class layout_shared_t;
|
||||
|
||||
|
||||
class layout_visitor {
|
||||
public:
|
||||
virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
|
||||
virtual void visit_layout_scanline(layout_scanline_t*) = 0;
|
||||
virtual void visit_layout_shared(layout_shared_t*) = 0;
|
||||
};
|
||||
|
||||
struct layout_t {
|
||||
layout_t(layout_type_t _type,
|
||||
@@ -43,6 +55,9 @@ struct layout_t {
|
||||
const std::vector<ir::value *> &_values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
|
||||
virtual void accept(layout_visitor* vst) = 0;
|
||||
|
||||
layout_type_t type;
|
||||
std::vector<int> axes;
|
||||
std::vector<unsigned> shapes;
|
||||
@@ -66,6 +81,7 @@ struct layout_hmma_884_t: public layout_t {
|
||||
const std::vector<ir::value *> &_values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
|
||||
};
|
||||
|
||||
struct layout_scanline_t: public layout_t {
|
||||
@@ -75,6 +91,7 @@ struct layout_scanline_t: public layout_t {
|
||||
const std::vector<ir::value *> &values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
};
|
||||
|
||||
struct layout_shared_t: public layout_t {
|
||||
@@ -85,9 +102,11 @@ struct layout_shared_t: public layout_t {
|
||||
ir::type *ty,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
};
|
||||
|
||||
|
||||
|
||||
class layout {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
@@ -147,7 +147,7 @@ private:
|
||||
};
|
||||
|
||||
|
||||
class generator: public ir::visitor {
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
|
||||
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
@@ -163,7 +163,9 @@ public:
|
||||
|
||||
generator(LLVMContext *ctx,
|
||||
Function *fn,
|
||||
Module *dst,
|
||||
Builder *builder,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap,
|
||||
target *tgt,
|
||||
@@ -176,7 +178,7 @@ public:
|
||||
unsigned num_packs_0, unsigned num_packs_1,
|
||||
unsigned pack_size_0, unsigned pack_size_1,
|
||||
unsigned num_warps)
|
||||
: ctx_(ctx), fn_(fn), builder_(builder), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
: ctx_(ctx), fn_(fn), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
|
||||
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
|
||||
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
|
||||
@@ -221,14 +223,27 @@ public:
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
|
||||
void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
void visit_constant_fp(ir::constant_fp*);
|
||||
void visit_alloc_const(ir::alloc_const*);
|
||||
|
||||
void visit_function(ir::function*);
|
||||
|
||||
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
|
||||
void visit_layout_scanline(analysis::layout_scanline_t*);
|
||||
void visit_layout_shared(analysis::layout_shared_t*);
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Function *fn_;
|
||||
Builder *builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
target *tgt_;
|
||||
@@ -249,29 +264,15 @@ class selection{
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
private:
|
||||
// utils
|
||||
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||
std::vector<unsigned> extract_shapes(ir::value *v);
|
||||
|
||||
// LLVM conversions
|
||||
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
|
||||
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
|
||||
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
||||
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
|
||||
Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst);
|
||||
Value* alloc_shared(Builder &builder, Module& dst);
|
||||
|
||||
// grid construction
|
||||
void create_grids(std::vector<ir::value *> &grids,
|
||||
std::map<unsigned, ir::value *> &references,
|
||||
ir::function *fn);
|
||||
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
|
||||
void create_distributed_tile(ir::value *v, Builder &builder);
|
||||
void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
|
||||
void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||
|
||||
// lower scalar instruction
|
||||
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "enums.h"
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -32,6 +33,7 @@ private:
|
||||
public:
|
||||
static undef_value* get(type* ty);
|
||||
std::string repr() const { return "undef"; }
|
||||
void accept(visitor* vst) { vst->visit_undef_value(this); }
|
||||
};
|
||||
|
||||
|
||||
@@ -44,31 +46,13 @@ public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_int(this); }
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Metaparameter (int) */
|
||||
class metaparameter: public constant_int {
|
||||
private:
|
||||
metaparameter(type *ty, const std::vector<unsigned>& space);
|
||||
|
||||
public:
|
||||
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
||||
static metaparameter *create(context &ctx, type *ty, const std::vector<unsigned>& space);
|
||||
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
|
||||
bool has_value() { return has_value_; }
|
||||
const std::vector<unsigned>& get_space() { return space_; }
|
||||
void set_space(const std::vector<unsigned> &space) { space_ = space; }
|
||||
uint64_t get_value() const { assert(has_value_); return value_; }
|
||||
std::string repr() const { return has_value_? std::to_string(value_) : "?" ;}
|
||||
private:
|
||||
std::vector<unsigned> space_;
|
||||
bool has_value_;
|
||||
};
|
||||
|
||||
/* constant fp */
|
||||
/* Constant fp */
|
||||
class constant_fp: public constant{
|
||||
constant_fp(type *ty, double value);
|
||||
|
||||
@@ -79,13 +63,14 @@ public:
|
||||
static constant* get(context &ctx, double v);
|
||||
static constant* get(type *ty, double v);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_fp(this); }
|
||||
|
||||
private:
|
||||
double value_;
|
||||
};
|
||||
|
||||
|
||||
/* global value */
|
||||
/* Global Value */
|
||||
class global_value: public constant {
|
||||
public:
|
||||
enum linkage_types_t {
|
||||
@@ -109,7 +94,6 @@ public:
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space = 0);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
};
|
||||
|
||||
/* global variable */
|
||||
@@ -118,6 +102,8 @@ public:
|
||||
alloc_const(type *ty, constant_int *size,
|
||||
const std::string &name = "");
|
||||
std::string repr() const { return get_name(); }
|
||||
void accept(visitor* vst) { vst->visit_alloc_const(this); }
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
@@ -14,7 +14,6 @@ class constant;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class metaparameter;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
@@ -36,8 +35,6 @@ public:
|
||||
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
// Metaparameters
|
||||
std::vector<metaparameter*> mp_constants_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -112,6 +112,9 @@ public:
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_function(this); }
|
||||
|
||||
private:
|
||||
module *parent_;
|
||||
bool init_;
|
||||
|
@@ -71,8 +71,6 @@ public:
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
// visit
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
@@ -759,6 +757,7 @@ public:
|
||||
static make_range_sta *get(make_range* range);
|
||||
make_range* get_range() const;
|
||||
std::string repr() const { return "nv_static_program_idx"; }
|
||||
_TRITON_DEFINE_ACCEPT(make_range_sta)
|
||||
|
||||
private:
|
||||
make_range *range_;
|
||||
|
@@ -13,6 +13,7 @@ namespace ir{
|
||||
class type;
|
||||
class use;
|
||||
class user;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
@@ -74,6 +75,9 @@ public:
|
||||
void replace_all_uses_with(value *target);
|
||||
void replace_uses_of_with(value *before, value *after);
|
||||
|
||||
// Visitor
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
ops_t ops_;
|
||||
unsigned num_ops_;
|
||||
|
@@ -61,10 +61,25 @@ class copy_to_shared_inst;
|
||||
class copy_from_shared_inst;
|
||||
class barrier_inst;
|
||||
class make_range_dyn;
|
||||
class make_range_sta;
|
||||
class make_range;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class function;
|
||||
|
||||
class visitor {
|
||||
public:
|
||||
@@ -108,8 +123,15 @@ public:
|
||||
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||
virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -43,7 +43,6 @@ namespace ir {
|
||||
class module;
|
||||
class function;
|
||||
class context;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
namespace runtime{
|
||||
|
@@ -343,16 +343,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
/* convert ir::constant to Constant */
|
||||
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
|
||||
Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
|
||||
if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
|
||||
return ConstantInt::get(dst_ty, cc->get_value());
|
||||
if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
|
||||
return ConstantFP::get(dst_ty, cc->get_value());
|
||||
// unknown constant
|
||||
throw std::runtime_error("unknown conversion from ir::constant to Constant");
|
||||
}
|
||||
|
||||
/* convert ir::alloc_const to llvm::GlobalVariable */
|
||||
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {
|
||||
@@ -387,145 +377,6 @@ inline int32_t ceil(int32_t num, int32_t div){
|
||||
return (num + div - 1)/div;
|
||||
}
|
||||
|
||||
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
auto order = layout.order;
|
||||
const auto& shapes = layout.shapes;
|
||||
size_t dim = shapes.size();
|
||||
std::vector<int> nts = layout.nts;
|
||||
std::vector<int> mts = layout.mts;
|
||||
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = builder.getInt32(nts[k]);
|
||||
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
|
||||
unsigned per_block = nts[k] * mts[k];
|
||||
unsigned per_thread = nts[k] * shapes[k] / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts[k] * per_block + n % nts[k];
|
||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
|
||||
}
|
||||
}
|
||||
|
||||
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = layout.shapes;
|
||||
if(shapes.size() > 3)
|
||||
throw std::runtime_error("unsupported");
|
||||
|
||||
bool is_batched = shapes.size() >= 3;
|
||||
|
||||
Value *_1 = builder.getInt32(1);
|
||||
Value *_2 = builder.getInt32(2);
|
||||
Value *_3 = builder.getInt32(3);
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = layout.fpw.at(0);
|
||||
unsigned fpw_1 = layout.fpw.at(1);
|
||||
unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = layout.wpt.at(0);
|
||||
unsigned wpt_1 = layout.wpt.at(1);
|
||||
unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
|
||||
// hmma block tile size
|
||||
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
||||
// number of repetition
|
||||
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
|
||||
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
|
||||
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
|
||||
// size of each pack (interleaving)
|
||||
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
||||
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
||||
// number of packs (interleaving)
|
||||
num_packs_0_ = num_rep_0 / pack_size_0_;
|
||||
num_packs_1_ = num_rep_1 / pack_size_1_;
|
||||
|
||||
/* intra warp offset */
|
||||
// offset of quad in pair
|
||||
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
||||
builder.getInt32(fpw_0 * pack_size_0_));
|
||||
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
||||
builder.getInt32(fpw_1 * pack_size_1_));
|
||||
|
||||
// Quad pair id
|
||||
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
||||
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
|
||||
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
|
||||
// Quad pair offset
|
||||
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
||||
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
||||
|
||||
/* inter warp offset */
|
||||
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
|
||||
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
|
||||
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
|
||||
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
|
||||
|
||||
/* offsets */
|
||||
// a offset
|
||||
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
|
||||
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||
// b offsets
|
||||
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
|
||||
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||
|
||||
// c offsets
|
||||
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
|
||||
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
builder.CreateAdd(warp_offset_j, pair_b_off));
|
||||
|
||||
/* indices */
|
||||
// i indices
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned pack = 0; pack < num_packs_0_; pack++)
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
|
||||
}
|
||||
// j indices
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned pack = 0; pack < num_packs_1_; pack++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++)
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
||||
}
|
||||
// z indices
|
||||
std::vector<Value*> idx_z;
|
||||
for(unsigned pack = 0; pack < num_rep_2; pack++)
|
||||
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
|
||||
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
if(layout.type == analysis::HMMA_884)
|
||||
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
else if(layout.type == analysis::SCANLINE)
|
||||
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
}
|
||||
|
||||
/* -------------------
|
||||
* ---- Init Tiles ----
|
||||
* ------------------- */
|
||||
@@ -549,7 +400,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
||||
if(parent->empty())
|
||||
builder.SetInsertPoint(parent);
|
||||
else
|
||||
builder.SetInsertPoint(&*parent->getFirstInsertionPt());
|
||||
builder.SetInsertPoint(&*parent->getFirstNonPHI());
|
||||
// create double-buffered pointer
|
||||
PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
|
||||
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
|
||||
@@ -587,41 +438,6 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
tmap_.insert({v, T});
|
||||
}
|
||||
|
||||
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, seen, sh_mem_ptr);
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(v))
|
||||
create_shared_tile(i, builder, sh_mem_ptr);
|
||||
else
|
||||
create_distributed_tile(v, builder);
|
||||
}
|
||||
|
||||
void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
|
||||
// fetch linear ID
|
||||
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
|
||||
Value *warp_size = builder.getInt32(32);
|
||||
Value* u_thread_id = tgt_->get_local_id(mod, builder, 0);
|
||||
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
for(auto x: layouts_->get_all())
|
||||
init_axes(*x.second, builder, u_thread_warp_id, u_warp_id);
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
create_tile(i, builder, seen, sh_mem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
@@ -641,51 +457,34 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen
|
||||
if(!seen.insert(src).second)
|
||||
return;
|
||||
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
if(src->get_type()->is_tile_ty()){
|
||||
builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin());
|
||||
auto *i = dynamic_cast<ir::instruction*>(src);
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(src)){
|
||||
create_shared_tile(i, builder, sh_mem_ptr_);
|
||||
}
|
||||
else
|
||||
create_distributed_tile(src, builder);
|
||||
}
|
||||
builder.SetInsertPoint(current);
|
||||
|
||||
|
||||
auto *inst = dynamic_cast<ir::instruction*>(src);
|
||||
if(inst && !dynamic_cast<ir::phi_node*>(src))
|
||||
for(ir::value *op: inst->ops())
|
||||
lower_value(op, builder, gen, seen);
|
||||
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
builder.SetInsertPoint(current);
|
||||
auto *phi = dynamic_cast<ir::phi_node*>(src);
|
||||
bool phi_inserted = phi && !current->empty();
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
if(phi && !current->empty() && current->getFirstNonPHI())
|
||||
builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
|
||||
if(auto *usr = dynamic_cast<ir::user*>(src))
|
||||
usr->accept(gen);
|
||||
|
||||
if(dynamic_cast<ir::make_range*>(src)){
|
||||
distributed_tile *T = (distributed_tile *)tmap_.at(src);
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
else if(dynamic_cast<ir::make_range_sta*>(src)){
|
||||
distributed_tile *T = (distributed_tile *)tmap_.at(src);
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *res = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(res));
|
||||
T->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
else if(auto *cst = dynamic_cast<ir::constant*>(src)){
|
||||
vmap_[cst] = llvm_constant(cst, builder.getContext());
|
||||
}
|
||||
else if(inst){
|
||||
inst->accept(gen);
|
||||
}
|
||||
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
if(phi && !current->empty() && current->getFirstNonPHI())
|
||||
builder.SetInsertPoint(current);
|
||||
|
||||
// if(dynamic_cast<ir::phi_node*>(src))
|
||||
// for(ir::value *op: inst->ops())
|
||||
// lower_value(op, builder, seen);
|
||||
|
||||
|
||||
}
|
||||
|
||||
/* ----------------------------
|
||||
@@ -702,12 +501,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
|
||||
}
|
||||
}
|
||||
|
||||
ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) {
|
||||
unsigned size = 1;
|
||||
for(auto shape: ty->get_tile_shapes())
|
||||
size *= shape;
|
||||
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
|
||||
}
|
||||
|
||||
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
@@ -777,6 +570,9 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(ir::alloc_const *x: src.allocs())
|
||||
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
|
||||
|
||||
// allocate shared memory
|
||||
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
|
||||
|
||||
// iterate over functions
|
||||
std::set<ir::value*> seen;
|
||||
|
||||
@@ -785,14 +581,13 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
// create LLVM function
|
||||
Function *ffn = llvm_fn(fn, dst_builder, dst);
|
||||
|
||||
// allocate shared memory
|
||||
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
|
||||
// create tile
|
||||
generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
|
||||
|
||||
// initialize layouts
|
||||
init_layouts(fn, dst_builder, sh_mem_ptr_);
|
||||
|
||||
generator gen(&dst_ctx, ffn, &dst_builder, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_,
|
||||
offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ );
|
||||
for(auto x: layouts_->get_all())
|
||||
x.second->accept(&gen);
|
||||
|
||||
// generate LLVM-IR code
|
||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||
@@ -1536,6 +1331,179 @@ Type *generator::type(ir::type *ty) {
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
void generator::visit_undef_value(ir::undef_value *ud) {
|
||||
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type()));
|
||||
}
|
||||
|
||||
void generator::visit_constant_int(ir::constant_int *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty());
|
||||
vmap_[cst] = ConstantInt::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_constant_fp(ir::constant_fp *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty());
|
||||
vmap_[cst] = ConstantFP::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
|
||||
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty());
|
||||
Type *array_ty = llvm::ArrayType::get(element_ty, size);
|
||||
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
|
||||
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
vmap_[alloc] = builder_->CreateBitCast(array, element_ty->getPointerTo(4));
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_function(ir::function*) {
|
||||
|
||||
}
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||
|
||||
const auto& shapes = layout->shapes;
|
||||
if(shapes.size() > 3)
|
||||
throw std::runtime_error("unsupported");
|
||||
|
||||
bool is_batched = shapes.size() >= 3;
|
||||
|
||||
Value *_1 = builder_->getInt32(1);
|
||||
Value *_2 = builder_->getInt32(2);
|
||||
Value *_3 = builder_->getInt32(3);
|
||||
Value *_4 = builder_->getInt32(4);
|
||||
Value *_16 = builder_->getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = layout->fpw.at(0);
|
||||
unsigned fpw_1 = layout->fpw.at(1);
|
||||
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = layout->wpt.at(0);
|
||||
unsigned wpt_1 = layout->wpt.at(1);
|
||||
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
|
||||
// hmma block tile size
|
||||
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
||||
// number of repetition
|
||||
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
|
||||
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
|
||||
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
|
||||
// size of each pack (interleaving)
|
||||
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
||||
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
||||
// number of packs (interleaving)
|
||||
num_packs_0_ = num_rep_0 / pack_size_0_;
|
||||
num_packs_1_ = num_rep_1 / pack_size_1_;
|
||||
|
||||
/* intra warp offset */
|
||||
// offset of quad in pair
|
||||
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
|
||||
builder_->getInt32(fpw_0 * pack_size_0_));
|
||||
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
|
||||
builder_->getInt32(fpw_1 * pack_size_1_));
|
||||
|
||||
// Quad pair id
|
||||
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
|
||||
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
|
||||
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
|
||||
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
|
||||
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
|
||||
// Quad pair offset
|
||||
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
|
||||
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
|
||||
|
||||
/* inter warp offset */
|
||||
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
|
||||
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
|
||||
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
|
||||
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
|
||||
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
|
||||
|
||||
/* offsets */
|
||||
// a offset
|
||||
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
|
||||
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
|
||||
// b offsets
|
||||
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
|
||||
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
|
||||
|
||||
// c offsets
|
||||
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
|
||||
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
|
||||
builder_->CreateAdd(warp_offset_j, pair_b_off));
|
||||
|
||||
/* indices */
|
||||
// i indices
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned pack = 0; pack < num_packs_0_; pack++)
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
|
||||
}
|
||||
// j indices
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned pack = 0; pack < num_packs_1_; pack++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++)
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
||||
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
||||
}
|
||||
// z indices
|
||||
std::vector<Value*> idx_z;
|
||||
for(unsigned pack = 0; pack < num_rep_2; pack++)
|
||||
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||
|
||||
auto order = layout->order;
|
||||
const auto& shapes = layout->shapes;
|
||||
size_t dim = shapes.size();
|
||||
std::vector<int> nts = layout->nts;
|
||||
std::vector<int> mts = layout->mts;
|
||||
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = builder_->getInt32(nts[k]);
|
||||
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
|
||||
unsigned per_block = nts[k] * mts[k];
|
||||
unsigned per_thread = nts[k] * shapes[k] / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts[k] * per_block + n % nts[k];
|
||||
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t*) {
|
||||
|
||||
}
|
||||
|
||||
void generator::for_each(ir::value *x, const std::function<void(indices_t)>& fn) {
|
||||
if(!x->get_type()->is_tile_ty())
|
||||
return fn({});
|
||||
|
@@ -76,27 +76,6 @@ constant *constant_fp::get(type *ty, double v){
|
||||
return result;
|
||||
}
|
||||
|
||||
// metaparameter
|
||||
metaparameter::metaparameter(type *ty, const std::vector<unsigned> &space)
|
||||
: constant_int(ty, 0), space_(space), has_value_(false){ }
|
||||
|
||||
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
||||
context_impl *impl = ctx.p_impl.get();
|
||||
std::vector<unsigned> space;
|
||||
for(unsigned i = lo; i <= hi; i *= 2)
|
||||
space.push_back(i);
|
||||
metaparameter *result = new metaparameter(ty, space);
|
||||
impl->mp_constants_.push_back(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<unsigned> &space) {
|
||||
context_impl *impl = ctx.p_impl.get();
|
||||
metaparameter *result = new metaparameter(ty, space);
|
||||
impl->mp_constants_.push_back(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
|
Reference in New Issue
Block a user