[codegen/tune]: added fragmentation types

This commit is contained in:
Philippe Tillet
2019-06-06 16:48:32 -07:00
parent f58c9a4d2b
commit cdf5a0d011
7 changed files with 101 additions and 35 deletions

View File

@@ -20,8 +20,8 @@ using GPUDevice = Eigen::GpuDevice;
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {16};
const tunable int32 TN = {16};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
@@ -126,7 +126,7 @@ class BlockSparseGemmOp : public OpKernel {
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
// just-in-time compile source-code
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
// launch info

View File

@@ -20,6 +20,6 @@ hresult = np.dot(hb.T, ha)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
a: ha,
b: hb})
a: ha,
b: hb})
print(result - hresult)

View File

@@ -70,6 +70,7 @@ private:
unsigned vector_size_;
};
// Distribtued tile
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
@@ -98,6 +99,15 @@ private:
};
// Fragmented tile
class fragmented_tile: public tile{
public:
private:
};
// Selection pass
class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
@@ -118,9 +128,9 @@ private:
// grid construction
void create_grids(std::vector<ir::value *> &grids,
std::map<ir::metaparameter *, ir::value *> &references,
std::map<unsigned, ir::value *> &references,
ir::function *fn);
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<ir::metaparameter *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
@@ -143,7 +153,7 @@ private:
tune *params_;
target *tgt_;
shmem_info *buffer_info_;
std::map<ir::metaparameter*, distributed_axis> axes_;
std::map<unsigned, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
};

View File

@@ -21,11 +21,17 @@ class tune {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
enum fragment_t{
STRIDED_SCAN,
HMMA_FRAGMENT_C
};
private:
void add_constraint(node_t x, node_t y);
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph);
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
@@ -34,7 +40,8 @@ public:
std::vector<ir::metaparameter *> get_params(ir::module& mod);
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
unsigned get_param_group(ir::value *value, unsigned ax);
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; groups_[dst] = groups_[src]; }
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
void init(ir::module &mod);
@@ -46,12 +53,14 @@ private:
std::vector<unsigned*> pool_;
graph_t dependencies_;
std::set<node_t> nodes_;
std::map<node_t, fragment_t> fragments_;
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
unsigned num_global_ranges_;
unsigned num_threads_;
std::vector<ir::instruction*> grids_;
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
};

View File

@@ -459,12 +459,12 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->get_param(v, "nts.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
}
}
void selection::create_grids(std::vector<ir::value*> &grids,
std::map<ir::metaparameter*, ir::value*> &references,
std::map<unsigned, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
@@ -479,7 +479,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second || dynamic_cast<ir::mask_inst*>(v))
return;
// recurse
if(auto *user = dynamic_cast<ir::user*>(v))
@@ -492,7 +492,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
unsigned x = params_->get_param_group(v, d);
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
@@ -517,7 +517,7 @@ bool static inline has_phi_user(ir::value *v) {
return false;
}
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
const std::map<ir::metaparameter*, ir::value*>& references,
const std::map<unsigned, ir::value*>& references,
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
@@ -576,7 +576,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::vector<distributed_axis> axes(cshapes.size());
for(size_t d = 0; d < cshapes.size(); d++){
if(cshapes[d]->get_value() > 1){
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
unsigned x = params_->get_param_group(v, d);
axes[d] = axes_.at(x);
}
else{
@@ -607,7 +607,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
std::vector<ir::value*> grids;
std::map<ir::metaparameter*, ir::value*> references;
std::map<unsigned, ir::value*> references;
create_grids(grids, references, fn);
for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i))
@@ -812,7 +812,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
// res = builder.CreateCall(f_mul_add, {ConstantFP::get(a->getType(), 1), ConstantFP::get(b->getType(), 1), res});
res = builder.CreateCall(f_mul_add, {a, b, res});
}

View File

@@ -15,6 +15,19 @@ namespace codegen{
tune::tune(): num_global_ranges_(0){ }
bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = !x->is_a_trans() && x->is_b_trans();
result = result && a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
void tune::add_constraint(node_t x, node_t y) {
dependencies_[x].insert(y);
dependencies_[y].insert(x);
@@ -34,6 +47,7 @@ void tune::init_c_phi(ir::instruction *v) {
}
void tune::init_c_graph(ir::instruction *v) {
// Reference shape
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
ir::type::tile_shapes_t shapes;
@@ -83,20 +97,41 @@ void tune::init_c_graph(ir::instruction *v) {
add_constraint({v, 1}, {D, 1});
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)){
else if(dynamic_cast<ir::user*>(v)) {
for(unsigned k = 0; k < v->get_num_results(); k++)
for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops())
for(unsigned i = 0; i < shapes.size(); i ++){
for(ir::value* op: v->ops()){
add_constraint({v->get_result(k), i}, {op, i});
}
}
}
}
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph) {
tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
std::list<node_t> work;
std::set<node_t> seen;
work.push_back(x);
while(!work.empty()){
node_t current = work.back();
if(is_hmma(current.first))
return HMMA_FRAGMENT_C;
work.pop_back();
seen.insert(current);
for(node_t y: graph[current]){
if(seen.find(y) == seen.end())
work.push_back(y);
}
}
return STRIDED_SCAN;
}
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first][x.second] = group_id;
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
params_[x.first].insert({"nts" + suffix, mps[0]});
params_[x.first].insert({"mts" + suffix, mps[1]});
for(int i = 0; i < mps.size(); i++)
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
ir::type *ty = x.first->get_type();
if(ty->is_tile_ty()){
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
@@ -109,11 +144,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
}
if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
}
for(const node_t &y: graph[x])
connected_components(y, mps, nodes, graph);
connected_components(y, mps, prefixes, nodes, graph, group_id);
}
}
@@ -142,6 +177,10 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
return params_.at(i);
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
}
void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
@@ -159,12 +198,21 @@ void tune::run(ir::module &mod) {
if(i->has_tile_result_or_op())
init_c_phi(i);
// Layout parameters
while(!nodes_.empty()){
unsigned group_id = 0;
while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
nts->set_value(1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
node_t node = *nodes_.begin();
fragment_t fragment = get_fragmentation_type(node, dependencies_);
if(fragment == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw}, {"fpw"}, nodes_, dependencies_, group_id++);
}
}
}
@@ -269,7 +317,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
if(num_threads % 64 != 0)
if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");

View File

@@ -62,8 +62,8 @@ std::string gemm::src(bool AT, bool BT) {
}
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {16};
const tunable int32 TN = {16};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};