[code generation] reparameterization
@@ -86,7 +86,7 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
 
   // matrix multiplication parameters
-  size_t M = 128, N = 128, K = 128;
+  size_t M = 512, N = 512, K = 512;
   size_t bound = 8;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
@@ -412,7 +412,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
   if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
     assert(expr_ == nullptr);
     //TODO: implement ranges
-    value = ir::metaparameter::create(mod->get_context(), ty, 8, 128);
+    value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
   }
   if(expr_){
     value = expr_->codegen(mod);
@@ -371,18 +371,37 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
   return result;
 }
 
+inline int32_t ceil(int32_t num, int32_t div){
+  return (num + div - 1)/div;
+}
+
+inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
+  static const size_t warp_size = 32;
+  size_t nthreads = 1, nwarps = 1;
+  nw.resize(bs.size());
+  ws.resize(bs.size());
+  for(size_t i = 0; i < bs.size(); ++i){
+    nthreads *= bs[i];
+    nw[i] = ceil(nthreads, nwarps*warp_size);
+    nwarps *= nw[i];
+  }
+  for(size_t i = 0; i < bs.size(); ++i)
+    ws[i] = bs[i] / nw[i];
+}
+
 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
   const auto& shapes = v->get_type()->get_tile_shapes();
   size_t dim = shapes.size();
   std::vector<unsigned> contiguous(dim);
+  std::vector<unsigned> block_size(dim);
   std::vector<unsigned> warp_size(dim);
   std::vector<unsigned> n_warps(dim);
   for(unsigned i = 0; i < shapes.size(); i++){
     std::string str_i = std::to_string(i);
     contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
-    warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
-    n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
+    block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
   }
+  to_warps(block_size, n_warps, warp_size);
   std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
   std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
   // Create axes
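As an illustrative aside, not part of the patch itself: the stand-alone sketch below shows what the new warp decomposition computes. The ceil/to_warps bodies mirror the lines added above; the sample block size {32, 4}, the ceil_div name, and the main driver are hypothetical and only there to make the example compile and run on its own.

#include <cstdint>
#include <cstdio>
#include <vector>

// Same ceiling division as the ceil() helper added above (renamed here to
// avoid clashing with std::ceil in a stand-alone translation unit).
inline int32_t ceil_div(int32_t num, int32_t div) { return (num + div - 1) / div; }

// Mirrors to_warps(): splits a per-dimension block size bs into a number of
// warps nw per dimension and an intra-warp extent ws per dimension.
void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws) {
  static const size_t warp_size = 32;
  size_t nthreads = 1, nwarps = 1;
  nw.resize(bs.size());
  ws.resize(bs.size());
  for (size_t i = 0; i < bs.size(); ++i) {
    nthreads *= bs[i];                              // threads covered by dims 0..i
    nw[i] = ceil_div(nthreads, nwarps * warp_size); // extra warp splits needed along dim i
    nwarps *= nw[i];
  }
  for (size_t i = 0; i < bs.size(); ++i)
    ws[i] = bs[i] / nw[i];                          // intra-warp extent along dim i
}

int main() {
  std::vector<unsigned> bs = {32, 4}, nw, ws;       // hypothetical 32x4 thread block
  to_warps(bs, nw, ws);
  std::printf("nw = {%u, %u}, ws = {%u, %u}\n", nw[0], nw[1], ws[0], ws[1]);
  // Prints: nw = {1, 4}, ws = {32, 1} -- 128 threads split into 4 warps.
  return 0;
}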
@@ -144,8 +144,8 @@ void tune::run(ir::module &mod) {
   // Layout parameters
   while(!nodes_.empty()){
     ir::type *ty = mod.get_builder().get_int32_ty();
-    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
-    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
+    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
+    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
     connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
   }
 }
@@ -157,21 +157,12 @@ void tune::init(ir::module &mod) {
     std::map<ir::metaparameter*, ir::instruction*> references;
     create_grids(grids_, references, fn);
   }
-  // number of warps
-  auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
-    std::string strk = std::to_string(axis);
-    unsigned mts = params_[i]["mts.d" + strk]->get_value();
-    unsigned nts = params_[i]["nts.d" + strk]->get_value();
-    unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
-    return shape / (mts * nts);
-  };
   // number of threads
   num_threads_ = 1;
   ir::instruction *first = grids_.front();
   for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
     std::string suffix = ".d" + std::to_string(k);
     num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
-    num_threads_ *= get_num_warps(first, k);
   }
 }
 
@@ -243,15 +234,10 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
     int num_threads = 1;
     for(size_t k = 0; k < shapes.size(); k++)
       num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
-    if(num_threads != 32)
-      errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
-    // The number of warps required by the layout is the same
-    // for all tiles in the function
-    int required_num_warps = 1;
-    for(size_t k = 0; k < shapes.size(); k++)
-      required_num_warps *= get_num_warps(i, k);
-    if(required_num_warps != num_warps)
-      errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
+    if(num_threads % 32 != 0)
+      errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
+    if(num_threads != num_threads_)
+      errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
   }
   return errors.empty();
 }
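As an illustrative aside, not part of the patch itself: with the get_num_warps lambda gone, the block size is simply the product of the per-dimension mts parameters, and the constraint above only requires it to be a multiple of the warp size and consistent across tiles. A minimal sketch of that check under hypothetical mts values {32, 4}:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-dimension mts values (mts.d0, mts.d1).
  std::vector<unsigned> mts = {32, 4};
  int num_threads = 1;
  for (unsigned m : mts)
    num_threads *= m;                 // 32 * 4 = 128 threads per block
  if (num_threads % 32 != 0)
    std::printf("invalid: %d threads is not a multiple of the warp size (32)\n", num_threads);
  else
    std::printf("ok: %d threads per block = %d warps\n", num_threads, num_threads / 32);
  return 0;
}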
@@ -142,7 +142,6 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
     launch_information info = launch_info_map_.at("matmul");
     for(unsigned p: params)
       std::cout << p << " " << std::flush;
-    std::cout << std::endl;
     benchmark(kernel, info);
   });
 }