[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -258,7 +258,8 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt);
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -276,6 +277,7 @@ public:
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private:
size_t size_;
@@ -290,6 +292,7 @@ private:
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};

View File

@@ -32,11 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};

View File

@@ -15,18 +15,26 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public:
cts(bool use_async = false): use_async_(use_async) {}
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);
private:
bool use_async_;
bool has_sm80_;
analysis::layouts* layouts_;
};
}

View File

@@ -142,9 +142,9 @@ public:
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
@@ -176,7 +176,7 @@ public:
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

View File

@@ -112,7 +112,7 @@ public:
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);

View File

@@ -435,13 +435,31 @@ private:
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
public:
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
protected:
EVICTION_POLICY eviction_;
};
// load
@@ -453,14 +471,8 @@ public:
CG,
};
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
bool get_is_volatile() const { return is_volatile_; }
protected:
@@ -472,12 +484,6 @@ protected:
if (cache_ == CG) return ".cg";
return "";
}
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
EVICTION_POLICY eviction_;
CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
@@ -553,7 +559,7 @@ public:
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
public:
@@ -564,11 +570,11 @@ public:
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -578,14 +584,14 @@ public:
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
@@ -755,6 +761,8 @@ private:
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
};
class atomic_rmw_inst: public atomic_inst {
@@ -856,6 +864,8 @@ public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
@@ -872,6 +882,8 @@ private:
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
};
//class outer_inst: public builtin_inst {

View File

@@ -22,6 +22,7 @@ public:
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}

View File

@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
}
}

View File

@@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps,
order_ = {0, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
order_ = {1, 0};
// rep_ = {2, 2, 1};
}
/* warps per tile */
@@ -233,6 +231,24 @@ mma_layout::mma_layout(size_t num_warps,
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;
if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
@@ -250,6 +266,9 @@ mma_layout::mma_layout(size_t num_warps,
}
} while(changed);
}
}
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
@@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
size_ = 0;
arg_layout_ = arg;
@@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {

View File

@@ -14,40 +14,105 @@ namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(!layout || layout->is_tmp())
continue;
for(ir::value* v:layout->get_values()){
layouts_map[v].insert(layout);
}
}
// create live intervals
std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);
// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));
changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}
// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}
for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
}

View File

@@ -28,12 +28,15 @@ void swizzle::run(ir::module &) {
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int per_phase = 1;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
@@ -46,7 +49,7 @@ void swizzle::run(ir::module &) {
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}

View File

@@ -31,27 +31,28 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::pipeline pipeline(has_sm80, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::transform::cts cts(&layouts, has_sm80);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
inliner.run(ir);
dce.run(ir);
// ir.print(std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
@@ -84,10 +85,15 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
// exit(1);
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();

View File

@@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *lane = urem(tid, i32(32));
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
// compute vector width
size_t vec = 1;
bool is_mma_first_row = false;
if(op->get_type()->is_block_ty()){
auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]);
@@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
auto layout = layouts_->get(x)->to_scanline();
if(layout){
size_t nts = layout->nts(ord[0]);
vec = std::min(nts, aln);
}
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
assert(layout);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
// code generation
auto idxs = idxs_.at(x);
@@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
has_evict_policy = false; // currently disable until supported in `store`
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
// has_evict_policy = false; // currently disable until supported in `store`
// -----
// create inline asm string
// -----
@@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (has_evict_policy) asm_oss << ".L2::cache_hint";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
@@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_oss << "}";
asm_oss << ", [ $" << n_words + 1; // load
asm_oss << " + " << in_off << "]"; // constant offset
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
bool has_other = other && (other != UndefValue::get(other->getType()));
std::vector<Value *> others;
@@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
asm_oss << "0x" << std::hex << cst->getSExtValue();
else{
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
others.push_back(v);
}
asm_oss.flags(flags);
@@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
if (has_evict_policy)
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
// ---
@@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_evict_policy)
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
@@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Value*> args = {pred, ptr};
for(Value *v: others)
args.push_back(v);
if (has_evict_policy)
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
@@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
ir::value *msk_op = nullptr;
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
msk_op = msk_st->get_mask_operand();
// vector size
size_t vec = 1;
if(val_op->get_type()->is_block_ty()){
@@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
vec = std::min(nts, aln);
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
assert(layout);
// vec = std::min(nts, aln);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
if(ty->isIntegerTy(1))
ty = builder_->getInt8Ty();
for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i];
// pointer
indices_t idx = idxs[i];
// pointers
Value *ptr = vals_[ptr_op][idx];
// vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value
Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
builder_->SetInsertPoint(no_op->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
store(val, ptr);
builder_->SetInsertPoint(no_op);
size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
size_t in_off;
if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr = cst ? in_gep->getPointerOperand() : in_gep;
}
else
store(val, ptr);
else{
in_off = 0;
}
// mask
Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
size_t nbits = dtsize*8;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// -----
// create inline asm string
// -----
std::ostringstream asm_oss;
asm_oss << "@$0"; // predicate
asm_oss << " st.global";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
asm_oss << " [ $1 + " << in_off << "]";
asm_oss << " , {";
for(int i = 0; i < n_words; i++){ // return values
if(i > 0) asm_oss << ",";
asm_oss << "$" << 2 + i;
}
asm_oss << "}";
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
// ----
// create inline ASM signature
// ---
Type* val_arg_ty = IntegerType::get(*ctx_, width);
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(int ii = 0; ii < n_words; ii++)
arg_tys.push_back(val_arg_ty);
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
// ---
// create inline ASM constraints
// ---
std::string asm_cstrt = "b,l";
for(int ii = 0; ii < n_words; ii++){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
// ---
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, ptr};
for(unsigned int ii = 0; ii < n_words; ii++){
size_t n_subw = width / nbits;
Value* curr = UndefValue::get(vec_ty(ty, n_subw));
for(unsigned int jj = 0; jj < n_subw; jj++){
Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
if(new_elt->getType()->isIntegerTy(1))
new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
new_elt = bit_cast(new_elt, ty);
curr = builder_->CreateInsertElement(curr, new_elt, jj);
}
args.push_back(bit_cast(curr, val_arg_ty));
}
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
call(_asm, args);
}
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
@@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
// Value *ex2arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
}
}
@@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
// order
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
bool is_a_trans = C->is_trans_a();
// is_a_trans = false;
if(C->is_trans_a()){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
std::swap(offset_a_m_, offset_a_k_);
}
// std::cout << "visiting" << std::endl;
// if(C->is_trans_b()){
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
// layouts
analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
@@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);
// max_phase_a = 4;
// vec_a = 8;
// std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
// std::cout << vec_a << " " << vec_b << std::endl;
/* --------------------------------- */
/* --- pre-compute pointer lanes --- */
/* --------------------------------- */
@@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
if(C->is_trans_a()){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
}
auto ord_b = layouts_->get(B)->get_order();
if(C->is_trans_b()){
std::swap(ord_b[0], ord_b[1]);
std::swap(shape_b[0], shape_b[1]);
}
NK = shape_a[1];
analysis::mma_layout* layout = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
@@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
@@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
if(FirstBB != CurrBB)
// if true, this will move pointer declarations to the entry basic block
// not prefetched cases tend to be more limited in resource usage
// so we don't pre-compute ptrs to save registers
bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
@@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{mn, k}] = val;
};
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
std::function<void(int,int,int,bool)> load_a;
analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
bool is_a_shared = layout_a != nullptr;
if(is_a_shared) {
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
// pointers
std::vector<Value*> ptrs_a(num_ptr_a);
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
// loading function
load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
}
else {
load_a = [&](int m, int k, int inc, bool is_prefetch) {
distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
int ldm = ax_n.values.size();
if(ldm != num_rep_k*4)
throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
// std::cout << m << " " << k << std::endl;
// std::cout << idxs_[A].size() << std::endl;
// std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
// int ldm = num_rep_k*4;
Value* ha0 = UndefValue::get(fp16x2_ty);
Value* ha1 = UndefValue::get(fp16x2_ty);
Value* ha2 = UndefValue::get(fp16x2_ty);
Value* ha3 = UndefValue::get(fp16x2_ty);
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
ha[{m, k}] = ha0;
ha[{m+1, k}] = ha1;
ha[{m, k+1}] = ha2;
ha[{m+1, k+1}] = ha3;
};
}
// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
int k_order_b = 0;
// if(C->is_trans_b()){
// std::swap(mma_instr_b[0], mma_instr_b[1]);
// std::swap(mat_shape_b[0], mat_shape_b[1]);
// k_order_b = k_order_b ^ 1;
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
mma_instr_b, mat_shape_b,
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
// A pointer
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer
// pointers
int num_ptr_b = b_loader.get_num_ptr();
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
// loading function
std::function<void(int,int,int,bool)> load_b;
load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_n * 2;
std::vector<size_t> idx = {
(m + 0)*cols_per_thread + (n*2 + 0),
@@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
};
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{mn, k}] = val;
};
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
@@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
i = 0;
vals_[C][idx] = fcs.at(key)[i++];
};
}
/**
@@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
} else if (layout->to_mma()) {
shuffle_width = 4;
warps_per_inner = layout->to_mma()->wpt(1);
col_per_thread = 16;
col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
}
assert(warp_j != nullptr);
@@ -2403,6 +2566,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
Value* is_warp0 = icmp_eq(warp, i32(0));
Value* is_thread0 = icmp_eq(thread, i32(0));
Value* lane_j = urem(lane, i32(shuffle_width));
if(warps_per_inner > 1)
add_barrier();
// compute partial sum for each warp, and store to shared memory
for(size_t i = 0; i < n_elts/col_per_thread; i++){
@@ -2425,6 +2589,11 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
// store partial result to shared memory
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
// single warp on the reduce dimension -- no need to use shmem
if(warps_per_inner==1){
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
else{
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
@@ -2432,6 +2601,9 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
}
}
}
if(warps_per_inner==1)
return;
add_barrier();
// at this point, partial accumulator synchronized in shared memory
// Just need to reduce `warp_per_inner` numbers in shared memory
@@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
case ir::reduce_inst::FMAX: return max_num(x, y);
case ir::reduce_inst::FMIN: return min_num(x, y);
case ir::reduce_inst::XOR: return xor_(x, y);
default: throw std::runtime_error("unreachable");
}
};
@@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
Value *base;
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
// std::cout << off << std::endl;
base = gep(shmem_, i32(off));
base = bit_cast(base, ptr_ty(ty, 3));
std::vector<int> n_reps;
for(int i = 0; i < shape.size(); i++){
@@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
//
int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
if(in_layout->to_mma()){
mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
per_phase = 1;
max_phase = 8;
}
int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
if(in_layout->to_mma()){
n_shared_0 = 8;
n_shared_1 = 1;
}
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_block_shapes();
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
@@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
// input ptr info
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
int off = (off_1*shapes[in_order[0]] + off_0);
// std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(ptrs.find(key) == ptrs.end()){
if(FirstBB->getTerminator())
@@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
builder_->SetInsertPoint(CurrBB);
ptrs[key] = gep(shmems_.at(cts), {off});
}
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
if(in_layout->to_mma()){
off_0 = id_0/n_shared_0*n_shared_0*8;
off_1 = id_1/n_shared_1*n_shared_1*8;
}
int off = (off_1*shapes[in_order[0]] + off_0);
Value* ptr = gep(ptrs[key], {i32(off)});
ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
// asm
@@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) {
if(tgt_->as_nvidia()->sm() >= 80)
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
@@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = bbs_[block];
builder_->SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
// i->print(std::cout);
visit_value(i);
// std::cout << "done" << std::endl;
}
@@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) {
std::vector<distributed_axis> axes(rank);
std::vector<int> ord(rank);
// compute axes
// std::cout << "axes" << std::endl;
for(size_t d = 0; d < shapes.size(); d++){
// std::cout << d << " " << shapes[d] << std::endl;
// std::cout << a_axes_->get(v, d) << std::endl;
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
@@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) {
axes[d].values = {i32(0)};
}
}
// std::cout << "axes ok" << std::endl;
// compute order
analysis::data_layout* layout = layouts_->get(v);
std::iota(ord.begin(), ord.end(), 0);
@@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) {
for(indices_t idx: idxs_.at(x)){
PHINode *phi = (PHINode*)vals_[x][idx];
Value *inc = vals_[x->get_incoming_value(n)][idx];
// x->print(std::cout);
phi->addIncoming(inc, block);
}
}

View File

@@ -12,8 +12,8 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
// simplify layout conversions using the following simple rules:
@@ -64,15 +64,18 @@ void coalesce::run(ir::module &mod) {
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// coalesce before copy_to_shared
// It's dirty, but the backend is being rewritten from scratch. :)
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_rank() == 2)
@@ -89,7 +92,8 @@ void coalesce::run(ir::module &mod) {
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);

View File

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
@@ -10,7 +12,7 @@ namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false;
}
inline bool is_shmem_res(ir::value* v){
bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return;
}
// already in shared memory
@@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
}
void cts::run(ir::module &mod) {
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});
// Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
ir::for_each_instruction(mod, [&](ir::instruction* i) {
size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op)
add_copy(i, op, builder, true, copies);
}
});
}

View File

@@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}

View File

@@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after)
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
if(!curr_phi)
break;
curr_phi->replace_uses_of_with(before, after);
// curr_phi->replace_uses_of_with(before, after);
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
if (curr_phi->get_incoming_block(idx) == before)
curr_phi->set_incoming_block(idx, after);
}
}

View File

@@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
return insert(unmasked_store_inst::create(ptr, val, eviction));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
return insert(masked_store_inst::create(ptr, val, mask, eviction));
}
//===----------------------------------------------------------------------===//
@@ -412,8 +412,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {

View File

@@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
// Add incoming
void phi_node::add_incoming(value *v, basic_block *block){
assert(v && "PHI node got a null value!!");
resize_ops(get_num_operands() + 1);
blocks_.resize(get_num_operands() + 1);
set_incoming_value(get_num_operands() - 1, v);
@@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
//===----------------------------------------------------------------------===//
// io_inst
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next)
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
{ }
// load
@@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
// store
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
{ }
// unmasked_store
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new unmasked_store_inst(ptr, val, name, next);
return new unmasked_store_inst(ptr, val, eviction, name, next);
}
// masked store
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, mask);
}
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, name, next);
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, eviction, name, next);
}
//===----------------------------------------------------------------------===//
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);

View File

@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
return result;
}
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::post_order(fn)){
auto inst_list = block->get_inst_list();
for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
do_work(*it);
}
}
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::reverse_post_order(fn))

View File

@@ -840,10 +840,10 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float32', 'int8']
if not (allow_tf32 and (dtype == 'int8'))])
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80:
@@ -852,21 +852,30 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
M, N, K = 128, 128, 64
num_warps = 8
trans_a, trans_b = False, False
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr):
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
@@ -875,39 +884,65 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if ADD_COLS:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
if DO_SOFTMAX:
max = tl.max(z, 1)
z = z - max[:, None]
num = tl.exp(z)
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
rs = RandomState(17)
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = numpy_random((M, N), dtype_str=dtype, rs=rs)
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
ALLOW_TF32=allow_tf32)
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
z_ref = np.matmul(x, y)
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0, :][None, :]
if epilogue == 'softmax':
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']

View File

@@ -211,7 +211,7 @@ class ValueConstructor:
return phi
v = unique_handles.pop()
phi.handle.replace_all_uses_with(v)
phi.handle.erase_from_parent()
# phi.handle.erase_from_parent()
# TODO: remove trivial phis recursively
return triton.language.tensor(v, phi.type)

View File

@@ -732,7 +732,7 @@ def reshape(input, shape, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, _builder=None):
def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -744,7 +744,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, allow_tf32, _builder)
return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)
# -----------------------
@@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
@builtin
def store(pointer, value, mask=None, _builder=None):
def store(pointer, value, eviction_policy="", mask=None, _builder=None):
"""
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -799,7 +799,7 @@ def store(pointer, value, mask=None, _builder=None):
value = _to_tensor(value, _builder)
if mask is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, _builder)
return semantic.store(pointer, value, mask, eviction_policy, _builder)
# -----------------------

View File

@@ -648,6 +648,18 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//
def _parse_eviction_policy(eviction_policy):
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
eviction = ir.EVICTION_POLICY.EVICT_LAST
elif eviction_policy == "evict_first":
eviction = ir.EVICTION_POLICY.EVICT_FIRST
else:
raise ValueError(f"Eviction policy {eviction_policy} not supported")
return eviction
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
@@ -684,14 +696,7 @@ def load(ptr: tl.tensor,
raise ValueError(f"Cache modifier {cache_modifier} not supported")
# eviction policy
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
eviction = ir.EVICTION_POLICY.EVICT_LAST
elif eviction_policy == "evict_first":
eviction = ir.EVICTION_POLICY.EVICT_FIRST
else:
raise ValueError(f"Eviction policy {eviction_policy} not supported")
eviction = _parse_eviction_policy(eviction_policy)
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
@@ -721,6 +726,7 @@ def load(ptr: tl.tensor,
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -735,14 +741,15 @@ def store(ptr: tl.tensor,
elt_ty_ptr = tl.int8
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)
# cast to target data-type
val = cast(val, elt_ty, builder)
if not mask:
return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
#########
# atomic
@@ -897,27 +904,31 @@ def atomic_xchg(ptr: tl.tensor,
# ===----------------------------------------------------------------------===//
def dot(lhs: tl.tensor,
rhs: tl.tensor,
def dot(a: tl.tensor,
b: tl.tensor,
trans_a: bool,
trans_b: bool,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[-1] == rhs.shape[0]
assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
in_a = 1 if not trans_a else 0
in_b = 1 if trans_b else 0
assert a.type.is_block() and b.type.is_block()
assert len(a.shape) == 2 and len(b.shape) == 2
assert a.shape[in_a] == b.shape[in_b]
assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
"small blocks not supported!"
if lhs.type.scalar.is_int():
if a.type.scalar.is_int():
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = lhs.type.shape[0]
N = rhs.type.shape[1]
M = a.type.shape[in_a ^ 1]
N = b.type.shape[in_b ^ 1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
return tl.tensor(ret, ret_ty)
# ===----------------------------------------------------------------------===//

View File

@@ -0,0 +1,198 @@
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kk, stride_kn,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_qm = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
q = tl.load(q_ptrs)
for start_n in range(0, start_qm + 1):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.dot(q, k)
qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
p = p.to(tl.float16)
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# r_ptrs += BLOCK_N
l_i = l_i_new
m_i = m_i_new
start_qm = tl.program_id(0)
offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_out
tl.store(out_ptrs, acc)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v):
BLOCK = 128
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-2]
assert Lq == Lk
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=64, num_warps=4,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
return o
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)])
def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16):
torch.manual_seed(20)
q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
# triton implementation
tri_out = attention(q, k, v)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
ref_qk = torch.matmul(q, k)
for z in range(Z):
for h in range(H):
ref_qk[:, :, M == 0] = float("-inf")
ref_qk = torch.softmax(ref_qk, dim=-1)
ref_out = torch.matmul(ref_qk, v)
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64
# vary batch size for fixed heads / seq
batch_bench = triton.testing.Benchmark(
x_names=['BATCH'],
x_vals=[2**i for i in range(0, 8)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}',
args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
)
# vary seq length for fixed head and batch=4
seq_bench = triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}',
args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
)
@triton.testing.perf_report([batch_bench, seq_bench])
def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"):
warmup = 25
rep = 500
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
fn = lambda: attention(q, k, v)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
bench_flash_attention.run(save_path='.', print_data=True)