[GENERAL] Merged einsum feature branch. Various feature additions, performance
improvements, and bugfixes:
* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvement on FP32 kernels containing matrix multiplication
* Added re-coalescing pass for FP16 kernels containing matrix multiplication
* Various bugfixes
This commit is contained in:
@@ -1,21 +1,37 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
inline bool is_shared(ir::value *v) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
|
||||
// Tells whether operand index `op` of instruction `i` is one of the
// (instruction, operand) pairs singled out by this pass:
//   - operands 0 and 1 of an ir::INST_DOT instruction
//   - operand 0 of an ir::INST_COPY_FROM_SHARED instruction
// Any other combination yields false.
// NOTE(review): the name suggests "operand expected in shared memory" --
// confirm against the pass documentation.
inline bool is_shmem_op(ir::instruction* i, int op) {
|
||||
// dot: only the first two operands qualify
if(i->get_id() == ir::INST_DOT)
|
||||
return op==0 || op==1;
|
||||
// copy-from-shared: only its first operand qualifies
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
|
||||
return op==0;
|
||||
// everything else does not qualify
return false;
|
||||
}
|
||||
|
||||
// Classifies the value `v` by the instruction that produces it.
// Returns false when `v` is not an ir::instruction. Otherwise, as the text
// stands, the result is decided by the storage_info lookup below.
// NOTE(review): the storage_info return appears to be a removed diff line
// left over from the former is_shared() helper; while it is present, the
// INST_TRANS / INST_REDUCE / INST_COPY_TO_SHARED checks after it are
// unreachable. Confirm which version is intended.
inline bool is_shmem_res(ir::value* v){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
// values that are not instructions never qualify
if(!i)
|
||||
return false;
|
||||
// unconditional return: compares the first member of this instruction id's
// storage_info entry against codegen::SHARED
return storage_info.at(i->get_id()).first == codegen::SHARED;
|
||||
// id-based checks: each of these three instruction kinds would return true
if(i->get_id() == ir::INST_TRANS)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_REDUCE)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_COPY_TO_SHARED)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// run pass on module
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
@@ -36,9 +52,8 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
|
||||
return;
|
||||
}
|
||||
ir::value_id_t id = i->get_id();
|
||||
// already in shared memory
|
||||
if(to_shared && storage_info.at(id).first == SHARED)
|
||||
if(to_shared && is_shmem_res(i))
|
||||
return;
|
||||
// copy
|
||||
builder.set_insert_point_after(i);
|
||||
@@ -53,18 +68,19 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
|
||||
void cts::run(ir::module &mod) {
|
||||
// Add shared copies
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
auto storage = storage_info.at(i->get_id());
|
||||
for(ir::function* fn: mod.get_function_list()){
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
size_t num_op = i->get_num_operands();
|
||||
// copy to shared operands
|
||||
for(size_t k = 0; k < storage.second.size(); k++)
|
||||
if(storage.second[k] == SHARED)
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(is_shmem_op(i, k))
|
||||
add_copy(i, i->get_operand(k), builder, true);
|
||||
// copy from shared operands
|
||||
for(size_t k = 0; k < storage.second.size(); k++)
|
||||
if(storage.second[k] == DISTRIBUTED &&
|
||||
is_shared(i->get_operand(k))){
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(!dynamic_cast<ir::phi_node*>(i) &&
|
||||
!is_shmem_op(i,k) &&
|
||||
is_shmem_res(i->get_operand(k))){
|
||||
add_copy(i, i->get_operand(k), builder, false);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user