[codegen] matrix multiplication is now bank-conflict free for all layouts
This commit is contained in:
Philippe Tillet
2019-10-01 16:57:59 -04:00
parent ed1b2bc563
commit 86a3e5d897
6 changed files with 71 additions and 41 deletions

View File

@@ -39,6 +39,7 @@ class tiles {
private:
void init_hmma_tile(ir::value *i);
void init_scanline_tile(ir::value *i);
bool is_trans(ir::value *i);
public:
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);

View File

@@ -91,36 +91,46 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
bool liveness::do_pad(ir::value *x) {
// alignment for matrix product
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
auto order = tiles_->order(x);
// a
ir::value *a = dot->get_operand(0);\
size_t previous_a = pad_[a];
bool a_trans = dynamic_cast<ir::trans_inst*>(a);
bool a_row = order[0] == 1;
if(tiles_->hmma(x) == HMMA_A_ROW)
pad_[a] = 16;
else if(tiles_->hmma(x) == HMMA_A_COL)
pad_[a] = 8;
else if(a_trans ^ a_row)
pad_[a] = 4;
else
pad_[a] = 0;
// b
ir::value *b = dot->get_operand(1);
size_t previous_b = pad_[b];
bool b_trans = dynamic_cast<ir::trans_inst*>(a);
bool b_col = order[0] == 0;
if(tiles_->hmma(x) == HMMA_B_COL)
pad_[b] = 16;
if(tiles_->hmma(x) == HMMA_B_ROW)
pad_[b] = 8;
if(b_trans ^ b_col)
pad_[b] = 4;
else
pad_[b] = 0;
return previous_a != pad_[a] || previous_b != pad_[b];
// if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
// auto order = tiles_->order(x);
// // a
// ir::value *a = dot->get_operand(0);\
// size_t previous_a = pad_[a];
// bool a_trans = dynamic_cast<ir::trans_inst*>(a);
// bool a_row = order[0] == 0;
// if(tiles_->hmma(x) == HMMA_A_ROW)
// pad_[a] = 16;
// else if(tiles_->hmma(x) == HMMA_A_COL)
// pad_[a] = 8;
// else if(a_trans ^ a_row)
// pad_[a] = 4;
// else
// pad_[a] = 0;
// // b
// ir::value *b = dot->get_operand(1);
// size_t previous_b = pad_[b];
// bool b_trans = dynamic_cast<ir::trans_inst*>(b);
// bool b_col = order[0] == 0;
// if(tiles_->hmma(x) == HMMA_B_COL)
// pad_[b] = 16;
// if(tiles_->hmma(x) == HMMA_B_ROW)
// pad_[b] = 8;
// if(b_trans ^ b_col)
// pad_[b] = 4;
// else
// pad_[b] = 0;
// return previous_a != pad_[a] || previous_b != pad_[b];
// }
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts);
ir::value *arg = cts->get_operand(0);
auto arg_order = tiles_->order(arg);
if(cts_order != arg_order)
pad_[cts] = 4;
}
// if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
// pad_[tr] = 4;
// }
// padding for phi-nodes
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool has_changed = false;
@@ -157,7 +167,6 @@ unsigned liveness::num_bytes(ir::value *x) {
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x);
std::cout << x->get_name() << " " << pad << std::endl;
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
num_bytes += pad * num_bytes / ld;

View File

@@ -184,6 +184,20 @@ void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
}
// Reports whether `v` is a transposition, or an instruction whose operands
// are all (recursively) transpositions.
//  - an ir::trans_inst yields true;
//  - any other ir::instruction yields true only if every operand does, so an
//    instruction with no operands also yields true;
//  - a value that is not an instruction yields false.
// NOTE(review): the walk covers every ir::instruction (not only phi nodes) and
// keeps no visited set -- confirm that operand chains reaching this function
// cannot cycle without first hitting a non-transposed operand.
bool tiles::is_trans(ir::value *v) {
  // the value itself is a transposition
  if(dynamic_cast<ir::trans_inst *>(v) != nullptr)
    return true;
  // non-instruction values are never considered transposed
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // stop at the first operand that is not transposed; later operands are
  // not inspected
  for(ir::value *operand: inst->ops())
    if(!is_trans(operand))
      return false;
  return true;
}
void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
@@ -220,6 +234,7 @@ void tiles::run(ir::module &) {
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
@@ -240,13 +255,18 @@ void tiles::run(ir::module &) {
order_[i] = order;
}
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::copy_to_shared_inst*> cts;
std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
cts.push_back(x);
if(cts.empty())
continue;
order_[i] = order(cts[0]->get_operand(0));
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->id(a)] = is_trans(a) ? row : col;
order_[layout_->id(b)] = is_trans(b) ? col : row;
}
}
// tiling parameters
for(auto x: largest_){

View File

@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -45,8 +45,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"64"}});
opt.defines.push_back({"TN", {"64"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4};
// create function
@@ -79,7 +79,8 @@ int main() {
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
{true, false}, {true, true}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 2048, 2048, 2048}
// config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -54,7 +54,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// reduction loop
for(int k = K; k > TK; k-= TK){
for(int k = K; k > 0; k-= TK){
c += USEA @ USEB;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;