[code generation] uniformized shape and layout metaparameters

This commit is contained in:
Philippe Tillet
2019-03-09 12:31:21 -05:00
parent 5f29263044
commit b721202812
3 changed files with 15 additions and 14 deletions

View File

@@ -65,22 +65,20 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
int main() {
std::vector<unsigned> params = {
// shapes
16, 16, 8,
// a0
2, 8, 1,
2, 8, 1, 16,
// b0
4, 4, 1,
4, 4, 1, 16,
// c
2, 4, 8, 4, 1, 1,
// a1
2, 4, 1,
2, 4, 1, 8,
// b1
1, 8, 1
};
unsigned TM = params[0];
unsigned TN = params[1];
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
unsigned TM = params[6];
unsigned TN = params[10];
unsigned nthreads = params[1]*params[2]*params[15]*params[16];
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);

View File

@@ -85,6 +85,12 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
params_[x.first].insert({"p0" + suffix, mps[0]});
params_[x.first].insert({"p1" + suffix, mps[1]});
params_[x.first].insert({"p2" + suffix, mps[2]});
ir::type *ty = x.first->get_type();
if(ty->is_tile_ty()){
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));

View File

@@ -62,12 +62,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
// tuning parameters
tune.run(module);
unsigned i = 0;
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
for(ir::metaparameter *x: tune.get_params(module)){
x->set_value(params[3 + i++]);
}
for(ir::metaparameter *x: tune.get_params(module))
x->set_value(params[i++]);
// constraints
std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);