[GENERAL] Merged v1.0alpha into master. Added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i) {
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
if(connect_ret && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);

View File

@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v;
result = i;
}
}
@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
}
}
size_t data_layout::find_axis(int to_find) const {
int data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
if(it == axes_.end())
return -1;
return std::distance(axes_.begin(), it);
}
@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
* MMA Layout *
* -------------------------------- */
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
mma_layout::mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
if(tgt->as_nvidia()->sm() < 80){
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
}
else{
fpw_ = {1, 1, 1};
spw_ = {16, 8, 1};
rep_ = {2, 2, 1};
}
/* warps per tile */
// try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
// if(num_warps != effective_num_warps)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
/* shape per block */
spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
int contiguous = 1;
if(ptr){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
// std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
// if(num_warps * 32 != effective_num_threads)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
if(!(i_0 && !i_1) &&
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
}
shared_layout::shared_layout(const data_layout *arg,
shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size_ = 0;
arg_layout_ = arg;
// double-buffering
for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order_ != arg_order) {
pad = 4;
}
shape_[order_[0]] += pad;
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v);
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
if(it_hmma_c != values.end()){
ir::instruction *dot = (ir::instruction*)*it_hmma_c;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
}
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
}
void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
mma884_layout* in_layout = get(val)->to_mma884();
mma_layout* in_layout = get(val)->to_mma();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
shape[k] = in_layout->to_mma()->spt(k);
// create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;

View File

@@ -0,0 +1,54 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
void swizzle::run(ir::module &) {
per_phase_.clear();
max_phase_.clear();
for(auto &x: layouts_->get_all()){
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
if(!layout)
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
continue;
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else{
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
}
}
}
}
}
}