[codegen] more progress

This commit is contained in:
Philippe Tillet
2019-10-03 14:11:50 -04:00
parent 1bf0c8adeb
commit a1e0512703
6 changed files with 63 additions and 56 deletions

View File

@@ -91,46 +91,36 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
bool liveness::do_pad(ir::value *x) {
// alignment for matrix product
// if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
// auto order = tiles_->order(x);
// // a
// ir::value *a = dot->get_operand(0);\
// size_t previous_a = pad_[a];
// bool a_trans = dynamic_cast<ir::trans_inst*>(a);
// bool a_row = order[0] == 0;
// if(tiles_->hmma(x) == HMMA_A_ROW)
// pad_[a] = 16;
// else if(tiles_->hmma(x) == HMMA_A_COL)
// pad_[a] = 8;
// else if(a_trans ^ a_row)
// pad_[a] = 4;
// else
// pad_[a] = 0;
// // b
// ir::value *b = dot->get_operand(1);
// size_t previous_b = pad_[b];
// bool b_trans = dynamic_cast<ir::trans_inst*>(b);
// bool b_col = order[0] == 0;
// if(tiles_->hmma(x) == HMMA_B_COL)
// pad_[b] = 16;
// if(tiles_->hmma(x) == HMMA_B_ROW)
// pad_[b] = 8;
// if(b_trans ^ b_col)
// pad_[b] = 4;
// else
// pad_[b] = 0;
// return previous_a != pad_[a] || previous_b != pad_[b];
// }
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
// a: pad operand 0 according to its HMMA layout
ir::value *a = dot->get_operand(0);
size_t previous_a = pad_[a];
if(tiles_->hmma(a) == HMMA_A_ROW)
pad_[a] = 16;
else if(tiles_->hmma(a) == HMMA_A_COL)
pad_[a] = 8;
else
pad_[a] = 0;
// b: pad operand 1 according to its HMMA layout
ir::value *b = dot->get_operand(1);
size_t previous_b = pad_[b];
if(tiles_->hmma(b) == HMMA_B_COL)
pad_[b] = 16;
// BUG FIX: this was a separate `if`, so the trailing `else` bound to it and
// reset pad_[b] back to 0 whenever the B_COL branch had just set it to 16.
// `else if` mirrors the operand-a chain above.
else if(tiles_->hmma(b) == HMMA_B_ROW)
pad_[b] = 8;
else
pad_[b] = 0;
// report whether either operand's padding changed (drives fixpoint iteration)
return previous_a != pad_[a] || previous_b != pad_[b];
}
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
// pad a copy-to-shared whenever its layout order differs from its argument's
ir::value *src = cts->get_operand(0);
size_t old_pad = pad_[cts];
if(tiles_->order(cts) != tiles_->order(src))
pad_[cts] = 4;
// true iff the padding for this value was updated
return pad_[cts] != old_pad;
}
// if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
// pad_[tr] = 4;
// }
// padding for phi-nodes
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool has_changed = false;
@@ -142,7 +132,7 @@ bool liveness::do_pad(ir::value *x) {
}
return has_changed;
}
// default -- no pading
// default -- no padding
size_t previous = pad_[x];
pad_[x] = std::max<int>(previous, 0);
return pad_[x] != previous;

View File

@@ -254,20 +254,20 @@ void tiles::run(ir::module &) {
}
order_[i] = order;
}
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->id(a)] = is_trans(a) ? row : col;
order_[layout_->id(b)] = is_trans(b) ? col : row;
}
}
// for(size_t i = 0; i < num_groups; i++){
// std::vector<ir::dot_inst*> dots;
// for(ir::value* v: layout_->values(i))
// if(auto *x = dynamic_cast<ir::dot_inst*>(v))
// dots.push_back(x);
// for(ir::dot_inst* dot: dots){
// ir::value* a = dot->get_operand(0);
// ir::value* b = dot->get_operand(1);
// std::vector<int> col = {0, 1};
// std::vector<int> row = {1, 0};
// order_[layout_->id(a)] = is_trans(a) ? row : col;
// order_[layout_->id(b)] = is_trans(b) ? col : row;
// }
// }
// tiling parameters
for(auto x: largest_){
ir::value *i = x.second;

View File

@@ -1049,6 +1049,19 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
tmap_[x] = out;
}
bool is_trans(ir::value *v) {
// A value is considered transposed if it is a trans_inst itself, or an
// instruction all of whose operands are (recursively) transposed.
// Note: vacuously true for an instruction with zero operands.
if(dynamic_cast<ir::trans_inst *>(v))
return true;
if(auto *inst = dynamic_cast<ir::instruction *>(v)) {
for(ir::value *op: inst->ops())
if(!is_trans(op))
return false;
return true;
}
return false;
}
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1082,8 +1095,11 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
bool is_a_row = ord_a[ord_a.size() - 2] == 1;
bool is_b_row = ord_b[ord_b.size() - 2] == 1;
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
if(is_a_row){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
@@ -1124,7 +1140,7 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
indices_t idx_b = {builder.CreateAdd(offset_b_k, _K), current_offset_b_i};
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);

View File

@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -32,7 +32,7 @@ int main() {
for(const auto& c: configs){
std::tie(AT, BT, M, N, K) = c;
std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K))
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

View File

@@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.num_warps = {nwarp};
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {2, 4, 8};
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.num_warps = {4};
}
// kernels