[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after)
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
if(!curr_phi)
break;
curr_phi->replace_uses_of_with(before, after);
// curr_phi->replace_uses_of_with(before, after);
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
if (curr_phi->get_incoming_block(idx) == before)
curr_phi->set_incoming_block(idx, after);
}
}

View File

@@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
return insert(unmasked_store_inst::create(ptr, val, eviction));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
return insert(masked_store_inst::create(ptr, val, mask, eviction));
}
//===----------------------------------------------------------------------===//
@@ -412,8 +412,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {

View File

@@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
// Add incoming
void phi_node::add_incoming(value *v, basic_block *block){
assert(v && "PHI node got a null value!!");
resize_ops(get_num_operands() + 1);
blocks_.resize(get_num_operands() + 1);
set_incoming_value(get_num_operands() - 1, v);
@@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
//===----------------------------------------------------------------------===//
// io_inst
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next)
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
{ }
// load
@@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
// store
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
{ }
// unmasked_store
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new unmasked_store_inst(ptr, val, name, next);
return new unmasked_store_inst(ptr, val, eviction, name, next);
}
// masked store
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, mask);
}
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, name, next);
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, eviction, name, next);
}
//===----------------------------------------------------------------------===//
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);

View File

@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
return result;
}
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::post_order(fn)){
auto inst_list = block->get_inst_list();
for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
do_work(*it);
}
}
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::reverse_post_order(fn))