[codegen/selection] tensor cores now used for transposed layotus

This commit is contained in:
Philippe Tillet
2019-07-17 17:20:38 -07:00
parent bfa39b8992
commit 2f0817b2cd
5 changed files with 57 additions and 14 deletions

View File

@@ -9,7 +9,7 @@
int main() { int main() {
bool AT = false; bool AT = false;
bool BT = false; bool BT = true;
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters // matrix multiplication parameters

View File

@@ -25,6 +25,7 @@ def run_dot():
# Test # Test
hresult = np.dot(ha.T, hb.T).T hresult = np.dot(ha.T, hb.T).T
dif = np.abs(result - hresult) dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult) print(hresult)
print(result) print(result)
print("dif: %f" % np.max(dif)) print("dif: %f" % np.max(dif))

View File

@@ -538,8 +538,11 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
/* intra warp offset */ /* intra warp offset */
// offset of quad in pair // offset of quad in pair
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_)); Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_)); builder.getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
builder.getInt32(fpw_1 * pack_size_1_));
// Quad pair id // Quad pair id
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
@@ -559,15 +562,17 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
// a offset // a offset
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a)); offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder.CreateAnd(u_thread_id, _3); offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
// // b offsets // b offsets
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b)); offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder.CreateAnd(u_thread_id, _3); offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
// c offsets // c offsets
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_); Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2), Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAdd(warp_offset_j, pair_b_off)); builder.CreateAdd(warp_offset_j, pair_b_off));
/* indices */ /* indices */
// i indices // i indices
std::vector<Value*> idx_i; std::vector<Value*> idx_i;
@@ -1026,7 +1031,25 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 " Value *offset_a_i = offset_a_i_;
Value *offset_a_k = offset_a_k_;
Value *offset_b_j = offset_b_j_;
Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
if(dot->is_a_trans()){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_a_k = builder.getInt32(0);
}
if(!dot->is_b_trans()){
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_b_k = builder.getInt32(0);
}
std::string op_a = dot->is_a_trans() ? "row" : "col";
std::string op_b = dot->is_b_trans() ? "row" : "col";
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, " "{$0, $1, $2, $3, $4, $5, $6, $7}, "
"{$8, $9}, " "{$8, $9}, "
"{$10, $11}, " "{$10, $11}, "
@@ -1046,10 +1069,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){ for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){ for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K); Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i_, builder.getInt32(pack_i*stride_rep_i*pack_size_0_)); Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j_, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k_, _K)}); indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k_, _K)}); indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++) for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){ for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)); Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));

View File

@@ -15,9 +15,22 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
if(dynamic_cast<ir::trans_inst*>(x)) if(dynamic_cast<ir::trans_inst*>(x))
return 4; return 4;
for(ir::user* user: x->get_users()) for(ir::user* user: x->get_users())
if(dynamic_cast<ir::dot_inst*>(user)) if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){ bool is_hmma = params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C;
return 16; bool is_op_0 = x == dot->get_operand(0);
bool is_op_1 = x == dot->get_operand(1);
if(is_hmma && is_op_0){
if(dot->is_a_trans())
return 20;
else
return 16;
}
if(is_hmma && is_op_1){
if(!dot->is_b_trans())
return 20;
else
return 16;
}
} }
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) { if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
unsigned result = 0; unsigned result = 0;

View File

@@ -221,7 +221,7 @@ void tune::run(ir::module &mod) {
} }
else { else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4); ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
} }
} }
@@ -235,7 +235,7 @@ void tune::run(ir::module &mod) {
continue; continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
*params_.at(i).at("nts.d0") = *tmp; *params_.at(i).at("nts.d0") = *tmp;
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){