[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities to the compiler backend:
- `tl.store` now uses inline PTX, making it possible to use things like evict_last.
- For A100, the mma layout can be converted directly to shared memory.
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used in both row- and column-major order.
- Fixed liveness analysis, which was broken.
- The mma layout can now be loaded/stored directly, without conversion. This is useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take LHS inputs in registers when they come from a previous `tl.dot` instruction. This is useful for, e.g., fused attention.
This commit is contained in:
@@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after)
|
||||
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
|
||||
if(!curr_phi)
|
||||
break;
|
||||
curr_phi->replace_uses_of_with(before, after);
|
||||
// curr_phi->replace_uses_of_with(before, after);
|
||||
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
|
||||
if (curr_phi->get_incoming_block(idx) == before)
|
||||
curr_phi->set_incoming_block(idx, after);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
|
||||
return insert(unmasked_store_inst::create(ptr, val, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
return insert(masked_store_inst::create(ptr, val, mask));
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
|
||||
return insert(masked_store_inst::create(ptr, val, mask, eviction));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -412,8 +412,8 @@ value *builder::create_log(value *arg){
|
||||
return insert(log_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
|
||||
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
|
||||
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
|
||||
}
|
||||
|
||||
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
||||
|
@@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
||||
|
||||
// Add incoming
|
||||
void phi_node::add_incoming(value *v, basic_block *block){
|
||||
assert(v && "PHI node got a null value!!");
|
||||
resize_ops(get_num_operands() + 1);
|
||||
blocks_.resize(get_num_operands() + 1);
|
||||
set_incoming_value(get_num_operands() - 1, v);
|
||||
@@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
return new unmasked_store_inst(ptr, val, eviction, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, eviction, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
|
@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Applies `do_work` to every instruction of `mod`, walking backward:
// for each function in the module's function list, basic blocks are taken in
// the order produced by cfg::post_order, and within each block the
// instructions are visited from the last one to the first.
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
  for(ir::function *fn: mod.get_function_list()) {
    for(ir::basic_block *block: cfg::post_order(fn)) {
      // `auto` deduces a non-reference type, so `insts` is a local object and
      // the walk below runs over it rather than over the block's own
      // container (presumably a snapshot copy of the instruction list --
      // depends on what get_inst_list() returns; kept as in the original).
      auto insts = block->get_inst_list();
      auto cursor = insts.rbegin();
      while(cursor != insts.rend()) {
        do_work(*cursor);
        ++cursor;
      }
    }
  }
}
|
||||
|
||||
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
|
Reference in New Issue
Block a user