basic split-k across warps working for GEMM
This commit is contained in:
@@ -24,8 +24,7 @@ bool is_hmma(ir::value *v){
|
||||
ir::type *b_ty = b->get_type();
|
||||
// inputs have to be FP16
|
||||
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
||||
// reduction has to be multiple of 4
|
||||
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
||||
// reduction has to be multiple of 4: TODO
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -66,9 +65,10 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
||||
if(i == axis)
|
||||
continue;
|
||||
// std::cout << arg->get_name() << " " << v->get_name() << std::endl;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
// add_constraint({reduce, 0}, {arg, 0});
|
||||
// add_constraint({reduce, 1}, {arg, 1});
|
||||
return;
|
||||
}
|
||||
else
|
||||
@@ -81,8 +81,10 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
bool is_one = shapes[i] == one;
|
||||
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
|
||||
if(is_one)
|
||||
if(is_one){
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed && is_same)
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else{
|
||||
@@ -114,9 +116,17 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
// Matrix multiplication
|
||||
else if(dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *A = v->get_operand(0);
|
||||
ir::value *B = v->get_operand(1);
|
||||
ir::value *D = v->get_operand(2);
|
||||
add_constraint({v, 0}, {D, 0});
|
||||
add_constraint({v, 1}, {D, 1});
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v, i}, {D, i});
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
if(shapes[i] == one)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
@@ -242,7 +252,7 @@ void tune::run(ir::module &mod) {
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
@@ -266,14 +276,14 @@ void tune::run(ir::module &mod) {
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
}
|
||||
@@ -365,6 +375,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids_){
|
||||
// std::cout << i->get_name() << std::endl;
|
||||
ir::type *ty = i->get_type();
|
||||
const auto &shapes = ty->get_tile_shapes();
|
||||
// for each dimension, the product of layout components
|
||||
@@ -396,11 +407,15 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
errors[i].push_back("HMMA must have only 4 fragments per warp");
|
||||
}
|
||||
int num_threads = get_req_num_threads(i);
|
||||
if(num_threads % 64 != 0)
|
||||
if(num_threads % 32 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
}
|
||||
// for(auto x: errors)
|
||||
// for(auto e: x.second)
|
||||
// std::cout << x.first->get_name() << ": " << e << std::endl;
|
||||
// exit(EXIT_SUCCESS);
|
||||
return errors.empty();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user