[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
return insert(unmasked_store_inst::create(ptr, val, eviction));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
return insert(masked_store_inst::create(ptr, val, mask, eviction));
}
//===----------------------------------------------------------------------===//
@@ -412,8 +412,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {