[codegen/selection] tensor cores now used for transposed layotus
This commit is contained in:
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = false;
|
bool AT = false;
|
||||||
bool BT = false;
|
bool BT = true;
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
|
@@ -25,6 +25,7 @@ def run_dot():
|
|||||||
# Test
|
# Test
|
||||||
hresult = np.dot(ha.T, hb.T).T
|
hresult = np.dot(ha.T, hb.T).T
|
||||||
dif = np.abs(result - hresult)
|
dif = np.abs(result - hresult)
|
||||||
|
np.savetxt('dif.dat', dif, '%2.4f')
|
||||||
print(hresult)
|
print(hresult)
|
||||||
print(result)
|
print(result)
|
||||||
print("dif: %f" % np.max(dif))
|
print("dif: %f" % np.max(dif))
|
||||||
|
@@ -538,8 +538,11 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
|||||||
|
|
||||||
/* intra warp offset */
|
/* intra warp offset */
|
||||||
// offset of quad in pair
|
// offset of quad in pair
|
||||||
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_));
|
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
||||||
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_));
|
builder.getInt32(fpw_0 * pack_size_0_));
|
||||||
|
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
||||||
|
builder.getInt32(fpw_1 * pack_size_1_));
|
||||||
|
|
||||||
// Quad pair id
|
// Quad pair id
|
||||||
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||||
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||||
@@ -559,15 +562,17 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
|||||||
// a offset
|
// a offset
|
||||||
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
|
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
|
||||||
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
|
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||||
// // b offsets
|
// b offsets
|
||||||
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
|
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
|
||||||
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
|
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||||
|
|
||||||
|
|
||||||
// c offsets
|
// c offsets
|
||||||
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
|
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
|
||||||
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||||
builder.CreateAdd(warp_offset_j, pair_b_off));
|
builder.CreateAdd(warp_offset_j, pair_b_off));
|
||||||
|
|
||||||
|
|
||||||
/* indices */
|
/* indices */
|
||||||
// i indices
|
// i indices
|
||||||
std::vector<Value*> idx_i;
|
std::vector<Value*> idx_i;
|
||||||
@@ -1026,7 +1031,25 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||||
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
|
||||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
Value *offset_a_i = offset_a_i_;
|
||||||
|
Value *offset_a_k = offset_a_k_;
|
||||||
|
Value *offset_b_j = offset_b_j_;
|
||||||
|
Value *offset_b_k = offset_b_k_;
|
||||||
|
|
||||||
|
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||||
|
if(dot->is_a_trans()){
|
||||||
|
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||||
|
offset_a_k = builder.getInt32(0);
|
||||||
|
}
|
||||||
|
if(!dot->is_b_trans()){
|
||||||
|
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||||
|
offset_b_k = builder.getInt32(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string op_a = dot->is_a_trans() ? "row" : "col";
|
||||||
|
std::string op_b = dot->is_b_trans() ? "row" : "col";
|
||||||
|
|
||||||
|
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
|
||||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||||
"{$8, $9}, "
|
"{$8, $9}, "
|
||||||
"{$10, $11}, "
|
"{$10, $11}, "
|
||||||
@@ -1046,10 +1069,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
||||||
for(unsigned K = 0; K < NK; K += 4){
|
for(unsigned K = 0; K < NK; K += 4){
|
||||||
Value *_K = builder.getInt32(K);
|
Value *_K = builder.getInt32(K);
|
||||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i_, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j_, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||||
Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k_, _K)});
|
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||||
Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k_, _K)});
|
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||||
|
if(dot->is_a_trans())
|
||||||
|
std::swap(idx_a[0], idx_a[1]);
|
||||||
|
if(!dot->is_b_trans())
|
||||||
|
std::swap(idx_b[0], idx_b[1]);
|
||||||
|
Value *ha = TA->get_value(idx_a);
|
||||||
|
Value *hb = TB->get_value(idx_b);
|
||||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||||
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
|
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
|
||||||
|
@@ -15,9 +15,22 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
|||||||
if(dynamic_cast<ir::trans_inst*>(x))
|
if(dynamic_cast<ir::trans_inst*>(x))
|
||||||
return 4;
|
return 4;
|
||||||
for(ir::user* user: x->get_users())
|
for(ir::user* user: x->get_users())
|
||||||
if(dynamic_cast<ir::dot_inst*>(user))
|
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
|
||||||
if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){
|
bool is_hmma = params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C;
|
||||||
return 16;
|
bool is_op_0 = x == dot->get_operand(0);
|
||||||
|
bool is_op_1 = x == dot->get_operand(1);
|
||||||
|
if(is_hmma && is_op_0){
|
||||||
|
if(dot->is_a_trans())
|
||||||
|
return 20;
|
||||||
|
else
|
||||||
|
return 16;
|
||||||
|
}
|
||||||
|
if(is_hmma && is_op_1){
|
||||||
|
if(!dot->is_b_trans())
|
||||||
|
return 20;
|
||||||
|
else
|
||||||
|
return 16;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||||
unsigned result = 0;
|
unsigned result = 0;
|
||||||
|
@@ -221,7 +221,7 @@ void tune::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
|
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -235,7 +235,7 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
|
Reference in New Issue
Block a user