[triton/codegen] added shared memory padding for HMMA arguments and vectorized loads
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
@@ -10,7 +11,7 @@
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
bool shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||
bool result = false;
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
||||
@@ -18,15 +19,24 @@ bool shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
return result;
|
||||
}
|
||||
if(dynamic_cast<ir::trans_inst*>(x))
|
||||
return true;
|
||||
return false;
|
||||
return 4;
|
||||
for(ir::user* user: x->get_users())
|
||||
if(dynamic_cast<ir::dot_inst*>(user))
|
||||
if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){
|
||||
if(x == user->get_operand(0))
|
||||
return 16;
|
||||
else
|
||||
return 16;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
if(is_ld_padded(x)){
|
||||
unsigned pad = is_ld_padded(x);
|
||||
if(pad > 0){
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
|
||||
result += 4 * result / ld;
|
||||
result += pad * result / ld;
|
||||
}
|
||||
if(buffer_info_->is_double(x))
|
||||
result *= 2;
|
||||
|
Reference in New Issue
Block a user