[codegen] matrix multiplication is now bank-conflict-free for all layouts
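The title's claim is about shared-memory bank conflicts during the matmul's tile accesses. As a rough, stand-alone illustration of the idea behind the padding added below (this sketch is not part of the commit: it assumes 32 banks of 4 bytes, 4-byte elements, and a warp walking down one column of a row-major tile, whereas the pass picks pads of 4/8/16 — apparently in elements — to suit its vectorized and HMMA access patterns):

#include <cstdio>
#include <set>

// How many distinct banks does a warp touch when lane i reads element (i, 0)
// of a row-major tile whose padded row length is `ld` elements?
static int distinct_banks(int ld) {
  std::set<int> banks;
  for (int lane = 0; lane < 32; ++lane)
    banks.insert((lane * ld) % 32);   // 4-byte elements: element index == 32-bit word index
  return (int)banks.size();
}

int main() {
  std::printf("pad=0: %d banks\n", distinct_banks(32));      // 1 bank  -> 32-way conflict
  std::printf("pad=1: %d banks\n", distinct_banks(32 + 1));  // 32 banks -> conflict-free
}

Padding the leading dimension changes the stride between consecutive lanes, so a column access spreads over all banks instead of piling onto one.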
@@ -39,6 +39,7 @@ class tiles {
private:
  void init_hmma_tile(ir::value *i);
  void init_scanline_tile(ir::value *i);
  bool is_trans(ir::value *i);

public:
  tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
@@ -91,36 +91,46 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
bool liveness::do_pad(ir::value *x) {
  // alignment for matrix product
  if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
    auto order = tiles_->order(x);
    // a
    ir::value *a = dot->get_operand(0);
    size_t previous_a = pad_[a];
    bool a_trans = dynamic_cast<ir::trans_inst*>(a);
    bool a_row = order[0] == 1;
    if(tiles_->hmma(x) == HMMA_A_ROW)
      pad_[a] = 16;
    else if(tiles_->hmma(x) == HMMA_A_COL)
      pad_[a] = 8;
    else if(a_trans ^ a_row)
      pad_[a] = 4;
    else
      pad_[a] = 0;
    // b
    ir::value *b = dot->get_operand(1);
    size_t previous_b = pad_[b];
    bool b_trans = dynamic_cast<ir::trans_inst*>(b);
    bool b_col = order[0] == 0;
    if(tiles_->hmma(x) == HMMA_B_COL)
      pad_[b] = 16;
    else if(tiles_->hmma(x) == HMMA_B_ROW)
      pad_[b] = 8;
    else if(b_trans ^ b_col)
      pad_[b] = 4;
    else
      pad_[b] = 0;
    return previous_a != pad_[a] || previous_b != pad_[b];
  // if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
  //   auto order = tiles_->order(x);
  //   // a
  //   ir::value *a = dot->get_operand(0);
  //   size_t previous_a = pad_[a];
  //   bool a_trans = dynamic_cast<ir::trans_inst*>(a);
  //   bool a_row = order[0] == 0;
  //   if(tiles_->hmma(x) == HMMA_A_ROW)
  //     pad_[a] = 16;
  //   else if(tiles_->hmma(x) == HMMA_A_COL)
  //     pad_[a] = 8;
  //   else if(a_trans ^ a_row)
  //     pad_[a] = 4;
  //   else
  //     pad_[a] = 0;
  //   // b
  //   ir::value *b = dot->get_operand(1);
  //   size_t previous_b = pad_[b];
  //   bool b_trans = dynamic_cast<ir::trans_inst*>(b);
  //   bool b_col = order[0] == 0;
  //   if(tiles_->hmma(x) == HMMA_B_COL)
  //     pad_[b] = 16;
  //   if(tiles_->hmma(x) == HMMA_B_ROW)
  //     pad_[b] = 8;
  //   if(b_trans ^ b_col)
  //     pad_[b] = 4;
  //   else
  //     pad_[b] = 0;
  //   return previous_a != pad_[a] || previous_b != pad_[b];
  // }
  if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
    auto cts_order = tiles_->order(cts);
    ir::value *arg = cts->get_operand(0);
    auto arg_order = tiles_->order(arg);
    if(cts_order != arg_order)
      pad_[cts] = 4;
  }
  // if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
  //   pad_[tr] = 4;
  // }
  // padding for phi-nodes
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    bool has_changed = false;
@@ -157,7 +167,6 @@ unsigned liveness::num_bytes(ir::value *x) {
  }
  unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
  unsigned pad = pad_.at(x);
  std::cout << x->get_name() << " " << pad << std::endl;
  if(pad > 0){
    unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
    num_bytes += pad * num_bytes / ld;
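A quick stand-alone check of the size formula above (num_bytes += pad * num_bytes / ld): the pad adds `pad` extra elements for every line of length `ld` along the leading dimension. The tile shape, element size and pad below are made up for illustration:

#include <cstdio>

int main() {
  unsigned elem_bytes = 2;                                  // e.g. fp16
  unsigned shape[2]   = {128, 32};                          // hypothetical tile shape
  unsigned ld         = shape[0];                           // leading dimension
  unsigned pad        = 8;                                  // elements of padding per line
  unsigned num_bytes  = shape[0] * shape[1] * elem_bytes;   // 8192 bytes unpadded
  num_bytes += pad * num_bytes / ld;                        // + 8 * 8192 / 128 = +512
  std::printf("%u bytes\n", num_bytes);                     // prints 8704
}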
@@ -184,6 +184,20 @@ void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
}


bool tiles::is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {
    return true;
  }
  if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
    bool result = true;
    for(ir::value *op: phi->ops())
      result = result && is_trans(op);
    return result;
  }
  return false;
}

void tiles::run(ir::module &) {
  hmma_.clear();
  largest_.clear();
@@ -220,6 +234,7 @@ void tiles::run(ir::module &) {
    largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
  }

  // find out the layout ordering of a group
  for(size_t i = 0; i < num_groups; i++){
    std::set<ir::io_inst*> io;
@@ -240,13 +255,18 @@ void tiles::run(ir::module &) {
    order_[i] = order;
  }
  for(size_t i = 0; i < num_groups; i++){
    std::vector<ir::copy_to_shared_inst*> cts;
    std::vector<ir::dot_inst*> dots;
    for(ir::value* v: layout_->values(i))
      if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
        cts.push_back(x);
    if(cts.empty())
      continue;
    order_[i] = order(cts[0]->get_operand(0));
      if(auto *x = dynamic_cast<ir::dot_inst*>(v))
        dots.push_back(x);
    for(ir::dot_inst* dot: dots){
      ir::value* a = dot->get_operand(0);
      ir::value* b = dot->get_operand(1);
      std::vector<int> col = {0, 1};
      std::vector<int> row = {1, 0};
      order_[layout_->id(a)] = is_trans(a) ? row : col;
      order_[layout_->id(b)] = is_trans(b) ? col : row;
    }
  }
  // tiling parameters
  for(auto x: largest_){
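For reference, a small stand-alone mirror of the order selection above for the two dot operands (not part of the commit; it only re-evaluates the same ternaries to show the resulting {0,1}/{1,0} vectors):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> col = {0, 1};
  std::vector<int> row = {1, 0};
  for (bool trans : {false, true}) {
    std::vector<int> a = trans ? row : col;   // order_[layout_->id(a)] above
    std::vector<int> b = trans ? col : row;   // order_[layout_->id(b)] above
    std::printf("trans=%d  a:{%d,%d}  b:{%d,%d}\n", (int)trans, a[0], a[1], b[0], b[1]);
  }
}

In other words, a transposed operand simply gets the swapped dimension order, which is presumably what the copy_to_shared check in do_pad later compares against its argument's order.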
@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
  std::cout << source << std::endl;
  cu_context::context_switcher ctx(*context);
  // JIT compile source-code
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -45,8 +45,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
  opt.defines.push_back({"TYPE", {ty}});
  opt.defines.push_back({"AT", {AT?"1":"0"}});
  opt.defines.push_back({"BT", {BT?"1":"0"}});
  opt.defines.push_back({"TM", {"64"}});
  opt.defines.push_back({"TN", {"64"}});
  opt.defines.push_back({"TM", {"128"}});
  opt.defines.push_back({"TN", {"128"}});
  opt.defines.push_back({"TK", {"8"}});
  opt.num_warps = {4};
  // create function
@@ -79,7 +79,8 @@ int main() {
  // shapes to benchmark
  typedef std::tuple<bool, bool, int, int, int> config_t;
  std::vector<config_t> configs;
  for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
  for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
                                               {true, false}, {true, true}}){
    std::vector<config_t> tmp = {
      config_t{x[0], x[1], 2048, 2048, 2048}
      // config_t{x[0], x[1], 16, 2048, 2048},
@@ -54,7 +54,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
  TYPE a[SHAPE_A] = *pa;
  TYPE b[SHAPE_B] = *pb;
  // reduction loop
  for(int k = K; k > TK; k-= TK){
  for(int k = K; k > 0; k-= TK){
    c += USEA @ USEB;
    pa = pa + TK * STRIDE_AK;
    pb = pb + TK * STRIDE_BK;
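The two loop bounds shown above differ by exactly one TK-sized tile. A stand-alone trip-count comparison with illustrative values (K = 2048 as in the benchmarked shapes, TK = 8 as in the defines):

#include <cstdio>

int main() {
  const int K = 2048, TK = 8;
  int n_gt_tk = 0, n_gt_0 = 0;
  for (int k = K; k > TK; k -= TK) ++n_gt_tk;  // stops before the final tile
  for (int k = K; k > 0;  k -= TK) ++n_gt_0;   // covers every TK-sized tile
  std::printf("k > TK: %d iterations, k > 0: %d iterations\n", n_gt_tk, n_gt_0);
  // prints: k > TK: 255 iterations, k > 0: 256 iterations
}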