[triton/codegen] added shared memory padding for HMMA arguments and vectorized loads

This commit is contained in:
Philippe Tillet
2019-06-11 13:27:54 -07:00
parent cbd916994d
commit 1b5a742a88
9 changed files with 61 additions and 39 deletions

View File

@@ -1,6 +1,7 @@
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/tune.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -10,7 +11,7 @@
namespace triton{
namespace codegen{
bool shmem_allocation::is_ld_padded(ir::value *x) {
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool result = false;
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
@@ -18,15 +19,24 @@ bool shmem_allocation::is_ld_padded(ir::value *x) {
return result;
}
if(dynamic_cast<ir::trans_inst*>(x))
return true;
return false;
return 4;
for(ir::user* user: x->get_users())
if(dynamic_cast<ir::dot_inst*>(user))
if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){
if(x == user->get_operand(0))
return 16;
else
return 16;
}
return 0;
}
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
if(is_ld_padded(x)){
unsigned pad = is_ld_padded(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
result += 4 * result / ld;
result += pad * result / ld;
}
if(buffer_info_->is_double(x))
result *= 2;