[auto-tuning] much smaller parameters space

This commit is contained in:
Philippe Tillet
2019-08-09 16:57:18 -07:00
parent fd49cdc92b
commit 1400d960a6
20 changed files with 470 additions and 367 deletions

View File

@@ -111,7 +111,7 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 64, 64, 128}
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},

View File

@@ -13,6 +13,7 @@ namespace ir{
class instruction;
class function;
class metaparameter;
class constant_int;
}
namespace codegen{
@@ -34,7 +35,9 @@ private:
void init_c_graph(ir::instruction *v);
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
void create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn);
unsigned get_req_num_threads(ir::instruction *i);
@@ -49,8 +52,6 @@ public:
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
void init(ir::module &mod);
unsigned get_num_global_range();
unsigned get_global_range_size(unsigned axis);
unsigned get_num_threads();
private:
@@ -61,10 +62,9 @@ private:
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
unsigned num_global_ranges_;
unsigned num_threads_;
std::vector<ir::instruction*> grids_;
std::vector<ir::value*> grids_;
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
ir::metaparameter* num_warps_;
};

View File

@@ -14,9 +14,9 @@ namespace ir {
namespace codegen{
namespace transform{
class optimize_dce {
class dce {
public:
optimize_dce() {}
dce() {}
void run(ir::module &mod);
};

View File

@@ -28,7 +28,7 @@ namespace transform{
class reassociate {
struct cst_info {
ir::getelementptr_inst* dyn_ptr;
ir::value* dyn_ptr;
ir::getelementptr_inst* sta_ptr;
};
@@ -38,12 +38,11 @@ private:
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
reassociate(analysis::tune *params, analysis::alignment_info *align);
reassociate(analysis::tune *params);
void run(ir::module& module);
private:
analysis::tune* params_;
analysis::alignment_info* align_;
};
}

View File

@@ -19,12 +19,17 @@ namespace ir {
namespace codegen{
namespace transform{
class optimize_trans {
class peephole {
private:
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
private:
ir::value *replace_phi(ir::value* value, ir::builder &builder, const std::vector<ir::constant_int *> &perm);
public:
optimize_trans() {}
peephole() {}
void run(ir::module &mod);
};

View File

@@ -66,29 +66,30 @@ public:
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
optimize_dot(&tune),
dce(),
optimize_trans(),
peephole(),
alignment_info(),
reassociate(&tune, &alignment_info),
reassociate(&tune),
target_(target) { }
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
ir::print(module, std::cout);
peephole.run(module);
dce.run(module);
}
void target_dependent(ir::module &module) {
alignment_info.run(module);
reassociate.run(module);
peephole.run(module);
if(target_->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
shmem_barriers.run(module);
}
alignment_info.run(module);
vectorize.run(module);
dce.run(module);
// ir::print(module, std::cout);
ir::print(module, std::cout);
}
codegen::selection selection;
@@ -100,8 +101,8 @@ public:
codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize;
codegen::transform::optimize_dot optimize_dot;
codegen::transform::optimize_dce dce;
codegen::transform::optimize_trans optimize_trans;
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate;
codegen::target* target_;
};

View File

@@ -8,7 +8,6 @@ namespace triton{
namespace runtime{
struct launch_information{
std::vector<unsigned> global_range_size;
unsigned num_threads;
std::map<std::string, unsigned> globals;
};

View File

@@ -14,7 +14,8 @@ namespace triton{
namespace codegen{
namespace analysis{
tune::tune(): num_global_ranges_(0){ }
tune::tune() {
}
bool is_hmma(ir::value *v){
bool result = false;
@@ -123,8 +124,8 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one)
static_params_.insert({{v, i}, 1});
// add_constraint({v, i}, {A, i});
// add_constraint({v, i}, {B, i});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
@@ -172,11 +173,6 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
// unsigned ax = range->get_axis();
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
// }
if(static_params_.find(x) != static_params_.end()){
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
@@ -189,21 +185,13 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
std::vector<ir::metaparameter*> result;
std::set<ir::metaparameter*> seen;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && !x.second->has_value()){
result.push_back(x.second);
}
for(auto x: mod.globals()){
for(auto x: mod.globals()) {
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
if(seen.insert(mp).second && !mp->has_value())
result.push_back(mp);
}
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
result.push_back(num_warps_);
return result;
}
@@ -212,7 +200,6 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
// std::cout << "group? " << value->get_name() << " " << ax << std::endl;
unsigned result = groups_.at(value).at(ax);
return result;
}
@@ -229,139 +216,164 @@ void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op()){
if(i->has_tile_result_or_op())
init_c_graph(i);
}
// Build phi constraints
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_phi(i);
// Layout parameters
unsigned group_id = 0;
for(auto x: nodes_){
for(auto x: nodes_)
fragments_[x] = get_fragmentation_type(x, dependencies_);
}
while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty();
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
if(node.second == 2)
fpw->set_value(1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
}
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(fragments_.at({i, 0}) == STRIDED_SCAN){
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
// std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
// *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
// *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
if(fragments_.at({i, 0}) == STRIDED_SCAN){
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
else{
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
}
}
}
void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::map<ir::metaparameter*, ir::instruction*> references;
std::map<unsigned, ir::value*> references;
create_grids(grids_, references, fn);
}
num_threads_ = get_req_num_threads(grids_.front());
}
int num_threads = get_num_threads();
int num_warps = num_warps_->get_value();
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
unsigned tune::get_req_num_threads(ir::instruction *i){
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
unsigned result = 1;
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
result *= params_.at(i).at("mts" + suffix)->get_value();
for(ir::value *i: grids_){
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
int shape_0 = shapes[0]->get_value();
int shape_1 = shapes[1]->get_value();
int size = i->get_type()->get_tile_num_elements();
/* HMMA parameters*/
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<int> fpw = {1, 1, 1};
std::vector<int> fpw_nm1;
int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
// store parameters
for(int d = 0; d < shapes.size(); d++)
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<int> wpt = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
for(int d = 0; d < shapes.size(); d++)
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
assert(num_warps == effective_num_warps);
}
return result;
}
else {
unsigned result = 32;
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
result *= params_.at(i).at("wpt" + suffix)->get_value();
/* Scan-line */
else{
int shape = shapes[0]->get_value();
int current = num_threads;
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
current = current / params_.at(i).at("mts.d0")->get_value();
for(size_t d = 1; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
shape = shapes[d]->get_value();
params_.at(i).at("nts.d" + str_d)->set_value(1);
params_.at(i).at("mts.d" + str_d)->set_value(clamp(current, 1, shape));
current = current / params_.at(i).at("mts.d" + str_d)->get_value();
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
}
assert(num_threads == effective_num_threads);
}
return result;
}
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<ir::metaparameter*, ir::instruction*> &references,
ir::function *fn) {
void tune::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0;
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
result += (shape != one);
for(ir::constant_int* shape: v->get_type()->get_tile_shapes()) {
result += (shape->get_value() > 1)?shape->get_value():0;
}
return result;
};
// bind references
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
for(auto &param: params_.at(i)){
if(param.second->get_value() == 1)
std::set<ir::value*> seen;
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
// recurse
if(auto *user = dynamic_cast<ir::user*>(v))
for(ir::value *op: user->ops())
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
ir::instruction *&r = references[param.second];
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
r = i;
unsigned x = get_param_group(v, d);
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
}
}
};
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
bind_references(i);
// create grid
for(auto &ref: references)
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
@@ -370,85 +382,11 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
using std::to_string;
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
std::string strk = to_string(axis);
if(fragments_.at({i, axis}) == STRIDED_SCAN){
unsigned mts = params_[i]["mts.d" + strk]->get_value();
unsigned nts = params_[i]["nts.d" + strk]->get_value();
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
return shape / (mts * nts);
}
else{
return (unsigned)params_[i]["wpt.d" + strk]->get_value();
}
};
// number of warps
ir::instruction *first = grids_.front();
int num_warps = 1;
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
num_warps *= get_num_warps(first, k);
// check constraints
for(ir::instruction *i: grids_){
// std::cout << i->get_name() << std::endl;
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components
// must device the shape
for(size_t k = 0; k < shapes.size(); k++) {
std::string strk = to_string(k);
unsigned multiple;
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
ir::metaparameter *mts = params_[i]["mts.d" + strk];
ir::metaparameter *nts = params_[i]["nts.d" + strk];
multiple = mts->get_value()*nts->get_value();
}
else {
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
multiple = fpw->get_value()*wpt->get_value();
if(k < 2)
multiple *= 8;
}
if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")");
}
// the product of mma fragments per warp must be 4
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
unsigned prod = 1;
for(size_t k = 0; k < shapes.size(); k++){
prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
}
if(prod > 4)
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
}
// for(auto x: errors)
// for(auto e: x.second)
// std::cout << x.first->get_name() << ": " << e << std::endl;
// exit(EXIT_SUCCESS);
return errors.empty();
}
unsigned tune::get_num_global_range() {
return num_global_ranges_;
}
unsigned tune::get_global_range_size(unsigned axis) {
return global_range_sizes_.at(axis)->get_value();
}
unsigned tune::get_num_threads() {
return num_threads_;
return num_warps_->get_value()*32;
}

View File

@@ -243,7 +243,7 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
/* convert ir::constant to Constant */
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
Type *dst_ty = llvm_type(cst->get_type(), ctx);
Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
return ConstantInt::get(dst_ty, cc->get_value());
if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
@@ -478,8 +478,9 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
nw[i] = ceil(nthreads, nwarps*warp_size);
nwarps *= nw[i];
}
for(size_t i = 0; i < bs.size(); ++i)
for(size_t i = 0; i < bs.size(); ++i){
ws[i] = bs[i] / nw[i];
}
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
@@ -565,7 +566,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateURem(builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)), builder.getInt32(fpw_1));
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
@@ -1296,7 +1298,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else {
result->for_each([&](indices_t idx){
auto value = [&](ir::value *x) {
if(x->get_type()->is_tile_ty())
if(auto *cst = dynamic_cast<ir::constant_int*>(x))
return (Value*)llvm_constant(cst, ctx);
else if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
else
return llvm_value(x, builder);

View File

@@ -9,7 +9,7 @@ namespace codegen{
namespace transform{
void optimize_dce::run(ir::module &mod) {
void dce::run(ir::module &mod) {
std::list<ir::instruction*> work_list;
std::set<ir::instruction*> marked;

View File

@@ -155,8 +155,8 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
return new_value;
}
reassociate::reassociate(analysis::tune* params, analysis::alignment_info* align)
: params_(params), align_(align)
reassociate::reassociate(analysis::tune* params)
: params_(params)
{ }
@@ -190,93 +190,108 @@ void reassociate::run(ir::module &mod) {
// reassociate
std::map<ir::value*, cst_info> infos;
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list()){
// getelementptr instruction
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
// unpack GEP instruction
ir::value* py = pz->get_pointer_operand();
ir::value* offset = *pz->idx_begin();
// reassociate index
ir::value *sta = nullptr;
ir::value *dyn = offset;
reassociate_idx(offset, builder, dyn, sta);
if(sta){
builder.set_insert_point(pz);
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
params_->copy(dyn_ptr, pz);
params_->copy(sta_ptr, pz);
align_->copy(sta_ptr, pz);
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
}
// reassociate pointer argument
if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py))
if(infos.find(gepy) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[gepy].sta_ptr;
ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *new_dyn = builder.create_gep(dyn, {off});
ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name());
params_->copy(new_dyn, pz);
params_->copy(new_pz, pz);
align_->copy(new_pz, pz);
pz->replace_all_uses_with(new_pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now
std::vector<ir::value*> ops = phi->ops();
if(ops.size() != 2)
std::set<ir::value*> replaced;
size_t n_replaced;
do{
n_replaced = replaced.size();
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list()){
// getelementptr instruction
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
if(replaced.find(pz) != replaced.end())
continue;
if(ops[0] != pz && ops[1] != pz)
continue;
// grab incoming
size_t idx_z = (ops[0] == pz) ? 0 : 1;
size_t idx_a = (ops[0] == pz) ? 1 : 0;
// check if pa is known to have constant offset
ir::value *vpa = phi->get_incoming_value(idx_a);
auto it = infos.find(vpa);
if(it == infos.end())
continue;
ir::getelementptr_inst *pa = (ir::getelementptr_inst*)vpa;
// unpack dynamically/statically offset pointer
ir::getelementptr_inst *dyn_ptr = it->second.dyn_ptr;
ir::getelementptr_inst *sta_ptr = it->second.sta_ptr;
// we take static offset out of the phi function
builder.set_insert_point(phi);
ir::phi_node *new_phi = builder.create_phi(phi->get_type(), 2);
// new pz for phi has the same offsets
builder.set_insert_point(pz);
std::vector<ir::value*> idxs(pz->idx_begin(), pz->idx_end());
ir::value *new_phi_pz = builder.create_gep(new_phi, idxs);
// fold the static offset into the new pz value
ir::value *new_pz = builder.create_gep(new_phi_pz, {*sta_ptr->idx_begin()});
// populate incoming values
new_phi->add_incoming(dyn_ptr, phi->get_incoming_block(idx_a));
new_phi->add_incoming(new_phi_pz, phi->get_incoming_block(idx_z));
// replace phi uses
phi->replace_all_uses_with(new_phi);
// replace pz uses
pz->replace_all_uses_with(new_pz);
// copy params
params_->copy(new_phi_pz, pz);
params_->copy(new_phi, phi);
params_->copy(new_pz, pz);
align_->copy(new_pz, pz);
// unpack GEP instruction
ir::value* py = pz->get_pointer_operand();
ir::value* offset = *pz->idx_begin();
// reassociate index
ir::value *sta = nullptr;
ir::value *dyn = offset;
reassociate_idx(offset, builder, dyn, sta);
if(sta){
builder.set_insert_point(pz);
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
params_->copy(dyn_ptr, pz);
params_->copy(sta_ptr, pz);
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
replaced.insert(pz);
}
// reassociate pointer argument
if(infos.find(py) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[py].sta_ptr;
ir::value *dyn = infos[py].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
params_->copy(pz_dyn, pz);
params_->copy(pz_sta, pz);
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
replaced.insert(pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now
std::vector<ir::value*> ops = phi->ops();
if(ops.size() != 2)
continue;
if(ops[0] != pz && ops[1] != pz)
continue;
// grab incoming
size_t idx_z = (ops[0] == pz) ? 0 : 1;
size_t idx_a = (ops[0] == pz) ? 1 : 0;
// check if pa is known to have constant offset
ir::value *vpa = phi->get_incoming_value(idx_a);
auto it_a = infos.find(vpa);
if(it_a == infos.end())
continue;
// unpack dynamically/statically offset pointer
ir::value *pa_dyn = it_a->second.dyn_ptr;
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
ir::value *pz = phi->get_incoming_value(idx_z);
// extract offset
ir::value *off = *pa_sta->idx_begin();
builder.set_insert_point(phi);
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
// re-add the offset
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
phi->replace_all_uses_with(phi_sta);
// remove offset from pz
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
auto insts = x->get_parent()->get_inst_list();
auto it = std::find(insts.begin(), insts.end(), x);
it++;
builder.set_insert_point(*it);
}
ir::value *neg_off = builder.create_neg(off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
// copy parameters
params_->copy(pz_dyn, pz);
params_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
params_->copy(neg_off, off);
params_->copy(phi_dyn, phi);
params_->copy(phi_sta, phi);
infos[phi_sta].dyn_ptr = phi_dyn;
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
replaced.insert(phi);
}
}
}
}
}
}
}
}
}while(replaced.size() != n_replaced);
}
}

View File

@@ -7,14 +7,42 @@ namespace codegen{
namespace transform{
ir::value* optimize_trans::replace_phi(ir::value* value,
ir::builder& builder,
const std::vector<ir::constant_int*> &perm){
inline bool is_trans(ir::value *v){
auto *x = dynamic_cast<ir::trans_inst*>(v);
if(!x)
return false;
std::vector<ir::constant_int*> perm = x->get_perm();
std::vector<ir::constant_int*> ref;
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
for(size_t i = 0; i < perm.size(); i++)
ref.push_back(ir::constant_int::get(int32_ty, i));
std::swap(ref[0], ref[1]);
// true is perm == ref
return std::equal(perm.begin(), perm.end(), ref.begin());
}
inline bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
}
return result;
}
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
const std::vector<ir::constant_int*>& perm) {
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
@@ -31,43 +59,159 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
trans->set_operand(0, i);
return trans;
}
throw std::runtime_error("cannot transpose phi");
}
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
auto trans = dynamic_cast<ir::trans_inst*>(value);
if(!trans)
return false;
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
return false;
ir::value* op = *ops.begin();
auto* phi = dynamic_cast<ir::phi_node*>(op);
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(op, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi);
return true;
}
void optimize_trans::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// iterate
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// transposition
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
continue;
ir::value* op = *ops.begin();
// todo: chains of transpositions
// trans(phi) -> phi(trans(), trans()...)
if(dynamic_cast<ir::phi_node*>(op)){
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi);
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
if(auto dot = dynamic_cast<ir::dot_inst*>(value)){
builder.set_insert_point(value);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
bool trans_a = is_trans(A);
bool trans_b = is_trans(B);
// NN
if(!dot->is_a_trans() && !dot->is_b_trans()){
if(is_hmma(dot)) {
ir::value *AA = A;
ir::value *BB = B;
if(trans_a){
AA = ((ir::trans_inst*)A)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
if(trans_b){
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);
return true;
}
else{
// dot(op(a), trans(b))
if(trans_b){
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
return true;
}
// dot(op(a), b)
if(!trans_b){
// create permutations
size_t size = B->get_type()->get_tile_shapes().size();
std::vector<ir::constant_int*> perm(size);
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
for(size_t i = 0; i < size; i++)
perm[i] = ir::constant_int::get(int32_ty, i);
std::swap(perm[0], perm[1]);
// replace NN -> NT (trans)
ir::value* BB = builder.create_trans(B, perm);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
return true;
}
}
}
// reductions
if(auto x = dynamic_cast<ir::reduce_inst*>(i)) {
ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(i->get_type()->get_context()), 1);
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
if(shapes[x->get_axis()] == one){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
x->replace_all_uses_with(new_red);
}
}
}
return false;
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
return false;
ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(value->get_type()->get_context()), 1);
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
if(shapes[x->get_axis()] == one){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
x->replace_all_uses_with(new_red);
return true;
}
return false;
}
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
if(!x)
return false;
auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
if(!y)
return false;
auto idx = *y->idx_begin();
auto z = dynamic_cast<ir::binary_operator*>(idx);
if(!z)
return false;
bool is_sub = z->get_op() == ir::binary_operator::llop::Sub;
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
bool is_lhs_0 = lhs && (lhs->get_value()==0);
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
x->replace_all_uses_with(y->get_pointer_operand());
return true;
}
return false;
}
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
bool was_modified = false;
// rewrite dots first
do{
was_modified = false;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list())
rewrite_dot(i, builder);
}while(was_modified);
// rewrite other ops
do{
was_modified = false;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
}
}while(was_modified);
}
}

View File

@@ -363,7 +363,7 @@ void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
runtime::launch_information info) {
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
unsigned TM = info.globals["TM"], TN = info.globals["TN"];
unsigned GZ = 1;
set_arg(kernel, a, b, c, bias);
std::array<size_t, 3> grid = {1};

View File

@@ -106,11 +106,9 @@ void dot::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int TM = {8};
const tunable int TN = {8};
const tunable int TM = {128};
const tunable int TN = {128};
const tunable int TK = {32};
const tunable int GZ = {1};
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
restrict read_only align(16) )" + b_ty_ + R"( *B,
@@ -127,18 +125,14 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
float xc[)" + XCS + R"(] = 0;
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
)" + a_ty_ + R"( a[)" + AS + R"(] = *pa;
)" + b_ty_ + R"( b[)" + BS + R"(] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
bool checka[)" + AS + R"(] = k > TK;
bool checkb[)" + BS + R"(] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);

View File

@@ -48,20 +48,20 @@ void backend::platforms::init() {
if(dispatch::cuinit()){
cache_.push_back(new cu_platform());
}
//if OpenCL is here
if(dispatch::clinit()){
cl_uint num_platforms;
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
std::vector<cl_platform_id> ids(num_platforms);
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
for(cl_platform_id id: ids)
cache_.push_back(new cl_platform(id));
}
//if host is here
bool host_visible = true;
if(host_visible){
cache_.push_back(new host_platform());
}
// //if OpenCL is here
// if(dispatch::clinit()){
// cl_uint num_platforms;
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
// std::vector<cl_platform_id> ids(num_platforms);
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
// for(cl_platform_id id: ids)
// cache_.push_back(new cl_platform(id));
// }
// //if host is here
// bool host_visible = true;
// if(host_visible){
// cache_.push_back(new host_platform());
// }
if(cache_.empty())
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
}

View File

@@ -12,7 +12,7 @@ namespace ir{
constant *constant::get_null_value(type *ty) {
context &ctx = ty->get_context();
switch (ty->get_type_id()) {
switch (ty->get_scalar_ty()->get_type_id()) {
case type::IntegerTyID:
return constant_int::get(ty, 0);
case type::HalfTyID:

View File

@@ -147,13 +147,13 @@ binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const
}
binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->is_floating_point_ty());
assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(llvm::Instruction::FSub, zero, arg, name, next);
}
binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->is_integer_ty());
assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(llvm::Instruction::Sub, zero, arg, name, next);
}

View File

@@ -73,6 +73,14 @@ const type::tile_shapes_t &type::get_tile_shapes() const {
return ((tile_type*)this)->get_shapes();
}
unsigned type::get_tile_num_elements() const {
const tile_shapes_t& shapes = get_tile_shapes();
unsigned result = 1;
for(ir::constant_int *x: shapes)
result *= x->get_value();
return result;
}
// composite predicates
bool type::is_int_or_tileint_ty()

View File

@@ -57,10 +57,8 @@ unsigned user::get_num_hidden() const {
}
void user::replace_all_uses_with(value *target) {
for(auto it = users_.begin(); it != users_.end();){
for(auto it = users_.begin(); it != users_.end(); it++){
(*it)->replace_uses_of_with(this, target);
target->add_use(*it);
erase_use(*it++);
}
}
@@ -68,6 +66,8 @@ void user::replace_uses_of_with(value *before, value *after) {
for(size_t i = 0; i < ops_.size(); i++)
if(ops_[i] == before)
ops_[i] = after;
after->add_use(this);
erase_use(this);
}
}

View File

@@ -82,10 +82,6 @@ void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std:
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
passes.selection.run(module, *result);
// launch information
info.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
// add globals
for(auto x: module.globals())
info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value();