[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -12,8 +12,8 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
// simplify layout conversions using the following simple rules:
@@ -64,15 +64,18 @@ void coalesce::run(ir::module &mod) {
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// coalesce before copy_to_shared
// It's dirty, but the backend is being rewritten from scratch. :)
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
@@ -89,7 +92,8 @@ void coalesce::run(ir::module &mod) {
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);

View File

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
@@ -10,9 +12,9 @@ namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false;
}
inline bool is_shmem_res(ir::value* v){
bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return;
}
// already in shared memory
@@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});
// Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder();
ir::for_each_instruction(mod, [&](ir::instruction* i) {
size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands
bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op)
add_copy(i, op, builder, true, copies);
}
});
}

View File

@@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}