[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities to the compiler backend:
- `tl.store` now uses inline PTX, making it possible to use eviction policies such as `evict_last`.
- For A100, the mma layout can be converted directly to shared memory.
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used in both row- and column-major order.
- Fixed liveness analysis, which was broken.
- The mma layout can now be loaded/stored directly without conversion. This is useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take LHS inputs in registers when they come from a previous `tl.dot` instruction. This is useful for, e.g., fused attention.
This commit is contained in:
@@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
||||
|
||||
// Add incoming
|
||||
void phi_node::add_incoming(value *v, basic_block *block){
|
||||
assert(v && "PHI node got a null value!!");
|
||||
resize_ops(get_num_operands() + 1);
|
||||
blocks_.resize(get_num_operands() + 1);
|
||||
set_incoming_value(get_num_operands() - 1, v);
|
||||
@@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
return new unmasked_store_inst(ptr, val, eviction, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, eviction, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
|
Reference in New Issue
Block a user