[GENERAL] Merged v1.0alpha into master. Added features are:
- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
  graph_.add_edge({dot, d}, {D, d});
}

void axes::update_graph_elementwise(ir::instruction *i) {
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
  if(i->get_num_operands() == 0)
    return;
  ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
  for(unsigned d = 0; d < rank; d++)
  for(ir::value* opx: i->ops())
  for(ir::value* opy: i->ops()){
    if(!i->get_type()->is_void_ty())
    if(connect_ret && !i->get_type()->is_void_ty())
      graph_.add_edge({i, d}, {opx, d});
    graph_.add_edge({opx, d}, {opy, d});
  }
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
    case ir::INST_TRANS: return update_graph_trans(i);
    case ir::INST_BROADCAST: return update_graph_broadcast(i);
    case ir::INST_DOT: return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
    case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
    case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
    case ir::INST_RECOALESCE: return update_graph_no_edge(i);
    default: return update_graph_elementwise(i);

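Context for the connect_ret flag introduced above: update_graph_elementwise unions the distribution axes of an instruction's operands dimension by dimension, and optionally ties the result tile into the same class. The new INST_MASKED_LOAD_ASYNC case passes connect_ret = false, plausibly because the destination of an async copy is a shared-memory buffer whose layout is chosen independently of its address/mask operands. A minimal toy sketch of the edge structure (hypothetical names, not Triton's API):

#include <utility>
#include <vector>

// Toy model of the axes graph: nodes are (value id, dimension) pairs and
// connected components later become shared distribution axes.
using node = std::pair<int, int>;
struct axis_graph {
  std::vector<std::pair<node, node>> edges;
  void add_edge(node a, node b) { edges.push_back({a, b}); }
};

// Mirrors update_graph_elementwise: operands are always unified per dimension;
// the result joins them only when connect_ret is true.
void update_elementwise(axis_graph& g, int ret, const std::vector<int>& ops,
                        int rank, bool connect_ret) {
  for(int d = 0; d < rank; d++)
  for(int x : ops)
  for(int y : ops) {
    if(connect_ret) g.add_edge({ret, d}, {x, d});
    g.add_edge({x, d}, {y, d});
  }
}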
@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && is_hmma_c(i) && i->get_operand(n) == v)
      result = v;
      result = i;
  }
}

@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
  }
}

size_t data_layout::find_axis(int to_find) const {
int data_layout::find_axis(int to_find) const {
  auto it = std::find(axes_.begin(), axes_.end(), to_find);
  if(it == axes_.end())
    return -1;
  return std::distance(axes_.begin(), it);
}

@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
 * MMA Layout *
 * -------------------------------- */

mma884_layout::mma884_layout(size_t num_warps,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
mma_layout::mma_layout(size_t num_warps,
                       const std::vector<int>& axes,
                       const std::vector<unsigned>& shape,
                       const std::vector<ir::value *> &values,
                       analysis::align* align, target* tgt,
                       shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  fpw_ = {1, 1, 1};
  std::vector<int> fpw_nm1;
  unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
  do {
    fpw_nm1 = fpw_;
    if(fpw_[0]*fpw_[1] < num_fragments)
      fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
    if(fpw_[0]*fpw_[1] < num_fragments)
      fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
  }while(fpw_nm1 != fpw_);
  if(tgt->as_nvidia()->sm() < 80){
    fpw_ = {1, 1, 1};
    std::vector<int> fpw_nm1;
    unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
    do {
      fpw_nm1 = fpw_;
      if(fpw_[0]*fpw_[1] < num_fragments)
        fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
      if(fpw_[0]*fpw_[1] < num_fragments)
        fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
    }while(fpw_nm1 != fpw_);
    auto ord_a = layout_a->get_order();
    auto ord_b = layout_b->get_order();
    bool is_a_row = ord_a[0] != 0;
    bool is_b_row = ord_b[0] != 0;
    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
    bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
    int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
    spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
  }
  else{
    fpw_ = {1, 1, 1};
    spw_ = {16, 8, 1};
    rep_ = {2, 2, 1};
  }

  /* warps per tile */
  // try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
  do{
    wpt_nm1 = wpt_;
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
  }while(wpt_nm1 != wpt_);

  /* sanity check */
  unsigned effective_num_warps = 1;
  for(size_t d = 0; d < shape.size(); d++)
    effective_num_warps *= wpt_[d];
  // if(num_warps != effective_num_warps)
  //   throw std::runtime_error("cannot create a kernel with this amount of warps");
  /* shape per block */
  spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}


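To see what the warps-per-tile loop converges to, here is a standalone sketch of the same doubling loop (clamp assumed to be the usual min/max combination), run on a 128x128 tile with the Ampere shape-per-warp spw = {16, 8} and 8 warps:

#include <algorithm>
#include <cstdio>

static int clamp(int x, int lo, int hi) { return std::max(lo, std::min(x, hi)); }

int main() {
  int shape[2] = {128, 128}, spw[2] = {16, 8};
  int num_warps = 8, wpt[3] = {1, 1, 1}, prev[3];
  do {
    std::copy(wpt, wpt + 3, prev);
    // Alternately double each dimension until the product covers all warps.
    if(wpt[0] * wpt[1] * wpt[2] < num_warps)
      wpt[0] = clamp(wpt[0] * 2, 1, shape[0] / spw[0]);
    if(wpt[0] * wpt[1] * wpt[2] < num_warps)
      wpt[1] = clamp(wpt[1] * 2, 1, shape[1] / spw[1]);
  } while(prev[0] != wpt[0] || prev[1] != wpt[1] || prev[2] != wpt[2]);
  std::printf("wpt = {%d, %d}\n", wpt[0], wpt[1]); // prints wpt = {4, 2}
}

The alternating doubling is what keeps the warp grid as square as possible: here it settles on 4x2 warps, each owning a 16x8 slice repeated over the 128x128 tile.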
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
  ir::value *ptr = nullptr;
  for(ir::value *v: values)
  for(ir::user *usr: v->get_users())
  if(auto *st = dynamic_cast<ir::store_inst*>(usr))
  if(auto *st = dynamic_cast<ir::io_inst*>(usr))
    ptr = st->get_pointer_operand();

  unsigned i = order_[0];
  int contiguous = 4;
  if(ptr)
    contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
  int contiguous = 1;
  if(ptr){
    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
  }

  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
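The new bound replaces the hard-coded factor of 4 with 128 / nbits, i.e. as many elements as fit in a single 128-bit memory transaction. A quick sanity check of that arithmetic for common element types:

// Elements per 128-bit vectorized access, as computed by 128 / nbits:
static_assert(128 / 16 == 8,  "fp16:   8 elements per access");
static_assert(128 / 32 == 4,  "fp32:   4 elements per access");
static_assert(128 / 64 == 2,  "fp64:   2 elements per access");
static_assert(128 / 8  == 16, "int8:  16 elements per access");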
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    num_threads = num_threads / mts_[i];
  }
  /* sanity check */
  unsigned effective_num_threads = 1;
  for(size_t d = 0; d < shape_.size(); d++)
    effective_num_threads *= mts_[d];

  // std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
  // if(num_warps * 32 != effective_num_threads)
  //   throw std::runtime_error("cannot create a kernel with this amount of warps");
}


@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
  ir::value *value_1 = phi->get_incoming_value(1);
  ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
  ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
  if(!i_0 || !i_1 ||
     !dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
     !dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
  if(!(i_0 && !i_1) &&
     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
    return;
  if(is_latch_1)
    res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
}


shared_layout::shared_layout(const data_layout *arg,
shared_layout::shared_layout(data_layout *arg,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
                             analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {

  size_ = 0;
  arg_layout_ = arg;

  // double-buffering
  for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
    extract_hmma_dot_use(v, hmma_dot_a, 0);
    extract_hmma_dot_use(v, hmma_dot_b, 1);
  }


  // non-mma ordering
  std::vector<int> col = {0, 1};
  std::vector<int> row = {1, 0};
  for(size_t s = 2; s < get_rank(); s++){
    col.push_back(s);
    row.push_back(s);
  }
  bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
  bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
  if(is_nonhmma_dot_a)
    order_ = is_trans(dot_a) ? row : col;
  else if(is_nonhmma_dot_b)
    order_ = is_trans(dot_b) ? col : row;

  // padding
  size_t pad = 0;
  if(hmma_dot_a){
    bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
    pad = 24 - shape_[row ? 0 : 1] % 32;
  }
  else if(hmma_dot_b){
    bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
    pad = 24 - shape_[row ? 1 : 0] % 32;
  }
  else if(order_ != arg_order) {
    pad = 4;
  }
  shape_[order_[0]] += pad;
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
}

void layouts::create(size_t id, const std::vector<ir::value*>& values) {
//  if(layouts_.find(id) != layouts_.end())
//    return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
  const auto& axes = axes_->get(largest);
  const auto& shapes = largest->get_type()->get_tile_shapes();
  auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
    return dynamic_cast<ir::copy_to_shared_inst*>(v);
    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
           dynamic_cast<ir::masked_load_async_inst*>(v);
  });
  // type
  if(it_hmma_c != values.end())
    layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
  if(it_hmma_c != values.end()){
    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    create(groups_.at(a), values_.at(groups_.at(a)));
    create(groups_.at(b), values_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
  }
  else if(it_cts != values.end()){
    ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
  }
  else
  else{
    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
  }
}

void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
  }
  if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
    ir::value *val = recoalasce->get_operand(0);
    mma884_layout* in_layout = get(val)->to_mma884();
    mma_layout* in_layout = get(val)->to_mma();
    scanline_layout* out_layout = get(i)->to_scanline();
    if(!in_layout || !out_layout)
      return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
    shape[ld] = in_shape[ld];
    for(size_t k = 0; k < in_shape.size(); k++)
      if(k != ld)
        shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
        shape[k] = in_layout->to_mma()->spt(k);
    // create layout
    layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
    tmp_[recoalasce] = id;

lib/codegen/analysis/swizzle.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
void swizzle::run(ir::module &) {
|
||||
per_phase_.clear();
|
||||
max_phase_.clear();
|
||||
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
|
||||
if(!layout)
|
||||
continue;
|
||||
ir::value* mma_dot_a = layout->hmma_dot_a();
|
||||
ir::value* mma_dot_b = layout->hmma_dot_b();
|
||||
if(!mma_dot_a && !mma_dot_b){
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
continue;
|
||||
}
|
||||
auto ord = layout->get_order();
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(tgt_->as_nvidia()->sm() < 80){
|
||||
int inner = mma_dot_a ? 0 : 1;
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
|
||||
if(mma_dot_a)
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
else
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
}
|
||||
else{
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = 8 / per_phase_[layout];
|
||||
vec_[layout] = 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
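A worked example of the Ampere (sm >= 80) branch above: for an fp16 operand (dtsize = 2) staged by a scanline layout with mts(ord[0]) = 8 and nts(ord[0]) = 4, per_phase = max(128 / (8*4*2), 1) = 2, max_phase = 8 / 2 = 4, and vec = 8. The consumer of these parameters lives in the suppressed generator.cc diff, so the helper below is only an illustrative sketch of how a phase-based XOR swizzle typically uses them, not the commit's code:

#include <cstdint>

// Hypothetical helper: XOR-swizzle a (row, col) shared-memory coordinate
// so that the rows of a warp-wide load hit distinct banks without padding.
inline uint32_t swizzled_col(uint32_t row, uint32_t col,
                             uint32_t per_phase, uint32_t max_phase, uint32_t vec) {
  uint32_t phase = (row / per_phase) % max_phase;    // phase advances every per_phase rows
  return ((col / vec) ^ phase) * vec + (col % vec);  // XOR vec-wide groups by the phase
}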

[File diff suppressed because it is too large]
@@ -1,325 +0,0 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"

namespace triton{
namespace codegen{

using namespace llvm;

inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
  // function
  if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
    Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
    std::vector<Type*> param_tys;
    std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
                   [&ctx](ir::type* t){ return llvm_type(t, ctx);});
    return FunctionType::get(return_ty, param_tys, false);
  }
  // pointer
  if(ty->is_pointer_ty()){
    Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
    unsigned addr_space = ty->get_pointer_address_space();
    return PointerType::get(elt_ty, addr_space);
  }
  // integer
  if(ty->is_integer_ty()){
    unsigned bitwidth = ty->get_integer_bitwidth();
    return IntegerType::get(ctx, bitwidth);
  }
  // primitive types
  switch(ty->get_type_id()){
    case ir::type::VoidTyID: return Type::getVoidTy(ctx);
    case ir::type::HalfTyID: return Type::getHalfTy(ctx);
    case ir::type::FloatTyID: return Type::getFloatTy(ctx);
    case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
    case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
    case ir::type::LabelTyID: return Type::getLabelTy(ctx);
    case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
    case ir::type::TokenTyID: return Type::getTokenTy(ctx);
    default: break;
  }
  // unknown type
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
  size_t dim = shapes.size();
  std::vector<Value*> result(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder.getInt32(shapes[order[k]]);
    Value *rem = builder.CreateURem(trailing, dim_k);
    trailing = builder.CreateUDiv(trailing, dim_k);
    result[order[k]] = rem;
  }
  result[order[dim - 1]] = trailing;
  return result;
}
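delinearize splits a flat index into multi-dimensional coordinates, peeling off the innermost (order[0]) dimension first with a remainder/divide cascade. A constant-folded sketch of the same arithmetic, assuming shapes = {4, 8} and order = {0, 1}:

#include <cassert>
#include <vector>

// Same remainder/divide cascade as delinearize, on plain integers.
std::vector<int> delinearize_const(int trailing, const std::vector<int>& order,
                                   const std::vector<int>& shapes) {
  std::vector<int> result(shapes.size());
  for(size_t k = 0; k + 1 < shapes.size(); k++) {
    result[order[k]] = trailing % shapes[order[k]]; // coordinate in dim order[k]
    trailing /= shapes[order[k]];                   // carry to the next dim
  }
  result[order.back()] = trailing;
  return result;
}

int main() {
  auto r = delinearize_const(13, {0, 1}, {4, 8});
  assert(r[0] == 1 && r[1] == 3); // 13 = 3*4 + 1
}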

inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}



machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
                                             std::map<ir::value *, Value *>& vmap,
                                             std::map<ir::value *, tile *>& tmap)
  : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {

  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
  // double-buffered
  if(layout_->get_double_buffer()) {
    BasicBlock *current = builder_->GetInsertBlock();
    auto info = *layout_->get_double_buffer();
    ir::phi_node *phi = info.phi;
    BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
    if(parent->empty())
      builder_->SetInsertPoint(parent);
    else
      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
    // create pointers
    ptr_ = builder_->CreatePHI(ptr_ty, 2);
    pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
    pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
    offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
    next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
    builder_->SetInsertPoint(current);
  }
  else{
    size_t offset = alloc_->offset(layout_);
    ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
    ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
  }
}


tile* machine_shared_layout::create(ir::value *v) {
  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  auto double_buffer = layout_->get_double_buffer();
  // offset
  Value *offset = nullptr;
  if(double_buffer && v == double_buffer->phi)
    offset = offset_;
  // base pointer
  Value *ptr = ptr_;
  if(double_buffer && v == double_buffer->latch)
    ptr = next_ptr_;
  else if(double_buffer && v == double_buffer->first)
    ptr = pre_ptr_;
  // create tile
  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}

machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                                                       analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                                                       analysis::data_layout *layout)
  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {

}

tile *machine_distributed_layout::create(ir::value *v) {
  Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  size_t rank = shapes.size();
  std::vector<distributed_axis> axes(rank);
  std::vector<int> order(rank);
  // compute axes
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] > 1){
      unsigned x = a_axes_->get(v, d);
      axes[d] = axes_.at(x);
    }
    else{
      axes[d].contiguous = 1;
      axes[d].values = {builder_->getInt32(0)};
    }
  }
  // compute order
  std::iota(order.begin(), order.end(), 0);
  auto cmp = [&](int x, int y) {
    unsigned axx = a_axes_->get(v, x);
    unsigned axy = a_axes_->get(v, y);
    size_t posx = layout_->find_axis(axx);
    size_t posy = layout_->find_axis(axy);
    if(posx < rank && posy < rank)
      return layout_->get_order(posx) < layout_->get_order(posy);
    return false;
  };
  std::sort(order.begin(), order.end(), cmp);
  return new distributed_tile(ty, shapes, order, axes, *builder_);
}

machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
                                             target *tgt, analysis::axes *a_axes,
                                             std::map<unsigned, distributed_axis>& axes,
                                             analysis::mma884_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  const auto& shape = layout->get_shape();
  if(shape.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shape.size() >= 3;

  Value *_1 = builder_->getInt32(1);
  Value *_2 = builder_->getInt32(2);
  Value *_3 = builder_->getInt32(3);
  Value *_4 = builder_->getInt32(4);
  Value *_16 = builder_->getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout->fpw(0);
  unsigned fpw_1 = layout->fpw(1);
  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout->wpt(0);
  unsigned wpt_1 = layout->wpt(1);
  unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
  // mma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // mma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetition
  unsigned num_rep_0 = shape[0] / hmma_bts_0;
  unsigned num_rep_1 = shape[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_1 * pack_size_1_));

  // Quad pair id
  Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
                                          builder_->CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));


  /* axes */
  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}


machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
                                                 target *tgt,
                                                 analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
                                                 analysis::scanline_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  auto order = layout->get_order();
  const auto& shape = layout->get_shape();
  Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
  // Delinearize
  size_t dim = shape.size();
  std::vector<Value*> thread_id(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
    Value *rem = builder_->CreateURem(full_thread_id, dim_k);
    full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
    thread_id[order[k]] = rem;
  }
  thread_id[order[dim - 1]] = full_thread_id;
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    int nts = layout->nts(k);
    int mts = layout->mts(k);
    std::string str_k = std::to_string(k);
    Value *contiguous_k = builder_->getInt32(nts);
    Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
    unsigned per_block = nts * mts;
    unsigned per_thread = nts * shape[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      unsigned offset = n / nts * per_block + n % nts;
      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
  }
}


}
}
@@ -1,214 +0,0 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"

namespace triton{
namespace codegen{

using namespace llvm;

/* Distributed Tile */
void distributed_tile::init_indices() {
  std::vector<size_t> id(axes_.size(), 0);
  // build
  size_t k = 0;
  while(true) {
    indices_t current;
    for(size_t d = 0; d < id.size(); d++)
      current.push_back(axes_[d].values[id[d]]);
    size_t sz = indices_.size();
    indices_[current] = sz;
    values_[current] = nullptr;
    ordered_indices_.push_back(current);
    id[order_[0]]++;
    while(id[order_[k]] == axes_[order_[k]].values.size()){
      if(k == id.size() - 1)
        return;
      id[order_[k++]] = 0;
      id[order_[k]]++;
    }
    k = 0;
  }
}


distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
  : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
  init_indices();
}

void distributed_tile::set_value(indices_t idx, Value *x) {
  assert(x->getType() == ty_ && "cannot set a value of different type");
  Value *&result = values_[idx];
  assert(!result && "value cannot be set twice");
  result = x;
}

Value* distributed_tile::get_value(indices_t idx) {
  Value *result = values_.at(idx);
  assert(result && "value has not been set");
  return result;
}

unsigned distributed_tile::get_linear_index(indices_t idx) {
  return indices_[idx];
}

indices_t distributed_tile::get_ordered_indices(unsigned id) {
  return ordered_indices_.at(id);
}


void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
  if(end < 0)
    end = ordered_indices_.size() + end + 1;
  for(unsigned i = start; i < end; i++)
    fn(ordered_indices_[i]);
}

void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
  int rank = sizes.size();
  int len = 1;
  for(int s: sizes)
    len *= s;

  for(int i = 0; i < len; i++){
    indices_t idx(rank);
    int current = i;
    for(int k = 0; k < rank; k++){
      idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
      current = current / sizes[k];
    }
    fn(idx);
  }
}


/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
  BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
  Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
  if(dyn_cast<Constant>(arg)){
    cst = arg;
    non_cst = _0;
    return;
  }
  if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
    non_cst = arg;
    cst = _0;
    return;
  }
  Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
  Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
  if(cst_lhs && cst_rhs){
    cst = arg;
    non_cst = _0;
  }
  else if(cst_lhs){
    cst = cst_lhs;
    non_cst = bin_op->getOperand(1);
  }
  else if(cst_rhs){
    cst = cst_rhs;
    non_cst = bin_op->getOperand(0);
  }
  else{
    non_cst = arg;
    cst = _0;
  }
}

void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
  non_cst_idx.clear();
  cst_idx.clear();
  for(Value *idx: arg_idx){
    Value *non_cst, *cst;
    extract_constant(idx, non_cst, cst);
    non_cst_idx.push_back(non_cst);
    cst_idx.push_back(cst);
  }
}


Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
                                  const std::vector<int>& perm, const std::vector<int>& order,
                                  indices_t idx) {
  // strides
  std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
  strides[order[0]] = builder.getInt32(1);
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
  // result
  Value *result = builder.getInt32(0);
  for(size_t i = 0; i < idx.size(); i++)
    result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
  return result;
}

shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
  tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
  return_vector_ = false;
  if(perm_.empty()){
    perm_.resize(shapes.size());
    std::iota(perm_.begin(), perm_.end(), 0);
  }
}

void shared_tile::set_value(indices_t idx, Value *value) {
  Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
  unsigned addr_space = ptr->getType()->getPointerAddressSpace();
  ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
  builder_.CreateStore(value, ptr);
}

void shared_tile::set_vector_size(unsigned vector_size) {
  vector_size_ = vector_size;
}

void shared_tile::set_return_mode(bool return_vector){
  return_vector_ = return_vector;
}


Value* shared_tile::get_value(indices_t idx) {
  indices_t non_cst_idx, cst_idx;
  extract_constant(idx, non_cst_idx, cst_idx);
  Value *&base_ptr = ptr_cache_[non_cst_idx];
  unsigned vector_size = vector_size_;
  Type *ty = ty_;
  if(ty->isHalfTy() && (vector_size % 2 == 0)){
    ty = IntegerType::get(ty->getContext(), 32);
    vector_size = vector_size / 2;
  }
  if(base_ptr == nullptr){
    // BasicBlock* store = builder_.GetInsertBlock();
    // if(!non_cst_idx.empty())
    //   if(isa<Instruction>(non_cst_idx.front())){
    //     builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
    //   }
    base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
    if(vector_size_ > 1){
      Type *vec_ty = VectorType::get(ty, vector_size);
      Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
      base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
    }
    // builder_.SetInsertPoint(store);
  }
  Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
  Value *div = offset;
  if(vector_size_ > 1)
    div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
  Value *ptr = builder_.CreateGEP(base_ptr, div);
  Value *result = builder_.CreateLoad(ptr);
  if(return_vector_ == false && vector_size_ > 1) {
    Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
    result = builder_.CreateExtractElement(result, rem);
  }
  return result;
}



}
}
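For reference, shared_offset above folds a multi-dimensional index into a linear offset using strides laid out along `order` (order[0] is the fastest-moving dimension). A constant sketch for a 2D tile with shapes = {M, N} and order = {1, 0}, i.e. row-major:

// order = {1, 0}: strides[1] = 1 and strides[0] = shapes[1] = N,
// so index (i, j) maps to i*N + j.
constexpr int shared_offset_2d(int i, int j, int N) { return i * N + j; }
static_assert(shared_offset_2d(2, 3, 8) == 19, "2*8 + 3");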
@@ -14,6 +14,12 @@ namespace triton{
namespace codegen{

// base


nvidia_cu_target* target::as_nvidia() {
  return dynamic_cast<nvidia_cu_target*>(this);
}

bool target::is_gpu() const {
  return is_gpu_;
}
@@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m

Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
  return builder.CreateCall(barrier, {});
  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}

Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
    Intrinsic::amdgcn_workgroup_id_y,
    Intrinsic::amdgcn_workgroup_id_z
  };
  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
  Value* group_id = builder.CreateCall(get_group_id, {});
  Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
  return group_id;
}

@@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
    Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_z
  };
  Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
  Value* cta_id = builder.CreateCall(get_cta_id, {});
  Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
  return cta_id;
}

@@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
    Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_z
  };
  Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
  return builder.CreateCall(get_nctaid, {});
  return builder.CreateIntrinsic(ids[ax], {}, {});
}

// CPU

@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {


  for(size_t id = 0; id < num_groups; id++) {
    if(!layout_->get(id)->to_mma884())
    if(!layout_->get(id)->to_mma())
      continue;
    // extract memory stores
    const auto& values = layout_->values_of(id);

@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
    return true;
  if(i->get_id() == ir::INST_COPY_TO_SHARED)
    return true;
  if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
    return true;
  return false;
}


// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
  auto *i = dynamic_cast<ir::instruction*>(x);
  // not an instruction
  if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
  // copy
  builder.set_insert_point_after(i);
  ir::value *copy;
  if(to_shared)
  if(to_shared){
    copy = builder.create_copy_to_shared(x);
  }
  else
    copy = builder.create_copy_from_shared(x);
  parent->replace_uses_of_with(x, copy);

@@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  add_reference(i, res);
}

void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
    std::set<ir::value*> incoming;
    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
      if(incoming.insert(inc_val).second){
        ir::basic_block *block = inc_val->get_parent();
        builder.set_insert_point(block->get_inst_list().back());
        builder.create_barrier();
        if(type.first)
          builder.create_async_wait();
        if(type.second)
          builder.create_barrier();
      }
    }
  }
@@ -85,8 +88,9 @@ std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
                                         const interval_vec_t &written_to,
                                         const interval_vec_t &read_from,
                                         std::set<ir::instruction*>& insert_loc,
                                         std::set<ir::value*>& safe_war) {
                                         std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
                                         std::set<ir::value*>& safe_war,
                                         std::vector<ir::instruction*>& to_sync) {
  ir::basic_block::inst_list_t instructions = block->get_inst_list();
  interval_vec_t new_written_to = written_to;
  interval_vec_t new_read_from = read_from;
@@ -95,6 +99,8 @@ std::pair<membar::interval_vec_t,
    interval_vec_t read, written;
    get_read_intervals(i, read);
    get_written_intervals(i, written);
    if(written.size())
      to_sync.push_back(i);
    bool read_after_write = intersect(new_written_to, read);
    bool write_after_read = intersect(new_read_from, written);
    // double buffering
@@ -104,9 +110,14 @@ std::pair<membar::interval_vec_t,
    }
    // record hazards
    if(read_after_write || write_after_read) {
      insert_loc.insert(i);
      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
      insert_loc.insert({i, {copy_async_wait, barrier}});
      new_written_to.clear();
      new_read_from.clear();
      to_sync.clear();
    }
    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
@@ -125,17 +136,17 @@ void membar::run(ir::module &mod) {
    if(!layout || !layout->get_double_buffer())
      continue;
    for(ir::value *v: layout->get_values())
      if(v != layout->get_double_buffer()->phi)
      if(v != layout->get_double_buffer()->phi){
        safe_war.insert(v);
      }
  }


  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    std::map<ir::basic_block*, interval_vec_t> written_to;
    std::map<ir::basic_block*, interval_vec_t> read_from;
    std::set<ir::instruction*> insert_locs;
    std::vector<ir::instruction*> to_sync;
    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
    size_t n_inserted_im1 = 0;
    bool done = false;
    do{
@@ -150,7 +161,7 @@ void membar::run(ir::module &mod) {
      for(ir::basic_block* pred: block->get_predecessors())
        pred_read_from.push_back(read_from[pred]);
      // apply transfer function
      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
      written_to[block] = result.first;
      read_from[block] = result.second;
    }
@@ -158,8 +169,9 @@ void membar::run(ir::module &mod) {
    done = (n_inserted_im1 == n_inserted_i);
    n_inserted_im1 = n_inserted_i;
  }while(!done);
  for(ir::instruction* i: insert_locs)
    insert_barrier(i, builder);
  for(auto x: insert_locs){
    insert_barrier(x.first, x.second, builder);
  }
}
}

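The membar transfer function now records, per hazard site, which synchronization to emit: the pair's first flag requests an async-wait when any pending write came from a cp.async-style masked load, and the second requests a full barrier when any came from an ordinary copy to shared memory. A condensed sketch of that decision (toy types, not the pass's real ones):

#include <utility>
#include <vector>

enum class write_kind { load_async, copy_to_shared };

// Returns {needs_async_wait, needs_barrier} for the writes pending at a hazard,
// mirroring the copy_async_wait/barrier pair built in membar::transfer.
std::pair<bool, bool> sync_needed(const std::vector<write_kind>& pending) {
  bool async_wait = false, barrier = false;
  for(write_kind w : pending) {
    async_wait |= (w == write_kind::load_async);
    barrier    |= (w == write_kind::copy_to_shared);
  }
  return {async_wait, barrier};
}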
@@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){

//}

bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
  auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
  if(!copy_to_shared)
    return false;
  ir::value *arg = copy_to_shared->get_operand(0);
  ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
  if(!ld)
    return false;
  builder.set_insert_point(copy_to_shared);
  ir::value *ptr = ld->get_pointer_operand();
  ir::value *msk = ld->get_mask_operand();
  ir::value *val = ld->get_false_value_operand();
  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
  copy_to_shared->replace_all_uses_with(new_load);
  return true;

}

bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  auto x = dynamic_cast<ir::reduce_inst*>(value);
  if(!x)
@@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) {
      continue;
    bool was_modified = false;
    was_modified = was_modified || rewrite_mult(i, builder);
    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
    was_modified = was_modified || rewrite_trans_phi(i, builder);
    was_modified = was_modified || rewrite_unit_red(i, builder);
    was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
    // if(tgt_->as_nvidia()->sm() >= 80)
    //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
    if(was_modified)
      seen.insert(i);
  }

lib/codegen/transform/reorder.cc (new file, 51 lines)
@@ -0,0 +1,51 @@
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void reorder::run(ir::module& mod){
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
ir::value* _ptr = ld->get_pointer_operand();
|
||||
ir::value* _msk = ld->get_mask_operand();
|
||||
ir::value* _val = ld->get_false_value_operand();
|
||||
auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
auto val = std::find(block->begin(), block->end(), _val);
|
||||
if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
continue;
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
int dist_ptr = std::distance(ptr, it);
|
||||
int dist_msk = std::distance(msk, it);
|
||||
int dist_val = std::distance(val, it);
|
||||
if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
builder.set_insert_point(++ptr);
|
||||
if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
builder.set_insert_point(++msk);
|
||||
if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
builder.set_insert_point(++val);
|
||||
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& x: to_replace)
|
||||
x.first->replace_all_uses_with(x.second);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -48,46 +48,6 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
// CUDA //
/* ------------------------ */

// architecture
cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
  switch(sm.first) {
    case 7:
      switch(sm.second){
        case 0: return Architecture::SM_7_0;
      }

    case 6:
      switch(sm.second){
        case 0: return Architecture::SM_6_0;
        case 1: return Architecture::SM_6_1;
      }

    case 5:
      switch(sm.second){
        case 0: return Architecture::SM_5_0;
        case 2: return Architecture::SM_5_2;
        default: return Architecture::UNKNOWN;
      }

    case 3:
      switch(sm.second){
        case 0: return Architecture::SM_3_0;
        case 5: return Architecture::SM_3_5;
        case 7: return Architecture::SM_3_7;
        default: return Architecture::UNKNOWN;
      }

    case 2:
      switch(sm.second){
        case 0: return Architecture::SM_2_0;
        case 1: return Architecture::SM_2_1;
        default: return Architecture::UNKNOWN;
      }

    default: return Architecture::UNKNOWN;
  }
}

// information query
template<CUdevice_attribute attr>
int cu_device::cuGetInfo() const{
@@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
  return map.at(key);
}

// architecture
cu_device::Architecture cu_device::architecture() const{
  return nv_arch(compute_capability());
}

// number of address bits
size_t cu_device::address_bits() const{
  return sizeof(size_t)*8;
@@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
}

// force the device to be interpreted as a particular cc
void cu_device::interpret_as(std::pair<size_t, size_t> cc){
  interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
void cu_device::interpret_as(int cc){
  interpreted_as_ = std::make_shared<int>(cc);
}

// compute capability
std::pair<size_t, size_t> cu_device::compute_capability() const {
int cu_device::compute_capability() const {
  if(interpreted_as_)
    return *interpreted_as_;
  size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
  size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
  return std::make_pair(_major, _minor);
  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
  return major*10 + minor;
}

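compute_capability now returns the usual flattened encoding (major*10 + minor) instead of a pair, which is what lets the layout and swizzle passes above branch on plain integer comparisons like sm() < 80:

// Flattened compute-capability encoding used throughout the new code:
// 7.0 -> 70 (Volta), 7.5 -> 75 (Turing), 8.0 -> 80 (Ampere/A100),
// so "Ampere or newer" is simply cc >= 80.
constexpr int flatten_cc(int major, int minor) { return major * 10 + minor; }
static_assert(flatten_cc(8, 0) == 80, "A100");
static_assert(flatten_cc(7, 5) == 75, "Turing");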
// maximum number of threads per block
@@ -218,7 +173,7 @@ std::string cu_device::infos() const{

// target
std::unique_ptr<codegen::target> cu_device::make_target() const {
  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
}


@@ -93,6 +93,7 @@ namespace driver

bool dispatch::cuinit(){
  if(cuda_==nullptr){
    putenv((char*)"CUDA_CACHE_DISABLE=1");
    std::string libcuda = tools::getenv("TRITON_LIBCUDA");
    if(libcuda.empty())
      cuda_ = dlopen("libcuda.so", RTLD_LAZY);

@@ -20,7 +20,9 @@
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
#include <fstream>
#include <unistd.h>
#include <memory>
#include <regex>
#include "triton/driver/module.h"
#include "triton/driver/context.h"
#include "triton/driver/error.h"
@@ -41,6 +43,19 @@
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"

std::string exec(const char* cmd) {
  std::array<char, 128> buffer;
  std::string result;
  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
  if (!pipe) {
    throw std::runtime_error("popen() failed!");
  }
  while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
    result += buffer.data();
  }
  return result;
}

namespace triton
{
namespace driver
@@ -63,11 +78,11 @@ void module::init_llvm() {
}

module::module(CUmodule mod, bool has_ownership)
  : polymorphic_resource(mod, has_ownership) {
  : polymorphic_resource(mod, has_ownership), spilled_(0) {
}

module::module(host_module_t mod, bool has_ownership)
  : polymorphic_resource(mod, has_ownership) {
  : polymorphic_resource(mod, has_ownership), spilled_(0) {
}


@@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
                                 file_type_t ft) {
  init_llvm();
  // // debug
  // llvm::legacy::PassManager pm;
  llvm::legacy::PassManager pm;
  std::string tmp;
  // llvm::raw_string_ostream oss(llir_);
  // pm.add(llvm::createPrintModulePass(llvm::outs()));
  // pm.add(llvm::createVerifierPass());
  // pm.run(*module);
  pm.add(llvm::createVerifierPass());
  pm.run(*module);
  // create machine
  module->setTargetTriple(triple);
  std::string error;
@@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_

// create execution engine
for(llvm::Function& fn: src->functions())
  hst_->functions[fn.getName()] = &fn;
  hst_->functions[fn.getName().str()] = &fn;

// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget();
@@ -225,7 +242,8 @@ static std::map<int, int> vptx = {
  {10010, 64},
  {10020, 65},
  {11000, 70},
  {11010, 71}
  {11010, 71},
  {11020, 72}
};

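The vptx table maps the CUDA driver version (in cuDriverGetVersion's encoding, e.g. 11020 for CUDA 11.2) to the newest PTX ISA that driver can JIT, stored as major*10 + minor. The new 11020 entry is what enables PTX 7.2 and hence the mma.16816 path. A self-contained sketch of the lookup, mirroring the table:

#include <map>

// Driver version -> PTX ISA version, as in the vptx table above.
std::map<int, int> vptx_copy = {{10010, 64}, {10020, 65}, {11000, 70}, {11010, 71}, {11020, 72}};

int main() {
  int ptx = vptx_copy.at(11020); // CUDA 11.2 -> 72
  int ptx_major = ptx / 10;      // 7  -> emitted as ".version 7.2"
  int ptx_minor = ptx % 10;      // 2  -> and feature string "+ptx72"
  return ptx_major * 10 + ptx_minor == 72 ? 0 : 1;
}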
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
||||
@@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
assert(short_ptr);
|
||||
short_ptr->setValue(true);
|
||||
// compute capability
|
||||
auto _cc = ((driver::cu_device*)device)->compute_capability();
|
||||
int cc = _cc.first*10 + _cc.second;
|
||||
cc = std::min(cc, max_nvvm_cc);
|
||||
int cc = ((driver::cu_device*)device)->compute_capability();
|
||||
std::string sm = "sm_" + std::to_string(cc);
|
||||
// driver version
|
||||
int version;
|
||||
@@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
throw std::runtime_error("Triton requires CUDA 10+");
|
||||
// PTX version
|
||||
int ptx = vptx.at(version);
|
||||
ptx = std::min(ptx, max_nvvm_ptx);
|
||||
int ptx_major = ptx / 10;
|
||||
int ptx_minor = ptx % 10;
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
|
||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");

@@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
}


cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }

cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
  // JIT compile source-code
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
  unsigned int errbufsize = 8096;
  std::string errbuf(errbufsize, 0);
  void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};

  try{
    dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
  }catch(exception::cuda::invalid_ptx const &){
  // // compile ptx with ptxas
  // char _fsrc[] = "/tmp/triton_k_XXXXXX";
  // char _flog[] = "/tmp/triton_l_XXXXXX";
  // int fdsrc = mkstemp(_fsrc);
  // int fdlog = mkstemp(_flog);
  // std::string fsrc = _fsrc;
  // std::string flog = _flog;
  // std::ofstream ofs(fsrc);
  // ofs << source;
  // ofs.close();
  // std::string cmd;
  // int err;
  // driver::cu_device* cu_device = (driver::cu_device*)device;
  // cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
  // err = system(cmd.c_str());
  // dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
  // std::ifstream file(flog);
  // std::string log;
  // if(file)
  //   while (!file.eof()) log.push_back(file.get());
  // unlink(_fsrc);
  // unlink(_flog);

  // std::smatch match;
  // std::regex expr ("\\b([0-9]+) bytes spill");
  // spilled_ = 0;
  // while (std::regex_search (log,match,expr)){
  //   spilled_ += std::stoi(match[1]);
  //   log = match.suffix();
  // }
  // std::cout << log << std::endl;

  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
                        CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
                        CU_JIT_LOG_VERBOSE};
  unsigned int errbufsize = 8192;
  unsigned int logbufsize = 8192;
  char _err[errbufsize];
  char _log[logbufsize];
  void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
  dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
  std::string err(_err);
  std::string log(_log);

  // std::cout << log << std::endl;
  std::smatch match;
  std::regex expr ("\\b([0-9]+) bytes spill");
  spilled_ = 0;
  while (std::regex_search(log,match,expr)){
    spilled_ += std::stoi(match[1]);
    log = match.suffix();
  }
}
catch(exception::cuda::invalid_ptx const &){
  //#ifdef TRITON_LOG_PTX_ERROR
  std::cout << source << std::endl;
  std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
  std::cerr << errbuf << std::endl;
  // exit(1);
  //#endif
  throw;
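
This rewritten constructor is what backs the commit's "compilation error for kernels that spill": the module is loaded with CU_JIT_LOG_VERBOSE so the JIT reports register spills in the info log, and every "<N> bytes spill" occurrence is summed into spilled_. A self-contained sketch of that accounting (the sample log line is illustrative, shaped like ptxas -v output):

#include <iostream>
#include <regex>
#include <string>

// Sum every "<N> bytes spill" occurrence in a JIT/ptxas log (sketch).
int count_spilled_bytes(std::string log) {
  std::smatch match;
  std::regex expr("\\b([0-9]+) bytes spill");
  int spilled = 0;
  while (std::regex_search(log, match, expr)) {
    spilled += std::stoi(match[1]);  // capture group 1 holds the byte count
    log = match.suffix();            // continue scanning after this match
  }
  return spilled;
}

int main() {
  std::string log = "ptxas info: 16 bytes spill stores, 16 bytes spill loads";
  std::cout << count_spilled_bytes(log) << "\n";  // prints 32
}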

@@ -1,5 +1,6 @@
#include <string>
#include <algorithm>
#include <iostream>
#include "triton/ir/basic_block.h"
#include "triton/ir/builder.h"
#include "triton/ir/constant.h"
@@ -253,6 +254,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)

value *builder::create_load(value *ptr, const std::string &name){
  return insert(unmasked_load_inst::create(ptr, name));
  // type *ty = ptr->get_type()->get_pointer_element_ty();
  // value *mask = constant_int::get(get_int1_ty(), 1);
  // value *undef = undef_value::get(ty);
  // if(ptr->get_type()->is_tile_ty()){
  //   auto shapes = ptr->get_type()->get_tile_shapes();
  //   return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
  // }
  // return insert(masked_load_inst::create(ptr, mask, undef, name));

}

value *builder::create_store(value *ptr, value *val, const std::string &name){
@@ -263,6 +273,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value,
  return insert(masked_load_inst::create(ptr, mask, false_value, name));
}


value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
  return insert(masked_store_inst::create(ptr, val, mask, name));
}
@@ -348,13 +359,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
  return insert(copy_to_shared_inst::create(arg, name));
}


value *builder::create_copy_from_shared(value *arg, const std::string &name) {
  return insert(copy_from_shared_inst::create(arg, name));
}

value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
  return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
}

value *builder::create_barrier(const std::string &name) {
  return insert(barrier_inst::create(ctx_, name));
}

value *builder::create_async_wait() {
  return insert(async_wait_inst::create(ctx_));
}

}
}
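
The last three additions are the C++-side builder hooks for Ampere's asynchronous global-to-shared copies, which make_bin enables for sm >= 80. A hypothetical usage sketch, not from this commit, showing the intended issue-then-fence pattern (the function and value names are illustrative):

#include "triton/ir/builder.h"

using namespace triton;

// Sketch: asynchronously load a masked tile, then fence it before use.
ir::value* emit_async_tile_load(ir::builder& bld, ir::value* ptr,
                                ir::value* mask, ir::value* other) {
  ir::value* tile = bld.create_masked_load_async(ptr, mask, other, "ldga");
  bld.create_async_wait();  // wait for the in-flight asynchronous copies
  bld.create_barrier("");   // make the shared tile visible to the whole CTA
  return tile;
}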

@@ -463,6 +463,20 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
  return new masked_load_inst(ptr, mask, false_value, name, next);
}

// masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
                                               const std::string &name, instruction *next)
  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
  set_operand(0, ptr);
  set_operand(1, mask);
  set_operand(2, false_value);
}

masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
                                                       const std::string &name, instruction *next) {
  return new masked_load_async_inst(ptr, mask, false_value, name, next);
}

// atomic add

atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
@@ -804,6 +818,14 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
  return new barrier_inst(ctx, name, next);
}

async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
                                 instruction *next)
  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }

async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
  return new async_wait_inst(ctx, name, next);
}


// nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)

@@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) {
        os << get_name(ops[i], cnt++);
        os << (i < num_ops - 1?", ":"");
      }
      os << ";" << std::endl;
      os << ";";
      // os << " (";
      // for(ir::user* usr: inst->get_users())
      //   os << get_name(usr, cnt++) << ", " ;
      // os << " )";
      os << std::endl;
    }
  }
  os << "}" << std::endl;

@@ -68,9 +68,10 @@ unsigned user::get_num_hidden() const {

value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
  for(size_t i = 0; i < ops_.size(); i++)
    if(ops_[i] == before)
    if(ops_[i] == before){
      ops_[i] = after;
      after->add_use(this);
      after->add_use(this);
    }
  return before->erase_use(this);
}
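
The braces are the substance of this hunk. In the old, unbraced version the for body was just the if statement, so the use-list update sat outside the loop entirely:

// Old shape (reconstructed from the removed line above):
//
//   for(size_t i = 0; i < ops_.size(); i++)
//     if(ops_[i] == before)
//       ops_[i] = after;
//   after->add_use(this);   // ran once, unconditionally, after the loop
//
// With the braces, a use of `after` is registered once per operand that is
// actually replaced, before the stale use of `before` is erased.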

@@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
    return set_ret(bld_->create_dot(lhs, rhs, _0));
  }
  case Token::MASKED_DEREF: {
    // TODO: FIXME
    ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
    ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
    auto it = bld_->get_insert_block();
    if(ret_ty->is_tile_ty())
      false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
    bld_->set_insert_point(it);
    return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
  }
  case Token::ELLIPSIS: {
@@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
  if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
    if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
      false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
                                                 cond,
                                                 false_val);
    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
    ld->replace_all_uses_with(new_ld);
    ld->erase_from_parent();
    return set_ret(new_ld);
@@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
  });
  if(init_)
    VisitStmt(init_);
  // VisitExpr(cond_);
  // ir::value *cond = ret_;
  // bld_->create_cond_br(cond, loop_bb, next_bb);
  bld_->create_br(loop_bb);
  VisitExpr(cond_);
  ir::value *cond = ret_;
  bld_->create_cond_br(cond, loop_bb, next_bb);
  // bld_->create_br(loop_bb);
  bld_->set_insert_point(loop_bb);
  if(body_)
    VisitStmt(body_);
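
The VisitForStmt change swaps an unconditional branch into the loop body for a guarded entry: the condition is now evaluated in the preheader and decides between loop_bb and next_bb. Schematically (block names from the hunk; the IR rendering is illustrative):

// preheader:
//   %cond = <result of VisitExpr(cond_)>
//   cond_br %cond, loop_bb, next_bb
// loop_bb:
//   <body_>, <step_>, back-edge ...
// next_bb:
//   <code after the loop>
//
// so a for-loop whose condition is initially false appears to no longer
// execute its body once before testing.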

@@ -1,4 +1,4 @@
#include <string>
#include <string>
#include <mutex>
#include <regex>
#include <functional>
@@ -9,11 +9,13 @@
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h"
@@ -29,6 +31,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/print.h"
#include "triton/runtime/error.h"
#include "triton/tools/bench.hpp"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp"
@@ -67,7 +70,7 @@ void _loop_nest(std::vector<size_t> const & ranges,
/* OPTIONS */
/* --------------------- */

std::string function::options_t::to_str() const{
std::string options_t::to_str() const{
  std::string ret = "nw-" + std::to_string(num_warps);
  for(const auto& x : defines){
    ret += '-';
@@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) {
  throw std::runtime_error("unknown type");
}

void function::caller::write(std::ofstream &ofs) {
  // write name
  ofs << name_ << std::endl;
  // write signature
  for(size_t i = 0; i < param_tys_.size(); i++)
    ofs << param_tys_[i] << " ";
  ofs << std::endl;
  // write module
  std::string source = ((driver::cu_module*)(&*parent_))->source();
  ofs << source;
}
//void function::caller::write(std::ofstream &ofs) {
//  // write name
//  ofs << name_ << std::endl;
//  // write signature
//  for(size_t i = 0; i < param_tys_.size(); i++)
//    ofs << param_tys_[i] << " ";
//  ofs << std::endl;
//  // write module
//  std::string source = ((driver::cu_module*)(&*parent_))->ptx();
//  ofs << source;
//}

void function::caller::read(std::ifstream &ifs) {
  // read name
  std::getline(ifs, name_);
  // read signature
  std::string line;
  std::getline(ifs, line);
  std::istringstream current(line);
  int param;
  param_tys_.clear();
  while(current >> param)
    param_tys_.push_back((arg_type)param);
  // read module
  std::string src((std::istreambuf_iterator<char>(ifs)),
                  std::istreambuf_iterator<char>());
  parent_.reset(new driver::cu_module(src));
  bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
//void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
//  // read name
//  std::getline(ifs, name_);
//  // read signature
//  std::string line;
//  std::getline(ifs, line);
//  std::istringstream current(line);
//  int param;
//  param_tys_.clear();
//  while(current >> param)
//    param_tys_.push_back((arg_type)param);
//  // read module
//  std::string src((std::istreambuf_iterator<char>(ifs)),
//                  std::istreambuf_iterator<char>());
//  parent_.reset(new driver::cu_module(ctx, src));
//  bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));

}
//}

function::caller::caller(std::ifstream &ifs, const options_t& opt)
  : opt_(opt) {
  read(ifs);
}
//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
//  : opt_(opt) {
//  read(ctx, ifs);
//}

function::caller::caller(ir::function *ir,
                         std::shared_ptr<driver::module> parent, const options_t& opt)
@@ -198,20 +201,23 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
  // generate llvm code
  llvm::LLVMContext ctx;
  std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
  // optimizations
  bool cts_use_async = target->as_nvidia()->sm() >= 80;
  // create passes
  codegen::analysis::align align;
  codegen::analysis::axes axes;
  codegen::transform::cts cts(cts_use_async);
  codegen::transform::disassociate disassociate;
  codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
  codegen::analysis::liveness liveness(&layouts);
  codegen::analysis::swizzle swizzle(&layouts, target.get());
  codegen::analysis::allocation allocation(&liveness);
  codegen::transform::membar barriers(&liveness, &layouts, &allocation);
  codegen::transform::dce dce;
  codegen::transform::peephole peephole;
  codegen::transform::peephole peephole(target.get());
  codegen::transform::reassociate reassociate;
  codegen::transform::coalesce coalesce(&align, &layouts);
  codegen::transform::cts cts;
  codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
  // run passes
  dce.run(module);
  disassociate.run(module);
@@ -233,17 +239,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
  }
  peephole.run(module);
  dce.run(module);
  // ir::print(module, std::cout);
  align.run(module);
  axes.run(module);
  layouts.run(module);
  swizzle.run(module);
  liveness.run(module);
  allocation.run(module);
  if(allocation.allocated_size() > device->max_shared_memory())
    throw std::runtime_error("using too much shared memory");
    throw exception::out_of_shared_memory();
  barriers.run(module);
  // ir::print(module, std::cout);
  isel.visit(module, *llvm);
  std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
  if(res->spilled() > 256)
    throw exception::out_of_registers();
  return res;
}
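
These two typed exceptions (out_of_shared_memory when the allocation pass oversubscribes the device, out_of_registers when the JIT log reports more than 256 spilled bytes) are what let the auto-tuner discard bad configurations instead of aborting. A sketch of the consuming side, mirroring the precompile() hunk below (err, opt and ir are the names defined there):

try {
  bin = make_bin(*ir, device, opt);  // may throw exception::out_of_shared_memory
                                     // or exception::out_of_registers
} catch (const exception::base& e) {
  err[opt] = e.what();               // record why this config failed; skip it
}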

@@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) {
  auto ir = make_ir(parser);
  // triton-ir -> binary
  std::unique_ptr<driver::module> bin;
  // try{
  try{
    bin = make_bin(*ir, device, opt);
  // }catch(const std::runtime_error&){
  //   return nullptr;
  // }
  }catch(const exception::base&){
    throw;
  }
  // create callable
  ir::function *tmp = ir->get_function_list()[0];
  callers_[opt].reset(new caller(tmp, std::move(bin), opt));
@@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space)
  for(const auto& x: space.defines)
    ranges.push_back(x.second.size());
  // functor for source with given option
  std::map<options_t, std::string> err;
  auto do_make = [&](std::vector<size_t> params) {
    // compilation options
    unsigned i = 0;
@@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space)
    for(auto D: space.defines)
      opt.defines[D.first] = D.second[params[i++]];
    // compile
    make(device, opt);
    try{
      make(device, opt);
    }catch(const exception::base& e){
      err[opt] = e.what();
    }
  };
  // multi-threaded compilation
  _loop_nest(ranges, do_make);
  if(callers_.empty())
    throw std::runtime_error("could not compile kernel");
  if(callers_.empty()){
    std::ostringstream dbg;
    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
    for(auto x: err){
      dbg << "[ ";
      dbg << x.first.num_warps << ", ";
      dbg << "{ ";
      for(const auto& y: x.first.defines)
        dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
      dbg << " } ] -> " << x.second << std::endl;
    }
    throw exception::no_valid_configuration(dbg.str());
  }
}

std::string function::ptx(driver::device* device, const options_t& opt) {
std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
  make(device, opt);
  const auto& fn = callers_.at(opt);
  if(!fn)
    return "";
  return ((driver::cu_module*)fn->parent())->source();
  switch(mode){
    case ASM_LLIR:{
      return fn->parent()->llir();
    }
    case ASM_NV_PTX:
    case ASM_NV_SASS:{
      std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
      // SASS
      std::string input = std::tmpnam(nullptr);
      std::string output = std::tmpnam(nullptr);
      std::ofstream ofs(input);
      ofs << ptx;
      ofs.close();
      if(mode == ASM_NV_PTX)
        return ptx;
      std::string cmd;
      int err;
      // compile ptx
      driver::cu_device* cu_device = (driver::cu_device*)device;
      cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
      err = system(cmd.c_str());
      // disassemble
      cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
      err = system(cmd.c_str());
      std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
      std::string to_delete = " /*";
      std::ifstream ifs(output);
      std::string line;
      std::string sass;
      while(std::getline(ifs, line))
        if(!std::regex_match(line, comment))
          sass += line + "\n";
      return sass;
    }
    default:
      return "";
  }


}
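
The SASS branch shells out to ptxas and cuobjdump, then drops the lines that contain only the hex instruction encodings. A standalone sketch of that filtering step (the sample lines are illustrative):

#include <istream>
#include <regex>
#include <string>

// Keep SASS text, drop cuobjdump's "/* 0x... */" encoding-only lines (sketch).
std::string strip_encoding_comments(std::istream& in) {
  std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
  std::string line, sass;
  while (std::getline(in, line))
    if (!std::regex_match(line, comment))  // full-line matches only
      sass += line + "\n";
  return sass;
}

// A line such as "        /* 0x001c7c00e22007f6 */" is removed, while
// "        MOV R1, c[0x0][0x20] ;" is kept.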

// returns program with best compilation options for given parameter