[GENERAL] Merged v1.0alpha into master. Added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i) {
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
if(connect_ret && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);

View File

@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v;
result = i;
}
}
@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
}
}
size_t data_layout::find_axis(int to_find) const {
int data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
if(it == axes_.end())
return -1;
return std::distance(axes_.begin(), it);
}
@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
* MMA Layout *
* -------------------------------- */
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
mma_layout::mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
if(tgt->as_nvidia()->sm() < 80){
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
}
else{
fpw_ = {1, 1, 1};
spw_ = {16, 8, 1};
rep_ = {2, 2, 1};
}
/* warps per tile */
// try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
// if(num_warps != effective_num_warps)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
/* shape per block */
spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
int contiguous = 1;
if(ptr){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
// std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
// if(num_warps * 32 != effective_num_threads)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
if(!(i_0 && !i_1) &&
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
}
shared_layout::shared_layout(const data_layout *arg,
shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size_ = 0;
arg_layout_ = arg;
// double-buffering
for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order_ != arg_order) {
pad = 4;
}
shape_[order_[0]] += pad;
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v);
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
if(it_hmma_c != values.end()){
ir::instruction *dot = (ir::instruction*)*it_hmma_c;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
}
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
}
void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
mma884_layout* in_layout = get(val)->to_mma884();
mma_layout* in_layout = get(val)->to_mma();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
shape[k] = in_layout->to_mma()->spt(k);
// create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;

View File

@@ -0,0 +1,54 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
void swizzle::run(ir::module &) {
per_phase_.clear();
max_phase_.clear();
for(auto &x: layouts_->get_all()){
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
if(!layout)
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
continue;
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else{
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
}
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,325 +0,0 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"
namespace triton{
namespace codegen{
using namespace llvm;
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_shared_layout::create(ir::value *v) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto double_buffer = layout_->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = offset_;
// base pointer
Value *ptr = ptr_;
if(double_buffer && v == double_buffer->latch)
ptr = next_ptr_;
else if(double_buffer && v == double_buffer->first)
ptr = pre_ptr_;
// create tile
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> order(rank);
// compute axes
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
// compute order
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
size_t posx = layout_->find_axis(axx);
size_t posy = layout_->find_axis(axy);
if(posx < rank && posy < rank)
return layout_->get_order(posx) < layout_->get_order(posy);
return false;
};
std::sort(order.begin(), order.end(), cmp);
return new distributed_tile(ty, shapes, order, axes, *builder_);
}
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shape = layout->get_shape();
if(shape.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shape.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// mma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// mma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::scanline_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->get_order();
const auto& shape = layout->get_shape();
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts * mts;
unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
}
}
}
}

View File

@@ -1,214 +0,0 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"
namespace triton{
namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order_[0]]++;
while(id[order_[k]] == axes_[order_[k]].values.size()){
if(k == id.size() - 1)
return;
id[order_[k++]] = 0;
id[order_[k]]++;
}
k = 0;
}
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
if(end < 0)
end = ordered_indices_.size() + end + 1;
for(unsigned i = start; i < end; i++)
fn(ordered_indices_[i]);
}
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
int rank = sizes.size();
int len = 1;
for(int s: sizes)
len *= s;
for(int i = 0; i < len; i++){
indices_t idx(rank);
int current = i;
for(int k = 0; k < rank; k++){
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
current = current / sizes[k];
}
fn(idx);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
const std::vector<int>& perm, const std::vector<int>& order,
indices_t idx) {
// strides
std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < idx.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
}
}

View File

@@ -14,6 +14,12 @@ namespace triton{
namespace codegen{
// base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
}
bool target::is_gpu() const {
return is_gpu_;
}
@@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {});
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
return group_id;
}
@@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
return cta_id;
}
@@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_nctaid, {});
return builder.CreateIntrinsic(ids[ax], {}, {});
}
// CPU

View File

@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {
for(size_t id = 0; id < num_groups; id++) {
if(!layout_->get(id)->to_mma884())
if(!layout_->get(id)->to_mma())
continue;
// extract memory stores
const auto& values = layout_->values_of(id);

View File

@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
return true;
return false;
}
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
// copy
builder.set_insert_point_after(i);
ir::value *copy;
if(to_shared)
if(to_shared){
copy = builder.create_copy_to_shared(x);
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);

View File

@@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
add_reference(i, res);
}
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
builder.create_barrier();
if(type.first)
builder.create_async_wait();
if(type.second)
builder.create_barrier();
}
}
}
@@ -85,8 +88,9 @@ std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc,
std::set<ir::value*>& safe_war) {
std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
std::set<ir::value*>& safe_war,
std::vector<ir::instruction*>& to_sync) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from;
@@ -95,6 +99,8 @@ std::pair<membar::interval_vec_t,
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
if(written.size())
to_sync.push_back(i);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering
@@ -104,9 +110,14 @@ std::pair<membar::interval_vec_t,
}
// record hazards
if(read_after_write || write_after_read) {
insert_loc.insert(i);
auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
insert_loc.insert({i, {copy_async_wait, barrier}});
new_written_to.clear();
new_read_from.clear();
to_sync.clear();
}
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
@@ -125,17 +136,17 @@ void membar::run(ir::module &mod) {
if(!layout || !layout->get_double_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi)
if(v != layout->get_double_buffer()->phi){
safe_war.insert(v);
}
}
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from;
std::set<ir::instruction*> insert_locs;
std::vector<ir::instruction*> to_sync;
std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
size_t n_inserted_im1 = 0;
bool done = false;
do{
@@ -150,7 +161,7 @@ void membar::run(ir::module &mod) {
for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]);
// apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
written_to[block] = result.first;
read_from[block] = result.second;
}
@@ -158,8 +169,9 @@ void membar::run(ir::module &mod) {
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(ir::instruction* i: insert_locs)
insert_barrier(i, builder);
for(auto x: insert_locs){
insert_barrier(x.first, x.second, builder);
}
}
}

View File

@@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
//}
bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
if(!copy_to_shared)
return false;
ir::value *arg = copy_to_shared->get_operand(0);
ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
if(!ld)
return false;
builder.set_insert_point(copy_to_shared);
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
@@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) {
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// if(tgt_->as_nvidia()->sm() >= 80)
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}

View File

@@ -0,0 +1,51 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"
namespace triton {
namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
ir::builder &builder = mod.get_builder();
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
ir::value* _ptr = ld->get_pointer_operand();
ir::value* _msk = ld->get_mask_operand();
ir::value* _val = ld->get_false_value_operand();
auto ptr = std::find(block->begin(), block->end(), _ptr);
auto msk = std::find(block->begin(), block->end(), _msk);
auto val = std::find(block->begin(), block->end(), _val);
if(ptr == block->end() || msk == block->end() || val == block->end())
continue;
auto it = std::find(block->begin(), block->end(), i);
int dist_ptr = std::distance(ptr, it);
int dist_msk = std::distance(msk, it);
int dist_val = std::distance(val, it);
if(dist_ptr < dist_msk && dist_ptr < dist_val)
builder.set_insert_point(++ptr);
if(dist_msk < dist_ptr && dist_msk < dist_val)
builder.set_insert_point(++msk);
if(dist_val < dist_ptr && dist_val < dist_msk)
builder.set_insert_point(++val);
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
to_replace.push_back(std::make_pair(ld, new_ld));
}
}
for(auto& x: to_replace)
x.first->replace_all_uses_with(x.second);
}
}
}
}

View File

@@ -48,46 +48,6 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
// CUDA //
/* ------------------------ */
// architecture
cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
switch(sm.first) {
case 7:
switch(sm.second){
case 0: return Architecture::SM_7_0;
}
case 6:
switch(sm.second){
case 0: return Architecture::SM_6_0;
case 1: return Architecture::SM_6_1;
}
case 5:
switch(sm.second){
case 0: return Architecture::SM_5_0;
case 2: return Architecture::SM_5_2;
default: return Architecture::UNKNOWN;
}
case 3:
switch(sm.second){
case 0: return Architecture::SM_3_0;
case 5: return Architecture::SM_3_5;
case 7: return Architecture::SM_3_7;
default: return Architecture::UNKNOWN;
}
case 2:
switch(sm.second){
case 0: return Architecture::SM_2_0;
case 1: return Architecture::SM_2_1;
default: return Architecture::UNKNOWN;
}
default: return Architecture::UNKNOWN;
}
}
// information query
template<CUdevice_attribute attr>
int cu_device::cuGetInfo() const{
@@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
return map.at(key);
}
// architecture
cu_device::Architecture cu_device::architecture() const{
return nv_arch(compute_capability());
}
// number of address bits
size_t cu_device::address_bits() const{
return sizeof(size_t)*8;
@@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
}
// force the device to be interpreted as a particular cc
void cu_device::interpret_as(std::pair<size_t, size_t> cc){
interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
void cu_device::interpret_as(int cc){
interpreted_as_ = std::make_shared<int>(cc);
}
// compute capability
std::pair<size_t, size_t> cu_device::compute_capability() const {
int cu_device::compute_capability() const {
if(interpreted_as_)
return *interpreted_as_;
size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
return std::make_pair(_major, _minor);
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
return major*10 + minor;
}
// maximum number of threads per block
@@ -218,7 +173,7 @@ std::string cu_device::infos() const{
// target
std::unique_ptr<codegen::target> cu_device::make_target() const {
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
}

View File

@@ -93,6 +93,7 @@ namespace driver
bool dispatch::cuinit(){
if(cuda_==nullptr){
putenv((char*)"CUDA_CACHE_DISABLE=1");
std::string libcuda = tools::getenv("TRITON_LIBCUDA");
if(libcuda.empty())
cuda_ = dlopen("libcuda.so", RTLD_LAZY);

View File

@@ -20,7 +20,9 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <fstream>
#include <unistd.h>
#include <memory>
#include <regex>
#include "triton/driver/module.h"
#include "triton/driver/context.h"
#include "triton/driver/error.h"
@@ -41,6 +43,19 @@
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
std::string exec(const char* cmd) {
std::array<char, 128> buffer;
std::string result;
std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
if (!pipe) {
throw std::runtime_error("popen() failed!");
}
while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
result += buffer.data();
}
return result;
}
namespace triton
{
namespace driver
@@ -63,11 +78,11 @@ void module::init_llvm() {
}
module::module(CUmodule mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) {
: polymorphic_resource(mod, has_ownership), spilled_(0) {
}
module::module(host_module_t mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) {
: polymorphic_resource(mod, has_ownership), spilled_(0) {
}
@@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
file_type_t ft) {
init_llvm();
// // debug
// llvm::legacy::PassManager pm;
llvm::legacy::PassManager pm;
std::string tmp;
// llvm::raw_string_ostream oss(llir_);
// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass());
// pm.run(*module);
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
@@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_
// create execution engine
for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName()] = &fn;
hst_->functions[fn.getName().str()] = &fn;
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget();
@@ -225,7 +242,8 @@ static std::map<int, int> vptx = {
{10010, 64},
{10020, 65},
{11000, 70},
{11010, 71}
{11010, 71},
{11020, 72}
};
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
@@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
auto _cc = ((driver::cu_device*)device)->compute_capability();
int cc = _cc.first*10 + _cc.second;
cc = std::min(cc, max_nvvm_cc);
int cc = ((driver::cu_device*)device)->compute_capability();
std::string sm = "sm_" + std::to_string(cc);
// driver version
int version;
@@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
throw std::runtime_error("Triton requires CUDA 10+");
// PTX version
int ptx = vptx.at(version);
ptx = std::min(ptx, max_nvvm_ptx);
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
@@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
}
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }
cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
std::string errbuf(errbufsize, 0);
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::invalid_ptx const &){
// // compile ptx with ptxas
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
// char _flog[] = "/tmp/triton_l_XXXXXX";
// int fdsrc = mkstemp(_fsrc);
// int fdlog = mkstemp(_flog);
// std::string fsrc = _fsrc;
// std::string flog = _flog;
// std::ofstream ofs(fsrc);
// ofs << source;
// ofs.close();
// std::string cmd;
// int err;
// driver::cu_device* cu_device = (driver::cu_device*)device;
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
// err = system(cmd.c_str());
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
// std::ifstream file(flog);
// std::string log;
// if(file)
// while (!file.eof()) log.push_back(file.get());
// unlink(_fsrc);
// unlink(_flog);
// std::smatch match;
// std::regex expr ("\\b([0-9]+) bytes spill");
// spilled_ = 0;
// while (std::regex_search (log,match,expr)){
// spilled_ += std::stoi(match[1]);
// log = match.suffix();
// }
// std::cout << log << std::endl;
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
CU_JIT_LOG_VERBOSE};
unsigned int errbufsize = 8192;
unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
std::string err(_err);
std::string log(_log);
// std::cout << log << std::endl;
std::smatch match;
std::regex expr ("\\b([0-9]+) bytes spill");
spilled_ = 0;
while (std::regex_search(log,match,expr)){
spilled_ += std::stoi(match[1]);
log = match.suffix();
}
}
catch(exception::cuda::invalid_ptx const &){
//#ifdef TRITON_LOG_PTX_ERROR
std::cout << source << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
std::cerr << errbuf << std::endl;
// exit(1);
//#endif
throw;

View File

@@ -1,5 +1,6 @@
#include <string>
#include <algorithm>
#include <iostream>
#include "triton/ir/basic_block.h"
#include "triton/ir/builder.h"
#include "triton/ir/constant.h"
@@ -253,6 +254,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
value *builder::create_load(value *ptr, const std::string &name){
return insert(unmasked_load_inst::create(ptr, name));
// type *ty = ptr->get_type()->get_pointer_element_ty();
// value *mask = constant_int::get(get_int1_ty(), 1);
// value *undef = undef_value::get(ty);
// if(ptr->get_type()->is_tile_ty()){
// auto shapes = ptr->get_type()->get_tile_shapes();
// return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
// }
// return insert(masked_load_inst::create(ptr, mask, undef, name));
}
value *builder::create_store(value *ptr, value *val, const std::string &name){
@@ -263,6 +273,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value,
return insert(masked_load_inst::create(ptr, mask, false_value, name));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
return insert(masked_store_inst::create(ptr, val, mask, name));
}
@@ -348,13 +359,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
return insert(copy_to_shared_inst::create(arg, name));
}
value *builder::create_copy_from_shared(value *arg, const std::string &name) {
return insert(copy_from_shared_inst::create(arg, name));
}
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
}
value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
}
value *builder::create_async_wait() {
return insert(async_wait_inst::create(ctx_));
}
}
}

View File

@@ -463,6 +463,20 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
return new masked_load_inst(ptr, mask, false_value, name, next);
}
// masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, name, next);
}
// atomic add
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
@@ -804,6 +818,14 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
return new barrier_inst(ctx, name, next);
}
async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, name, next);
}
// nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)

View File

@@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) {
os << get_name(ops[i], cnt++);
os << (i < num_ops - 1?", ":"");
}
os << ";" << std::endl;
os << ";";
// os << " (";
// for(ir::user* usr: inst->get_users())
// os << get_name(usr, cnt++) << ", " ;
// os << " )";
os << std::endl;
}
}
os << "}" << std::endl;

View File

@@ -68,9 +68,10 @@ unsigned user::get_num_hidden() const {
value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
for(size_t i = 0; i < ops_.size(); i++)
if(ops_[i] == before)
if(ops_[i] == before){
ops_[i] = after;
after->add_use(this);
after->add_use(this);
}
return before->erase_use(this);
}

View File

@@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
return set_ret(bld_->create_dot(lhs, rhs, _0));
}
case Token::MASKED_DEREF: {
// TODO: FIXME
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
auto it = bld_->get_insert_block();
if(ret_ty->is_tile_ty())
false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
bld_->set_insert_point(it);
return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
}
case Token::ELLIPSIS: {
@@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
cond,
false_val);
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
ld->replace_all_uses_with(new_ld);
ld->erase_from_parent();
return set_ret(new_ld);
@@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
});
if(init_)
VisitStmt(init_);
// VisitExpr(cond_);
// ir::value *cond = ret_;
// bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->create_br(loop_bb);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
// bld_->create_br(loop_bb);
bld_->set_insert_point(loop_bb);
if(body_)
VisitStmt(body_);

View File

@@ -1,4 +1,4 @@
#include <string>
#include <string>
#include <mutex>
#include <regex>
#include <functional>
@@ -9,11 +9,13 @@
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h"
@@ -29,6 +31,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/print.h"
#include "triton/runtime/error.h"
#include "triton/tools/bench.hpp"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp"
@@ -67,7 +70,7 @@ void _loop_nest(std::vector<size_t> const & ranges,
/* OPTIONS */
/* --------------------- */
std::string function::options_t::to_str() const{
std::string options_t::to_str() const{
std::string ret = "nw-" + std::to_string(num_warps);
for(const auto& x : defines){
ret += '-';
@@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) {
throw std::runtime_error("unknown type");
}
void function::caller::write(std::ofstream &ofs) {
// write name
ofs << name_ << std::endl;
// write signature
for(size_t i = 0; i < param_tys_.size(); i++)
ofs << param_tys_[i] << " ";
ofs << std::endl;
// write module
std::string source = ((driver::cu_module*)(&*parent_))->source();
ofs << source;
}
//void function::caller::write(std::ofstream &ofs) {
// // write name
// ofs << name_ << std::endl;
// // write signature
// for(size_t i = 0; i < param_tys_.size(); i++)
// ofs << param_tys_[i] << " ";
// ofs << std::endl;
// // write module
// std::string source = ((driver::cu_module*)(&*parent_))->ptx();
// ofs << source;
//}
void function::caller::read(std::ifstream &ifs) {
// read name
std::getline(ifs, name_);
// read signature
std::string line;
std::getline(ifs, line);
std::istringstream current(line);
int param;
param_tys_.clear();
while(current >> param)
param_tys_.push_back((arg_type)param);
// read module
std::string src((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
parent_.reset(new driver::cu_module(src));
bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
//void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
// // read name
// std::getline(ifs, name_);
// // read signature
// std::string line;
// std::getline(ifs, line);
// std::istringstream current(line);
// int param;
// param_tys_.clear();
// while(current >> param)
// param_tys_.push_back((arg_type)param);
// // read module
// std::string src((std::istreambuf_iterator<char>(ifs)),
// std::istreambuf_iterator<char>());
// parent_.reset(new driver::cu_module(ctx, src));
// bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
}
//}
function::caller::caller(std::ifstream &ifs, const options_t& opt)
: opt_(opt) {
read(ifs);
}
//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
// : opt_(opt) {
// read(ctx, ifs);
//}
function::caller::caller(ir::function *ir,
std::shared_ptr<driver::module> parent, const options_t& opt)
@@ -198,20 +201,23 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// optimizations
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::peephole peephole(target.get());
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(module);
disassociate.run(module);
@@ -233,17 +239,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
}
peephole.run(module);
dce.run(module);
// ir::print(module, std::cout);
align.run(module);
axes.run(module);
layouts.run(module);
swizzle.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > device->max_shared_memory())
throw std::runtime_error("using too much shared memory");
throw exception::out_of_shared_memory();
barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
if(res->spilled() > 256)
throw exception::out_of_registers();
return res;
}
@@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) {
auto ir = make_ir(parser);
// triton-ir -> binary
std::unique_ptr<driver::module> bin;
// try{
try{
bin = make_bin(*ir, device, opt);
// }catch(const std::runtime_error&){
// return nullptr;
// }
}catch(const exception::base&){
throw;
}
// create callable
ir::function *tmp = ir->get_function_list()[0];
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
@@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space)
for(const auto& x: space.defines)
ranges.push_back(x.second.size());
// functor for source with given option
std::map<options_t, std::string> err;
auto do_make = [&](std::vector<size_t> params) {
// compilation options
unsigned i = 0;
@@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space)
for(auto D: space.defines)
opt.defines[D.first] = D.second[params[i++]];
// compile
make(device, opt);
try{
make(device, opt);
}catch(const exception::base& e){
err[opt] = e.what();
}
};
// multi-threaded compilation
_loop_nest(ranges, do_make);
if(callers_.empty())
throw std::runtime_error("could not compile kernel");
if(callers_.empty()){
std::ostringstream dbg;
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
for(auto x: err){
dbg << "[ ";
dbg << x.first.num_warps << ", ";
dbg << "{ ";
for(const auto& y: x.first.defines)
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
dbg << " } ] -> " << x.second << std::endl;
}
throw exception::no_valid_configuration(dbg.str());
}
}
std::string function::ptx(driver::device* device, const options_t& opt) {
std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
make(device, opt);
const auto& fn = callers_.at(opt);
if(!fn)
return "";
return ((driver::cu_module*)fn->parent())->source();
switch(mode){
case ASM_LLIR:{
return fn->parent()->llir();
}
case ASM_NV_PTX:
case ASM_NV_SASS:{
std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
// SASS
std::string input = std::tmpnam(nullptr);
std::string output = std::tmpnam(nullptr);
std::ofstream ofs(input);
ofs << ptx;
ofs.close();
if(mode == ASM_NV_PTX)
return ptx;
std::string cmd;
int err;
// compile ptx
driver::cu_device* cu_device = (driver::cu_device*)device;
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
err = system(cmd.c_str());
// disassemble
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
err = system(cmd.c_str());
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
std::string to_delete = " /*";
std::ifstream ifs(output);
std::string line;
std::string sass;
while(std::getline(ifs, line))
if(!std::regex_match(line, comment))
sass += line + "\n";
return sass;
}
default:
return "";
}
}
// returns program with best compilation options for given parameter