Basic split-K across warps now working for GEMM

This commit is contained in:
Philippe Tillet
2019-08-05 19:33:28 -07:00
parent 899b2b72e1
commit d62e581ab3
12 changed files with 99 additions and 63 deletions

View File

@@ -24,8 +24,7 @@ bool is_hmma(ir::value *v){
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
// reduction has to be multiple of 4: TODO
}
return result;
}
@@ -66,9 +65,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
// std::cout << arg->get_name() << " " << v->get_name() << std::endl;
add_constraint({reduce, current++}, {arg, i});
}
// add_constraint({reduce, 0}, {arg, 0});
// add_constraint({reduce, 1}, {arg, 1});
return;
}
else
@@ -81,8 +81,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < shapes.size(); i ++){
bool is_one = shapes[i] == one;
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
if(is_one)
if(is_one){
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i});
}
else if(!is_skewed && is_same)
add_constraint({v, i}, {op, current++});
else{
@@ -114,9 +116,17 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
add_constraint({v, 0}, {D, 0});
add_constraint({v, 1}, {D, 1});
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one)
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
@@ -242,7 +252,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
@@ -266,14 +276,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}
@@ -365,6 +375,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
// check constraints
for(ir::instruction *i: grids_){
// std::cout << i->get_name() << std::endl;
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components
@@ -396,11 +407,15 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
if(num_threads % 64 != 0)
if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
}
// for(auto x: errors)
// for(auto e: x.second)
// std::cout << x.first->get_name() << ": " << e << std::endl;
// exit(EXIT_SUCCESS);
return errors.empty();
}