[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row-major and column-major.
- Fixed liveness analysis; the previous implementation computed incorrect live intervals, and has been replaced with a backward dataflow (gen/kill) fixed-point computation.
- Can now load/store the mma layout directly, without converting. Useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take its LHS input in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
}
}

View File

@@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps,
order_ = {0, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
order_ = {1, 0};
// rep_ = {2, 2, 1};
}
/* warps per tile */
@@ -233,24 +231,45 @@ mma_layout::mma_layout(size_t num_warps,
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;
if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while (changed);
} while(changed);
}
}
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
@@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
size_ = 0;
arg_layout_ = arg;
@@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {

View File

@@ -14,43 +14,108 @@ namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(!layout || layout->is_tmp())
continue;
for(ir::value* v:layout->get_values()){
layouts_map[v].insert(layout);
}
}
// create live intervals
std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);
// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));
changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}
// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}
for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
}
}

View File

@@ -28,12 +28,15 @@ void swizzle::run(ir::module &) {
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int per_phase = 1;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
@@ -46,7 +49,7 @@ void swizzle::run(ir::module &) {
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}