[codegen] separated lower_dot_inst into lower_outer_dot,
lower_hmma_dot and lower_scanline_dot
This commit is contained in:
@@ -999,194 +999,148 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
|
||||
});
|
||||
}
|
||||
|
||||
void selection::lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
|
||||
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
|
||||
}
|
||||
|
||||
void selection::lower_scalar_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
|
||||
}
|
||||
|
||||
void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
const auto& shapes = dot->get_type()->get_tile_shapes();
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(dot);
|
||||
Module *module = fn->getParent();
|
||||
ir::value *A = dot->get_operand(0);
|
||||
ir::value *B = dot->get_operand(1);
|
||||
ir::value *C = dot->get_operand(2);
|
||||
bool AT = dot->is_a_trans();
|
||||
bool BT = dot->is_b_trans();
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
||||
unsigned NK = A_shapes[red_axis]->get_value();
|
||||
|
||||
// std::cout << red_axis << " " << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN) {
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
// input indices
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
// add batching dimension
|
||||
for(size_t i = 2; i < idx.size(); i++){
|
||||
a_idx.insert(a_idx.end(), idx[i]);
|
||||
b_idx.insert(b_idx.end(), idx[i]);
|
||||
}
|
||||
// load value
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
if(a->getType() != c_ty)
|
||||
a = builder.CreateFPCast(a, c_ty);
|
||||
if(b->getType() != c_ty)
|
||||
b = builder.CreateFPCast(b, c_ty);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
TA->set_vector_size(4*pack_size_0_);
|
||||
TB->set_vector_size(4*pack_size_1_);
|
||||
TA->set_return_mode(true);
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
|
||||
|
||||
TC->for_each([&](indices_t idx){
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
fcs[key].push_back(TD->get_value(idx));
|
||||
});
|
||||
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||
|
||||
Value *offset_a_i = offset_a_i_;
|
||||
Value *offset_a_k = offset_a_k_;
|
||||
Value *offset_b_j = offset_b_j_;
|
||||
Value *offset_b_k = offset_b_k_;
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||
if(dot->is_a_trans()){
|
||||
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_a_k = builder.getInt32(0);
|
||||
}
|
||||
if(!dot->is_b_trans()){
|
||||
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_b_k = builder.getInt32(0);
|
||||
}
|
||||
|
||||
std::string op_a = dot->is_a_trans() ? "row" : "col";
|
||||
std::string op_b = dot->is_b_trans() ? "row" : "col";
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||
"{$8, $9}, "
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
|
||||
unsigned ld_fc = num_rep_i * 2;
|
||||
|
||||
|
||||
for(auto& x: fcs){
|
||||
std::vector<Value *>& fc = x.second;
|
||||
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
|
||||
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
||||
for(unsigned K = 0; K < NK; K += 4){
|
||||
Value *_K = builder.getInt32(K);
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||
if(dot->is_a_trans())
|
||||
std::swap(idx_a[0], idx_a[1]);
|
||||
if(!dot->is_b_trans())
|
||||
std::swap(idx_b[0], idx_b[1]);
|
||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
|
||||
};
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
||||
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
|
||||
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
|
||||
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
|
||||
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
|
||||
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
|
||||
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
|
||||
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
|
||||
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
|
||||
}
|
||||
}
|
||||
else {
|
||||
TA->set_vector_size(4*pack_size_0_);
|
||||
TB->set_vector_size(4*pack_size_1_);
|
||||
TA->set_return_mode(true);
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
fcs[key].push_back(TC->get_value(idx));
|
||||
});
|
||||
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||
|
||||
Value *offset_a_i = offset_a_i_;
|
||||
Value *offset_a_k = offset_a_k_;
|
||||
Value *offset_b_j = offset_b_j_;
|
||||
Value *offset_b_k = offset_b_k_;
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||
if(dot->is_a_trans()){
|
||||
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_a_k = builder.getInt32(0);
|
||||
}
|
||||
if(!dot->is_b_trans()){
|
||||
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_b_k = builder.getInt32(0);
|
||||
}
|
||||
|
||||
std::string op_a = dot->is_a_trans() ? "row" : "col";
|
||||
std::string op_b = dot->is_b_trans() ? "row" : "col";
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||
"{$8, $9}, "
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
|
||||
unsigned ld_fc = num_rep_i * 2;
|
||||
|
||||
|
||||
for(auto& x: fcs){
|
||||
std::vector<Value *>& fc = x.second;
|
||||
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
|
||||
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
||||
for(unsigned K = 0; K < NK; K += 4){
|
||||
Value *_K = builder.getInt32(K);
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||
if(dot->is_a_trans())
|
||||
std::swap(idx_a[0], idx_a[1]);
|
||||
if(!dot->is_b_trans())
|
||||
std::swap(idx_b[0], idx_b[1]);
|
||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
|
||||
};
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
||||
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
|
||||
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
|
||||
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
|
||||
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
|
||||
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
|
||||
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
|
||||
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
|
||||
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write back
|
||||
unsigned i = 0;
|
||||
result->for_each([&](indices_t idx){
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
if(i >= fcs.at(key).size())
|
||||
i = 0;
|
||||
result->set_value(idx, fcs.at(key)[i++]);
|
||||
});
|
||||
|
||||
TA->set_return_mode(false);
|
||||
TB->set_return_mode(false);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
|
||||
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
indices_t a_idx = {idx[0], builder.getInt32(0)};
|
||||
indices_t b_idx = {builder.getInt32(0), idx[1]};
|
||||
if(AT)
|
||||
|
||||
// write back
|
||||
unsigned i = 0;
|
||||
TC->for_each([&](indices_t idx){
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
if(i >= fcs.at(key).size())
|
||||
i = 0;
|
||||
TC->set_value(idx, fcs.at(key)[i++]);
|
||||
});
|
||||
|
||||
TA->set_return_mode(false);
|
||||
TB->set_return_mode(false);
|
||||
}
|
||||
|
||||
void selection::lower_scanline_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
|
||||
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add) {
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
TC->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
// input indices
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(dot->is_a_trans())
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
if(dot->is_b_trans())
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
// add batching dimension
|
||||
for(size_t i = 2; i < idx.size(); i++){
|
||||
a_idx.insert(a_idx.end(), idx[i]);
|
||||
b_idx.insert(b_idx.end(), idx[i]);
|
||||
}
|
||||
// load value
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
if(a->getType() != c_ty)
|
||||
@@ -1194,8 +1148,59 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
||||
if(b->getType() != c_ty)
|
||||
b = builder.CreateFPCast(b, c_ty);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
TC->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
|
||||
// Emits, for every index of TC, a call to f_mul_add(a, b, d) and stores the
// result back into TC at that index, where
//   d is the value of TD at the same index,
//   a is the value of TA at (idx[0], 0), with the two coordinates exchanged
//     when dot->is_a_trans() is true,
//   b is the value of TB at (0, idx[1]), with the two coordinates exchanged
//     when dot->is_b_trans() is true.
// a and b are converted to c_ty with CreateFPCast when their LLVM type differs
// from c_ty. Only the first two components of the index are forwarded to
// TA/TB. ctx and fn are not used by this function.
void selection::lower_outer_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
                                distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD,
                                Type *c_ty, Function *f_mul_add) {
  TC->for_each([&](indices_t pos){
    // accumulator input for this element
    Value *acc = TD->get_value(pos);
    // the second coordinate of A and the first coordinate of B are fixed to 0
    Value *zero = builder.getInt32(0);
    const bool a_transposed = dot->is_a_trans();
    const bool b_transposed = dot->is_b_trans();
    indices_t lhs_pos = a_transposed ? indices_t{zero, pos[0]} : indices_t{pos[0], zero};
    indices_t rhs_pos = b_transposed ? indices_t{pos[1], zero} : indices_t{zero, pos[1]};
    // fetch both operands before emitting any conversion
    Value *lhs = TA->get_value(lhs_pos);
    Value *rhs = TB->get_value(rhs_pos);
    if(lhs->getType() != c_ty) {
      lhs = builder.CreateFPCast(lhs, c_ty);
    }
    if(rhs->getType() != c_ty) {
      rhs = builder.CreateFPCast(rhs, c_ty);
    }
    // acc = lhs * rhs + acc
    TC->set_value(pos, builder.CreateCall(f_mul_add, {lhs, rhs, acc}));
  });
}
|
||||
|
||||
// Entry point for lowering a dot instruction. Looks up the tiles of the
// instruction and of its third operand, declares the fmuladd intrinsic for the
// third operand's scalar type, reads the extent NK of A along axis 0 (when A is
// transposed) or axis 1 (otherwise), then dispatches:
//   NK == 1                              -> lower_outer_dot (A and B taken as distributed tiles)
//   NK != 1, fragment 0 is STRIDED_SCAN  -> lower_scanline_dot (A and B taken as shared tiles)
//   NK != 1, any other fragment          -> lower_hmma_dot (A and B taken as shared tiles)
void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
  // tile that receives the result of the instruction
  distributed_tile *out_tile = (distributed_tile*)tmap_.at(dot);
  Module *parent = fn->getParent();
  ir::value *lhs = dot->get_operand(0);
  ir::value *rhs = dot->get_operand(1);
  ir::value *acc = dot->get_operand(2);

  // tile and scalar type of the third operand
  distributed_tile *acc_tile = (distributed_tile*)tmap_.at(acc);
  Type *acc_ty = llvm_type(acc->get_type()->get_scalar_ty(), ctx);
  Function *fma = Intrinsic::getDeclaration(parent, Intrinsic::fmuladd, {acc_ty});

  // extent of A along the axis selected by its transposition flag
  auto lhs_shapes = lhs->get_type()->get_tile_shapes();
  const size_t k_axis = dot->is_a_trans() ? 0 : 1;
  const unsigned reduce_extent = lhs_shapes[k_axis]->get_value();

  if(reduce_extent == 1) {
    distributed_tile *lhs_tile = (distributed_tile*)tmap_.at(lhs);
    distributed_tile *rhs_tile = (distributed_tile*)tmap_.at(rhs);
    lower_outer_dot(dot, ctx, fn, builder, out_tile, lhs_tile, rhs_tile, acc_tile, acc_ty, fma);
    return;
  }

  shared_tile *lhs_tile = (shared_tile*)tmap_.at(lhs);
  shared_tile *rhs_tile = (shared_tile*)tmap_.at(rhs);
  const bool is_scanline = params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN;
  if(is_scanline) {
    lower_scanline_dot(dot, ctx, fn, builder, out_tile, lhs_tile, rhs_tile, acc_tile, reduce_extent, acc_ty, fma);
  }
  else {
    lower_hmma_dot(dot, ctx, fn, builder, out_tile, lhs_tile, rhs_tile, acc_tile, reduce_extent);
  }
}
|
||||
|
||||
|
Reference in New Issue
Block a user