[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities to the compiler backend:
- `tl.store` now uses inline PTX, making it possible to use cache hints such as `evict_last`.
- For A100, the mma layout can be converted directly to shared memory.
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used in both row-major and column-major order.
- Fixed liveness analysis, which was previously broken.
- The mma layout can now be loaded/stored directly, without conversion. This is useful when the `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take LHS inputs in registers when they come from a previous `tl.dot` instruction. This is useful for, e.g., fused attention.
This commit is contained in:
@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(shared_layout* x: V)
|
||||
for(shared_layout* x: V){
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
|
||||
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
order_ = {0, 1};
|
||||
}
|
||||
else{
|
||||
// fpw_ = {1, 1, 1};
|
||||
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
|
||||
contig_per_thread_ = {1, 2};
|
||||
order_ = {1, 0};
|
||||
// rep_ = {2, 2, 1};
|
||||
}
|
||||
|
||||
/* warps per tile */
|
||||
@@ -233,24 +231,45 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
}while(wpt_nm1 != wpt_);
|
||||
} else {
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||
break;
|
||||
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||
wpt_[0] *= 2;
|
||||
changed = true;
|
||||
// try to have a warp own entire rows of the output
|
||||
// this makes it easier to fuse multiple mmas by fusing
|
||||
// registers
|
||||
bool one_warp_per_row = false;
|
||||
for(ir::value* v: values)
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto* dot = dynamic_cast<ir::dot_inst*>(u);
|
||||
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
|
||||
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
|
||||
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
|
||||
}
|
||||
// std::cout << one_warp_per_row << std::endl;
|
||||
|
||||
if(one_warp_per_row){
|
||||
wpt_[1] = 1;
|
||||
wpt_[0] = num_warps;
|
||||
}
|
||||
else{
|
||||
do {
|
||||
changed = false;
|
||||
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||
break;
|
||||
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||
wpt_[0] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
} else {
|
||||
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||
wpt_[1] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||
wpt_[1] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} while (changed);
|
||||
} while(changed);
|
||||
}
|
||||
}
|
||||
|
||||
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
|
||||
|
||||
/* shape per block */
|
||||
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||
}
|
||||
@@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
ir::type *ty,
|
||||
analysis::align* align, target *tgt)
|
||||
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
|
||||
analysis::align* align, target *tgt, bool is_tmp)
|
||||
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
|
||||
|
||||
size_ = 0;
|
||||
arg_layout_ = arg;
|
||||
@@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg,
|
||||
ir::instruction *i, bool is_index) {
|
||||
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
|
||||
: i->get_type()->get_scalar_ty();
|
||||
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
|
||||
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
|
||||
if (is_index) {
|
||||
tmp_index_[i] = id;
|
||||
} else {
|
||||
|
@@ -14,43 +14,108 @@ namespace analysis{
|
||||
void liveness::run(ir::module &mod) {
|
||||
intervals_.clear();
|
||||
|
||||
// Assigns index to each instruction
|
||||
std::map<ir::value*, slot_index> indices;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices.insert({instr, index});
|
||||
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(!layout || layout->is_tmp())
|
||||
continue;
|
||||
for(ir::value* v:layout->get_values()){
|
||||
layouts_map[v].insert(layout);
|
||||
}
|
||||
}
|
||||
|
||||
// create live intervals
|
||||
|
||||
|
||||
std::map<ir::user*, std::set<shared_layout*>> live_in;
|
||||
while(true){
|
||||
bool changed = false;
|
||||
ir::instruction* last_inst = nullptr;
|
||||
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
|
||||
// gen
|
||||
std::set<shared_layout*> gen;
|
||||
for(ir::value* v: i->ops())
|
||||
for(shared_layout* layout: layouts_map[v])
|
||||
gen.insert(layout);
|
||||
// kill
|
||||
std::set<shared_layout*> kill;
|
||||
for(shared_layout* layout: layouts_map[i])
|
||||
kill.insert(layout);
|
||||
// temporaries are handled separately
|
||||
if(layouts_->has_tmp(i)){
|
||||
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
|
||||
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
|
||||
}
|
||||
if(layouts_->has_tmp_index(i)){
|
||||
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
|
||||
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
|
||||
}
|
||||
// live-out
|
||||
std::set<shared_layout*> live_out;
|
||||
std::vector<ir::instruction*> succs = {last_inst};
|
||||
if(i == i->get_parent()->get_inst_list().back())
|
||||
for(ir::basic_block* succ: i->get_parent()->get_successors())
|
||||
succs.push_back(succ->get_inst_list().front());
|
||||
for(ir::instruction* succ: succs)
|
||||
for(shared_layout* layout: live_in[succ])
|
||||
if(!layout->is_tmp())
|
||||
live_out.insert(layout);
|
||||
|
||||
// new sets
|
||||
std::set<shared_layout*> live_out_minus_kill;
|
||||
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
|
||||
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
|
||||
std::set<shared_layout*> new_live_in;
|
||||
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
|
||||
std::inserter(new_live_in, new_live_in.end()));
|
||||
|
||||
changed = changed || (new_live_in != live_in[i]);
|
||||
live_in[i] = new_live_in;
|
||||
last_inst = i;
|
||||
});
|
||||
if(!changed)
|
||||
break;
|
||||
}
|
||||
|
||||
// ir::for_each_instruction(mod, [&](ir::instruction* i){
|
||||
// i->print(std::cout);
|
||||
// std::cout << " live_in: " << live_in[i].size() << std::endl;
|
||||
// });
|
||||
|
||||
|
||||
|
||||
// Assigns index to each instruction
|
||||
std::map<ir::value*, slot_index> indices;
|
||||
slot_index index = 0;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction* instr){
|
||||
index += 1;
|
||||
indices.insert({instr, index});
|
||||
});
|
||||
|
||||
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(layout)
|
||||
intervals_[layout] = segment{INT32_MAX, 0};
|
||||
}
|
||||
|
||||
for(auto& x: live_in)
|
||||
for(shared_layout* layout: x.second)
|
||||
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
|
||||
|
||||
for(auto& x: live_in)
|
||||
for(shared_layout* layout: x.second){
|
||||
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
|
||||
}
|
||||
|
||||
|
||||
for(auto &x: layouts_->get_all()) {
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(!layout)
|
||||
continue;
|
||||
// users
|
||||
std::set<ir::user*> users;
|
||||
for(ir::value *v: layout->get_values()){
|
||||
for(ir::user *u: v->get_users())
|
||||
users.insert(u);
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
for(ir::value *v: layout->get_values())
|
||||
if(indices.find(v) != indices.end())
|
||||
start = std::min(start, indices.at(v));
|
||||
unsigned end = 0;
|
||||
for(ir::user *u: users)
|
||||
if(indices.find(u) != indices.end())
|
||||
end = std::max(end, indices.at(u));
|
||||
if(end == 0)
|
||||
end = start + 1;
|
||||
intervals_[layout] = segment{start, end};
|
||||
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@@ -28,12 +28,15 @@ void swizzle::run(ir::module &) {
|
||||
}
|
||||
auto ord = layout->get_order();
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
int per_phase = 1;
|
||||
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(in_layout)
|
||||
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
else
|
||||
per_phase = 1;
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
|
||||
int inner = mma_dot_a ? 0 : 1;
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
per_phase_[layout] = per_phase;
|
||||
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
|
||||
if(mma_dot_a)
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
@@ -46,7 +49,7 @@ void swizzle::run(ir::module &) {
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
} else {
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
per_phase_[layout] = per_phase;
|
||||
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
|
||||
vec_[layout] = layout->get_mma_vec();
|
||||
}
|
||||
|
Reference in New Issue
Block a user