[feature] added basic tensor core support
This commit is contained in:
@@ -152,11 +152,16 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
indices_t non_cst_idx, cst_idx;
|
||||
extract_constant(idx, non_cst_idx, cst_idx);
|
||||
Value *&base_ptr = ptr_cache_[non_cst_idx];
|
||||
unsigned vector_size = vector_size_;
|
||||
Type *ty = ty_;
|
||||
if(ty->isHalfTy() && (vector_size % 2 == 0)){
|
||||
ty = IntegerType::get(ty->getContext(), 32);
|
||||
vector_size = vector_size / 2;
|
||||
}
|
||||
if(base_ptr == nullptr){
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
|
||||
// base_ptr = builder_.CreateBitCast(base_ptr, load_ptr_->getType());
|
||||
if(vector_size_ > 1){
|
||||
Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
|
||||
Type *vec_ty = VectorType::get(ty, vector_size);
|
||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
|
||||
}
|
||||
@@ -477,26 +482,64 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
// offset_i = tid & 2 + tid & 8
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
builder.CreateAnd(u_thread_id, _8));
|
||||
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
|
||||
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
||||
|
||||
// warp tile size
|
||||
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
Value *warp_tile_size_0 = builder.getInt32(wts_0);
|
||||
Value *warp_tile_size_1 = builder.getInt32(wts_1);
|
||||
|
||||
/* intra warp offset */
|
||||
Value *qpa_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); // quad pair id
|
||||
Value *qpb_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), builder.CreateUDiv(_16, builder.getInt32(fpw_1))); // quad pair id
|
||||
// B offsets
|
||||
Value *qpb_off = builder.CreateURem(builder.CreateMul(qpb_id, _8), warp_tile_size_1); // offset of quad pair in warp
|
||||
// A offsets
|
||||
Value *qa_off = builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4);// offset of quad in pair
|
||||
Value *qpa_off = builder.CreateURem(builder.CreateMul(qpa_id, _8), warp_tile_size_0); // offset of LHS quad pair in warp
|
||||
|
||||
/* inter warp offset */
|
||||
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
|
||||
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_offset_i = builder.CreateMul(warp_id_0, warp_tile_size_0);
|
||||
Value *warp_offset_j = builder.CreateMul(warp_id_1, warp_tile_size_1);
|
||||
|
||||
// offset_i = warp_offset_i + (u_thread_id & 1) + qpa_off + qa_off
|
||||
Value *offset_i = builder.CreateAdd(warp_offset_i,
|
||||
builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
builder.CreateAdd(qpa_off, qa_off)));
|
||||
|
||||
// repetitions
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
|
||||
// idx_i
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4 + 1)));
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned base_i = 0; base_i < shapes[0]->get_value(); base_i += stride_rep_i)
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(base_i + i*2)));
|
||||
}
|
||||
|
||||
// offset_j = warp_offset_j + (u_thread_id & 2) + qpb_off
|
||||
Value *offset_j = builder.CreateAdd(warp_offset_j,
|
||||
builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
qpb_off));
|
||||
|
||||
|
||||
// idx_j
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*2)));
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned base_j = 0; base_j < shapes[1]->get_value(); base_j += stride_rep_j)
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(base_j + j*4)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(base_j + j*4 + 1)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||
}
|
||||
@@ -797,6 +840,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0)
|
||||
packets[id] = result->get_value(idx);
|
||||
});
|
||||
in->for_each([&](indices_t idx){
|
||||
unsigned linear = in->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
packets[id] = builder.CreateInsertElement(packets.at(id), in->get_value(idx), linear % vector_size);
|
||||
});
|
||||
result->for_each([&](indices_t idx){
|
||||
@@ -834,7 +881,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
std::cout << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
@@ -862,8 +908,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
else
|
||||
{
|
||||
TA->set_vector_size(2);
|
||||
TB->set_vector_size(2);
|
||||
TA->set_vector_size(4);
|
||||
TB->set_vector_size(4);
|
||||
TA->set_return_mode(true);
|
||||
TB->set_return_mode(true);
|
||||
Value *_0 = builder.getInt32(0);
|
||||
@@ -873,22 +919,47 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
Value *warp_tile_size_0 = builder.getInt32(wts_0);
|
||||
Value *warp_tile_size_1 = builder.getInt32(wts_1);
|
||||
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
Value *u_thread_id = builder.CreateURem(tid, builder.getInt32(32));
|
||||
Value *u_warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
||||
|
||||
/* intra-warp offset */
|
||||
Value *qpa_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); // quad pair id
|
||||
Value *qpb_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), builder.CreateUDiv(_16, builder.getInt32(fpw_1))); // quad pair id
|
||||
Value *qpa_off = builder.CreateURem(builder.CreateMul(qpa_id, _8), warp_tile_size_0); // offset of LHS quad pair in warp
|
||||
Value *qpb_off = builder.CreateURem(builder.CreateMul(qpb_id, _8), warp_tile_size_1); // offset of quad pair in warp
|
||||
Value *q_off = builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4);// offset of quad in pair
|
||||
|
||||
/* inter-warp offset */
|
||||
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
|
||||
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_offset_i = builder.CreateMul(warp_id_0, warp_tile_size_0);
|
||||
Value *warp_offset_j = builder.CreateMul(warp_id_1, warp_tile_size_1);
|
||||
|
||||
/* repetitions */
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
|
||||
// offset_a_i = (tid & 4)*2 + (tid & 16)/4;
|
||||
// offset_a_k = (tid & 3)
|
||||
Value *offset_a_i = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4), _2),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
Value *offset_a_k = builder.CreateAnd(tid, _3);
|
||||
Value *offset_a_i = builder.CreateAdd(warp_offset_i, builder.CreateAdd(qpa_off, q_off));
|
||||
Value *offset_a_k = builder.CreateAnd(u_thread_id, _3);
|
||||
|
||||
// offset_b_i = (tid & 4)*1 + (tid & 16)/4
|
||||
// offset_b_i = (tid & 8)*1 + (tid & 16)/4
|
||||
// offset_b_k = (tid & 3)
|
||||
Value *offset_b_i = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
Value *offset_b_k = builder.CreateAnd(tid, _3);
|
||||
Value *offset_b_i = builder.CreateAdd(warp_offset_j, builder.CreateAdd(qpb_off, q_off));
|
||||
Value *offset_b_k = builder.CreateAnd(u_thread_id, _3);
|
||||
|
||||
|
||||
std::vector<Value *> fc;
|
||||
@@ -902,26 +973,45 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||
"{$8, $9}, "
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||
"{$8, $9}, "
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
|
||||
unsigned num_rep_j = shapes[1]->get_value() / stride_rep_j;
|
||||
unsigned ld_fc = num_rep_i * 2;
|
||||
for(unsigned ii = 0; ii < num_rep_i; ii++)
|
||||
for(unsigned jj = 0; jj < num_rep_j; jj++)
|
||||
for(unsigned K = 0; K < NK; K += 4){
|
||||
Value *_K = builder.getInt32(K);
|
||||
Value *ha0 = TA->get_value({offset_a_i, builder.CreateAdd(offset_a_k, _K)});
|
||||
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _2), builder.CreateAdd(offset_a_k, _K)});
|
||||
Value *hb0 = TB->get_value({offset_b_i, builder.CreateAdd(offset_b_k, _K)});
|
||||
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _2), builder.CreateAdd(offset_b_k, _K)});
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[0], fc[2], fc[1], fc[3], fc[4], fc[6], fc[5], fc[7]});
|
||||
fc[0] = builder.CreateExtractValue(nc, {0});
|
||||
fc[2] = builder.CreateExtractValue(nc, {1});
|
||||
fc[1] = builder.CreateExtractValue(nc, {2});
|
||||
fc[3] = builder.CreateExtractValue(nc, {3});
|
||||
fc[4] = builder.CreateExtractValue(nc, {4});
|
||||
fc[6] = builder.CreateExtractValue(nc, {5});
|
||||
fc[5] = builder.CreateExtractValue(nc, {6});
|
||||
fc[7] = builder.CreateExtractValue(nc, {7});
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(ii * stride_rep_i));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_i, builder.getInt32(jj * stride_rep_j));
|
||||
Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k, _K)});
|
||||
Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k, _K)});
|
||||
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(0));
|
||||
Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(1));
|
||||
Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(0));
|
||||
Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(1));
|
||||
std::vector<size_t> idx = {
|
||||
(ii*2 + 0) + (jj*4 + 0)*ld_fc,
|
||||
(ii*2 + 0) + (jj*4 + 1)*ld_fc,
|
||||
(ii*2 + 1) + (jj*4 + 0)*ld_fc,
|
||||
(ii*2 + 1) + (jj*4 + 1)*ld_fc,
|
||||
(ii*2 + 0) + (jj*4 + 2)*ld_fc,
|
||||
(ii*2 + 0) + (jj*4 + 3)*ld_fc,
|
||||
(ii*2 + 1) + (jj*4 + 2)*ld_fc,
|
||||
(ii*2 + 1) + (jj*4 + 3)*ld_fc
|
||||
};
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
||||
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
|
||||
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
|
||||
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
|
||||
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
|
||||
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
|
||||
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
|
||||
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
|
||||
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
|
||||
}
|
||||
|
||||
// write back
|
||||
|
Reference in New Issue
Block a user