[codegen/selection] performance fix-up when A is transposed for hmma
This commit is contained in:
@@ -8,12 +8,12 @@
|
|||||||
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = false;
|
bool AT = true;
|
||||||
bool BT = true;
|
bool BT = false;
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 8192, N = 8192, K = 8192;
|
int32_t M = 2048, N = 2048, K = 2048;
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
|
@@ -984,7 +984,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||||
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
|
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
|
||||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
||||||
|
unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value();
|
||||||
if(NK != 1)
|
if(NK != 1)
|
||||||
{
|
{
|
||||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||||
@@ -1147,6 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||||
|
// vector_size = result->axis(0).contiguous;
|
||||||
std::map<unsigned, Value*> packets;
|
std::map<unsigned, Value*> packets;
|
||||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||||
result->for_each([&](indices_t idx){
|
result->for_each([&](indices_t idx){
|
||||||
|
@@ -21,13 +21,13 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
|||||||
bool is_op_1 = x == dot->get_operand(1);
|
bool is_op_1 = x == dot->get_operand(1);
|
||||||
if(is_hmma && is_op_0){
|
if(is_hmma && is_op_0){
|
||||||
if(dot->is_a_trans())
|
if(dot->is_a_trans())
|
||||||
return 20;
|
return 4;
|
||||||
else
|
else
|
||||||
return 16;
|
return 16;
|
||||||
}
|
}
|
||||||
if(is_hmma && is_op_1){
|
if(is_hmma && is_op_1){
|
||||||
if(!dot->is_b_trans())
|
if(!dot->is_b_trans())
|
||||||
return 20;
|
return 4;
|
||||||
else
|
else
|
||||||
return 16;
|
return 16;
|
||||||
}
|
}
|
||||||
|
@@ -235,7 +235,7 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
|
Reference in New Issue
Block a user