[codegen] now matrix-multiplication is bank-conflict free for all
layouts
This commit is contained in:
@@ -39,6 +39,7 @@ class tiles {
|
|||||||
private:
|
private:
|
||||||
void init_hmma_tile(ir::value *i);
|
void init_hmma_tile(ir::value *i);
|
||||||
void init_scanline_tile(ir::value *i);
|
void init_scanline_tile(ir::value *i);
|
||||||
|
bool is_trans(ir::value *i);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||||
|
@@ -91,36 +91,46 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
|
|||||||
|
|
||||||
bool liveness::do_pad(ir::value *x) {
|
bool liveness::do_pad(ir::value *x) {
|
||||||
// alignment for matrix product
|
// alignment for matrix product
|
||||||
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
|
// if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
|
||||||
auto order = tiles_->order(x);
|
// auto order = tiles_->order(x);
|
||||||
// a
|
// // a
|
||||||
ir::value *a = dot->get_operand(0);\
|
// ir::value *a = dot->get_operand(0);\
|
||||||
size_t previous_a = pad_[a];
|
// size_t previous_a = pad_[a];
|
||||||
bool a_trans = dynamic_cast<ir::trans_inst*>(a);
|
// bool a_trans = dynamic_cast<ir::trans_inst*>(a);
|
||||||
bool a_row = order[0] == 1;
|
// bool a_row = order[0] == 0;
|
||||||
if(tiles_->hmma(x) == HMMA_A_ROW)
|
// if(tiles_->hmma(x) == HMMA_A_ROW)
|
||||||
pad_[a] = 16;
|
// pad_[a] = 16;
|
||||||
else if(tiles_->hmma(x) == HMMA_A_COL)
|
// else if(tiles_->hmma(x) == HMMA_A_COL)
|
||||||
pad_[a] = 8;
|
// pad_[a] = 8;
|
||||||
else if(a_trans ^ a_row)
|
// else if(a_trans ^ a_row)
|
||||||
pad_[a] = 4;
|
// pad_[a] = 4;
|
||||||
else
|
// else
|
||||||
pad_[a] = 0;
|
// pad_[a] = 0;
|
||||||
// b
|
// // b
|
||||||
ir::value *b = dot->get_operand(1);
|
// ir::value *b = dot->get_operand(1);
|
||||||
size_t previous_b = pad_[b];
|
// size_t previous_b = pad_[b];
|
||||||
bool b_trans = dynamic_cast<ir::trans_inst*>(a);
|
// bool b_trans = dynamic_cast<ir::trans_inst*>(b);
|
||||||
bool b_col = order[0] == 0;
|
// bool b_col = order[0] == 0;
|
||||||
if(tiles_->hmma(x) == HMMA_B_COL)
|
// if(tiles_->hmma(x) == HMMA_B_COL)
|
||||||
pad_[b] = 16;
|
// pad_[b] = 16;
|
||||||
if(tiles_->hmma(x) == HMMA_B_ROW)
|
// if(tiles_->hmma(x) == HMMA_B_ROW)
|
||||||
pad_[b] = 8;
|
// pad_[b] = 8;
|
||||||
if(b_trans ^ b_col)
|
// if(b_trans ^ b_col)
|
||||||
pad_[b] = 4;
|
// pad_[b] = 4;
|
||||||
else
|
// else
|
||||||
pad_[b] = 0;
|
// pad_[b] = 0;
|
||||||
return previous_a != pad_[a] || previous_b != pad_[b];
|
// return previous_a != pad_[a] || previous_b != pad_[b];
|
||||||
|
// }
|
||||||
|
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
|
||||||
|
auto cts_order = tiles_->order(cts);
|
||||||
|
ir::value *arg = cts->get_operand(0);
|
||||||
|
auto arg_order = tiles_->order(arg);
|
||||||
|
if(cts_order != arg_order)
|
||||||
|
pad_[cts] = 4;
|
||||||
}
|
}
|
||||||
|
// if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
|
||||||
|
// pad_[tr] = 4;
|
||||||
|
// }
|
||||||
// padding for phi-nodes
|
// padding for phi-nodes
|
||||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||||
bool has_changed = false;
|
bool has_changed = false;
|
||||||
@@ -157,7 +167,6 @@ unsigned liveness::num_bytes(ir::value *x) {
|
|||||||
}
|
}
|
||||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||||
unsigned pad = pad_.at(x);
|
unsigned pad = pad_.at(x);
|
||||||
std::cout << x->get_name() << " " << pad << std::endl;
|
|
||||||
if(pad > 0){
|
if(pad > 0){
|
||||||
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
||||||
num_bytes += pad * num_bytes / ld;
|
num_bytes += pad * num_bytes / ld;
|
||||||
|
@@ -184,6 +184,20 @@ void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool tiles::is_trans(ir::value *v) {
|
||||||
|
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
|
||||||
|
bool result = true;
|
||||||
|
for(ir::value *op: phi->ops())
|
||||||
|
result = result && is_trans(op);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void tiles::run(ir::module &) {
|
void tiles::run(ir::module &) {
|
||||||
hmma_.clear();
|
hmma_.clear();
|
||||||
largest_.clear();
|
largest_.clear();
|
||||||
@@ -220,6 +234,7 @@ void tiles::run(ir::module &) {
|
|||||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// find out the layout ordering of a group
|
// find out the layout ordering of a group
|
||||||
for(size_t i = 0; i < num_groups; i++){
|
for(size_t i = 0; i < num_groups; i++){
|
||||||
std::set<ir::io_inst*> io;
|
std::set<ir::io_inst*> io;
|
||||||
@@ -240,13 +255,18 @@ void tiles::run(ir::module &) {
|
|||||||
order_[i] = order;
|
order_[i] = order;
|
||||||
}
|
}
|
||||||
for(size_t i = 0; i < num_groups; i++){
|
for(size_t i = 0; i < num_groups; i++){
|
||||||
std::vector<ir::copy_to_shared_inst*> cts;
|
std::vector<ir::dot_inst*> dots;
|
||||||
for(ir::value* v: layout_->values(i))
|
for(ir::value* v: layout_->values(i))
|
||||||
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||||
cts.push_back(x);
|
dots.push_back(x);
|
||||||
if(cts.empty())
|
for(ir::dot_inst* dot: dots){
|
||||||
continue;
|
ir::value* a = dot->get_operand(0);
|
||||||
order_[i] = order(cts[0]->get_operand(0));
|
ir::value* b = dot->get_operand(1);
|
||||||
|
std::vector<int> col = {0, 1};
|
||||||
|
std::vector<int> row = {1, 0};
|
||||||
|
order_[layout_->id(a)] = is_trans(a) ? row : col;
|
||||||
|
order_[layout_->id(b)] = is_trans(b) ? col : row;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// tiling parameters
|
// tiling parameters
|
||||||
for(auto x: largest_){
|
for(auto x: largest_){
|
||||||
|
@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
std::cout << source << std::endl;
|
|
||||||
cu_context::context_switcher ctx(*context);
|
cu_context::context_switcher ctx(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -45,8 +45,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"TYPE", {ty}});
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||||
opt.defines.push_back({"TM", {"64"}});
|
opt.defines.push_back({"TM", {"128"}});
|
||||||
opt.defines.push_back({"TN", {"64"}});
|
opt.defines.push_back({"TN", {"128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {4};
|
opt.num_warps = {4};
|
||||||
// create function
|
// create function
|
||||||
@@ -79,7 +79,8 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
|
||||||
|
{true, false}, {true, true}}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{x[0], x[1], 2048, 2048, 2048}
|
config_t{x[0], x[1], 2048, 2048, 2048}
|
||||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||||
|
@@ -54,7 +54,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
TYPE a[SHAPE_A] = *pa;
|
TYPE a[SHAPE_A] = *pa;
|
||||||
TYPE b[SHAPE_B] = *pb;
|
TYPE b[SHAPE_B] = *pb;
|
||||||
// reduction loop
|
// reduction loop
|
||||||
for(int k = K; k > TK; k-= TK){
|
for(int k = K; k > 0; k-= TK){
|
||||||
c += USEA @ USEB;
|
c += USEA @ USEB;
|
||||||
pa = pa + TK * STRIDE_AK;
|
pa = pa + TK * STRIDE_AK;
|
||||||
pb = pb + TK * STRIDE_BK;
|
pb = pb + TK * STRIDE_BK;
|
||||||
|
Reference in New Issue
Block a user