[code generation] reparameterization
@@ -86,7 +86,7 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();

   // matrix multiplication parameters
-  size_t M = 128, N = 128, K = 128;
+  size_t M = 512, N = 512, K = 512;
   size_t bound = 8;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
@@ -412,7 +412,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
   if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
     assert(expr_ == nullptr);
     //TODO: implement ranges
-    value = ir::metaparameter::create(mod->get_context(), ty, 8, 128);
+    value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
   }
   if(expr_){
     value = expr_->codegen(mod);
@@ -371,18 +371,37 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
   return result;
 }

+inline int32_t ceil(int32_t num, int32_t div){
+  return (num + div - 1)/div;
+}
+
+inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
+  static const size_t warp_size = 32;
+  size_t nthreads = 1, nwarps = 1;
+  nw.resize(bs.size());
+  ws.resize(bs.size());
+  for(size_t i = 0; i < bs.size(); ++i){
+    nthreads *= bs[i];
+    nw[i] = ceil(nthreads, nwarps*warp_size);
+    nwarps *= nw[i];
+  }
+  for(size_t i = 0; i < bs.size(); ++i)
+    ws[i] = bs[i] / nw[i];
+}
+
 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
   const auto& shapes = v->get_type()->get_tile_shapes();
   size_t dim = shapes.size();
   std::vector<unsigned> contiguous(dim);
+  std::vector<unsigned> block_size(dim);
   std::vector<unsigned> warp_size(dim);
   std::vector<unsigned> n_warps(dim);
   for(unsigned i = 0; i < shapes.size(); i++){
     std::string str_i = std::to_string(i);
     contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
-    warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
-    n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
+    block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
   }
+  to_warps(block_size, n_warps, warp_size);
   std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
   std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
   // Create axes
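The two helpers added above are self-contained and can be exercised on their own: to_warps takes the per-dimension thread counts of a block (block_size, i.e. the mts parameters) and factors them into a per-dimension warp count (nw) and a per-dimension intra-warp thread count (ws), so that a full warp always holds 32 threads. A minimal standalone sketch, with the helper bodies copied from the hunk and a made-up block size of {32, 4}:

#include <cstdint>
#include <cstdio>
#include <vector>

// copied from the hunk above
inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}

// bs: threads per block along each tile dimension (the mts parameters)
// nw: resulting warps along each dimension
// ws: resulting threads-per-warp along each dimension
inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
  static const size_t warp_size = 32;
  size_t nthreads = 1, nwarps = 1;
  nw.resize(bs.size());
  ws.resize(bs.size());
  for(size_t i = 0; i < bs.size(); ++i){
    nthreads *= bs[i];
    nw[i] = ceil(nthreads, nwarps*warp_size);
    nwarps *= nw[i];
  }
  for(size_t i = 0; i < bs.size(); ++i)
    ws[i] = bs[i] / nw[i];
}

int main(){
  std::vector<unsigned> bs = {32, 4};  // 128 threads per block (illustrative)
  std::vector<unsigned> nw, ws;
  to_warps(bs, nw, ws);
  for(size_t i = 0; i < bs.size(); ++i)
    printf("d%zu: %u threads -> %u warp(s) x %u thread(s)\n", i, bs[i], nw[i], ws[i]);
  // d0: 32 threads -> 1 warp(s) x 32 thread(s)
  // d1: 4 threads -> 4 warp(s) x 1 thread(s)
  return 0;
}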
@@ -144,8 +144,8 @@ void tune::run(ir::module &mod) {
   // Layout parameters
   while(!nodes_.empty()){
     ir::type *ty = mod.get_builder().get_int32_ty();
-    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
-    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
+    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
+    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
     connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
   }
 }
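Assuming the last two arguments of metaparameter::create are the bounds of the tuning range, nts (contiguous values owned by one thread along a dimension) is now pinned to 2 while mts (threads along a dimension) can grow to 32. One plausible reading of how the pair covers a tile dimension, suggested by init_axes above (contiguous = nts, block_size = mts), is that each thread owns nts consecutive elements and the nts*mts-wide pattern repeats across the tile. A hypothetical sketch; the ownership formula is an illustration, not code from this commit:

#include <cstdio>

int main(){
  unsigned shape = 128;  // tile extent along one dimension (illustrative)
  unsigned nts = 2;      // contiguous elements per thread (tuned in [2, 2])
  unsigned mts = 32;     // threads along this dimension   (tuned in [4, 32])
  unsigned stride = nts * mts;     // elements covered per repetition of the pattern
  unsigned reps = shape / stride;  // repetitions needed to span the tile
  unsigned tid = 0;                // which thread to inspect
  printf("thread %u owns:", tid);
  for(unsigned r = 0; r < reps; ++r)
    for(unsigned j = 0; j < nts; ++j)
      printf(" %u", r*stride + tid*nts + j);
  printf("\n");  // thread 0 owns: 0 1 64 65
  return 0;
}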
@@ -157,21 +157,12 @@ void tune::init(ir::module &mod) {
     std::map<ir::metaparameter*, ir::instruction*> references;
     create_grids(grids_, references, fn);
   }
-  // number of warps
-  auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
-    std::string strk = std::to_string(axis);
-    unsigned mts = params_[i]["mts.d" + strk]->get_value();
-    unsigned nts = params_[i]["nts.d" + strk]->get_value();
-    unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
-    return shape / (mts * nts);
-  };
   // number of threads
   num_threads_ = 1;
   ir::instruction *first = grids_.front();
   for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
     std::string suffix = ".d" + std::to_string(k);
     num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
-    num_threads_ *= get_num_warps(first, k);
   }
 }
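With the per-warp bookkeeping dropped from tune::init, the block-wide thread count is now simply the product of the mts parameters of the first tiled instruction. A hedged standalone equivalent; the parameter map and its values are made up for illustration:

#include <cstdio>
#include <map>
#include <string>

int main(){
  // mts.d<k> of the first tiled instruction, as tune::init would read them
  std::map<std::string, unsigned> params = {{"mts.d0", 32}, {"mts.d1", 4}};
  unsigned num_threads = 1;
  for(unsigned k = 0; k < 2; k++)
    num_threads *= params.at("mts.d" + std::to_string(k));
  printf("num_threads = %u (%u warps)\n", num_threads, num_threads/32);  // 128 (4 warps)
  return 0;
}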
@@ -243,15 +234,10 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
     int num_threads = 1;
     for(size_t k = 0; k < shapes.size(); k++)
       num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
-    if(num_threads != 32)
-      errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
-    // The number of warps required by the layout is the same
-    // for all tiles in the function
-    int required_num_warps = 1;
-    for(size_t k = 0; k < shapes.size(); k++)
-      required_num_warps *= get_num_warps(i, k);
-    if(required_num_warps != num_warps)
-      errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
+    if(num_threads % 32 != 0)
+      errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
+    if(num_threads != num_threads_)
+      errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
   }
   return errors.empty();
 }
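The check therefore moves from "the mts product must be exactly one warp (32)" to "the mts product must be a whole number of warps and must match num_threads_ for every tile in the function". A small sketch of the new acceptance rule applied to a few made-up candidate configurations:

#include <cstdio>
#include <vector>

int main(){
  const unsigned num_threads_ = 128;  // block-wide count fixed by the first tile
  // candidate per-dimension mts values for some other tile (illustrative)
  std::vector<std::vector<unsigned>> candidates = {{32, 4}, {16, 4}, {8, 6}};
  for(const auto &mts : candidates){
    unsigned num_threads = 1;
    for(unsigned v : mts)
      num_threads *= v;
    bool ok = (num_threads % 32 == 0) && (num_threads == num_threads_);
    printf("{%u, %u} -> %u threads: %s\n", mts[0], mts[1], num_threads, ok ? "ok" : "rejected");
  }
  // {32, 4} -> 128 threads: ok
  // {16, 4} -> 64 threads: rejected
  // {8, 6}  -> 48 threads: rejected
  return 0;
}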
@@ -142,7 +142,6 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
    launch_information info = launch_info_map_.at("matmul");
    for(unsigned p: params)
      std::cout << p << " " << std::flush;
    std::cout << std::endl;
    benchmark(kernel, info);
  });
}