[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- `tl.store` now uses inline PTX, making it possible to use eviction policies such as `evict_last`.
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used in both row- and col-major order.
- Fixed liveness analysis; this was broken.
- The mma layout can now be loaded/stored directly, without converting. Useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take LHS inputs in registers when they come from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -258,7 +258,8 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt);
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -276,6 +277,7 @@ public:
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private:
size_t size_;
@@ -290,6 +292,7 @@ private:
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};

View File

@@ -32,11 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};

View File

@@ -15,18 +15,26 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public:
cts(bool use_async = false): use_async_(use_async) {}
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);
private:
bool use_async_;
bool has_sm80_;
analysis::layouts* layouts_;
};
}

View File

@@ -142,9 +142,9 @@ public:
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
@@ -176,7 +176,7 @@ public:
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

View File

@@ -112,7 +112,7 @@ public:
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);

View File

@@ -435,13 +435,31 @@ private:
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
public:
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
protected:
EVICTION_POLICY eviction_;
};
// load
@@ -453,14 +471,8 @@ public:
CG,
};
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
bool get_is_volatile() const { return is_volatile_; }
protected:
@@ -472,12 +484,6 @@ protected:
if (cache_ == CG) return ".cg";
return "";
}
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
EVICTION_POLICY eviction_;
CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
@@ -553,7 +559,7 @@ public:
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
public:
@@ -564,11 +570,11 @@ public:
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -578,14 +584,14 @@ public:
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
@@ -755,6 +761,8 @@ private:
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
};
class atomic_rmw_inst: public atomic_inst {
@@ -856,6 +864,8 @@ public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
@@ -872,6 +882,8 @@ private:
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
};
//class outer_inst: public builtin_inst {

View File

@@ -22,6 +22,7 @@ public:
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}