[general] various cleaning and bugfix:

* added copy1d and copy2d benchmark
* fixed issue in reassociation pass
This commit is contained in:
Philippe Tillet
2019-09-02 23:00:49 -04:00
parent 90d80c3b2e
commit a842d337c5
36 changed files with 265 additions and 191 deletions

View File

@@ -13,7 +13,7 @@ namespace ir {
namespace codegen{
namespace analysis{
class alignment_info {
class align {
struct cst_info {
unsigned num_cst;
unsigned value;

View File

@@ -13,11 +13,10 @@ namespace ir{
namespace codegen{
namespace analysis{
namespace shmem{
typedef unsigned slot_index;
class info;
class meminfo;
struct segment {
slot_index start;
@@ -45,7 +44,7 @@ public:
public:
// constructor
liveness(info *info): info_(info){ }
liveness(meminfo *info): info_(info){ }
// accessors
const intervals_map_t& intervals() const { return intervals_; }
@@ -55,7 +54,7 @@ public:
void run(ir::module &mod);
private:
info *info_;
meminfo *info_;
has_storage_map_t has_dedicated_storage_;
indices_map_t indices_;
intervals_map_t intervals_;
@@ -64,7 +63,6 @@ private:
}
}
}
}
#endif

View File

@@ -17,14 +17,12 @@ namespace analysis{
class grids;
namespace shmem{
class liveness;
class info;
class meminfo;
class allocation {
class memalloc {
public:
allocation(liveness *live, info *buffer_info, grids *params)
memalloc(liveness *live, meminfo *buffer_info, grids *params)
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
// utilities
@@ -44,13 +42,12 @@ private:
size_t allocated_size_;
// dependences
liveness *liveness_;
info *buffer_info_;
meminfo *buffer_info_;
grids *params_;
};
}
}
}
}
#endif

View File

@@ -15,9 +15,8 @@ namespace ir {
namespace codegen{
namespace analysis{
namespace shmem{
class info {
class meminfo {
public:
void run(ir::module &mod);
// queries
@@ -38,6 +37,5 @@ private:
}
}
}
}
#endif

View File

@@ -5,7 +5,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/meminfo.h"
namespace llvm{
@@ -45,14 +45,10 @@ namespace codegen{
namespace analysis{
class grids;
class alignment_info;
class align;
class memalloc;
class meminfo;
namespace shmem{
class allocation;
class info;
}
}
class target;
@@ -196,7 +192,7 @@ private:
public:
selection(analysis::shmem::allocation *alloc, analysis::grids *params, analysis::shmem::info *buffer_info, analysis::alignment_info *alignment, target *tgt)
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), tgt_(tgt){ }
void run(ir::module &src, Module &dst);
@@ -204,10 +200,10 @@ public:
private:
vmap_t vmap_;
tmap_t tmap_;
analysis::shmem::allocation *alloc_;
analysis::memalloc *alloc_;
analysis::grids *params_;
analysis::shmem::info *buffer_info_;
analysis::alignment_info *alignment_;
analysis::meminfo *buffer_info_;
analysis::align *alignment_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;

View File

@@ -46,6 +46,7 @@ public:
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
bool is_gpu() const;
private:
@@ -62,6 +63,7 @@ public:
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
@@ -74,6 +76,7 @@ public:
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class cpu_target: public target {
@@ -86,6 +89,7 @@ public:
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}

View File

@@ -14,17 +14,15 @@ namespace ir {
namespace codegen{
namespace analysis{
namespace shmem{
class allocation;
class info;
class memalloc;
class meminfo;
}
}
namespace transform{
class shmem_barriers {
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
@@ -40,12 +38,12 @@ private:
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
public:
shmem_barriers(analysis::shmem::allocation *alloc, analysis::shmem::info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
membar(analysis::memalloc *alloc, analysis::meminfo *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
void run(ir::module &mod);
private:
analysis::shmem::allocation *alloc_;
analysis::shmem::info *buffer_info_;
analysis::memalloc *alloc_;
analysis::meminfo *buffer_info_;
};

View File

@@ -20,7 +20,7 @@ namespace codegen{
namespace analysis{
class grids;
class alignment_info;
class align;
}
namespace transform{
@@ -37,12 +37,12 @@ private:
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
reassociate(analysis::alignment_info* align, analysis::grids *params);
reassociate(analysis::align* align, analysis::grids *params);
void run(ir::module& module);
private:
analysis::grids* params_;
analysis::alignment_info* align_;
analysis::align* align_;
};
}

View File

@@ -61,6 +61,19 @@ public:
return kind_ != multiple_of;
}
std::string repr() const {
switch(kind_){
case readonly: return ".readonly";
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".readonly";
default: break;
}
assert(false);
return "";
}
private:
attribute_kind_t kind_;
unsigned value_;

View File

@@ -687,7 +687,7 @@ private:
public:
static nv_static_program_idx *get(constant_range* range);
constant_range* get_range() const;
std::string repr() const { return get_name(); }
std::string repr() const { return "nv_static_program_idx"; }
private:
constant_range *range_;

View File

@@ -9,16 +9,16 @@
#include <memory>
#include <functional>
// codegen
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/lang/parser.h"

View File

@@ -41,8 +41,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/analysis/align.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -29,14 +29,14 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
}
bool alignment_info::is_first_axis_unit(ir::value *x){
bool align::is_first_axis_unit(ir::value *x){
if(x->get_type()->is_tile_ty())
return x->get_type()->get_tile_shapes()[0] == 1;
else
return true;
}
alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
align::cst_info align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
// helper for the cache
@@ -102,7 +102,7 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
return cache({1, 0});
}
unsigned alignment_info::populate_max_contiguous(ir::value *v){
unsigned align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
// helper for the cache
@@ -181,7 +181,7 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
return cache(1);
}
unsigned alignment_info::populate_starting_multiple(ir::value *v){
unsigned align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
auto cache = [this,v](unsigned value){
@@ -240,7 +240,19 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
int rhs = populate_starting_multiple(x->get_operand(1));
return cache(gcd(lhs, rhs));
}
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
if(auto *x = dynamic_cast<ir::splat_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
return cache(op);
}
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
auto shapes = x->get_type()->get_tile_shapes();
if(shapes[0] == 1)
return cache(1);
else
return cache(op);
}
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
return cache(op);
}
@@ -271,22 +283,22 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
return cache(result);
}
unsigned alignment_info::get_starting_multiple(ir::value* v) const {
unsigned align::get_starting_multiple(ir::value* v) const {
return starting_multiple_.at(v);
}
unsigned alignment_info::get_max_contiguous(ir::value* v) const {
unsigned align::get_max_contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void alignment_info::copy(ir::value *dst, ir::value *src) {
void align::copy(ir::value *dst, ir::value *src) {
starting_multiple_[dst] = starting_multiple_[src];
max_contiguous_[dst] = max_contiguous_[src];
is_constant_[dst] = is_constant_[src];
}
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
void alignment_info::run(ir::module &mod) {
void align::run(ir::module &mod) {
// populate constant
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
@@ -304,9 +316,13 @@ void alignment_info::run(ir::module &mod) {
// populate maximum contiguous
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
for(ir::instruction *i: block->get_inst_list())
populate_max_contiguous(i);
}
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction *i: block->get_inst_list())
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl;
}

View File

@@ -1,6 +1,6 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -292,7 +292,7 @@ void grids::run(ir::module &mod) {
else{
unsigned shape = shapes[0];
unsigned current = num_threads;
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 4));
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
current = current / params_.at(i).at("mts.d0")->get_value();
for(size_t d = 1; d < shapes.size(); d++){

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
@@ -9,7 +9,6 @@
namespace triton{
namespace codegen{
namespace analysis{
namespace shmem{
// Entry point
void liveness::run(ir::module &mod) {
@@ -41,4 +40,3 @@ void liveness::run(ir::module &mod) {
}
}
}
}

View File

@@ -1,8 +1,8 @@
#include <algorithm>
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -12,9 +12,8 @@
namespace triton{
namespace codegen{
namespace analysis{
namespace shmem{
unsigned allocation::is_ld_padded(ir::value *x) {
unsigned memalloc::is_ld_padded(ir::value *x) {
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
if(trans->get_perm()[0]->get_value() != 0)
return 4;
@@ -46,7 +45,7 @@ unsigned allocation::is_ld_padded(ir::value *x) {
return 0;
}
unsigned allocation::get_num_bytes(ir::value *x) {
unsigned memalloc::get_num_bytes(ir::value *x) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis();
@@ -74,7 +73,7 @@ unsigned allocation::get_num_bytes(ir::value *x) {
return num_bytes;
}
void allocation::run(){
void memalloc::run(){
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
@@ -178,4 +177,3 @@ void allocation::run(){
}
}
}
}

View File

@@ -1,5 +1,5 @@
#include <algorithm>
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -10,10 +10,9 @@ namespace triton {
namespace codegen{
namespace analysis{
namespace shmem{
// run pass on module
bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -25,7 +24,7 @@ bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
throw std::runtime_error("unreachable");
}
void info::replace(ir::value* before, ir::value *after) {
void meminfo::replace(ir::value* before, ir::value *after) {
shared_.erase(before);
shared_.insert(after);
if(refs_.find(before) != refs_.end()){
@@ -72,7 +71,7 @@ void add_copy(ir::value *x, ir::builder &builder) {
}
}
void info::run(ir::module &mod) {
void meminfo::run(ir::module &mod) {
// Add shared copies
for(ir::function *fn: mod.get_function_list()){
ir::builder builder(mod.get_context());
@@ -122,15 +121,15 @@ void info::run(ir::module &mod) {
}
// query double-buffered status
bool info::is_double(ir::value *x)
bool meminfo::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }
// query shared status
bool info::is_shared(ir::value *x)
bool meminfo::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }
// get reference if any
ir::value *info::get_reference(ir::value *x)
ir::value *meminfo::get_reference(ir::value *x)
{ return refs_[x]; }
@@ -138,4 +137,3 @@ ir::value *info::get_reference(ir::value *x)
}
}
}
}

View File

@@ -1,8 +1,8 @@
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/align.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
@@ -1304,10 +1304,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
// ConstantInt *cst = nullptr;
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
// cst = dyn_cast<ConstantInt>(gep->idx_begin());
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
@@ -1326,23 +1323,28 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
Value *result_false = false_values->get_value(idx);
if(result_then->getType()->isVectorTy())
result_false = builder.CreateVectorSplat(vector_size, result_false);
result_false = builder.CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
((PHINode*)current_result)->addIncoming(result_false, current_bb);
}
else
current_result = result_then;
// ConstantInt *cst = nullptr;
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
// cst = dyn_cast<ConstantInt>(gep->idx_begin());
// llvm::Value* mask = masks->get_value(idx);
// std::string offset = "";
// if(cst)
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
// if(false_value)
// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
// if(false_values)
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
// Value *result = builder.CreateCall(iasm, {mask, ptr});
// Value *current_result = builder.CreateCall(iasm, {mask, ptr});
packets[id] = current_result;
}
@@ -1499,9 +1501,11 @@ void selection::run(ir::module &src, Module &dst) {
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
if(attr.is_llvm_attr())
if(attr.is_llvm_attr()){
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
}
}
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
// set metadata
Metadata *md_args[] = {

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/selection/target.h"
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"

View File

@@ -2,9 +2,9 @@
#include <set>
#include <algorithm>
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -16,7 +16,7 @@ namespace triton {
namespace codegen{
namespace transform{
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
bool membar::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
@@ -24,13 +24,13 @@ bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
});
}
bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
void membar::add_reference(ir::value *v, interval_vec_t &res){
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
unsigned offset = alloc_->get_offset(v);
unsigned num_bytes = alloc_->get_num_bytes(v);
@@ -38,17 +38,17 @@ void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
}
}
void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
}
void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i))
add_reference(i, res);
}
void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -67,16 +67,16 @@ void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder
}
}
shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector<interval_vec_t>& intervals) {
shmem_barriers::interval_vec_t result;
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
membar::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
}
std::pair<shmem_barriers::interval_vec_t,
shmem_barriers::interval_vec_t> shmem_barriers::transfer(ir::basic_block *block,
std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc) {
@@ -104,7 +104,7 @@ std::pair<shmem_barriers::interval_vec_t,
return std::make_pair(new_written_to, new_read_from);
}
void shmem_barriers::run(ir::module &mod) {
void membar::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);

View File

@@ -1,7 +1,8 @@
#include <algorithm>
#include <iostream>
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -161,7 +162,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
return new_value;
}
reassociate::reassociate(analysis::alignment_info *align, analysis::grids* params)
reassociate::reassociate(analysis::align *align, analysis::grids* params)
: params_(params), align_(align)
{ }
@@ -209,6 +210,29 @@ void reassociate::run(ir::module &mod) {
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list()){
// retiling
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
ir::value* op = rt->get_operand(0);
if(infos.find(op) != infos.end()){
builder.set_insert_point(rt);
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
ir::value* dyn = infos.at(op).dyn_ptr;
ir::value* cst = *sta->idx_begin();
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
auto shapes = rt->get_type()->get_tile_shapes();
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
params_->copy(ndyn, rt);
params_->copy(nsta, rt);
params_->copy(broadcast, rt);
align_->copy(ndyn, rt);
align_->copy(nsta, rt);
align_->copy(broadcast, rt);
infos[rt] = cst_info{ndyn, nsta};
}
}
}
// getelementptr instruction
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
if(replaced.find(pz) != replaced.end())

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/transform/vectorize.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"

View File

@@ -27,7 +27,7 @@
#include <memory>
#include "triton/driver/device.h"
#include "triton/driver/context.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/target.h"
namespace triton
{

View File

@@ -223,12 +223,12 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"])->setValue(true);
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_70", "", buffer, "", Assembly);
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_60", "", buffer, "", Assembly);
std::string result(buffer.begin(), buffer.end());
size_t start_replace = result.find(".version");
size_t end_replace = result.find('\n', start_replace);
assert(start_replace != std::string::npos);
result.replace(start_replace, end_replace - start_replace, ".version 6.4");
result.replace(start_replace, end_replace - start_replace, ".version 6.0");
return result;
}
@@ -245,10 +245,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::base const &){
#ifdef TRITON_LOG_PTX_ERROR
//#ifdef TRITON_LOG_PTX_ERROR
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
#endif
//#endif
throw;
}
}

View File

@@ -29,6 +29,7 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
value->set_name(name);
}
void module::set_value(const std::string& name, ir::value *value){

View File

@@ -22,6 +22,18 @@ std::string get_name(ir::value *v, unsigned i) {
void print(module &mod, std::ostream& os) {
unsigned cnt = 0;
for(ir::function *fn: mod.get_function_list()){
os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ;
for(ir::argument* arg: fn->args()) {
if(arg->get_arg_no() > 0)
os << ", ";
os << arg->get_type()->repr() << " " << arg->get_name();
auto attrs = fn->get_attributes(arg);
if(attrs.size() > 0)
os << " ";
for(ir::attribute attr: attrs)
os << attr.repr() << " ";
}
os << ")" << std::endl;
os << "{" << std::endl;
for(ir::basic_block *block: fn->blocks()){
auto const &predecessors = block->get_predecessors();

View File

@@ -373,8 +373,11 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
for(Object* obj: type->Params()){
std::string name = obj->Name();
args[i]->set_name(name);
for(ASTNode::Attr attr: obj->GetAttrList())
fn->add_attr(i, GenIRAttr(attr));
if(obj->Type()->ToPointer())
fn->add_attr(i + 1, ir::attribute(ir::aligned, 16));
for(ASTNode::Attr attr: obj->GetAttrList()){
fn->add_attr(i + 1, GenIRAttr(attr));
}
if(obj->IsRestrictQualified())
fn->add_attr(i, ir::attribute(ir::noalias));
mod_->set_value(name, nullptr, args[i]);

View File

@@ -3,7 +3,7 @@
#include <regex>
#include <functional>
#include <algorithm>
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
@@ -167,8 +167,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
bin = make_bin(*ir, stream->context(), opt);
}catch(const std::runtime_error& e) {
return;
}catch(const driver::exception::cuda::invalid_ptx& e) {
return;
}
// benchmark
ir::function *tmp = ir->get_function_list()[0];
@@ -191,23 +189,31 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
std::unique_ptr<codegen::target> target = context->device()->make_target();
// create passes
codegen::analysis::grids grids(opt.num_warps);
codegen::analysis::shmem::info shmem_info;
codegen::analysis::shmem::liveness shmem_liveness(&shmem_info);
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info);
codegen::analysis::meminfo shmem_info;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::align alignment_info;
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&alignment_info, &grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
// run passes
peephole.run(module);
dce.run(module);
grids.run(module);
alignment_info.run(module);
grids.run(module);
// ir::print(module, std::cout);
reassociate.run(module);
dce.run(module);
// ir::print(module, std::cout);
peephole.run(module);
if(target->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
@@ -217,7 +223,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
vectorize.run(module);
dce.run(module);
// ir::print(module, std::cout);
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));

View File

@@ -1,4 +1,4 @@
foreach(PROG dot)
foreach(PROG dot copy1d copy2d)
set(TARGET bench_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})

View File

@@ -17,7 +17,7 @@ inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
return [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
@@ -42,11 +42,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
// create options
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
if(AT)
opt.defines.push_back({"AT", {""}});
if(BT)
opt.defines.push_back({"BT", {""}});
opt.defines.push_back({"TM", {"64"}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"64"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4};
@@ -55,18 +53,18 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
// benchmark available libraries
std::vector<double> result;
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
// // cublas
// if(cublas::cublasinit()){
// NumericT alpha(static_cast<double>(1));
// NumericT beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
// result.push_back(tflops(cublas_ms));
// }
// cublas
if(cublas::cublasinit()){
NumericT alpha(static_cast<double>(1));
NumericT beta(static_cast<double>(0));
cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
&alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
result.push_back(tflops(cublas_ms));
}
// triton
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream);
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);}, stream);
result.push_back(tflops(triton_ms));
// done
return result;
@@ -79,11 +77,9 @@ int main() {
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 8192, 8192, 8192}
config_t{x[0], x[1], 2048, 2048, 2048}
// config_t{x[0], x[1], 16, 2048, 2048},
// config_t{x[0], x[1], 32, 2048, 2048},
// config_t{x[0], x[1], 64, 2048, 2048},

View File

@@ -24,11 +24,11 @@ typedef enum{
typedef enum {
CUBLAS_GEMM_DFALT = -1,
CUBLAS_GEMM_DEFAULT = -1,
CUBLAS_GEMM_ALGO0 = 0,
CUBLAS_GEMM_ALGO1 = 1,
CUBLAS_GEMM_ALGO2 = 2,
CUBLAS_GEMM_ALGO3 = 3,
CUBLAS_GEMM_ALGO4 = 4,
CUBLAS_GEMM_ALGO0 = 0, // maxwell_sgemm_32x128_nt
CUBLAS_GEMM_ALGO1 = 1, // maxwell_sgemm_64x64_nt
CUBLAS_GEMM_ALGO2 = 2, // maxwell_sgemm_128x32_nt
CUBLAS_GEMM_ALGO3 = 3, // maxwell_sgemm_128x64_nt
CUBLAS_GEMM_ALGO4 = 4, // maxwell_sgemm_128x128_nt
CUBLAS_GEMM_ALGO5 = 5,
CUBLAS_GEMM_ALGO6 = 6,
CUBLAS_GEMM_ALGO7 = 7,
@@ -102,4 +102,4 @@ typedef enum {
CUBLAS_TENSOR_OP_MATH = 1
} cublasMath_t;
#endif
#endif

View File

@@ -2,33 +2,33 @@ namespace src {
const char *dot =
R"(
#ifdef AT
#if AT == 1
#define USEA ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USEA a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif
#ifdef BT
#if BT == 1
#define USEB ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BK newaxis, :
#define BROADCAST_BN :, newaxis
#define SHAPE_B TN, TK
#else
#define USEB b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BK :, newaxis
#define BROADCAST_BN newaxis, :
#define SHAPE_B TK, TN
@@ -58,17 +58,15 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
c += USEA @ USEB;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0;
b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0;
}
// epilogue
int rxc[TM] = ridx * TM + 0 ... TM;
int ryc[TN] = ridy * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
*?(checkc) pc = c;
TYPE* pc[TM, TN] = C + ryc[newaxis, :] * ldc + rxc[:, newaxis];
*pc = c;
}
)";
}

View File

@@ -3,21 +3,35 @@
#ifndef _TRITON_TESTS_UTIL_H
#define _TRITON_TESTS_UTIL_H
#include <iomanip>
#include "triton/runtime/function.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
inline rt::function::grid_fn_ty grid1d(size_t N) {
return [N](const rt::function::options_t& x) {
return rt::grid_t{ceil(N, x.D<int>("TN"))};
};
}
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
return [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
};
}
enum order_t {
ROWMAJOR,
COLMAJOR
};
namespace aux{
template<std::size_t...> struct seq{};
@@ -51,11 +65,14 @@ namespace testing {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}
return true;
}
}
#endif
#endif

View File

@@ -1,5 +1,5 @@
foreach(PROG dot)
set(TARGET test_${PROG})
set(TARGET unit_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
target_link_libraries(${TARGET} triton dl)

View File

@@ -51,8 +51,8 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
typedef half_float::half NumericT;
std::string ty = "half";
typedef float NumericT;
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
std::vector<NumericT> hc(M*N);
@@ -78,17 +78,15 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
// run
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
if(AT)
opt.defines.push_back({"AT", {""}});
if(BT)
opt.defines.push_back({"BT", {""}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {std::to_string(TM)}});
opt.defines.push_back({"TN", {std::to_string(TN)}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.num_warps = {nwarp};
rt::function function(src::dot, opt);
try {
function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);
function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);
} catch (const std::runtime_error& e) {
return true;
}