[auto-tuning] much smaller parameter space
This commit is contained in:
@@ -111,7 +111,7 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 64, 64, 128}
|
||||
{false, true, 128, 128, 128}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
|
@@ -13,6 +13,7 @@ namespace ir{
|
||||
class instruction;
|
||||
class function;
|
||||
class metaparameter;
|
||||
class constant_int;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -34,7 +35,9 @@ private:
|
||||
void init_c_graph(ir::instruction *v);
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
|
||||
void create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn);
|
||||
unsigned get_req_num_threads(ir::instruction *i);
|
||||
|
||||
|
||||
@@ -49,8 +52,6 @@ public:
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
void init(ir::module &mod);
|
||||
unsigned get_num_global_range();
|
||||
unsigned get_global_range_size(unsigned axis);
|
||||
unsigned get_num_threads();
|
||||
|
||||
private:
|
||||
@@ -61,10 +62,9 @@ private:
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
unsigned num_global_ranges_;
|
||||
unsigned num_threads_;
|
||||
std::vector<ir::instruction*> grids_;
|
||||
std::vector<ir::value*> grids_;
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
ir::metaparameter* num_warps_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -14,9 +14,9 @@ namespace ir {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class optimize_dce {
|
||||
class dce {
|
||||
public:
|
||||
optimize_dce() {}
|
||||
dce() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
|
@@ -28,7 +28,7 @@ namespace transform{
|
||||
|
||||
class reassociate {
|
||||
struct cst_info {
|
||||
ir::getelementptr_inst* dyn_ptr;
|
||||
ir::value* dyn_ptr;
|
||||
ir::getelementptr_inst* sta_ptr;
|
||||
};
|
||||
|
||||
@@ -38,12 +38,11 @@ private:
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(analysis::tune *params, analysis::alignment_info *align);
|
||||
reassociate(analysis::tune *params);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
analysis::tune* params_;
|
||||
analysis::alignment_info* align_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -19,12 +19,17 @@ namespace ir {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class optimize_trans {
|
||||
class peephole {
|
||||
private:
|
||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
private:
|
||||
ir::value *replace_phi(ir::value* value, ir::builder &builder, const std::vector<ir::constant_int *> &perm);
|
||||
|
||||
public:
|
||||
optimize_trans() {}
|
||||
peephole() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
|
@@ -66,29 +66,30 @@ public:
|
||||
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
||||
optimize_dot(&tune),
|
||||
dce(),
|
||||
optimize_trans(),
|
||||
peephole(),
|
||||
alignment_info(),
|
||||
reassociate(&tune, &alignment_info),
|
||||
reassociate(&tune),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
ir::print(module, std::cout);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
alignment_info.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::selection selection;
|
||||
@@ -100,8 +101,8 @@ public:
|
||||
codegen::transform::shmem_barriers shmem_barriers;
|
||||
codegen::transform::vectorize vectorize;
|
||||
codegen::transform::optimize_dot optimize_dot;
|
||||
codegen::transform::optimize_dce dce;
|
||||
codegen::transform::optimize_trans optimize_trans;
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
@@ -8,7 +8,6 @@ namespace triton{
|
||||
namespace runtime{
|
||||
|
||||
struct launch_information{
|
||||
std::vector<unsigned> global_range_size;
|
||||
unsigned num_threads;
|
||||
std::map<std::string, unsigned> globals;
|
||||
};
|
||||
|
@@ -14,7 +14,8 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tune::tune(): num_global_ranges_(0){ }
|
||||
tune::tune() {
|
||||
}
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
@@ -123,8 +124,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
if(shapes[i] == one)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
// add_constraint({v, i}, {A, i});
|
||||
// add_constraint({v, i}, {B, i});
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
@@ -172,11 +173,6 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
// unsigned ax = range->get_axis();
|
||||
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
// }
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(static_params_.at(x));
|
||||
@@ -189,21 +185,13 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
std::vector<ir::metaparameter*> result;
|
||||
std::set<ir::metaparameter*> seen;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && !x.second->has_value()){
|
||||
result.push_back(x.second);
|
||||
}
|
||||
|
||||
for(auto x: mod.globals()){
|
||||
for(auto x: mod.globals()) {
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
if(seen.insert(mp).second && !mp->has_value())
|
||||
result.push_back(mp);
|
||||
}
|
||||
|
||||
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
result.push_back(num_warps_);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -212,7 +200,6 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
}
|
||||
|
||||
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
// std::cout << "group? " << value->get_name() << " " << ax << std::endl;
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
@@ -229,139 +216,164 @@ void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
// Create metaparameters
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op()){
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_graph(i);
|
||||
}
|
||||
|
||||
// Build phi constraints
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_phi(i);
|
||||
|
||||
// Layout parameters
|
||||
unsigned group_id = 0;
|
||||
for(auto x: nodes_){
|
||||
for(auto x: nodes_)
|
||||
fragments_[x] = get_fragmentation_type(x, dependencies_);
|
||||
}
|
||||
while(!nodes_.empty()) {
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
if(node.second == 2)
|
||||
fpw->set_value(1);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN){
|
||||
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
// std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
|
||||
// *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
|
||||
// *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN){
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
// for(size_t k = 2; k < shapes.size(); k++)
|
||||
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
|
||||
// *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
|
||||
// else
|
||||
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
|
||||
}
|
||||
else{
|
||||
// for(size_t k = 2; k < shapes.size(); k++)
|
||||
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
|
||||
// *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
|
||||
// else
|
||||
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
std::map<unsigned, ir::value*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
num_threads_ = get_req_num_threads(grids_.front());
|
||||
}
|
||||
int num_threads = get_num_threads();
|
||||
int num_warps = num_warps_->get_value();
|
||||
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
|
||||
|
||||
unsigned tune::get_req_num_threads(ir::instruction *i){
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
unsigned result = 1;
|
||||
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
result *= params_.at(i).at("mts" + suffix)->get_value();
|
||||
for(ir::value *i: grids_){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
int shape_0 = shapes[0]->get_value();
|
||||
int shape_1 = shapes[1]->get_value();
|
||||
int size = i->get_type()->get_tile_num_elements();
|
||||
/* HMMA parameters*/
|
||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<int> fpw = {1, 1, 1};
|
||||
std::vector<int> fpw_nm1;
|
||||
int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(int d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<int> wpt = {1, 1, 1};
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(int d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
|
||||
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_warps == effective_num_warps);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
else {
|
||||
unsigned result = 32;
|
||||
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
result *= params_.at(i).at("wpt" + suffix)->get_value();
|
||||
|
||||
/* Scan-line */
|
||||
else{
|
||||
int shape = shapes[0]->get_value();
|
||||
int current = num_threads;
|
||||
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
|
||||
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
|
||||
current = current / params_.at(i).at("mts.d0")->get_value();
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
shape = shapes[d]->get_value();
|
||||
params_.at(i).at("nts.d" + str_d)->set_value(1);
|
||||
params_.at(i).at("mts.d" + str_d)->set_value(clamp(current, 1, shape));
|
||||
current = current / params_.at(i).at("mts.d" + str_d)->get_value();
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_threads == effective_num_threads);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<ir::metaparameter*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
|
||||
void tune::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
|
||||
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape != one);
|
||||
for(ir::constant_int* shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape->get_value() > 1)?shape->get_value():0;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
// bind references
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
for(auto ¶m: params_.at(i)){
|
||||
if(param.second->get_value() == 1)
|
||||
std::set<ir::value*> seen;
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
// recurse
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
ir::instruction *&r = references[param.second];
|
||||
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
r = i;
|
||||
unsigned x = get_param_group(v, d);
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
bind_references(i);
|
||||
|
||||
// create grid
|
||||
for(auto &ref: references)
|
||||
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
|
||||
@@ -370,85 +382,11 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
|
||||
|
||||
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
using std::to_string;
|
||||
|
||||
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
|
||||
std::string strk = to_string(axis);
|
||||
if(fragments_.at({i, axis}) == STRIDED_SCAN){
|
||||
unsigned mts = params_[i]["mts.d" + strk]->get_value();
|
||||
unsigned nts = params_[i]["nts.d" + strk]->get_value();
|
||||
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
|
||||
return shape / (mts * nts);
|
||||
}
|
||||
else{
|
||||
return (unsigned)params_[i]["wpt.d" + strk]->get_value();
|
||||
}
|
||||
};
|
||||
|
||||
// number of warps
|
||||
ir::instruction *first = grids_.front();
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
|
||||
num_warps *= get_num_warps(first, k);
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids_){
|
||||
// std::cout << i->get_name() << std::endl;
|
||||
ir::type *ty = i->get_type();
|
||||
const auto &shapes = ty->get_tile_shapes();
|
||||
// for each dimension, the product of layout components
|
||||
// must divide the shape
|
||||
for(size_t k = 0; k < shapes.size(); k++) {
|
||||
std::string strk = to_string(k);
|
||||
unsigned multiple;
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
ir::metaparameter *mts = params_[i]["mts.d" + strk];
|
||||
ir::metaparameter *nts = params_[i]["nts.d" + strk];
|
||||
multiple = mts->get_value()*nts->get_value();
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
|
||||
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
|
||||
multiple = fpw->get_value()*wpt->get_value();
|
||||
if(k < 2)
|
||||
multiple *= 8;
|
||||
}
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||
}
|
||||
// the product of mma fragments per warp must not exceed 4
|
||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
unsigned prod = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++){
|
||||
prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
|
||||
}
|
||||
if(prod > 4)
|
||||
errors[i].push_back("HMMA must have only 4 fragments per warp");
|
||||
}
|
||||
int num_threads = get_req_num_threads(i);
|
||||
if(num_threads % 32 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
}
|
||||
// for(auto x: errors)
|
||||
// for(auto e: x.second)
|
||||
// std::cout << x.first->get_name() << ": " << e << std::endl;
|
||||
// exit(EXIT_SUCCESS);
|
||||
return errors.empty();
|
||||
}
|
||||
|
||||
unsigned tune::get_num_global_range() {
|
||||
return num_global_ranges_;
|
||||
}
|
||||
|
||||
unsigned tune::get_global_range_size(unsigned axis) {
|
||||
return global_range_sizes_.at(axis)->get_value();
|
||||
}
|
||||
|
||||
unsigned tune::get_num_threads() {
|
||||
return num_threads_;
|
||||
return num_warps_->get_value()*32;
|
||||
}
|
||||
|
||||
|
||||
|
@@ -243,7 +243,7 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
|
||||
/* convert ir::constant to Constant */
|
||||
Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
|
||||
Type *dst_ty = llvm_type(cst->get_type(), ctx);
|
||||
Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx);
|
||||
if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
|
||||
return ConstantInt::get(dst_ty, cc->get_value());
|
||||
if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
|
||||
@@ -478,8 +478,9 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
|
||||
nw[i] = ceil(nthreads, nwarps*warp_size);
|
||||
nwarps *= nw[i];
|
||||
}
|
||||
for(size_t i = 0; i < bs.size(); ++i)
|
||||
for(size_t i = 0; i < bs.size(); ++i){
|
||||
ws[i] = bs[i] / nw[i];
|
||||
}
|
||||
}
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
@@ -565,7 +566,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
||||
pair_b_id = builder.CreateURem(builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)), builder.getInt32(fpw_1));
|
||||
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
|
||||
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
|
||||
// Quad pair offset
|
||||
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
||||
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
||||
@@ -1296,7 +1298,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else {
|
||||
result->for_each([&](indices_t idx){
|
||||
auto value = [&](ir::value *x) {
|
||||
if(x->get_type()->is_tile_ty())
|
||||
if(auto *cst = dynamic_cast<ir::constant_int*>(x))
|
||||
return (Value*)llvm_constant(cst, ctx);
|
||||
else if(x->get_type()->is_tile_ty())
|
||||
return tmap_.at(x)->get_value(idx);
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
|
@@ -9,7 +9,7 @@ namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
void optimize_dce::run(ir::module &mod) {
|
||||
void dce::run(ir::module &mod) {
|
||||
std::list<ir::instruction*> work_list;
|
||||
std::set<ir::instruction*> marked;
|
||||
|
||||
|
@@ -155,8 +155,8 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(analysis::tune* params, analysis::alignment_info* align)
|
||||
: params_(params), align_(align)
|
||||
reassociate::reassociate(analysis::tune* params)
|
||||
: params_(params)
|
||||
{ }
|
||||
|
||||
|
||||
@@ -190,93 +190,108 @@ void reassociate::run(ir::module &mod) {
|
||||
|
||||
// reassociate
|
||||
std::map<ir::value*, cst_info> infos;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
// unpack GEP instruction
|
||||
ir::value* py = pz->get_pointer_operand();
|
||||
ir::value* offset = *pz->idx_begin();
|
||||
// reassociate index
|
||||
ir::value *sta = nullptr;
|
||||
ir::value *dyn = offset;
|
||||
reassociate_idx(offset, builder, dyn, sta);
|
||||
if(sta){
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
params_->copy(dyn_ptr, pz);
|
||||
params_->copy(sta_ptr, pz);
|
||||
align_->copy(sta_ptr, pz);
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
}
|
||||
// reassociate pointer argument
|
||||
if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py))
|
||||
if(infos.find(gepy) != infos.end()){
|
||||
builder.set_insert_point(pz);
|
||||
ir::getelementptr_inst *sta = infos[gepy].sta_ptr;
|
||||
ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr;
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *new_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name());
|
||||
params_->copy(new_dyn, pz);
|
||||
params_->copy(new_pz, pz);
|
||||
align_->copy(new_pz, pz);
|
||||
pz->replace_all_uses_with(new_pz);
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
std::vector<ir::value*> ops = phi->ops();
|
||||
if(ops.size() != 2)
|
||||
std::set<ir::value*> replaced;
|
||||
size_t n_replaced;
|
||||
do{
|
||||
n_replaced = replaced.size();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
if(replaced.find(pz) != replaced.end())
|
||||
continue;
|
||||
if(ops[0] != pz && ops[1] != pz)
|
||||
continue;
|
||||
// grab incoming
|
||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
||||
// check if pa is known to have constant offset
|
||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
auto it = infos.find(vpa);
|
||||
if(it == infos.end())
|
||||
continue;
|
||||
ir::getelementptr_inst *pa = (ir::getelementptr_inst*)vpa;
|
||||
// unpack dynamically/statically offset pointer
|
||||
ir::getelementptr_inst *dyn_ptr = it->second.dyn_ptr;
|
||||
ir::getelementptr_inst *sta_ptr = it->second.sta_ptr;
|
||||
// we take static offset out of the phi function
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node *new_phi = builder.create_phi(phi->get_type(), 2);
|
||||
// new pz for phi has the same offsets
|
||||
builder.set_insert_point(pz);
|
||||
std::vector<ir::value*> idxs(pz->idx_begin(), pz->idx_end());
|
||||
ir::value *new_phi_pz = builder.create_gep(new_phi, idxs);
|
||||
// fold the static offset into the new pz value
|
||||
ir::value *new_pz = builder.create_gep(new_phi_pz, {*sta_ptr->idx_begin()});
|
||||
// populate incoming values
|
||||
new_phi->add_incoming(dyn_ptr, phi->get_incoming_block(idx_a));
|
||||
new_phi->add_incoming(new_phi_pz, phi->get_incoming_block(idx_z));
|
||||
// replace phi uses
|
||||
phi->replace_all_uses_with(new_phi);
|
||||
// replace pz uses
|
||||
pz->replace_all_uses_with(new_pz);
|
||||
// copy params
|
||||
params_->copy(new_phi_pz, pz);
|
||||
params_->copy(new_phi, phi);
|
||||
params_->copy(new_pz, pz);
|
||||
align_->copy(new_pz, pz);
|
||||
// unpack GEP instruction
|
||||
ir::value* py = pz->get_pointer_operand();
|
||||
ir::value* offset = *pz->idx_begin();
|
||||
// reassociate index
|
||||
ir::value *sta = nullptr;
|
||||
ir::value *dyn = offset;
|
||||
reassociate_idx(offset, builder, dyn, sta);
|
||||
if(sta){
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
params_->copy(dyn_ptr, pz);
|
||||
params_->copy(sta_ptr, pz);
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate pointer argument
|
||||
if(infos.find(py) != infos.end()){
|
||||
builder.set_insert_point(pz);
|
||||
ir::getelementptr_inst *sta = infos[py].sta_ptr;
|
||||
ir::value *dyn = infos[py].dyn_ptr;
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
|
||||
params_->copy(pz_dyn, pz);
|
||||
params_->copy(pz_sta, pz);
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
std::vector<ir::value*> ops = phi->ops();
|
||||
if(ops.size() != 2)
|
||||
continue;
|
||||
if(ops[0] != pz && ops[1] != pz)
|
||||
continue;
|
||||
// grab incoming
|
||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
||||
// check if pa is known to have constant offset
|
||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
auto it_a = infos.find(vpa);
|
||||
if(it_a == infos.end())
|
||||
continue;
|
||||
// unpack dynamically/statically offset pointer
|
||||
ir::value *pa_dyn = it_a->second.dyn_ptr;
|
||||
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
|
||||
ir::value *pz = phi->get_incoming_value(idx_z);
|
||||
// extract offset
|
||||
ir::value *off = *pa_sta->idx_begin();
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
|
||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
||||
// re-add the offset
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
|
||||
phi->replace_all_uses_with(phi_sta);
|
||||
// remove offset from pz
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
||||
auto insts = x->get_parent()->get_inst_list();
|
||||
auto it = std::find(insts.begin(), insts.end(), x);
|
||||
it++;
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *neg_off = builder.create_neg(off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
// copy parameters
|
||||
params_->copy(pz_dyn, pz);
|
||||
params_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
|
||||
params_->copy(neg_off, off);
|
||||
params_->copy(phi_dyn, phi);
|
||||
params_->copy(phi_sta, phi);
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
||||
replaced.insert(phi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}while(replaced.size() != n_replaced);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -7,14 +7,42 @@ namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
ir::builder& builder,
|
||||
const std::vector<ir::constant_int*> &perm){
|
||||
// Returns true when `v` is a trans_inst whose permutation exchanges the first
// two axes and leaves every remaining axis in place, i.e. perm == {1, 0, 2, 3, ...}.
// Any value that is not a trans_inst yields false.
inline bool is_trans(ir::value *v){
  auto *x = dynamic_cast<ir::trans_inst*>(v);
  if(!x)
    return false;
  std::vector<ir::constant_int*> perm = x->get_perm();
  // a permutation with fewer than two entries cannot exchange axes 0 and 1;
  // bailing out here also keeps the swap below within bounds
  if(perm.size() < 2)
    return false;
  // build the reference permutation {1, 0, 2, 3, ...}
  std::vector<ir::constant_int*> ref;
  ref.reserve(perm.size());
  ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
  for(size_t i = 0; i < perm.size(); i++)
    ref.push_back(ir::constant_int::get(int32_ty, i));
  std::swap(ref[0], ref[1]);
  // true if perm == ref
  // NOTE(review): entries are compared by pointer, which presumes that
  // constant_int::get returns the same object for equal (type, value) pairs -- confirm
  return std::equal(perm.begin(), perm.end(), ref.begin());
}
|
||||
|
||||
// Tells whether `v` is a dot_inst whose first two operands both have a
// half-precision scalar type. Any value that is not a dot_inst yields false.
inline bool is_hmma(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(!dot)
    return false;
  ir::type *lhs_ty = dot->get_operand(0)->get_type();
  ir::type *rhs_ty = dot->get_operand(1)->get_type();
  // inputs have to be FP16
  if(!lhs_ty->get_scalar_ty()->is_half_ty())
    return false;
  // a further requirement (reduction size multiple of 4) is currently disabled:
  // (lhs_ty->get_tile_shapes()[1]->get_value() % 4) == 0
  return rhs_ty->get_scalar_ty()->is_half_ty();
}
|
||||
|
||||
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
|
||||
const std::vector<ir::constant_int*>& perm) {
|
||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||
// transpose operands
|
||||
std::vector<ir::value*> incs;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
|
||||
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
|
||||
@@ -31,43 +59,159 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
trans->set_operand(0, i);
|
||||
return trans;
|
||||
}
|
||||
throw std::runtime_error("cannot transpose phi");
|
||||
}
|
||||
|
||||
// Pushes a transposition through a phi node:
//   trans(phi) -> phi(trans(), trans()...)
// Applies only when `value` is a trans_inst with at most one user and at most
// one operand, and that operand is a phi_node. The replacement value is built
// by rewrite_trans_phi_impl and substituted for every use of the transposition.
// Returns true when the rewrite was performed, false otherwise.
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
  ir::trans_inst *transposition = dynamic_cast<ir::trans_inst*>(value);
  if(transposition == nullptr)
    return false;
  auto consumers = transposition->get_users();
  auto operands = transposition->ops();
  const bool has_single_user = consumers.size() <= 1;
  const bool has_single_operand = operands.size() <= 1;
  if(!(has_single_user && has_single_operand))
    return false;
  ir::value *source = *operands.begin();
  if(dynamic_cast<ir::phi_node*>(source) == nullptr)
    return false;
  ir::value *transposed_phi = rewrite_trans_phi_impl(source, builder, transposition->get_perm());
  transposition->replace_all_uses_with(transposed_phi);
  return true;
}
|
||||
|
||||
void optimize_trans::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
// iterate
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
// transposition
|
||||
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
|
||||
auto users = trans->get_users();
|
||||
auto ops = trans->ops();
|
||||
if(users.size() > 1 || ops.size() > 1)
|
||||
continue;
|
||||
ir::value* op = *ops.begin();
|
||||
// todo: chains of transpositions
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
if(dynamic_cast<ir::phi_node*>(op)){
|
||||
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
// Rewrites a dot instruction that is flagged neither A-transposed nor
// B-transposed (NN) so that transpositions applied to its operands are absorbed
// into the dot itself.
//
//  - half-precision dots (see is_hmma): every operand that is a trans_inst is
//    folded into the corresponding transposition flag of a new dot instruction;
//  - other dots: the dot is replaced by an NT dot, reusing the input of B's
//    transposition when B is a plain axis-0/axis-1 swap and otherwise inserting
//    an explicit transposition of B.
//
// The insertion point of `builder` is moved to `value` for every dot_inst, even
// when no rewrite happens. Only the uses of the original dot are redirected; the
// original instruction is not erased here.
// Returns true when a replacement dot was created, false otherwise.
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  auto dot = dynamic_cast<ir::dot_inst*>(value);
  if(!dot)
    return false;
  builder.set_insert_point(value);
  ir::value *A = dot->get_operand(0);
  ir::value *B = dot->get_operand(1);
  ir::value *D = dot->get_operand(2);
  bool trans_a = is_trans(A);
  bool trans_b = is_trans(B);
  // only NN dots are rewritten
  if(dot->is_a_trans() || dot->is_b_trans())
    return false;

  if(is_hmma(dot)) {
    // Folds a transposition on `operand` into the `transposed` flag and returns
    // the value the new dot should consume. The same logic is applied to A and
    // to B (previously the B case re-processed A and never updated B).
    auto absorb_trans = [&builder](ir::value *operand, bool &transposed) -> ir::value* {
      auto *T = dynamic_cast<ir::trans_inst*>(operand);
      if(!T)
        return operand;
      // plain swap of the first two axes: consume the transposition's input directly
      if(transposed)
        return T->get_operand(0);
      // general permutation: swap its first two entries and mark the operand as transposed
      std::vector<ir::constant_int*> perm(T->get_perm());
      std::swap(perm[0], perm[1]);
      ir::value *folded = builder.create_trans(T->get_operand(0), perm);
      T->replace_all_uses_with(folded);
      transposed = true;
      return folded;
    };
    ir::value *AA = absorb_trans(A, trans_a);
    ir::value *BB = absorb_trans(B, trans_b);
    ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
    dot->replace_all_uses_with(dot_atbt);
    return true;
  }

  // dot(op(a), trans(b))
  if(trans_b){
    // is_trans(B) guarantees that B is a trans_inst
    ir::value* BB = dynamic_cast<ir::trans_inst*>(B)->get_operand(0);
    ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
    dot->replace_all_uses_with(NT);
    return true;
  }

  // dot(op(a), b)
  // create the permutation {1, 0, 2, 3, ...}
  size_t size = B->get_type()->get_tile_shapes().size();
  std::vector<ir::constant_int*> perm(size);
  ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
  for(size_t i = 0; i < size; i++)
    perm[i] = ir::constant_int::get(int32_ty, i);
  std::swap(perm[0], perm[1]);
  // replace NN -> NT (trans)
  ir::value* BB = builder.create_trans(B, perm);
  ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
  dot->replace_all_uses_with(NT);
  return true;
}
|
||||
|
||||
// Replaces a reduction along an axis of extent one by a reshape of its input
// to the reduction's own tile shapes.
// Returns true when `value` is a reduce_inst that was rewritten, false otherwise.
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  auto *reduction = dynamic_cast<ir::reduce_inst*>(value);
  if(reduction == nullptr)
    return false;
  ir::type *int32_ty = ir::type::get_int32_ty(value->get_type()->get_context());
  ir::constant_int *unit = ir::constant_int::get(int32_ty, 1);
  ir::value *input = reduction->get_operand(0);
  auto input_shapes = input->get_type()->get_tile_shapes();
  // the extent of the reduced axis is compared by pointer against the constant 1
  if(input_shapes[reduction->get_axis()] != unit)
    return false;
  builder.set_insert_point(reduction);
  ir::value *reshaped = builder.create_reshape(input, reduction->get_type()->get_tile_shapes());
  reduction->replace_all_uses_with(reshaped);
  return true;
}
|
||||
|
||||
// Cancels a pair of opposite pointer offsets:
//   gep(gep(ptr, 0 - off), off) -> ptr
// `value` must be a getelementptr whose pointer operand is itself a
// getelementptr indexed by the subtraction `0 - off`, where `off` is the very
// value used as the first index of `value`.
// Returns true when the uses of `value` were redirected to `ptr`, false otherwise.
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
  auto outer = dynamic_cast<ir::getelementptr_inst*>(value);
  if(!outer)
    return false;
  auto inner = dynamic_cast<ir::getelementptr_inst*>(outer->get_pointer_operand());
  if(!inner)
    return false;
  auto inner_idx = *inner->idx_begin();
  auto negation = dynamic_cast<ir::binary_operator*>(inner_idx);
  if(!negation)
    return false;
  // the inner index has to be a subtraction ...
  if(!(negation->get_op() == ir::binary_operator::llop::Sub))
    return false;
  // ... whose left-hand side is the integer constant 0 ...
  auto *minuend = dynamic_cast<ir::constant_int*>(negation->get_operand(0));
  if(!minuend || !(minuend->get_value()==0))
    return false;
  // ... and whose right-hand side is the offset re-added by the outer gep
  if(!(negation->get_operand(1) == *outer->idx_begin()))
    return false;
  outer->replace_all_uses_with(inner->get_pointer_operand());
  return true;
}
|
||||
|
||||
|
||||
// Runs the peephole rewrites over every instruction of every basic block of
// every function in `mod`, in two phases: dot rewrites first, then the
// remaining patterns until a full sweep applies none of them.
void peephole::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // records whether the current sweep modified anything
  bool was_modified = false;

  // phase 1: rewrite dots
  // the result of rewrite_dot is not recorded in was_modified, so as written
  // this loop body executes exactly once
  do{
    was_modified = false;
    for(ir::function *fn: mod.get_function_list()){
      for(ir::basic_block *block: fn->blocks()){
        for(ir::instruction *inst: block->get_inst_list()){
          rewrite_dot(inst, builder);
        }
      }
    }
  }while(was_modified);

  // phase 2: rewrite other ops
  // once a rewrite succeeds, no further rewrite is attempted for the rest of
  // the sweep; the sweep is then repeated from the start
  do{
    was_modified = false;
    for(ir::function *fn: mod.get_function_list()){
      for(ir::basic_block *block: fn->blocks()){
        for(ir::instruction *inst: block->get_inst_list()){
          if(!was_modified)
            was_modified = rewrite_trans_phi(inst, builder);
          if(!was_modified)
            was_modified = rewrite_unit_red(inst, builder);
          if(!was_modified)
            was_modified = rewrite_gep_ptr_min_off_plus_off(inst, builder);
        }
      }
    }
  }while(was_modified);
}
|
||||
|
||||
}
|
||||
|
@@ -363,7 +363,7 @@ void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
unsigned TM = info.globals["TM"], TN = info.globals["TN"];
|
||||
unsigned GZ = 1;
|
||||
set_arg(kernel, a, b, c, bias);
|
||||
std::array<size_t, 3> grid = {1};
|
||||
|
@@ -106,11 +106,9 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int TM = {8};
|
||||
const tunable int TN = {8};
|
||||
const tunable int TM = {128};
|
||||
const tunable int TN = {128};
|
||||
const tunable int TK = {32};
|
||||
const tunable int GZ = {1};
|
||||
|
||||
|
||||
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty_ + R"( *B,
|
||||
@@ -127,18 +125,14 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
||||
bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
bool checka[)" + AS + R"(] = k > TK;
|
||||
bool checkb[)" + BS + R"(] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
|
@@ -48,20 +48,20 @@ void backend::platforms::init() {
|
||||
if(dispatch::cuinit()){
|
||||
cache_.push_back(new cu_platform());
|
||||
}
|
||||
//if OpenCL is here
|
||||
if(dispatch::clinit()){
|
||||
cl_uint num_platforms;
|
||||
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
std::vector<cl_platform_id> ids(num_platforms);
|
||||
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
for(cl_platform_id id: ids)
|
||||
cache_.push_back(new cl_platform(id));
|
||||
}
|
||||
//if host is here
|
||||
bool host_visible = true;
|
||||
if(host_visible){
|
||||
cache_.push_back(new host_platform());
|
||||
}
|
||||
// //if OpenCL is here
|
||||
// if(dispatch::clinit()){
|
||||
// cl_uint num_platforms;
|
||||
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
// std::vector<cl_platform_id> ids(num_platforms);
|
||||
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
// for(cl_platform_id id: ids)
|
||||
// cache_.push_back(new cl_platform(id));
|
||||
// }
|
||||
// //if host is here
|
||||
// bool host_visible = true;
|
||||
// if(host_visible){
|
||||
// cache_.push_back(new host_platform());
|
||||
// }
|
||||
if(cache_.empty())
|
||||
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ namespace ir{
|
||||
|
||||
constant *constant::get_null_value(type *ty) {
|
||||
context &ctx = ty->get_context();
|
||||
switch (ty->get_type_id()) {
|
||||
switch (ty->get_scalar_ty()->get_type_id()) {
|
||||
case type::IntegerTyID:
|
||||
return constant_int::get(ty, 0);
|
||||
case type::HalfTyID:
|
||||
|
@@ -147,13 +147,13 @@ binary_operator *binary_operator::create(op_t op, value *lhs, value *rhs, const
|
||||
}
|
||||
|
||||
// Creates the floating-point negation of `arg`, expressed as `zero - arg` with
// an FSub operation, where `zero` is obtained from
// constant_fp::get_zero_value_for_negation for the type of `arg`.
// `name` and `next` are forwarded unchanged to binary_operator::create.
binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
  // only the scalar type is checked, so that non-scalar (e.g. tile) arguments
  // with a floating-point element type are accepted
  assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
  value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
  return binary_operator::create(llvm::Instruction::FSub, zero, arg, name, next);
}
|
||||
|
||||
// Creates the integer negation of `arg`, expressed as `zero - arg` with a Sub
// operation. `name` and `next` are forwarded unchanged to binary_operator::create.
binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
  // only the scalar type is checked, so that non-scalar (e.g. tile) arguments
  // with an integer element type are accepted
  assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
  // NOTE(review): the zero comes from a constant_fp helper although the operand
  // is an integer -- confirm that it yields an integer zero for integer types
  value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
  return binary_operator::create(llvm::Instruction::Sub, zero, arg, name, next);
}
|
||||
|
@@ -73,6 +73,14 @@ const type::tile_shapes_t &type::get_tile_shapes() const {
|
||||
return ((tile_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
unsigned type::get_tile_num_elements() const {
|
||||
const tile_shapes_t& shapes = get_tile_shapes();
|
||||
unsigned result = 1;
|
||||
for(ir::constant_int *x: shapes)
|
||||
result *= x->get_value();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// composite predicates
|
||||
bool type::is_int_or_tileint_ty()
|
||||
|
@@ -57,10 +57,8 @@ unsigned user::get_num_hidden() const {
|
||||
}
|
||||
|
||||
// Redirects every user of this value to `target` by asking each user to
// substitute `target` for `this` among its operands.
// The use-list bookkeeping (add_use / erase_use) is left to
// replace_uses_of_with and is not repeated here.
void user::replace_all_uses_with(value *target) {
  // iterate over a snapshot so that any use-list update performed while
  // rewriting a user cannot invalidate the traversal
  auto snapshot = users_;
  for(const auto &u: snapshot)
    u->replace_uses_of_with(this, target);
}
|
||||
|
||||
@@ -68,6 +66,8 @@ void user::replace_uses_of_with(value *before, value *after) {
|
||||
for(size_t i = 0; i < ops_.size(); i++)
|
||||
if(ops_[i] == before)
|
||||
ops_[i] = after;
|
||||
after->add_use(this);
|
||||
erase_use(this);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -82,10 +82,6 @@ void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std:
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
info.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
// add globals
|
||||
for(auto x: module.globals())
|
||||
info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
|
Reference in New Issue
Block a user