[code generation] uniformized shape and layout metaparameters

This commit is contained in:
Philippe Tillet
2019-03-09 12:31:21 -05:00
parent 5f29263044
commit b721202812
3 changed files with 15 additions and 14 deletions

View File

@@ -65,22 +65,20 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
int main() { int main() {
std::vector<unsigned> params = { std::vector<unsigned> params = {
// shapes
16, 16, 8,
// a0 // a0
2, 8, 1, 2, 8, 1, 16,
// b0 // b0
4, 4, 1, 4, 4, 1, 16,
// c // c
2, 4, 8, 4, 1, 1, 2, 4, 8, 4, 1, 1,
// a1 // a1
2, 4, 1, 2, 4, 1, 8,
// b1 // b1
1, 8, 1 1, 8, 1
}; };
unsigned TM = params[0]; unsigned TM = params[6];
unsigned TN = params[1]; unsigned TN = params[10];
unsigned nthreads = params[10]*params[13]*params[11]*params[14]; unsigned nthreads = params[1]*params[2]*params[15]*params[16];
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context); triton::jit jit(context);

View File

@@ -85,6 +85,12 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
params_[x.first].insert({"p0" + suffix, mps[0]}); params_[x.first].insert({"p0" + suffix, mps[0]});
params_[x.first].insert({"p1" + suffix, mps[1]}); params_[x.first].insert({"p1" + suffix, mps[1]});
params_[x.first].insert({"p2" + suffix, mps[2]}); params_[x.first].insert({"p2" + suffix, mps[2]});
ir::type *ty = x.first->get_type();
if(ty->is_tile_ty()){
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
if(static_params_.find(x) != static_params_.end()){ if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x)); mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x)); mps[1]->set_value(static_params_.at(x));

View File

@@ -62,12 +62,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
// tuning parameters // tuning parameters
tune.run(module); tune.run(module);
unsigned i = 0; unsigned i = 0;
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]); for(ir::metaparameter *x: tune.get_params(module))
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]); x->set_value(params[i++]);
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
for(ir::metaparameter *x: tune.get_params(module)){
x->set_value(params[3 + i++]);
}
// constraints // constraints
std::map<ir::value*, std::vector<std::string>> errors; std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors); tune.check_constraints(module, errors);