[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)

This massively simplifies implementation of `reassociate` and also fixes
a bunch of bug. The pass could still be improved, but can already be used
to generate constant pointer offsets in eg the matmul epilogue
This commit is contained in:
Philippe Tillet
2021-05-07 17:54:37 -04:00
committed by Philippe Tillet
parent e16bee1a27
commit 840140bf26
12 changed files with 204 additions and 667 deletions

View File

@@ -74,6 +74,33 @@ struct distributed_axis {
Value* thread_id;
};
class adder{
public:
adder(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class multiplier{
public:
multiplier(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class geper{
public:
geper(Builder** builder): builder_(builder) { }
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
private:
Builder** builder_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void init_idx(ir::value *x);
@@ -143,9 +170,9 @@ public:
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
void visit_make_range_dyn(ir::make_range_dyn*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_make_range_sta(ir::make_range_sta*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
@@ -195,6 +222,11 @@ private:
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
// helper for creating llvm values
adder add;
multiplier mul;
geper gep;
};
}

View File

@@ -1,49 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
#include <map>
#include <set>
#include <vector>
namespace triton {
// forward declaration
namespace ir {
class module;
class value;
class builder;
class instruction;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class tiles;
class align;
}
namespace transform{
class reassociate {
struct cst_info {
ir::value* dyn_ptr;
ir::getelementptr_inst* sta_ptr;
};
private:
ir::instruction* is_bin_add(ir::value *x);
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst);
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
void run(ir::module& module);
};
}
}
}
#endif

View File

@@ -821,33 +821,33 @@ private:
int N_;
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant
class make_range_dyn: public instruction {
private:
make_range_dyn(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
_TRITON_DEFINE_CLONE(make_range_dyn)
_TRITON_DEFINE_ACCEPT(make_range_dyn)
//// On NVIDIA, implementation is such that
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
//// so as to enable re-association on nv_static_program_idx which is constant
//class make_range_dyn: public instruction {
//private:
// make_range_dyn(type *ty, const std::string &name, instruction *next);
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
// _TRITON_DEFINE_CLONE(make_range_dyn)
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
public:
static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
};
//public:
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
//};
class make_range_sta: public constant {
private:
make_range_sta(make_range *range);
//class make_range_sta: public constant {
//private:
// make_range_sta(make_range *range);
public:
static make_range_sta *get(make_range* range);
make_range* get_range() const;
std::string repr() const { return "nv_static_program_idx"; }
_TRITON_DEFINE_ACCEPT(make_range_sta)
//public:
// static make_range_sta *get(make_range* range);
// make_range* get_range() const;
// std::string repr() const { return "nv_static_program_idx"; }
// _TRITON_DEFINE_ACCEPT(make_range_sta)
private:
make_range *range_;
};
//private:
// make_range *range_;
//};
/* constant range */

View File

@@ -144,12 +144,12 @@ public:
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_make_range_sta(make_range_sta*) = 0;
// virtual void visit_make_range_sta(make_range_sta*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;

View File

@@ -174,8 +174,6 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
@@ -322,8 +320,6 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
auto shapes = v->get_type()->get_block_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
@@ -486,10 +482,6 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
return add_to_cache(x, {128}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))

View File

@@ -12,7 +12,6 @@
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
@@ -48,7 +47,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
// codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
@@ -76,7 +75,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
align.run(ir);
dce.run(ir);
if (target->is_gpu()) {
reassociate.run(ir);
// reassociate.run(ir);
cts.run(ir);
}
dce.run(ir);

View File

@@ -23,6 +23,63 @@ namespace codegen{
using namespace llvm;
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
// (x + cst) + y -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
bin->getOperand(1));
}
// (x + (y + cst)) -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(y))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
bin->getOperand(1));
}
// default
return (*builder_)->CreateAdd(x, y, name);
}
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1)))
if(dyn_cast<Constant>(y)){
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
(*builder_)->CreateMul(bin->getOperand(1), y));
}
// default
return (*builder_)->CreateMul(x, y, name);
}
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
return (*builder_)->CreateGEP(gep->getPointerOperand(),
(*builder_)->CreateAdd(cst1, cst2));
}
// ptr + (off + cst) -> (ptr + off) + cst
if(auto* bin = dyn_cast<BinaryOperator>(off))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
bin->getOperand(1));
}
// default
return (*builder_)->CreateGEP(ptr, off, name);
}
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
@@ -34,7 +91,6 @@ using namespace llvm;
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define add(...) builder_->CreateAdd(__VA_ARGS__)
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
@@ -52,7 +108,6 @@ using namespace llvm;
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
@@ -64,7 +119,6 @@ using namespace llvm;
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define mul(...) builder_->CreateMul(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
target *tgt,
unsigned num_warps)
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
tgt_(tgt), num_warps_(num_warps) {
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
}
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
* \brief Code Generation for `binary_operator`
*/
void generator::visit_binary_operator(ir::binary_operator*x) {
auto cvt = [](ir::binary_op_t op){
using ll = llvm::Instruction::BinaryOps;
auto cvt = [](ir::binary_op_t op){
using tt = ir::binary_op_t;
switch(op) {
case tt::Add: return ll::Add;
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
}
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
std::vector<Value*> vals;
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
vals.push_back(vals_[*it][idx]);
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
vals_[x][idx] = gep(ty, ptr, vals);
assert(vals.size() == 1);
vals_[x][idx] = gep(ptr, vals[0]);
}
}
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
call(iasm);
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
for(indices_t idx: idxs_.at(x)){
assert(idx.size() == 1);
if(idx[0] == i32(0))
vals_[x][idx] = idx[0];
else{
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
vals_[x][idx] = bin_add->getOperand(0);
}
}
}
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0))
// vals_[x][idx] = idx[0];
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// vals_[x][idx] = bin_add->getOperand(0);
// }
// }
//}
void generator::visit_make_range_sta(ir::make_range_sta* x) {
for(indices_t idx: idxs_.at(x)){
assert(idx.size() == 1);
if(idx[0] == i32(0)){
vals_[x][idx] = idx[0];
}
else{
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *cst = bin_add->getOperand(1);
assert(isa<Constant>(cst));
vals_[x][idx] = cst;
}
};
}
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0)){
// vals_[x][idx] = idx[0];
// }
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// Value *cst = bin_add->getOperand(1);
// assert(isa<Constant>(cst));
// vals_[x][idx] = cst;
// }
// };
//}
void generator::visit_make_range(ir::make_range* x) {
for(indices_t idx: idxs_.at(x)){

View File

@@ -1,267 +0,0 @@
#include <algorithm>
#include "triton/codegen/transform/reassociate.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
if(is_bin_add)
return (ir::instruction*)x;
return nullptr;
}
inline bool is_cst(ir::value *x) {
if(dynamic_cast<ir::constant*>(x))
return true;
if(dynamic_cast<ir::make_range*>(x))
return true;
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
return is_cst(v->get_operand(0));
return false;
}
ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::builder &builder,
ir::value *&noncst,
ir::value *&cst){
// value doesn't change by default
ir::value* new_value = old_value;
cst = nullptr;
noncst = old_value;
// handle retiling
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
auto shapes = op->get_type()->get_block_shapes();
ir::value *old_arg = op->get_operand(0);
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
// retile(x + y) = retile(x) + retile(y)
if(ir::instruction* bin_add = is_bin_add(new_arg))
if(cst){
ir::value *old_lhs = bin_add->get_operand(0);
ir::value *old_rhs = bin_add->get_operand(1);
ir::value *new_lhs = nullptr;
ir::value *new_rhs = nullptr;
if(dynamic_cast<ir::reshape_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs);
}
}
}
// handle binary addition
if(ir::instruction* op = is_bin_add(old_value)){
builder.set_insert_point(op);
std::string name = op->get_name();
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
builder.set_insert_point(op);
// (x + y) + z
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
ir::value *llhs = bin_lhs->get_operand(0);
ir::value *rlhs = bin_lhs->get_operand(1);
// (cst + x) + y -> cst + (x + y)
if(is_cst(llhs))
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
ir::value *lrhs = bin_rhs->get_operand(0);
ir::value *rrhs = bin_rhs->get_operand(1);
// x + (cst + y) -> cst + (x + y)
if(is_cst(lrhs))
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
// x + (y + cst) -> cst + (x + y)
if(is_cst(rrhs))
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
}
}
// extract constant and non-constant
if(ir::instruction *bin_add = is_bin_add(new_value)){
ir::value *new_lhs = bin_add->get_operand(0);
ir::value *new_rhs = bin_add->get_operand(1);
if(is_cst(new_lhs)){
cst = new_lhs;
noncst = new_rhs;
}
if(is_cst(new_rhs)){
cst = new_rhs;
noncst = new_lhs;
}
}
// clean-up if some re-ordering happened
if(old_value != new_value)
old_value->replace_all_uses_with(new_value);
return new_value;
}
/* run */
void reassociate::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::make_range*> ranges;
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list())
for(ir::value* op: i->ops())
if(auto *range = dynamic_cast<ir::make_range*>(op))
ranges.push_back(range);
}
builder.set_insert_point(rpo.front()->get_first_non_phi());
for(ir::make_range* old_range: ranges){
ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
ir::value* static_range = ir::make_range_sta::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
}
}
// reassociate
std::map<ir::value*, cst_info> infos;
std::set<ir::value*> replaced;
size_t n_replaced;
do{
n_replaced = replaced.size();
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list()){
// retiling
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
ir::value* op = rt->get_operand(0);
if(infos.find(op) != infos.end()){
builder.set_insert_point(rt);
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
ir::value* dyn = infos.at(op).dyn_ptr;
ir::value* cst = *sta->idx_begin();
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
auto shapes = rt->get_type()->get_block_shapes();
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
infos[rt] = cst_info{ndyn, nsta};
}
}
}
// getelementptr instruction
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
if(replaced.find(pz) != replaced.end())
continue;
// unpack GEP instruction
ir::value* py = pz->get_pointer_operand();
ir::value* offset = *pz->idx_begin();
// reassociate index
ir::value *sta = nullptr;
ir::value *dyn = offset;
reassociate_idx(offset, builder, dyn, sta);
if(sta){
builder.set_insert_point(pz);
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
replaced.insert(pz);
}
// reassociate pointer argument
if(infos.find(py) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[py].sta_ptr;
ir::value *dyn = infos[py].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
replaced.insert(pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now
std::vector<ir::value*> ops = phi->ops();
if(ops.size() != 2)
continue;
if(ops[0] != pz && ops[1] != pz)
continue;
// grab incoming
size_t idx_z = (ops[0] == pz) ? 0 : 1;
size_t idx_a = (ops[0] == pz) ? 1 : 0;
// check if pa is known to have constant offset
ir::value *vpa = phi->get_incoming_value(idx_a);
auto it_a = infos.find(vpa);
if(it_a == infos.end())
continue;
// unpack dynamically/statically offset pointer
ir::value *pa_dyn = it_a->second.dyn_ptr;
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
ir::value *pz = phi->get_incoming_value(idx_z);
// extract offset
ir::value *off = *pa_sta->idx_begin();
builder.set_insert_point(phi);
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
// re-add the offset
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
phi_sta->set_name( phi->get_name() + "_sta");
phi->replace_all_uses_with(phi_sta);
// remove offset from pz
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
auto insts = x->get_parent()->get_inst_list();
auto it = std::find(insts.begin(), insts.end(), x);
it++;
builder.set_insert_point(*it);
}
ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_block_ty())
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
infos[phi_sta].dyn_ptr = phi_dyn;
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
replaced.insert(phi);
}
}
}
}
}
}while(replaced.size() != n_replaced);
}
}
}
}

View File

@@ -81,6 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
std::cout << n_reg << std::endl;
if (shared_optin > 49152){
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);

View File

@@ -833,27 +833,27 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
}
// nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
return new make_range_dyn(ty, name, next);
}
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}
// nv_static_program_idx
make_range_sta::make_range_sta(make_range *range)
: constant(range->get_type(), 0), range_(range) { }
//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
make_range* make_range_sta::get_range() const
{ return range_; }
//make_range* make_range_sta::get_range() const
//{ return range_; }
make_range_sta* make_range_sta::get(make_range* range) {
static std::map<make_range*, make_range_sta*> cache;
if(cache.find(range) == cache.end())
cache.insert({range, new make_range_sta(range)});
return cache.at(range);
}
//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
// make_range

View File

@@ -137,8 +137,8 @@ def swish(x):
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
],
key=['M', 'N', 'K'],
)
@@ -202,11 +202,12 @@ def matmul(a, b, activation=None):
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
_matmul[grid](
pgm = _matmul[grid](
a, b, c, M, N, K, \
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
)
#print(pgm.asm('ttir'))
# return output
return c
@@ -218,13 +219,14 @@ def matmul(a, b, activation=None):
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
#torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# c_0 = matmul(a, b, activation=None)
# c_1 = torch.matmul(a, b)
# print(c_0)
# print(c_1)
# print(triton.testing.allclose(c_0, c_1))
# exit()
# %%
# Benchmark
@@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
x_vals=[8192], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['cublas', 'triton'], # possible values for `line_arg``
line_names=["cuBLAS", "Triton"], # label name for the lines

View File

@@ -1,233 +0,0 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include <tuple>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include <iomanip>
#include <cmath>
#include "triton/runtime/function.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
namespace src {
const char *dot =
R"(
#define STM 8
#define STN 8
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
int* locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
int width = STM*gridn;
int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM);
int stn = (pid % width) / (RSTM*STN);
int RSTN = min(gridn - stn*STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
TYPE anext[TM, TK] = *?(checka)pa;
TYPE bnext[TK, TN] = *?(checkb)pb;
acc += a @ b;
a = anext;
b = bnext;
// __debug_barrier();
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M &&
rcn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
}
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
float triton_dot(drv::context* context, drv::stream* stream,
bool AT, bool BT,
int32_t M, int32_t N, int32_t K){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::device* device = context->device();
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = N;
std::vector<std::string> sa = { "1", "lda" };
std::vector<std::string> sb = { "1", "ldb" };
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
// initialize buffers
std::vector<T> hc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// macros
rt::options_t opt;
opt.defines["STRIDE_AK"] = AT? "1" : "lda";
opt.defines["STRIDE_AM"] = AT? "lda" : "1";
opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128";
opt.defines["TN"] = "128";
opt.defines["TK"] = "64" ;
opt.defines["TZ"] = "1";
opt.num_warps = 4;
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
rt::add_arg(oss, *db->cu());
rt::add_arg(oss, *dc->cu());
rt::add_arg(oss, (float)1);
rt::add_arg(oss, M);
rt::add_arg(oss, N);
rt::add_arg(oss, K);
rt::add_arg(oss, lda);
rt::add_arg(oss, ldb);
rt::add_arg(oss, ldc);
rt::add_arg(oss, *dlocks->cu());
// function
rt::function function(src::dot, opt, device);
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
// grid
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [ceil, M, N](const rt::options_t& x) {
return rt::kernel::grid_t{ceil(M, x.D<int>("TM"))*
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
// metrics
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, stream);
return tflops(triton_ns);
}
float bench_dot(drv::context* context, drv::stream* stream,
bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
dtype_t dtype) {
switch(dtype){
case HALF: return triton_dot<half_float::half>(context, stream, AT, BT, M, N, K);
case FLOAT: return triton_dot<float>(context, stream, AT, BT, M, N, K);
case DOUBLE: return triton_dot<double>(context, stream, AT, BT, M, N, K);
default: return 0;
}
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs = {
{false, false, 8192, 8192, 8192}
};
// does the work
bool AT, BT;
int32_t M, N, K;
dtype_t dtype = HALF;
for(const auto& c: configs){
std::tie(AT, BT, M, N, K) = c;
float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype);
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl;
}
}