[codegen] more progress
This commit is contained in:
@@ -1049,6 +1049,19 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
|
||||
tmap_[x] = out;
|
||||
}
|
||||
|
||||
bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
return true;
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
|
||||
bool result = true;
|
||||
for(ir::value *op: phi->ops())
|
||||
result = result && is_trans(op);
|
||||
return result;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
|
||||
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
|
||||
@@ -1082,8 +1095,11 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
auto ord_a = tiles_->order(dot->get_operand(0));
|
||||
auto ord_b = tiles_->order(dot->get_operand(1));
|
||||
|
||||
bool is_a_row = ord_a[ord_a.size() - 2] == 1;
|
||||
bool is_b_row = ord_b[ord_b.size() - 2] == 1;
|
||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||
bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
|
||||
bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);
|
||||
|
||||
|
||||
if(is_a_row){
|
||||
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
@@ -1124,7 +1140,7 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||
indices_t idx_b = {builder.CreateAdd(offset_b_k, _K), current_offset_b_i};
|
||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
|
Reference in New Issue
Block a user