[runtime] overall of the run-time API
This commit is contained in:
@@ -14,7 +14,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tune::tune() {
|
||||
tune::tune(size_t num_warps): num_warps_(num_warps){
|
||||
}
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -183,20 +183,17 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
}
|
||||
|
||||
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
std::vector<ir::metaparameter*> result;
|
||||
std::set<ir::metaparameter*> seen;
|
||||
for(auto x: mod.globals()) {
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
if(seen.insert(mp).second && !mp->has_value())
|
||||
result.push_back(mp);
|
||||
}
|
||||
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
result.push_back(num_warps_);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
|
||||
return params_.at(i);
|
||||
throw std::runtime_error("remove me");
|
||||
// std::vector<ir::metaparameter*> result;
|
||||
// std::set<ir::metaparameter*> seen;
|
||||
// for(auto x: mod.globals()) {
|
||||
// if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
// if(seen.insert(mp).second && !mp->has_value())
|
||||
// result.push_back(mp);
|
||||
// }
|
||||
// num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
// result.push_back(num_warps_);
|
||||
// return result;
|
||||
}
|
||||
|
||||
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
@@ -257,7 +254,6 @@ void tune::init(ir::module &mod) {
|
||||
}
|
||||
|
||||
int num_threads = get_num_threads();
|
||||
int num_warps = num_warps_->get_value();
|
||||
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
|
||||
|
||||
for(ir::value *i: grids_){
|
||||
@@ -292,9 +288,9 @@ void tune::init(ir::module &mod) {
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
@@ -307,7 +303,7 @@ void tune::init(ir::module &mod) {
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_warps == effective_num_warps);
|
||||
assert(num_warps_ == effective_num_warps);
|
||||
}
|
||||
|
||||
/* Scan-line */
|
||||
@@ -386,7 +382,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
}
|
||||
|
||||
unsigned tune::get_num_threads() {
|
||||
return num_warps_->get_value()*32;
|
||||
return num_warps_*32;
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user