[code generation] uniformized shape and layout metaparameters
This commit is contained in:
@@ -65,22 +65,20 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
|
|||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::vector<unsigned> params = {
|
std::vector<unsigned> params = {
|
||||||
// shapes
|
|
||||||
16, 16, 8,
|
|
||||||
// a0
|
// a0
|
||||||
2, 8, 1,
|
2, 8, 1, 16,
|
||||||
// b0
|
// b0
|
||||||
4, 4, 1,
|
4, 4, 1, 16,
|
||||||
// c
|
// c
|
||||||
2, 4, 8, 4, 1, 1,
|
2, 4, 8, 4, 1, 1,
|
||||||
// a1
|
// a1
|
||||||
2, 4, 1,
|
2, 4, 1, 8,
|
||||||
// b1
|
// b1
|
||||||
1, 8, 1
|
1, 8, 1
|
||||||
};
|
};
|
||||||
unsigned TM = params[0];
|
unsigned TM = params[6];
|
||||||
unsigned TN = params[1];
|
unsigned TN = params[10];
|
||||||
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
|
unsigned nthreads = params[1]*params[2]*params[15]*params[16];
|
||||||
|
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
|
@@ -85,6 +85,12 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
|||||||
params_[x.first].insert({"p0" + suffix, mps[0]});
|
params_[x.first].insert({"p0" + suffix, mps[0]});
|
||||||
params_[x.first].insert({"p1" + suffix, mps[1]});
|
params_[x.first].insert({"p1" + suffix, mps[1]});
|
||||||
params_[x.first].insert({"p2" + suffix, mps[2]});
|
params_[x.first].insert({"p2" + suffix, mps[2]});
|
||||||
|
ir::type *ty = x.first->get_type();
|
||||||
|
if(ty->is_tile_ty()){
|
||||||
|
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
|
||||||
|
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||||
|
params_[x.first].insert({"shape" + suffix, mp});
|
||||||
|
}
|
||||||
if(static_params_.find(x) != static_params_.end()){
|
if(static_params_.find(x) != static_params_.end()){
|
||||||
mps[0]->set_value(static_params_.at(x));
|
mps[0]->set_value(static_params_.at(x));
|
||||||
mps[1]->set_value(static_params_.at(x));
|
mps[1]->set_value(static_params_.at(x));
|
||||||
|
@@ -62,12 +62,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
|
|||||||
// tuning parameters
|
// tuning parameters
|
||||||
tune.run(module);
|
tune.run(module);
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
|
for(ir::metaparameter *x: tune.get_params(module))
|
||||||
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
|
x->set_value(params[i++]);
|
||||||
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
|
|
||||||
for(ir::metaparameter *x: tune.get_params(module)){
|
|
||||||
x->set_value(params[3 + i++]);
|
|
||||||
}
|
|
||||||
// constraints
|
// constraints
|
||||||
std::map<ir::value*, std::vector<std::string>> errors;
|
std::map<ir::value*, std::vector<std::string>> errors;
|
||||||
tune.check_constraints(module, errors);
|
tune.check_constraints(module, errors);
|
||||||
|
Reference in New Issue
Block a user