[BACKEND] Compiler improvements (#557)
This PR adds several optimizations to the compiler backend:

- `tl.store` is now lowered through inline PTX, which makes it possible to use modifiers such as `evict_last`.
- On A100, the MMA layout can be converted directly to shared memory.
- On A100, an additional "transpose" argument in `dot` allows a tensor to be loaded once and used in both row-major and column-major form.
- Fixed the liveness analysis, which was broken.
- Values in the MMA layout can now be loaded/stored directly, without a layout conversion. This is useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take its LHS input from registers when it comes from a previous `tl.dot` instruction, which is useful for e.g. fused attention.
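To make the user-facing side concrete, below is a minimal sketch of how these features are exercised from the Python frontend. The keyword names (`eviction_policy`, `trans_b`) are assumptions inferred from the builder-signature changes in this diff (`create_store(..., eviction)`, `create_dot(..., trans_a, trans_b, allow_tf32)`), not a verbatim excerpt of the new API.

```python
import triton
import triton.language as tl

# Hypothetical sketch (not code from this PR): shows how the new backend
# capabilities might surface in the Python frontend.
@triton.jit
def fused_tile(X, W, Y, stride, BLOCK: tl.constexpr):
    rows = tl.arange(0, BLOCK)[:, None]
    cols = tl.arange(0, BLOCK)[None, :]
    # hint the L2 to keep this tile resident across outer-loop iterations
    x = tl.load(X + rows * stride + cols, eviction_policy="evict_last")
    w = tl.load(W + rows * stride + cols)
    # load W once and consume it column-major via the transpose flag (assumed kwarg)
    acc = tl.dot(x, w, trans_b=True)
    # eviction hint on the store, now expressible thanks to the inline-PTX store path
    tl.store(Y + rows * stride + cols, acc, eviction_policy="evict_first")
```

The `evict_last`/`evict_first` hints are what the new inline-PTX `tl.store` path makes expressible, and the transpose flag corresponds to the extra "transpose" argument on `dot` described above.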
@@ -258,7 +258,8 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt);
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -276,6 +277,7 @@ public:
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }

private:
size_t size_;
@@ -290,6 +292,7 @@ private:
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};

@@ -32,11 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);

public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);

private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};

@@ -15,18 +15,26 @@ namespace ir {
}

namespace codegen{

namespace analysis{
class layouts;
}

namespace transform{

class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);

public:
cts(bool use_async = false): use_async_(use_async) {}
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);

private:
bool use_async_;
bool has_sm80_;
analysis::layouts* layouts_;
};

}

@@ -142,9 +142,9 @@ public:
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
@@ -176,7 +176,7 @@ public:
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

@@ -112,7 +112,7 @@ public:
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);

@@ -435,13 +435,31 @@ private:
//===----------------------------------------------------------------------===//

class io_inst: public instruction {
public:

enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};

protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);

std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}

public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }

protected:
EVICTION_POLICY eviction_;
};

// load
@@ -453,14 +471,8 @@ public:
CG,
};

enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};

CACHE_MODIFIER get_cache_modifier() const { return cache_; }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
bool get_is_volatile() const { return is_volatile_; }

protected:
@@ -472,12 +484,6 @@ protected:
if (cache_ == CG) return ".cg";
return "";
}
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
EVICTION_POLICY eviction_;
CACHE_MODIFIER cache_;

std::string get_volatile_repr() {
@@ -553,7 +559,7 @@ public:
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);

public:
@@ -564,11 +570,11 @@ public:
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);

public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -578,14 +584,14 @@ public:
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next);

public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
@@ -755,6 +761,8 @@ private:
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
};

class atomic_rmw_inst: public atomic_inst {
@@ -856,6 +864,8 @@ public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }

public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
@@ -872,6 +882,8 @@ private:
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
};

//class outer_inst: public builtin_inst {

@@ -22,6 +22,7 @@ public:
};

void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);

}

@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
}

}

@@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps,
order_ = {0, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
order_ = {1, 0};
// rep_ = {2, 2, 1};
}

/* warps per tile */
@@ -233,6 +231,24 @@ mma_layout::mma_layout(size_t num_warps,
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;

if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
@@ -250,6 +266,9 @@ mma_layout::mma_layout(size_t num_warps,
}
} while(changed);
}
}

// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;

/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
@@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){

size_ = 0;
arg_layout_ = arg;
@@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {

@@ -14,40 +14,105 @@ namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();

// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(!layout || layout->is_tmp())
continue;
for(ir::value* v:layout->get_values()){
layouts_map[v].insert(layout);
}
}

// create live intervals

std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);

// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));

changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}

// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });

// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});

for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}

for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);

for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}

for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
}

@@ -28,12 +28,15 @@ void swizzle::run(ir::module &) {
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int per_phase = 1;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
@@ -46,7 +49,7 @@ void swizzle::run(ir::module &) {
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}

@@ -31,27 +31,28 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::pipeline pipeline(has_sm80, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::transform::cts cts(&layouts, has_sm80);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
inliner.run(ir);
dce.run(ir);
// ir.print(std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
@@ -84,10 +85,15 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
// exit(1);
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();

@@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||
Value *lane = urem(tid, i32(32));
|
||||
ir::value *op = x->get_pointer_operand();
|
||||
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
|
||||
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
bool is_mma_first_row = false;
|
||||
if(op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(op);
|
||||
size_t aln = alignment_->get(op, ord[0]);
|
||||
@@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
max_eq = std::max<size_t>(max_eq, 1);
|
||||
aln = std::min(aln, max_eq);
|
||||
}
|
||||
auto layout = layouts_->get(x)->to_scanline();
|
||||
if(layout){
|
||||
size_t nts = layout->nts(ord[0]);
|
||||
vec = std::min(nts, aln);
|
||||
}
|
||||
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
|
||||
assert(layout);
|
||||
|
||||
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
|
||||
// TODO: generalize
|
||||
is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
|
||||
(a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
|
||||
if(is_mma_first_row)
|
||||
vec = std::min<size_t>(2, aln);
|
||||
}
|
||||
// code generation
|
||||
auto idxs = idxs_.at(x);
|
||||
@@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
int tot_width = nbits*vec;
|
||||
int width = std::min(tot_width, max_word_width);
|
||||
int n_words = std::max(1, tot_width / width);
|
||||
bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
|
||||
has_evict_policy = false; // currently disable until supported in `store`
|
||||
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
|
||||
// has_evict_policy = false; // currently disable until supported in `store`
|
||||
// -----
|
||||
// create inline asm string
|
||||
// -----
|
||||
@@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||
if (has_evict_policy) asm_oss << ".L2::cache_hint";
|
||||
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
@@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
asm_oss << "}";
|
||||
asm_oss << ", [ $" << n_words + 1; // load
|
||||
asm_oss << " + " << in_off << "]"; // constant offset
|
||||
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
|
||||
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
|
||||
asm_oss << ";";
|
||||
bool has_other = other && (other != UndefValue::get(other->getType()));
|
||||
std::vector<Value *> others;
|
||||
@@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
|
||||
asm_oss << "0x" << std::hex << cst->getSExtValue();
|
||||
else{
|
||||
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
|
||||
asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
||||
others.push_back(v);
|
||||
}
|
||||
asm_oss.flags(flags);
|
||||
@@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
||||
for(Value *v: others)
|
||||
arg_tys.push_back(v->getType());
|
||||
if (has_evict_policy)
|
||||
if (has_l2_evict_policy)
|
||||
arg_tys.push_back(i64_ty);
|
||||
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
|
||||
// ---
|
||||
@@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
asm_cstrt += ",";
|
||||
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
}
|
||||
if (has_evict_policy)
|
||||
if (has_l2_evict_policy)
|
||||
asm_cstrt += ",l";
|
||||
// ---
|
||||
// finally call inline ASM
|
||||
@@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(Value *v: others)
|
||||
args.push_back(v);
|
||||
if (has_evict_policy)
|
||||
if (has_l2_evict_policy)
|
||||
args.push_back(policies_.at(x->get_eviction_policy()));
|
||||
|
||||
|
||||
@@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
// operands
|
||||
ir::value *ptr_op = x->get_pointer_operand();
|
||||
ir::value *val_op = x->get_value_operand();
|
||||
ir::value *msk_op = nullptr;
|
||||
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
|
||||
msk_op = msk_st->get_mask_operand();
|
||||
// vector size
|
||||
size_t vec = 1;
|
||||
if(val_op->get_type()->is_block_ty()){
|
||||
@@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
max_eq = std::max<size_t>(max_eq, 1);
|
||||
aln = std::min(aln, max_eq);
|
||||
}
|
||||
vec = std::min(nts, aln);
|
||||
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
|
||||
assert(layout);
|
||||
// vec = std::min(nts, aln);
|
||||
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
|
||||
// TODO: generalize
|
||||
bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
|
||||
(a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
|
||||
if(is_mma_first_row)
|
||||
vec = std::min<size_t>(2, aln);
|
||||
}
|
||||
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
|
||||
auto idxs = idxs_.at(val_op);
|
||||
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
||||
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||
ty = f16_ty;
|
||||
if(ty->isIntegerTy(1))
|
||||
ty = builder_->getInt8Ty();
|
||||
for(size_t i = 0; i < idxs.size(); i += vec){
|
||||
auto idx = idxs[i];
|
||||
// pointer
|
||||
indices_t idx = idxs[i];
|
||||
// pointers
|
||||
Value *ptr = vals_[ptr_op][idx];
|
||||
// vectorize
|
||||
Type *v_ty = vec_ty(ty, vec);
|
||||
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
|
||||
// value
|
||||
Value* val = UndefValue::get(v_ty);
|
||||
for(size_t ii = 0; ii < vec; ii++)
|
||||
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
|
||||
if(mx){
|
||||
Value *msk = vals_[mx->get_mask_operand()][idx];
|
||||
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
|
||||
builder_->SetInsertPoint(no_op->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
store(val, ptr);
|
||||
builder_->SetInsertPoint(no_op);
|
||||
size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
|
||||
size_t in_off;
|
||||
if(in_gep){
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||
ptr = cst ? in_gep->getPointerOperand() : in_gep;
|
||||
}
|
||||
else
|
||||
store(val, ptr);
|
||||
else{
|
||||
in_off = 0;
|
||||
}
|
||||
// mask
|
||||
Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
|
||||
size_t nbits = dtsize*8;
|
||||
// pack sub-words (< 32/64bits) into words
|
||||
// each load has width min(nbits*vec, 32/64)
|
||||
// and there are (nbits * vec)/width of them
|
||||
int max_word_width = std::max<int>(32, nbits);
|
||||
int tot_width = nbits*vec;
|
||||
int width = std::min(tot_width, max_word_width);
|
||||
int n_words = std::max(1, tot_width / width);
|
||||
// -----
|
||||
// create inline asm string
|
||||
// -----
|
||||
std::ostringstream asm_oss;
|
||||
asm_oss << "@$0"; // predicate
|
||||
asm_oss << " st.global";
|
||||
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
asm_oss << " [ $1 + " << in_off << "]";
|
||||
asm_oss << " , {";
|
||||
for(int i = 0; i < n_words; i++){ // return values
|
||||
if(i > 0) asm_oss << ",";
|
||||
asm_oss << "$" << 2 + i;
|
||||
}
|
||||
asm_oss << "}";
|
||||
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
|
||||
asm_oss << ";";
|
||||
// ----
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
Type* val_arg_ty = IntegerType::get(*ctx_, width);
|
||||
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
||||
for(int ii = 0; ii < n_words; ii++)
|
||||
arg_tys.push_back(val_arg_ty);
|
||||
if (has_l2_evict_policy)
|
||||
arg_tys.push_back(i64_ty);
|
||||
FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
|
||||
// ---
|
||||
// create inline ASM constraints
|
||||
// ---
|
||||
std::string asm_cstrt = "b,l";
|
||||
for(int ii = 0; ii < n_words; ii++){
|
||||
asm_cstrt += ",";
|
||||
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
}
|
||||
if (has_l2_evict_policy)
|
||||
asm_cstrt += ",l";
|
||||
// ---
|
||||
// finally call inline ASM
|
||||
// ---
|
||||
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(unsigned int ii = 0; ii < n_words; ii++){
|
||||
size_t n_subw = width / nbits;
|
||||
Value* curr = UndefValue::get(vec_ty(ty, n_subw));
|
||||
for(unsigned int jj = 0; jj < n_subw; jj++){
|
||||
Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
|
||||
if(new_elt->getType()->isIntegerTy(1))
|
||||
new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
|
||||
new_elt = bit_cast(new_elt, ty);
|
||||
curr = builder_->CreateInsertElement(curr, new_elt, jj);
|
||||
}
|
||||
args.push_back(bit_cast(curr, val_arg_ty));
|
||||
}
|
||||
if (has_l2_evict_policy)
|
||||
args.push_back(policies_.at(x->get_eviction_policy()));
|
||||
call(_asm, args);
|
||||
}
|
||||
}
|
||||
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
|
||||
@@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
// Value *ex2arg = vals_[x->get_operand(0)][idx];
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
}
|
||||
}
|
||||
@@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
// order
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
bool is_a_trans = C->is_trans_a();
|
||||
// is_a_trans = false;
|
||||
if(C->is_trans_a()){
|
||||
std::swap(ord_a[0], ord_a[1]);
|
||||
std::swap(shape_a[0], shape_a[1]);
|
||||
std::swap(offset_a_m_, offset_a_k_);
|
||||
}
|
||||
// std::cout << "visiting" << std::endl;
|
||||
// if(C->is_trans_b()){
|
||||
// std::swap(ord_b[0], ord_b[1]);
|
||||
// std::swap(shape_b[0], shape_b[1]);
|
||||
// }
|
||||
// layouts
|
||||
analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
|
||||
analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
|
||||
@@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
|
||||
int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);
|
||||
|
||||
|
||||
// max_phase_a = 4;
|
||||
// vec_a = 8;
|
||||
// std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
|
||||
// std::cout << vec_a << " " << vec_b << std::endl;
|
||||
|
||||
/* --------------------------------- */
|
||||
/* --- pre-compute pointer lanes --- */
|
||||
/* --------------------------------- */
|
||||
@@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
if(C->is_trans_a()){
|
||||
std::swap(ord_a[0], ord_a[1]);
|
||||
std::swap(shape_a[0], shape_a[1]);
|
||||
}
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
if(C->is_trans_b()){
|
||||
std::swap(ord_b[0], ord_b[1]);
|
||||
std::swap(shape_b[0], shape_b[1]);
|
||||
}
|
||||
NK = shape_a[1];
|
||||
analysis::mma_layout* layout = layouts_->get(C)->to_mma();
|
||||
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
|
||||
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
|
||||
bool is_a_row = ord_a[0] == 1;
|
||||
bool is_b_row = ord_b[0] == 1;
|
||||
|
||||
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
|
||||
const int mma_instr_m = mma_instr_shape[0];
|
||||
@@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
const int mat_shape_n = mat_shape[1];
|
||||
const int mat_shape_k = mat_shape[2];
|
||||
|
||||
const int per_phase_a = swizzle_->get_per_phase(layout_a);
|
||||
const int max_phase_a = swizzle_->get_max_phase(layout_a);
|
||||
const int per_phase_b = swizzle_->get_per_phase(layout_b);
|
||||
const int max_phase_b = swizzle_->get_max_phase(layout_b);
|
||||
|
||||
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
|
||||
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
|
||||
@@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
if(FirstBB != CurrBB)
|
||||
|
||||
// if true, this will move pointer declarations to the entry basic block
|
||||
// not prefetched cases tend to be more limited in resource usage
|
||||
// so we don't pre-compute ptrs to save registers
|
||||
bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
|
||||
if(licm_ptrs)
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
@@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
auto register_lds2 =
|
||||
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
|
||||
if (k < 2 && is_prefetch) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
|
||||
} else
|
||||
vals[{mn, k}] = val;
|
||||
};
|
||||
|
||||
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
|
||||
// v (s0_0(0), s1_0(2), | *num_rep_k
|
||||
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
|
||||
// -----------
|
||||
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
|
||||
std::function<void(int,int,int,bool)> load_a;
|
||||
analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
|
||||
bool is_a_shared = layout_a != nullptr;
|
||||
if(is_a_shared) {
|
||||
const int per_phase_a = swizzle_->get_per_phase(layout_a);
|
||||
const int max_phase_a = swizzle_->get_max_phase(layout_a);
|
||||
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
|
||||
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
|
||||
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
|
||||
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
|
||||
int num_ptr_a = a_loader.get_num_ptr();
|
||||
// pointers
|
||||
std::vector<Value*> ptrs_a(num_ptr_a);
|
||||
if(licm_ptrs)
|
||||
builder_->SetInsertPoint(CurrBB);
|
||||
for(int i = 0; i < num_ptr_a; i++)
|
||||
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
|
||||
if(licm_ptrs)
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
// loading function
|
||||
load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
|
||||
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
|
||||
shared_next_ptr_[layout_a], off_a, ptrs_a,
|
||||
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||
register_lds2(ha, m, k, inc, ha0, is_prefetch);
|
||||
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
|
||||
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
|
||||
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
|
||||
};
|
||||
}
|
||||
else {
|
||||
load_a = [&](int m, int k, int inc, bool is_prefetch) {
|
||||
distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
|
||||
int ldm = ax_n.values.size();
|
||||
if(ldm != num_rep_k*4)
|
||||
throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
|
||||
// std::cout << m << " " << k << std::endl;
|
||||
// std::cout << idxs_[A].size() << std::endl;
|
||||
// std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
|
||||
// int ldm = num_rep_k*4;
|
||||
Value* ha0 = UndefValue::get(fp16x2_ty);
|
||||
Value* ha1 = UndefValue::get(fp16x2_ty);
|
||||
Value* ha2 = UndefValue::get(fp16x2_ty);
|
||||
Value* ha3 = UndefValue::get(fp16x2_ty);
|
||||
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
|
||||
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
|
||||
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
|
||||
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
|
||||
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
|
||||
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
|
||||
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
|
||||
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
|
||||
ha[{m, k}] = ha0;
|
||||
ha[{m+1, k}] = ha1;
|
||||
ha[{m, k+1}] = ha2;
|
||||
ha[{m+1, k+1}] = ha3;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// | -> n (col-major)
|
||||
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
|
||||
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
|
||||
// -----------
|
||||
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
|
||||
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
|
||||
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
|
||||
analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
|
||||
const int per_phase_b = swizzle_->get_per_phase(layout_b);
|
||||
const int max_phase_b = swizzle_->get_max_phase(layout_b);
|
||||
std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
|
||||
std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
|
||||
int k_order_b = 0;
|
||||
// if(C->is_trans_b()){
|
||||
// std::swap(mma_instr_b[0], mma_instr_b[1]);
|
||||
// std::swap(mat_shape_b[0], mat_shape_b[1]);
|
||||
// k_order_b = k_order_b ^ 1;
|
||||
// std::swap(ord_b[0], ord_b[1]);
|
||||
// std::swap(shape_b[0], shape_b[1]);
|
||||
// }
|
||||
|
||||
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
|
||||
mma_instr_b, mat_shape_b,
|
||||
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
|
||||
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
|
||||
int num_ptr_b = b_loader.get_num_ptr();
|
||||
|
||||
if(licm_ptrs)
|
||||
builder_->SetInsertPoint(CurrBB);
|
||||
// A pointer
|
||||
std::vector<Value*> ptrs_a(num_ptr_a);
|
||||
for(int i = 0; i < num_ptr_a; i++)
|
||||
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
|
||||
// B pointer
|
||||
// pointers
|
||||
int num_ptr_b = b_loader.get_num_ptr();
|
||||
std::vector<Value*> ptrs_b(num_ptr_b);
|
||||
for(int i = 0; i < num_ptr_b; i++)
|
||||
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
|
||||
|
||||
|
||||
// loading function
|
||||
std::function<void(int,int,int,bool)> load_b;
|
||||
load_b = [&](int n, int k, int inc, bool is_prefetch) {
|
||||
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
|
||||
shared_next_ptr_[layout_b], off_b, ptrs_b,
|
||||
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||
register_lds2(hb, n, k, inc, hb0, is_prefetch);
|
||||
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
|
||||
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
|
||||
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
|
||||
};
|
||||
|
||||
|
||||
|
||||
// create mma & unpack result, m, n, k are offsets in mat
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
|
||||
" {$0, $1, $2, $3},"
|
||||
" {$4, $5, $6, $7},"
|
||||
" {$8, $9},"
|
||||
" {$10, $11, $12, $13};",
|
||||
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
|
||||
|
||||
// create mma & unpack result, m, n, k are offsets in mat
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
|
||||
unsigned cols_per_thread = num_rep_n * 2;
|
||||
std::vector<size_t> idx = {
|
||||
(m + 0)*cols_per_thread + (n*2 + 0),
|
||||
@@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
|
||||
fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
|
||||
};
|
||||
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
auto register_lds2 =
|
||||
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
|
||||
if (k < 2 && is_prefetch) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
|
||||
} else
|
||||
vals[{mn, k}] = val;
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
|
||||
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
|
||||
shared_next_ptr_[layout_a], off_a, ptrs_a,
|
||||
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||
register_lds2(ha, m, k, inc, ha0, is_prefetch);
|
||||
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
|
||||
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
|
||||
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
|
||||
};
|
||||
|
||||
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
|
||||
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
|
||||
shared_next_ptr_[layout_b], off_b, ptrs_b,
|
||||
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||
register_lds2(hb, n, k, inc, hb0, is_prefetch);
|
||||
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
|
||||
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
|
||||
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
|
||||
};
|
||||
|
||||
if (C->is_prefetched()) {
|
||||
// create phis
|
||||
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
|
||||
@@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
i = 0;
|
||||
vals_[C][idx] = fcs.at(key)[i++];
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
|
||||
} else if (layout->to_mma()) {
|
||||
shuffle_width = 4;
|
||||
warps_per_inner = layout->to_mma()->wpt(1);
|
||||
col_per_thread = 16;
|
||||
col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
|
||||
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
|
||||
}
|
||||
assert(warp_j != nullptr);
|
||||
@@ -2403,6 +2566,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
|
||||
Value* is_warp0 = icmp_eq(warp, i32(0));
|
||||
Value* is_thread0 = icmp_eq(thread, i32(0));
|
||||
Value* lane_j = urem(lane, i32(shuffle_width));
|
||||
if(warps_per_inner > 1)
|
||||
add_barrier();
|
||||
// compute partial sum for each warp, and store to shared memory
|
||||
for(size_t i = 0; i < n_elts/col_per_thread; i++){
|
||||
@@ -2425,6 +2589,11 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
|
||||
// store partial result to shared memory
|
||||
auto x_idxs = idxs_[x][i];
|
||||
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
|
||||
// single warp on the reduce dimension -- no need to use shmem
|
||||
if(warps_per_inner==1){
|
||||
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
|
||||
}
|
||||
else{
|
||||
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
|
||||
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
|
||||
if (with_index) {
|
||||
@@ -2432,6 +2601,9 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
|
||||
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
|
||||
}
|
||||
}
|
||||
}
|
||||
if(warps_per_inner==1)
|
||||
return;
|
||||
add_barrier();
|
||||
// at this point, partial accumulator synchronized in shared memory
|
||||
// Just need to reduce `warp_per_inner` numbers in shared memory
|
||||
@@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
case ir::reduce_inst::FMAX: return max_num(x, y);
|
||||
case ir::reduce_inst::FMIN: return min_num(x, y);
|
||||
case ir::reduce_inst::XOR: return xor_(x, y);
|
||||
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
};
|
||||
@@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
||||
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
||||
Value *base;
|
||||
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
|
||||
int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
|
||||
// std::cout << off << std::endl;
|
||||
base = gep(shmem_, i32(off));
|
||||
base = bit_cast(base, ptr_ty(ty, 3));
|
||||
std::vector<int> n_reps;
|
||||
for(int i = 0; i < shape.size(); i++){
|
||||
@@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
//
|
||||
int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
|
||||
int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
|
||||
if(in_layout->to_mma()){
|
||||
mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
|
||||
mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
|
||||
per_phase = 1;
|
||||
max_phase = 8;
|
||||
}
|
||||
|
||||
int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
|
||||
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
|
||||
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
|
||||
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
|
||||
if(in_layout->to_mma()){
|
||||
n_shared_0 = 8;
|
||||
n_shared_1 = 1;
|
||||
}
|
||||
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
auto shapes = cts->get_type()->get_block_shapes();
|
||||
|
||||
|
||||
// store to shared
|
||||
Value *current = nullptr;
|
||||
std::map<std::pair<int, int>, Value*> ptrs;
|
||||
@@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
// input ptr info
|
||||
int id_0 = id % (in_ld/min_vec);
|
||||
int id_1 = id / (in_ld/min_vec);
|
||||
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
|
||||
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
|
||||
int off = (off_1*shapes[in_order[0]] + off_0);
|
||||
// std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
|
||||
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
|
||||
if(ptrs.find(key) == ptrs.end()){
|
||||
if(FirstBB->getTerminator())
|
||||
@@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
builder_->SetInsertPoint(CurrBB);
|
||||
ptrs[key] = gep(shmems_.at(cts), {off});
|
||||
}
|
||||
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
|
||||
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
|
||||
if(in_layout->to_mma()){
|
||||
off_0 = id_0/n_shared_0*n_shared_0*8;
|
||||
off_1 = id_1/n_shared_1*n_shared_1*8;
|
||||
}
|
||||
int off = (off_1*shapes[in_order[0]] + off_0);
|
||||
Value* ptr = gep(ptrs[key], {i32(off)});
|
||||
ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
|
||||
// asm
|
||||
@@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) {
|
||||
if(tgt_->as_nvidia()->sm() >= 80)
|
||||
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
|
||||
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
|
||||
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
|
||||
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
|
||||
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
|
||||
policies_[evict] = call(iasm);
|
||||
}
|
||||
@@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) {
|
||||
BasicBlock *parent = bbs_[block];
|
||||
builder_->SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// i->print(std::cout);
|
||||
visit_value(i);
|
||||
// std::cout << "done" << std::endl;
|
||||
}
|
||||
@@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) {
|
||||
std::vector<distributed_axis> axes(rank);
|
||||
std::vector<int> ord(rank);
|
||||
// compute axes
|
||||
// std::cout << "axes" << std::endl;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
// std::cout << d << " " << shapes[d] << std::endl;
|
||||
// std::cout << a_axes_->get(v, d) << std::endl;
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
@@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) {
|
||||
axes[d].values = {i32(0)};
|
||||
}
|
||||
}
|
||||
// std::cout << "axes ok" << std::endl;
|
||||
// compute order
|
||||
analysis::data_layout* layout = layouts_->get(v);
|
||||
std::iota(ord.begin(), ord.end(), 0);
|
||||
@@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
PHINode *phi = (PHINode*)vals_[x][idx];
|
||||
Value *inc = vals_[x->get_incoming_value(n)][idx];
|
||||
// x->print(std::cout);
|
||||
phi->addIncoming(inc, block);
|
||||
}
|
||||
}
|
||||
|
@@ -12,8 +12,8 @@ namespace triton {
namespace codegen{
namespace transform{

coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }

// simplify layout conversions using the following simple rules:
@@ -64,15 +64,18 @@ void coalesce::run(ir::module &mod) {
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// coalesce before copy_to_shared
// It's dirty, but the backend is being rewritten from scratch. :)
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
@@ -89,7 +92,8 @@ void coalesce::run(ir::module &mod) {
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>

namespace triton {
@@ -10,7 +12,7 @@ namespace codegen{
namespace transform{

inline bool is_shmem_op(ir::instruction* i, int op) {
bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false;
}

inline bool is_shmem_res(ir::value* v){
bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){

// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return;
}
// already in shared memory
@@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
}

void cts::run(ir::module &mod) {
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});

// Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
ir::for_each_instruction(mod, [&](ir::instruction* i) {
size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op)
add_copy(i, op, builder, true, copies);
}
});
}

@@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
builder.set_insert_point(add);
|
||||
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
|
||||
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
|
||||
add->replace_all_uses_with(new_dot);
|
||||
return true;
|
||||
}
|
||||
|
@@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after)
|
||||
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
|
||||
if(!curr_phi)
|
||||
break;
|
||||
curr_phi->replace_uses_of_with(before, after);
|
||||
// curr_phi->replace_uses_of_with(before, after);
|
||||
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
|
||||
if (curr_phi->get_incoming_block(idx) == before)
|
||||
curr_phi->set_incoming_block(idx, after);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
|
||||
return insert(unmasked_store_inst::create(ptr, val, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
return insert(masked_store_inst::create(ptr, val, mask));
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
|
||||
return insert(masked_store_inst::create(ptr, val, mask, eviction));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -412,8 +412,8 @@ value *builder::create_log(value *arg){
|
||||
return insert(log_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
|
||||
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
|
||||
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
|
||||
}
|
||||
|
||||
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
||||
|
@@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
||||
|
||||
// Add incoming
|
||||
void phi_node::add_incoming(value *v, basic_block *block){
|
||||
assert(v && "PHI node got a null value!!");
|
||||
resize_ops(get_num_operands() + 1);
|
||||
blocks_.resize(get_num_operands() + 1);
|
||||
set_incoming_value(get_num_operands() - 1, v);
|
||||
@@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
return new unmasked_store_inst(ptr, val, eviction, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, eviction, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
|
@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
return result;
|
||||
}
|
||||
|
||||
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::post_order(fn)){
|
||||
auto inst_list = block->get_inst_list();
|
||||
for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
|
||||
do_work(*it);
|
||||
}
|
||||
}
|
||||
|
||||
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
|
@@ -840,10 +840,10 @@ def test_permute(dtype_str, shape, perm, device='cuda'):

@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
                         [(epilogue, allow_tf32, dtype)
                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                          for allow_tf32 in [True, False]
                          for dtype in ['float32', 'int8']
                          if not (allow_tf32 and (dtype == 'int8'))])
                          for dtype in ['float16']
                          if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80:
@@ -852,21 +852,30 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    elif dtype == 'float32' and allow_tf32:
        pytest.skip("Only test tf32 on devices with sm >= 80")

    M, N, K = 128, 128, 64
    num_warps = 8
    trans_a, trans_b = False, False

    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               W, stride_wn, stride_wl,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
               ALLOW_TF32: tl.constexpr):
               ALLOW_TF32: tl.constexpr,
               DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
               TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_l = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
        Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
        z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
        if ADD_MATRIX:
            z += tl.load(Zs)
        if ADD_ROWS:
@@ -875,39 +884,65 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
        if ADD_COLS:
            ZCs = Z + off_n * stride_zn
            z += tl.load(ZCs)[None, :]
        if DO_SOFTMAX:
            max = tl.max(z, 1)
            z = z - max[:, None]
            num = tl.exp(z)
            den = tl.sum(num, 1)
            z = num / den[:, None]
        if CHAIN_DOT:
            # tl.store(Zs, z)
            # tl.debug_barrier()
            z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
        tl.store(Zs, z)
    # input
    M, N, K = 64, 64, 32
    rs = RandomState(17)
    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
    x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
    y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
    w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
    if allow_tf32:
        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
        w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
    x_tri = to_triton(x, device=device)
    y_tri = to_triton(y, device=device)
    w_tri = to_triton(w, device=device)
    # triton result
    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
    z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
    z_tri = to_triton(z, device=device)
    if epilogue == 'trans':
        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                         y_tri, y_tri.stride(0), y_tri.stride(1),
                         w_tri, w_tri.stride(0), w_tri.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         TRANS_A=trans_a, TRANS_B=trans_b,
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=epilogue == 'add-matrix',
                         ADD_ROWS=epilogue == 'add-rows',
                         ADD_COLS=epilogue == 'add-cols',
                         ALLOW_TF32=allow_tf32)
                         DO_SOFTMAX=epilogue == 'softmax',
                         CHAIN_DOT=epilogue == 'chain-dot',
                         ALLOW_TF32=allow_tf32,
                         num_warps=num_warps)
    # torch result
    z_ref = np.matmul(x, y)
    x_ref = x.T if trans_a else x
    y_ref = y.T if trans_b else y
    z_ref = np.matmul(x_ref, y_ref)
    if epilogue == 'add-matrix':
        z_ref += z
    if epilogue == 'add-rows':
        z_ref += z[:, 0][:, None]
    if epilogue == 'add-cols':
        z_ref += z[0, :][None, :]
    if epilogue == 'softmax':
        num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
        denom = np.sum(num, axis=-1, keepdims=True)
        z_ref = num / denom
    if epilogue == 'chain-dot':
        z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
    # compare
    # print(z_ref[:,0], z_tri[:,0])
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']

@@ -211,7 +211,7 @@ class ValueConstructor:
            return phi
        v = unique_handles.pop()
        phi.handle.replace_all_uses_with(v)
        phi.handle.erase_from_parent()
        # phi.handle.erase_from_parent()
        # TODO: remove trivial phis recursively
        return triton.language.tensor(v, phi.type)

@@ -732,7 +732,7 @@ def reshape(input, shape, _builder=None):


@builtin
def dot(input, other, allow_tf32=True, _builder=None):
def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
    """
    Returns the matrix product of two blocks.

@@ -744,7 +744,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
    :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
    """
    allow_tf32 = _constexpr_to_value(allow_tf32)
    return semantic.dot(input, other, allow_tf32, _builder)
    return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)


# -----------------------
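A hedged usage sketch (not part of this diff; the kernel and tensor names below are made up): with the new trans_a/trans_b flags, an operand that is stored K-major can be fed to tl.dot as-is, instead of materializing a transposed copy first.

import triton
import triton.language as tl

@triton.jit
def _dot_tn(X, Y, Z,
            stride_xk, stride_xm, stride_yk, stride_yn, stride_zm, stride_zn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    # x is loaded as a (BLOCK_K, BLOCK_M) tile, i.e. A stored transposed
    x = tl.load(X + off_k[:, None] * stride_xk + off_m[None, :] * stride_xm)
    y = tl.load(Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn)
    # trans_a=True treats x as A^T, so z has shape (BLOCK_M, BLOCK_N)
    z = tl.dot(x, y, trans_a=True)
    tl.store(Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn, z)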

@@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",


@builtin
def store(pointer, value, mask=None, _builder=None):
def store(pointer, value, eviction_policy="", mask=None, _builder=None):
    """
    Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.

@@ -799,7 +799,7 @@ def store(pointer, value, mask=None, _builder=None):
    value = _to_tensor(value, _builder)
    if mask is not None:
        mask = _to_tensor(mask, _builder)
    return semantic.store(pointer, value, mask, _builder)
    return semantic.store(pointer, value, mask, eviction_policy, _builder)


# -----------------------
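A hedged usage sketch (not part of this diff; the kernel below is made up): the new eviction_policy argument of tl.store is lowered through the inline-PTX store path, so a streaming write can hint the cache replacement policy just as eviction_policy already does on tl.load.

import triton
import triton.language as tl

@triton.jit
def _streaming_copy(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask, eviction_policy="evict_last")
    # "evict_first": this data is written once and never re-read by the kernel
    tl.store(Y + offs, x, eviction_policy="evict_first", mask=mask)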

@@ -648,6 +648,18 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//


def _parse_eviction_policy(eviction_policy):
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy:
        if eviction_policy == "evict_last":
            eviction = ir.EVICTION_POLICY.EVICT_LAST
        elif eviction_policy == "evict_first":
            eviction = ir.EVICTION_POLICY.EVICT_FIRST
        else:
            raise ValueError(f"Eviction policy {eviction_policy} not supported")
    return eviction


def load(ptr: tl.tensor,
         mask: Optional[tl.tensor],
         other: Optional[tl.tensor],
@@ -684,14 +696,7 @@ def load(ptr: tl.tensor,
        raise ValueError(f"Cache modifier {cache_modifier} not supported")

    # eviction policy
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy:
        if eviction_policy == "evict_last":
            eviction = ir.EVICTION_POLICY.EVICT_LAST
        elif eviction_policy == "evict_first":
            eviction = ir.EVICTION_POLICY.EVICT_FIRST
        else:
            raise ValueError(f"Eviction policy {eviction_policy} not supported")
    eviction = _parse_eviction_policy(eviction_policy)

    if ptr.type.is_block():
        shape = ptr.type.get_block_shapes()
@@ -721,6 +726,7 @@ def load(ptr: tl.tensor,
def store(ptr: tl.tensor,
          val: tl.tensor,
          mask: Optional[tl.tensor],
          eviction_policy: str,
          builder: ir.builder) -> tl.tensor:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -735,14 +741,15 @@ def store(ptr: tl.tensor,
        elt_ty_ptr = tl.int8
        ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # eviction policy
    eviction = _parse_eviction_policy(eviction_policy)
    # cast to target data-type
    val = cast(val, elt_ty, builder)
    if not mask:
        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
        return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)

#########
# atomic
@@ -897,27 +904,31 @@ def atomic_xchg(ptr: tl.tensor,
# ===----------------------------------------------------------------------===//


def dot(lhs: tl.tensor,
        rhs: tl.tensor,
def dot(a: tl.tensor,
        b: tl.tensor,
        trans_a: bool,
        trans_b: bool,
        allow_tf32: bool,
        builder: ir.builder) -> tl.tensor:
    assert lhs.type.is_block() and rhs.type.is_block()
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
    assert lhs.shape[-1] == rhs.shape[0]
    assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
    in_a = 1 if not trans_a else 0
    in_b = 1 if trans_b else 0
    assert a.type.is_block() and b.type.is_block()
    assert len(a.shape) == 2 and len(b.shape) == 2
    assert a.shape[in_a] == b.shape[in_b]
    assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
        "small blocks not supported!"
    if lhs.type.scalar.is_int():
    if a.type.scalar.is_int():
        _0 = builder.get_int32(0)
        ret_scalar_ty = tl.int32
    else:
        _0 = builder.get_float32(0)
        ret_scalar_ty = tl.float32
    M = lhs.type.shape[0]
    N = rhs.type.shape[1]
    M = a.type.shape[in_a ^ 1]
    N = b.type.shape[in_b ^ 1]
    _0 = builder.create_splat(_0, [M, N])
    ret_ty = tl.block_type(ret_scalar_ty, [M, N])
    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                     ret_ty)
    ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
    return tl.tensor(ret, ret_ty)


# ===----------------------------------------------------------------------===//
python/tutorials/06-fused-attention.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import pytest
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q, K, V,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kk, stride_kn,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_qm = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * N_CTX + offs_m

    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    q = tl.load(q_ptrs)
    for start_n in range(0, start_qm + 1):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.dot(q, k)
        qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        p = p.to(tl.float16)
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
        # r_ptrs += BLOCK_N
        l_i = l_i_new
        m_i = m_i_new

    start_qm = tl.program_id(0)
    offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_out
    tl.store(out_ptrs, acc)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v):
        BLOCK = 128
        # shape constraints
        Lq, Lk = q.shape[-1], k.shape[-2]
        assert Lq == Lk
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q, k, v,
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=64, num_warps=4,
            num_stages=1,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        return o


attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)])
def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16):
    torch.manual_seed(20)
    q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
    k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
    v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
    # triton implementation
    tri_out = attention(q, k, v)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    ref_qk = torch.matmul(q, k)
    for z in range(Z):
        for h in range(H):
            ref_qk[:, :, M == 0] = float("-inf")
    ref_qk = torch.softmax(ref_qk, dim=-1)
    ref_out = torch.matmul(ref_qk, v)
    # compare
    triton.testing.assert_almost_equal(ref_out, tri_out)


try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64
# vary batch size for fixed heads / seq
batch_bench = triton.testing.Benchmark(
    x_names=['BATCH'],
    x_vals=[2**i for i in range(0, 8)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}',
    args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
)
# vary seq length for fixed head and batch=4
seq_bench = triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 16)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}',
    args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
)


@triton.testing.perf_report([batch_bench, seq_bench])
def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"):
    warmup = 25
    rep = 500
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
        fn = lambda: attention(q, k, v)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms


bench_flash_attention.run(save_path='.', print_data=True)
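A hedged aside (not part of the tutorial; shapes below are made up and the causal mask is ignored): the running-max/denominator rescaling in _fwd_kernel's inner loop can be checked against a one-shot softmax(QK^T)V with a few lines of NumPy.

import numpy as np

rs = np.random.RandomState(0)
qk = rs.randn(4, 3, 8).astype(np.float32)   # scores: 4 query rows, 3 key blocks of 8 columns
v = rs.randn(3, 8, 5).astype(np.float32)    # values: 3 blocks of 8 rows, head dim 5

m_i = np.full(4, -np.inf, dtype=np.float32)  # running row max
l_i = np.zeros(4, dtype=np.float32)          # running softmax denominator
acc = np.zeros((4, 5), dtype=np.float32)     # running output accumulator
for j in range(3):
    m_ij = qk[:, j].max(axis=1)
    p = np.exp(qk[:, j] - m_ij[:, None])
    l_ij = p.sum(axis=1)
    m_new = np.maximum(m_i, m_ij)
    alpha = np.exp(m_i - m_new)
    beta = np.exp(m_ij - m_new)
    l_new = alpha * l_i + beta * l_ij
    # same update as the kernel: rescale old acc, add the rescaled new block
    acc = acc * (l_i / l_new * alpha)[:, None] + (p * (beta / l_new)[:, None]) @ v[j]
    m_i, l_i = m_new, l_new

# one-shot reference over the concatenated key axis
scores = qk.reshape(4, -1)
p_ref = np.exp(scores - scores.max(axis=1, keepdims=True))
out_ref = (p_ref / p_ref.sum(axis=1, keepdims=True)) @ v.reshape(-1, 5)
assert np.allclose(acc, out_ref, atol=1e-4)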