[codegen] more progress

This commit is contained in:
Philippe Tillet
2019-10-03 14:11:50 -04:00
parent 1bf0c8adeb
commit a1e0512703
6 changed files with 63 additions and 56 deletions

View File

@@ -1049,6 +1049,19 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
tmap_[x] = out;
}
// Tells whether `v` is produced by a transposition.
//
// Returns true when `v` is itself an ir::trans_inst. For any other
// ir::instruction, returns true only if every one of its operands
// recursively satisfies is_trans. Values that are not instructions
// yield false.
//
// NOTE(review): an instruction with no operands is reported as true,
// because the "all operands" check is vacuously satisfied -- confirm
// that this is intended.
// NOTE(review): the recursion keeps no record of visited values;
// presumably the operand graph reachable from `v` is acyclic -- verify.
bool is_trans(ir::value *v) {
  // Direct hit: the value is a transposition instruction.
  if(dynamic_cast<ir::trans_inst *>(v) != nullptr)
    return true;
  // Anything that is not an instruction cannot forward a transposition.
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // Every operand must itself be a transposition; stop at the first
  // operand that is not.
  for(ir::value *operand: inst->ops()) {
    if(!is_trans(operand))
      return false;
  }
  return true;
}
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1082,8 +1095,11 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
bool is_a_row = ord_a[ord_a.size() - 2] == 1;
bool is_b_row = ord_b[ord_b.size() - 2] == 1;
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
if(is_a_row){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
@@ -1124,7 +1140,7 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
indices_t idx_b = {builder.CreateAdd(offset_b_k, _K), current_offset_b_i};
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);