[codegen] more progress
@@ -91,46 +91,36 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &

bool liveness::do_pad(ir::value *x) {
  // alignment for matrix product
//  if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
//    auto order = tiles_->order(x);
//    // a
//    ir::value *a = dot->get_operand(0);
//    size_t previous_a = pad_[a];
//    bool a_trans = dynamic_cast<ir::trans_inst*>(a);
//    bool a_row = order[0] == 0;
//    if(tiles_->hmma(x) == HMMA_A_ROW)
//      pad_[a] = 16;
//    else if(tiles_->hmma(x) == HMMA_A_COL)
//      pad_[a] = 8;
//    else if(a_trans ^ a_row)
//      pad_[a] = 4;
//    else
//      pad_[a] = 0;
//    // b
//    ir::value *b = dot->get_operand(1);
//    size_t previous_b = pad_[b];
//    bool b_trans = dynamic_cast<ir::trans_inst*>(b);
//    bool b_col = order[0] == 0;
//    if(tiles_->hmma(x) == HMMA_B_COL)
//      pad_[b] = 16;
//    if(tiles_->hmma(x) == HMMA_B_ROW)
//      pad_[b] = 8;
//    if(b_trans ^ b_col)
//      pad_[b] = 4;
//    else
//      pad_[b] = 0;
//    return previous_a != pad_[a] || previous_b != pad_[b];
//  }
  if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
    // a
    ir::value *a = dot->get_operand(0);
    size_t previous_a = pad_[a];
    if(tiles_->hmma(a) == HMMA_A_ROW)
      pad_[a] = 16;
    else if(tiles_->hmma(a) == HMMA_A_COL)
      pad_[a] = 8;
    else
      pad_[a] = 0;
    // b
    ir::value *b = dot->get_operand(1);
    size_t previous_b = pad_[b];
    if(tiles_->hmma(b) == HMMA_B_COL)
      pad_[b] = 16;
    if(tiles_->hmma(b) == HMMA_B_ROW)
      pad_[b] = 8;
    else
      pad_[b] = 0;
    return previous_a != pad_[a] || previous_b != pad_[b];
  }
  if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
    auto cts_order = tiles_->order(cts);
    ir::value *arg = cts->get_operand(0);
    auto arg_order = tiles_->order(arg);
    size_t previous = pad_[cts];
    if(cts_order != arg_order)
      pad_[cts] = 4;
    return pad_[cts] != previous;
  }
//  if(auto* tr = dynamic_cast<ir::trans_inst*>(x)) {
//    pad_[tr] = 4;
//  }
  // padding for phi-nodes
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    bool has_changed = false;
@@ -142,7 +132,7 @@ bool liveness::do_pad(ir::value *x) {
    }
    return has_changed;
  }
  // default -- no pading
  // default -- no padding
  size_t previous = pad_[x];
  pad_[x] = std::max<int>(previous, 0);
  return pad_[x] != previous;
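Note on the padding values above: the pad recorded for each HMMA operand is presumably added to the fastest-varying dimension of its shared-memory tile, which staggers successive rows across banks and avoids conflicts when the fragments are loaded. A minimal standalone sketch of the arithmetic only, with hypothetical names (padded_ld, elt_size) that are not from the codebase:

#include <cstddef>
#include <cstdio>

// Leading dimension once padding is applied: the fastest-varying dimension
// grows by `pad` elements, so successive rows start at staggered addresses.
static std::size_t padded_ld(std::size_t fast_dim, std::size_t pad) {
  return fast_dim + pad;
}

int main() {
  std::size_t ld = padded_ld(64, 16);       // e.g. a 64-wide fp16 operand with pad 16
  std::size_t elt_size = 2;                 // fp16
  std::printf("ld = %zu elements\n", ld);   // 80
  std::printf("row 3 starts at byte %zu\n", 3 * ld * elt_size);
}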
@@ -254,20 +254,20 @@ void tiles::run(ir::module &) {
    }
    order_[i] = order;
  }
  for(size_t i = 0; i < num_groups; i++){
    std::vector<ir::dot_inst*> dots;
    for(ir::value* v: layout_->values(i))
      if(auto *x = dynamic_cast<ir::dot_inst*>(v))
        dots.push_back(x);
    for(ir::dot_inst* dot: dots){
      ir::value* a = dot->get_operand(0);
      ir::value* b = dot->get_operand(1);
      std::vector<int> col = {0, 1};
      std::vector<int> row = {1, 0};
      order_[layout_->id(a)] = is_trans(a) ? row : col;
      order_[layout_->id(b)] = is_trans(b) ? col : row;
    }
  }
//  for(size_t i = 0; i < num_groups; i++){
//    std::vector<ir::dot_inst*> dots;
//    for(ir::value* v: layout_->values(i))
//      if(auto *x = dynamic_cast<ir::dot_inst*>(v))
//        dots.push_back(x);
//    for(ir::dot_inst* dot: dots){
//      ir::value* a = dot->get_operand(0);
//      ir::value* b = dot->get_operand(1);
//      std::vector<int> col = {0, 1};
//      std::vector<int> row = {1, 0};
//      order_[layout_->id(a)] = is_trans(a) ? row : col;
//      order_[layout_->id(b)] = is_trans(b) ? col : row;
//    }
//  }
  // tiling parameters
  for(auto x: largest_){
    ir::value *i = x.second;
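Note: the {0, 1} / {1, 0} vectors are axis orders with the fastest-varying axis listed first, so flipping a dot operand's order is how a transposed layout is expressed. A quick sketch of how such an order maps to strides, assuming only that order[0] names the contiguous axis (the helper strides_from_order is hypothetical):

#include <array>
#include <cstdio>

// Strides implied by an axis order where order[0] is the fastest-varying axis.
static std::array<int, 2> strides_from_order(std::array<int, 2> shape,
                                             std::array<int, 2> order) {
  std::array<int, 2> strides{};
  int s = 1;
  for(int k = 0; k < 2; k++){     // walk axes from fastest to slowest
    strides[order[k]] = s;
    s *= shape[order[k]];
  }
  return strides;
}

int main() {
  std::array<int, 2> shape = {16, 8};
  auto col = strides_from_order(shape, {0, 1});   // axis 0 contiguous
  auto row = strides_from_order(shape, {1, 0});   // axis 1 contiguous
  std::printf("col strides: %d %d\n", col[0], col[1]);   // 1 16
  std::printf("row strides: %d %d\n", row[0], row[1]);   // 8 1
}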
@@ -1049,6 +1049,19 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
  tmap_[x] = out;
}

bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {
    return true;
  }
  if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
    bool result = true;
    for(ir::value *op: phi->ops())
      result = result && is_trans(op);
    return result;
  }
  return false;
}

void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
                               distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {

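Note: is_trans treats a value as transposed if it is a trans_inst, or if it is any other instruction (e.g. a phi node) all of whose operands are recursively transposed; anything else is not transposed. A toy analogue of that recursion on a self-contained node type, not the compiler's IR classes:

#include <vector>

struct node {
  bool is_transpose;                    // a leaf transpose
  std::vector<const node*> operands;    // compound node (e.g. a phi)
};

// True if the node is a transpose, or a compound node whose operands are all
// (recursively) treated as transposed.
static bool treated_as_transposed(const node* v) {
  if(v->is_transpose)
    return true;
  if(!v->operands.empty()){
    for(const node* op : v->operands)
      if(!treated_as_transposed(op))
        return false;
    return true;
  }
  return false;
}

int main() {
  node t1{true, {}}, t2{true, {}};
  node phi{false, {&t1, &t2}};                  // phi merging two transposed values
  return treated_as_transposed(&phi) ? 0 : 1;   // exits 0: the phi counts as transposed
}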
@@ -1082,8 +1095,11 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
  auto ord_a = tiles_->order(dot->get_operand(0));
  auto ord_b = tiles_->order(dot->get_operand(1));

  bool is_a_row = ord_a[ord_a.size() - 2] == 1;
  bool is_b_row = ord_b[ord_b.size() - 2] == 1;
  bool is_a_trans = is_trans(dot->get_operand(0));
  bool is_b_trans = is_trans(dot->get_operand(1));
  bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
  bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);

  if(is_a_row){
    offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
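Note: with the extra is_a_trans / is_b_trans flags, the row-major-ness fed to the HMMA lowering is the storage order XOR'd with whether the operand is a (possibly phi-propagated) transpose, so a transposed column-major operand is handled like a row-major one. A tiny truth-table sketch, names hypothetical:

#include <cstdio>

// Effective "row-major as seen by the MMA": storage order flipped by transposition.
static bool effective_row(bool stored_row_major, bool transposed) {
  return transposed ^ stored_row_major;
}

int main() {
  for(int t = 0; t < 2; t++)
    for(int r = 0; r < 2; r++)
      std::printf("transposed=%d stored_row=%d -> effective_row=%d\n",
                  t, r, (int)effective_row(r, t));
}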
@@ -1124,7 +1140,7 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
      Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
      Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
      indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
      indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
      indices_t idx_b = {builder.CreateAdd(offset_b_k, _K), current_offset_b_i};
      idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
      idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
      Value *ha = TA->get_value(idx_a);
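Note: idx_b now lists the reduction index first, i.e. B's shared tile is addressed as {k, j} instead of {j, k}; which flat address that produces depends entirely on the strides the tile was given. A generic sketch of that dependence, with made-up stride numbers:

#include <array>
#include <cstdio>

// Flat offset of a 2-D index given per-dimension strides.
static int flat(std::array<int, 2> idx, std::array<int, 2> strides) {
  return idx[0] * strides[0] + idx[1] * strides[1];
}

int main() {
  std::array<int, 2> strides = {1, 72};   // e.g. first dim contiguous, padded ld = 72
  std::printf("{k=2, j=5} -> %d\n", flat({2, 5}, strides));   // 2 + 5*72 = 362
}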
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
  // std::cout << source << std::endl;
  cu_context::context_switcher ctx(*context);
  // JIT compile source-code
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
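Note: the two CUjit_option entries are normally paired with a caller-provided buffer so that PTX JIT failures come back with a readable log. A sketch of that pattern against the CUDA driver API (a standalone helper, not this constructor's actual error handling; assumes cuInit has run and a context is current):

#include <cuda.h>
#include <cstdint>
#include <cstdio>

// JIT-compile a PTX image and print the JIT error log on failure.
static CUmodule load_ptx(const char* ptx) {
  char log[4096] = {0};
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
  void* val[]        = {(void*)(std::uintptr_t)sizeof(log), (void*)log};
  CUmodule mod = nullptr;
  if(cuModuleLoadDataEx(&mod, ptx, 2, opt, val) != CUDA_SUCCESS)
    std::fprintf(stderr, "JIT error:\n%s\n", log);
  return mod;
}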
@@ -32,7 +32,7 @@ int main() {
  for(const auto& c: configs){
    std::tie(AT, BT, M, N, K) = c;
    std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K))
    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K))
      std::cout << ", " << perf << std::flush;
    std::cout << std::endl;
  }
@@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    opt.num_warps = {nwarp};
  }
  if(mode == BENCH) {
    opt.defines.push_back({"TM", {"64", "128"}});
    opt.defines.push_back({"TN", {"64", "128"}});
    opt.defines.push_back({"TK", {"8"}});
    opt.num_warps = {2, 4, 8};
    opt.defines.push_back({"TM", {"128"}});
    opt.defines.push_back({"TN", {"128"}});
    opt.defines.push_back({"TK", {"16"}});
    opt.num_warps = {4};
  }

  // kernels