[codegen] matrix multiplication is now bank-conflict free for all layouts
This commit is contained in:
Philippe Tillet
2019-10-01 16:57:59 -04:00
parent ed1b2bc563
commit 86a3e5d897
6 changed files with 71 additions and 41 deletions

View File

@@ -39,6 +39,7 @@ class tiles {
private: private:
void init_hmma_tile(ir::value *i); void init_hmma_tile(ir::value *i);
void init_scanline_tile(ir::value *i); void init_scanline_tile(ir::value *i);
bool is_trans(ir::value *i);
public: public:
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout); tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);

View File

@@ -91,36 +91,46 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
bool liveness::do_pad(ir::value *x) { bool liveness::do_pad(ir::value *x) {
// alignment for matrix product // alignment for matrix product
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) { // if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
auto order = tiles_->order(x); // auto order = tiles_->order(x);
// a // // a
ir::value *a = dot->get_operand(0);\ // ir::value *a = dot->get_operand(0);\
size_t previous_a = pad_[a]; // size_t previous_a = pad_[a];
bool a_trans = dynamic_cast<ir::trans_inst*>(a); // bool a_trans = dynamic_cast<ir::trans_inst*>(a);
bool a_row = order[0] == 1; // bool a_row = order[0] == 0;
if(tiles_->hmma(x) == HMMA_A_ROW) // if(tiles_->hmma(x) == HMMA_A_ROW)
pad_[a] = 16; // pad_[a] = 16;
else if(tiles_->hmma(x) == HMMA_A_COL) // else if(tiles_->hmma(x) == HMMA_A_COL)
pad_[a] = 8; // pad_[a] = 8;
else if(a_trans ^ a_row) // else if(a_trans ^ a_row)
pad_[a] = 4; // pad_[a] = 4;
else // else
pad_[a] = 0; // pad_[a] = 0;
// b // // b
ir::value *b = dot->get_operand(1); // ir::value *b = dot->get_operand(1);
size_t previous_b = pad_[b]; // size_t previous_b = pad_[b];
bool b_trans = dynamic_cast<ir::trans_inst*>(a); // bool b_trans = dynamic_cast<ir::trans_inst*>(b);
bool b_col = order[0] == 0; // bool b_col = order[0] == 0;
if(tiles_->hmma(x) == HMMA_B_COL) // if(tiles_->hmma(x) == HMMA_B_COL)
pad_[b] = 16; // pad_[b] = 16;
if(tiles_->hmma(x) == HMMA_B_ROW) // if(tiles_->hmma(x) == HMMA_B_ROW)
pad_[b] = 8; // pad_[b] = 8;
if(b_trans ^ b_col) // if(b_trans ^ b_col)
pad_[b] = 4; // pad_[b] = 4;
else // else
pad_[b] = 0; // pad_[b] = 0;
return previous_a != pad_[a] || previous_b != pad_[b]; // return previous_a != pad_[a] || previous_b != pad_[b];
// }
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts);
ir::value *arg = cts->get_operand(0);
auto arg_order = tiles_->order(arg);
if(cts_order != arg_order)
pad_[cts] = 4;
} }
// if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
// pad_[tr] = 4;
// }
// padding for phi-nodes // padding for phi-nodes
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) { if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool has_changed = false; bool has_changed = false;
@@ -157,7 +167,6 @@ unsigned liveness::num_bytes(ir::value *x) {
} }
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x); unsigned pad = pad_.at(x);
std::cout << x->get_name() << " " << pad << std::endl;
if(pad > 0){ if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
num_bytes += pad * num_bytes / ld; num_bytes += pad * num_bytes / ld;

View File

@@ -184,6 +184,20 @@ void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
} }
// Reports whether `v` is considered transposed.
//   - A `trans_inst` is transposed by definition.
//   - Any other instruction is transposed only if every one of its operands
//     is (checked recursively).
//   - A value that is not an instruction is never transposed.
// NOTE(review): an instruction with no operands satisfies the "every operand"
// rule vacuously and therefore yields true -- confirm this is intended.
// NOTE(review): the recursion keeps no visited set; presumably the operand
// graph reachable from `v` cannot loop back without crossing a `trans_inst`
// -- verify against callers.
bool tiles::is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v) != nullptr)
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // stop at the first operand that is not transposed
  for(ir::value *operand: inst->ops())
    if(!is_trans(operand))
      return false;
  return true;
}
void tiles::run(ir::module &) { void tiles::run(ir::module &) {
hmma_.clear(); hmma_.clear();
largest_.clear(); largest_.clear();
@@ -220,6 +234,7 @@ void tiles::run(ir::module &) {
largest_[i] = *std::max_element(values.begin(), values.end(), cmp); largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
} }
// find out the layout ordering of a group // find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io; std::set<ir::io_inst*> io;
@@ -240,13 +255,18 @@ void tiles::run(ir::module &) {
order_[i] = order; order_[i] = order;
} }
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::vector<ir::copy_to_shared_inst*> cts; std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values(i)) for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v)) if(auto *x = dynamic_cast<ir::dot_inst*>(v))
cts.push_back(x); dots.push_back(x);
if(cts.empty()) for(ir::dot_inst* dot: dots){
continue; ir::value* a = dot->get_operand(0);
order_[i] = order(cts[0]->get_operand(0)); ir::value* b = dot->get_operand(1);
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->id(a)] = is_trans(a) ? row : col;
order_[layout_->id(b)] = is_trans(b) ? col : row;
}
} }
// tiling parameters // tiling parameters
for(auto x: largest_){ for(auto x: largest_){

View File

@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
cu_context::context_switcher ctx(*context); cu_context::context_switcher ctx(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -45,8 +45,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}}); opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}}); opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"64"}}); opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"64"}}); opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}}); opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4}; opt.num_warps = {4};
// create function // create function
@@ -79,7 +79,8 @@ int main() {
// shapes to benchmark // shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t; typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs; std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){ for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
{true, false}, {true, true}}){
std::vector<config_t> tmp = { std::vector<config_t> tmp = {
config_t{x[0], x[1], 2048, 2048, 2048} config_t{x[0], x[1], 2048, 2048, 2048}
// config_t{x[0], x[1], 16, 2048, 2048}, // config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -54,7 +54,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
TYPE a[SHAPE_A] = *pa; TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb; TYPE b[SHAPE_B] = *pb;
// reduction loop // reduction loop
for(int k = K; k > TK; k-= TK){ for(int k = K; k > 0; k-= TK){
c += USEA @ USEB; c += USEA @ USEB;
pa = pa + TK * STRIDE_AK; pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK; pb = pb + TK * STRIDE_BK;