[code generation] uniformized shape and layout metaparameters
This commit is contained in:
@@ -65,22 +65,20 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
|
||||
|
||||
int main() {
|
||||
std::vector<unsigned> params = {
|
||||
// shapes
|
||||
16, 16, 8,
|
||||
// a0
|
||||
2, 8, 1,
|
||||
2, 8, 1, 16,
|
||||
// b0
|
||||
4, 4, 1,
|
||||
4, 4, 1, 16,
|
||||
// c
|
||||
2, 4, 8, 4, 1, 1,
|
||||
// a1
|
||||
2, 4, 1,
|
||||
2, 4, 1, 8,
|
||||
// b1
|
||||
1, 8, 1
|
||||
};
|
||||
unsigned TM = params[0];
|
||||
unsigned TN = params[1];
|
||||
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
|
||||
unsigned TM = params[6];
|
||||
unsigned TN = params[10];
|
||||
unsigned nthreads = params[1]*params[2]*params[15]*params[16];
|
||||
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
|
@@ -85,6 +85,12 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
params_[x.first].insert({"p0" + suffix, mps[0]});
|
||||
params_[x.first].insert({"p1" + suffix, mps[1]});
|
||||
params_[x.first].insert({"p2" + suffix, mps[2]});
|
||||
ir::type *ty = x.first->get_type();
|
||||
if(ty->is_tile_ty()){
|
||||
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
mps[0]->set_value(static_params_.at(x));
|
||||
mps[1]->set_value(static_params_.at(x));
|
||||
|
@@ -62,12 +62,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
|
||||
// tuning parameters
|
||||
tune.run(module);
|
||||
unsigned i = 0;
|
||||
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
|
||||
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
|
||||
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
|
||||
for(ir::metaparameter *x: tune.get_params(module)){
|
||||
x->set_value(params[3 + i++]);
|
||||
}
|
||||
for(ir::metaparameter *x: tune.get_params(module))
|
||||
x->set_value(params[i++]);
|
||||
|
||||
// constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
tune.check_constraints(module, errors);
|
||||
|
Reference in New Issue
Block a user