[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend: - Now using inline PTX for `tl.store`, making it possible to use things like evict_last - For A100, mma layout can be directly converted to shared memory - For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major. - Fixed liveness analysis; this was broken. - Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop. - `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
@@ -258,7 +258,8 @@ public:
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values_,
|
||||
ir::type *ty,
|
||||
analysis::align* align, target *tgt);
|
||||
analysis::align* align, target *tgt,
|
||||
bool is_tmp = false);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
// accessors
|
||||
size_t get_size() { return size_; }
|
||||
@@ -276,6 +277,7 @@ public:
|
||||
int get_mma_strided() { return mma_strided_; }
|
||||
bool allow_swizzle() const { return allow_swizzle_; }
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
bool is_tmp() const { return is_tmp_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
@@ -290,6 +292,7 @@ private:
|
||||
int mma_strided_;
|
||||
bool allow_swizzle_ = true;
|
||||
target *tgt_;
|
||||
bool is_tmp_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -32,11 +32,12 @@ private:
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
|
||||
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool has_sm80_;
|
||||
analysis::align* align_;
|
||||
analysis::layouts* layout_;
|
||||
};
|
||||
|
@@ -15,18 +15,26 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class cts {
|
||||
private:
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
|
||||
bool is_shmem_op(ir::instruction* i, int op);
|
||||
bool is_shmem_res(ir::value* i);
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
|
||||
|
||||
public:
|
||||
cts(bool use_async = false): use_async_(use_async) {}
|
||||
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool use_async_;
|
||||
bool has_sm80_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -142,9 +142,9 @@ public:
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
|
||||
// Struct instructions
|
||||
value *create_insert_value(value* val, value *elt, size_t idx);
|
||||
value *create_extract_value(value* val, size_t idx);
|
||||
@@ -176,7 +176,7 @@ public:
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
|
||||
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
|
@@ -112,7 +112,7 @@ public:
|
||||
static function *create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod);
|
||||
// blocks
|
||||
const blocks_t &blocks() { return blocks_; }
|
||||
blocks_t &blocks() { return blocks_; }
|
||||
const blocks_t &blocks() const { return blocks_; }
|
||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||
|
||||
|
@@ -435,13 +435,31 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class io_inst: public instruction {
|
||||
public:
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
protected:
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
return "";
|
||||
}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
|
||||
protected:
|
||||
EVICTION_POLICY eviction_;
|
||||
};
|
||||
|
||||
// load
|
||||
@@ -453,14 +471,8 @@ public:
|
||||
CG,
|
||||
};
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
@@ -472,12 +484,6 @@ protected:
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
return "";
|
||||
}
|
||||
EVICTION_POLICY eviction_;
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
@@ -553,7 +559,7 @@ public:
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
@@ -564,11 +570,11 @@ public:
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
@@ -578,14 +584,14 @@ public:
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
@@ -755,6 +761,8 @@ private:
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
|
||||
io_inst(ty, id, num_ops, NORMAL, name, next) {}
|
||||
};
|
||||
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
@@ -856,6 +864,8 @@ public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
bool allow_tf32() const { return allow_tf32_; }
|
||||
bool is_trans_a() const { return AT_ == Trans; }
|
||||
bool is_trans_b() const { return BT_ == Trans; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -872,6 +882,8 @@ private:
|
||||
DataType C_type_ = DataType::FP32;
|
||||
DataType A_type_ = DataType::FP16;
|
||||
DataType B_type_ = DataType::FP16;
|
||||
TransT AT_;
|
||||
TransT BT_;
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
|
@@ -22,6 +22,7 @@ public:
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user