[runtime] overhaul of the run-time API

This commit is contained in:
Philippe Tillet
2019-08-14 15:43:50 -07:00
parent b8cd63e0da
commit 38a8b0ab19
13 changed files with 633 additions and 86 deletions

View File

@@ -14,7 +14,7 @@ namespace triton{
namespace codegen{
namespace analysis{
tune::tune() {
tune::tune(size_t num_warps): num_warps_(num_warps){
}
bool is_hmma(ir::value *v){
@@ -183,20 +183,17 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
}
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
std::vector<ir::metaparameter*> result;
std::set<ir::metaparameter*> seen;
for(auto x: mod.globals()) {
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
if(seen.insert(mp).second && !mp->has_value())
result.push_back(mp);
}
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
result.push_back(num_warps_);
return result;
}
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
return params_.at(i);
throw std::runtime_error("remove me");
// std::vector<ir::metaparameter*> result;
// std::set<ir::metaparameter*> seen;
// for(auto x: mod.globals()) {
// if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
// if(seen.insert(mp).second && !mp->has_value())
// result.push_back(mp);
// }
// num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
// result.push_back(num_warps_);
// return result;
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
@@ -257,7 +254,6 @@ void tune::init(ir::module &mod) {
}
int num_threads = get_num_threads();
int num_warps = num_warps_->get_value();
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
for(ir::value *i: grids_){
@@ -292,9 +288,9 @@ void tune::init(ir::module &mod) {
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
@@ -307,7 +303,7 @@ void tune::init(ir::module &mod) {
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
assert(num_warps == effective_num_warps);
assert(num_warps_ == effective_num_warps);
}
/* Scan-line */
@@ -386,7 +382,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
}
unsigned tune::get_num_threads() {
return num_warps_->get_value()*32;
return num_warps_*32;
}