[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)
This massively simplifies implementation of `reassociate` and also fixes a bunch of bug. The pass could still be improved, but can already be used to generate constant pointer offsets in eg the matmul epilogue
This commit is contained in:
committed by
Philippe Tillet
parent
e16bee1a27
commit
840140bf26
@@ -74,6 +74,33 @@ struct distributed_axis {
|
||||
Value* thread_id;
|
||||
};
|
||||
|
||||
class adder{
|
||||
public:
|
||||
adder(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class multiplier{
|
||||
public:
|
||||
multiplier(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class geper{
|
||||
public:
|
||||
geper(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
|
||||
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
void init_idx(ir::value *x);
|
||||
@@ -143,9 +170,9 @@ public:
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
void visit_make_range_sta(ir::make_range_sta*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
void visit_constant_fp(ir::constant_fp*);
|
||||
@@ -195,6 +222,11 @@ private:
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
|
||||
// helper for creating llvm values
|
||||
adder add;
|
||||
multiplier mul;
|
||||
geper gep;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -1,49 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
// forward declaration
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class builder;
|
||||
class instruction;
|
||||
class getelementptr_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class tiles;
|
||||
class align;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reassociate {
|
||||
struct cst_info {
|
||||
ir::value* dyn_ptr;
|
||||
ir::getelementptr_inst* sta_ptr;
|
||||
};
|
||||
|
||||
private:
|
||||
ir::instruction* is_bin_add(ir::value *x);
|
||||
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst);
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
void run(ir::module& module);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -821,33 +821,33 @@ private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
// On NVIDIA, implementation is such that
|
||||
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
// so as to enable re-association on nv_static_program_idx which is constant
|
||||
class make_range_dyn: public instruction {
|
||||
private:
|
||||
make_range_dyn(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
_TRITON_DEFINE_CLONE(make_range_dyn)
|
||||
_TRITON_DEFINE_ACCEPT(make_range_dyn)
|
||||
//// On NVIDIA, implementation is such that
|
||||
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
//// so as to enable re-association on nv_static_program_idx which is constant
|
||||
//class make_range_dyn: public instruction {
|
||||
//private:
|
||||
// make_range_dyn(type *ty, const std::string &name, instruction *next);
|
||||
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
// _TRITON_DEFINE_CLONE(make_range_dyn)
|
||||
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
|
||||
|
||||
public:
|
||||
static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
//public:
|
||||
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
class make_range_sta: public constant {
|
||||
private:
|
||||
make_range_sta(make_range *range);
|
||||
//class make_range_sta: public constant {
|
||||
//private:
|
||||
// make_range_sta(make_range *range);
|
||||
|
||||
public:
|
||||
static make_range_sta *get(make_range* range);
|
||||
make_range* get_range() const;
|
||||
std::string repr() const { return "nv_static_program_idx"; }
|
||||
_TRITON_DEFINE_ACCEPT(make_range_sta)
|
||||
//public:
|
||||
// static make_range_sta *get(make_range* range);
|
||||
// make_range* get_range() const;
|
||||
// std::string repr() const { return "nv_static_program_idx"; }
|
||||
// _TRITON_DEFINE_ACCEPT(make_range_sta)
|
||||
|
||||
private:
|
||||
make_range *range_;
|
||||
};
|
||||
//private:
|
||||
// make_range *range_;
|
||||
//};
|
||||
|
||||
|
||||
/* constant range */
|
||||
|
@@ -144,12 +144,12 @@ public:
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||
// virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
|
@@ -174,8 +174,6 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return is_constant_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_is_constant_phi(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
@@ -322,8 +320,6 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
auto shapes = v->get_type()->get_block_shapes();
|
||||
if(dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||
}
|
||||
|
||||
@@ -486,10 +482,6 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
|
||||
return add_to_cache(x, {128}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_starting_multiple_gep(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
|
@@ -12,7 +12,6 @@
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
@@ -48,7 +47,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||
codegen::transform::reassociate reassociate;
|
||||
// codegen::transform::reassociate reassociate;
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||
// run passes
|
||||
@@ -76,7 +75,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu()) {
|
||||
reassociate.run(ir);
|
||||
// reassociate.run(ir);
|
||||
cts.run(ir);
|
||||
}
|
||||
dce.run(ir);
|
||||
|
@@ -23,6 +23,63 @@ namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
|
||||
// (x + cst) + y -> (x + y) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
// (x + (y + cst)) -> (x + y) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(y))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
|
||||
// default
|
||||
return (*builder_)->CreateAdd(x, y, name);
|
||||
}
|
||||
|
||||
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
|
||||
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1)))
|
||||
if(dyn_cast<Constant>(y)){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
|
||||
(*builder_)->CreateMul(bin->getOperand(1), y));
|
||||
}
|
||||
// default
|
||||
return (*builder_)->CreateMul(x, y, name);
|
||||
}
|
||||
|
||||
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
|
||||
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
|
||||
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
|
||||
return (*builder_)->CreateGEP(gep->getPointerOperand(),
|
||||
(*builder_)->CreateAdd(cst1, cst2));
|
||||
}
|
||||
// ptr + (off + cst) -> (ptr + off) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(off))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
// default
|
||||
return (*builder_)->CreateGEP(ptr, off, name);
|
||||
}
|
||||
|
||||
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
|
||||
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
|
||||
//}
|
||||
|
||||
|
||||
// types
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
@@ -34,7 +91,6 @@ using namespace llvm;
|
||||
// constants
|
||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||
// ops
|
||||
#define add(...) builder_->CreateAdd(__VA_ARGS__)
|
||||
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
||||
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
||||
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
||||
@@ -52,7 +108,6 @@ using namespace llvm;
|
||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
|
||||
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
||||
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
||||
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
||||
@@ -64,7 +119,6 @@ using namespace llvm;
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
||||
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
||||
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
||||
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
target *tgt,
|
||||
unsigned num_warps)
|
||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
||||
tgt_(tgt), num_warps_(num_warps) {
|
||||
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||
|
||||
}
|
||||
|
||||
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
||||
* \brief Code Generation for `binary_operator`
|
||||
*/
|
||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using tt = ir::binary_op_t;
|
||||
switch(op) {
|
||||
case tt::Add: return ll::Add;
|
||||
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
|
||||
std::vector<Value*> vals;
|
||||
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
|
||||
vals.push_back(vals_[*it][idx]);
|
||||
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
|
||||
vals_[x][idx] = gep(ty, ptr, vals);
|
||||
assert(vals.size() == 1);
|
||||
vals_[x][idx] = gep(ptr, vals[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
// output ptr info
|
||||
Value* out_base = shared[i].first;
|
||||
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
assert(idx.size() == 1);
|
||||
if(idx[0] == i32(0))
|
||||
vals_[x][idx] = idx[0];
|
||||
else{
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
vals_[x][idx] = bin_add->getOperand(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0))
|
||||
// vals_[x][idx] = idx[0];
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// vals_[x][idx] = bin_add->getOperand(0);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
assert(idx.size() == 1);
|
||||
if(idx[0] == i32(0)){
|
||||
vals_[x][idx] = idx[0];
|
||||
}
|
||||
else{
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *cst = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(cst));
|
||||
vals_[x][idx] = cst;
|
||||
}
|
||||
};
|
||||
}
|
||||
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0)){
|
||||
// vals_[x][idx] = idx[0];
|
||||
// }
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// Value *cst = bin_add->getOperand(1);
|
||||
// assert(isa<Constant>(cst));
|
||||
// vals_[x][idx] = cst;
|
||||
// }
|
||||
// };
|
||||
//}
|
||||
|
||||
void generator::visit_make_range(ir::make_range* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
|
@@ -1,267 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
|
||||
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
|
||||
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
|
||||
if(is_bin_add)
|
||||
return (ir::instruction*)x;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline bool is_cst(ir::value *x) {
|
||||
if(dynamic_cast<ir::constant*>(x))
|
||||
return true;
|
||||
if(dynamic_cast<ir::make_range*>(x))
|
||||
return true;
|
||||
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
|
||||
return is_cst(v->get_operand(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::builder &builder,
|
||||
ir::value *&noncst,
|
||||
ir::value *&cst){
|
||||
// value doesn't change by default
|
||||
ir::value* new_value = old_value;
|
||||
cst = nullptr;
|
||||
noncst = old_value;
|
||||
|
||||
// handle retiling
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_block_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
if(ir::instruction* bin_add = is_bin_add(new_arg))
|
||||
if(cst){
|
||||
ir::value *old_lhs = bin_add->get_operand(0);
|
||||
ir::value *old_rhs = bin_add->get_operand(1);
|
||||
ir::value *new_lhs = nullptr;
|
||||
ir::value *new_rhs = nullptr;
|
||||
if(dynamic_cast<ir::reshape_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle binary addition
|
||||
if(ir::instruction* op = is_bin_add(old_value)){
|
||||
builder.set_insert_point(op);
|
||||
std::string name = op->get_name();
|
||||
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
|
||||
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
|
||||
builder.set_insert_point(op);
|
||||
// (x + y) + z
|
||||
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
|
||||
ir::value *llhs = bin_lhs->get_operand(0);
|
||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(is_cst(llhs))
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
ir::value *lrhs = bin_rhs->get_operand(0);
|
||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(is_cst(lrhs))
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
|
||||
}
|
||||
}
|
||||
// extract constant and non-constant
|
||||
if(ir::instruction *bin_add = is_bin_add(new_value)){
|
||||
ir::value *new_lhs = bin_add->get_operand(0);
|
||||
ir::value *new_rhs = bin_add->get_operand(1);
|
||||
if(is_cst(new_lhs)){
|
||||
cst = new_lhs;
|
||||
noncst = new_rhs;
|
||||
}
|
||||
if(is_cst(new_rhs)){
|
||||
cst = new_rhs;
|
||||
noncst = new_lhs;
|
||||
}
|
||||
}
|
||||
// clean-up if some re-ordering happened
|
||||
if(old_value != new_value)
|
||||
old_value->replace_all_uses_with(new_value);
|
||||
return new_value;
|
||||
}
|
||||
|
||||
/* run */
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
|
||||
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::make_range*> ranges;
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::value* op: i->ops())
|
||||
if(auto *range = dynamic_cast<ir::make_range*>(op))
|
||||
ranges.push_back(range);
|
||||
}
|
||||
|
||||
builder.set_insert_point(rpo.front()->get_first_non_phi());
|
||||
for(ir::make_range* old_range: ranges){
|
||||
ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
|
||||
ir::value* static_range = ir::make_range_sta::get(old_range);
|
||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
||||
old_range->replace_all_uses_with(new_range);
|
||||
}
|
||||
}
|
||||
|
||||
// reassociate
|
||||
std::map<ir::value*, cst_info> infos;
|
||||
std::set<ir::value*> replaced;
|
||||
size_t n_replaced;
|
||||
do{
|
||||
n_replaced = replaced.size();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// retiling
|
||||
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
|
||||
ir::value* op = rt->get_operand(0);
|
||||
if(infos.find(op) != infos.end()){
|
||||
builder.set_insert_point(rt);
|
||||
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
|
||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
||||
ir::value* cst = *sta->idx_begin();
|
||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
||||
auto shapes = rt->get_type()->get_block_shapes();
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
infos[rt] = cst_info{ndyn, nsta};
|
||||
}
|
||||
}
|
||||
}
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
if(replaced.find(pz) != replaced.end())
|
||||
continue;
|
||||
// unpack GEP instruction
|
||||
ir::value* py = pz->get_pointer_operand();
|
||||
ir::value* offset = *pz->idx_begin();
|
||||
// reassociate index
|
||||
ir::value *sta = nullptr;
|
||||
ir::value *dyn = offset;
|
||||
reassociate_idx(offset, builder, dyn, sta);
|
||||
if(sta){
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate pointer argument
|
||||
if(infos.find(py) != infos.end()){
|
||||
builder.set_insert_point(pz);
|
||||
ir::getelementptr_inst *sta = infos[py].sta_ptr;
|
||||
ir::value *dyn = infos[py].dyn_ptr;
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
std::vector<ir::value*> ops = phi->ops();
|
||||
if(ops.size() != 2)
|
||||
continue;
|
||||
if(ops[0] != pz && ops[1] != pz)
|
||||
continue;
|
||||
// grab incoming
|
||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
||||
// check if pa is known to have constant offset
|
||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
auto it_a = infos.find(vpa);
|
||||
if(it_a == infos.end())
|
||||
continue;
|
||||
// unpack dynamically/statically offset pointer
|
||||
ir::value *pa_dyn = it_a->second.dyn_ptr;
|
||||
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
|
||||
ir::value *pz = phi->get_incoming_value(idx_z);
|
||||
// extract offset
|
||||
ir::value *off = *pa_sta->idx_begin();
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
|
||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
||||
// re-add the offset
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
|
||||
phi_sta->set_name( phi->get_name() + "_sta");
|
||||
phi->replace_all_uses_with(phi_sta);
|
||||
// remove offset from pz
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
||||
auto insts = x->get_parent()->get_inst_list();
|
||||
auto it = std::find(insts.begin(), insts.end(), x);
|
||||
it++;
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *_0 = builder.get_int32(0);
|
||||
if(off->get_type()->is_block_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
|
||||
ir::value *neg_off = builder.create_sub(_0, off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
||||
replaced.insert(phi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}while(replaced.size() != n_replaced);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -81,6 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
|
||||
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
|
||||
std::cout << n_reg << std::endl;
|
||||
if (shared_optin > 49152){
|
||||
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
|
||||
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||
|
@@ -833,27 +833,27 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
|
||||
}
|
||||
|
||||
|
||||
// nv_dynamic_program_idx
|
||||
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
|
||||
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new make_range_dyn(ty, name, next);
|
||||
}
|
||||
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
// return new make_range_dyn(ty, name, next);
|
||||
//}
|
||||
|
||||
// nv_static_program_idx
|
||||
make_range_sta::make_range_sta(make_range *range)
|
||||
: constant(range->get_type(), 0), range_(range) { }
|
||||
//// nv_static_program_idx
|
||||
//make_range_sta::make_range_sta(make_range *range)
|
||||
// : constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
make_range* make_range_sta::get_range() const
|
||||
{ return range_; }
|
||||
//make_range* make_range_sta::get_range() const
|
||||
//{ return range_; }
|
||||
|
||||
make_range_sta* make_range_sta::get(make_range* range) {
|
||||
static std::map<make_range*, make_range_sta*> cache;
|
||||
if(cache.find(range) == cache.end())
|
||||
cache.insert({range, new make_range_sta(range)});
|
||||
return cache.at(range);
|
||||
}
|
||||
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||
// static std::map<make_range*, make_range_sta*> cache;
|
||||
// if(cache.find(range) == cache.end())
|
||||
// cache.insert({range, new make_range_sta(range)});
|
||||
// return cache.at(range);
|
||||
//}
|
||||
|
||||
|
||||
// make_range
|
||||
|
@@ -137,8 +137,8 @@ def swish(x):
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
|
||||
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -202,11 +202,12 @@ def matmul(a, b, activation=None):
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
_matmul[grid](
|
||||
pgm = _matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
)
|
||||
#print(pgm.asm('ttir'))
|
||||
# return output
|
||||
return c
|
||||
|
||||
@@ -218,13 +219,14 @@ def matmul(a, b, activation=None):
|
||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||
|
||||
#torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=swish)
|
||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
# c_0 = matmul(a, b, activation=None)
|
||||
# c_1 = torch.matmul(a, b)
|
||||
# print(c_0)
|
||||
# print(c_1)
|
||||
# print(triton.testing.allclose(c_0, c_1))
|
||||
# exit()
|
||||
|
||||
# %%
|
||||
# Benchmark
|
||||
@@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1))
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||
x_vals=[8192], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['cublas', 'triton'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "Triton"], # label name for the lines
|
||||
|
@@ -1,233 +0,0 @@
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include <iomanip>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/external/half.hpp"
|
||||
#include "triton/runtime/function.h"
|
||||
#include <iomanip>
|
||||
#include <cmath>
|
||||
#include "triton/runtime/function.h"
|
||||
|
||||
namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
namespace src {
|
||||
|
||||
const char *dot =
|
||||
R"(
|
||||
#define STM 8
|
||||
#define STN 8
|
||||
|
||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune __multipleof(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8),
|
||||
int* locks) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = (M + TM - 1) / TM;
|
||||
int gridn = (N + TN - 1) / TN;
|
||||
int width = STM*gridn;
|
||||
int stm = pid / width;
|
||||
int RSTM = min(gridm - stm*STM, STM);
|
||||
int stn = (pid % width) / (RSTM*STN);
|
||||
int RSTN = min(gridn - stn*STN, STN);
|
||||
int laneid = pid % (RSTM * RSTN);
|
||||
int lanem = laneid / RSTN;
|
||||
int lanen = laneid % RSTN;
|
||||
int pidm = stm*STM + lanem;
|
||||
int pidn = stn*STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = pidz * K + 0 ... TK;
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
|
||||
TYPE* pa[TM, TK] = A + offa;
|
||||
TYPE* pb[TK, TN] = B + offb;
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rk[newaxis, :] < K;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
bool checka[TM, TK] = k > TK;
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
TYPE anext[TM, TK] = *?(checka)pa;
|
||||
TYPE bnext[TK, TN] = *?(checkb)pb;
|
||||
acc += a @ b;
|
||||
a = anext;
|
||||
b = bnext;
|
||||
// __debug_barrier();
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
int rcm[TM] = pidm * TM + 0 ... TM;
|
||||
int rcn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < M &&
|
||||
rcn[newaxis, :] < N;
|
||||
#if (TZ==1)
|
||||
*?(checkc) pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + rid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
||||
)";
|
||||
|
||||
}
|
||||
|
||||
enum dtype_t {
|
||||
FLOAT,
|
||||
HALF,
|
||||
DOUBLE
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct to_string;
|
||||
|
||||
template<> struct to_string<half_float::half>{
|
||||
static constexpr const char* value = "half";
|
||||
};
|
||||
|
||||
template<> struct to_string<float>{
|
||||
static constexpr const char* value = "float";
|
||||
};
|
||||
|
||||
template<> struct to_string<double>{
|
||||
static constexpr const char* value = "double";
|
||||
};
|
||||
|
||||
template<class T>
|
||||
float triton_dot(drv::context* context, drv::stream* stream,
|
||||
bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K){
|
||||
std::string ty = to_string<T>::value;
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
drv::device* device = context->device();
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = N;
|
||||
std::vector<std::string> sa = { "1", "lda" };
|
||||
std::vector<std::string> sb = { "1", "ldb" };
|
||||
// inputs
|
||||
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
||||
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
|
||||
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
|
||||
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
|
||||
// initialize buffers
|
||||
std::vector<T> hc(M*N);
|
||||
std::vector<T> ha(M*K);
|
||||
std::vector<T> hb(K*N);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = (float)rand()/RAND_MAX;
|
||||
stream->write(&*da, true, 0, ha);
|
||||
stream->write(&*db, true, 0, hb);
|
||||
// macros
|
||||
rt::options_t opt;
|
||||
opt.defines["STRIDE_AK"] = AT? "1" : "lda";
|
||||
opt.defines["STRIDE_AM"] = AT? "lda" : "1";
|
||||
opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
|
||||
opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
|
||||
opt.defines["TYPE"] = ty;
|
||||
opt.defines["TM"] = "128";
|
||||
opt.defines["TN"] = "128";
|
||||
opt.defines["TK"] = "64" ;
|
||||
opt.defines["TZ"] = "1";
|
||||
opt.num_warps = 4;
|
||||
// arguments
|
||||
std::stringstream oss;
|
||||
rt::add_arg(oss, *da->cu());
|
||||
rt::add_arg(oss, *db->cu());
|
||||
rt::add_arg(oss, *dc->cu());
|
||||
rt::add_arg(oss, (float)1);
|
||||
rt::add_arg(oss, M);
|
||||
rt::add_arg(oss, N);
|
||||
rt::add_arg(oss, K);
|
||||
rt::add_arg(oss, lda);
|
||||
rt::add_arg(oss, ldb);
|
||||
rt::add_arg(oss, ldc);
|
||||
rt::add_arg(oss, *dlocks->cu());
|
||||
// function
|
||||
rt::function function(src::dot, opt, device);
|
||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
|
||||
// grid
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||
return rt::kernel::grid_t{ceil(M, x.D<int>("TM"))*
|
||||
ceil(N, x.D<int>("TN")),
|
||||
(size_t)x.D<int>("TZ")};
|
||||
};
|
||||
|
||||
// metrics
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, stream);
|
||||
return tflops(triton_ns);
|
||||
}
|
||||
|
||||
float bench_dot(drv::context* context, drv::stream* stream,
|
||||
bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
dtype_t dtype) {
|
||||
switch(dtype){
|
||||
case HALF: return triton_dot<half_float::half>(context, stream, AT, BT, M, N, K);
|
||||
case FLOAT: return triton_dot<float>(context, stream, AT, BT, M, N, K);
|
||||
case DOUBLE: return triton_dot<double>(context, stream, AT, BT, M, N, K);
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs = {
|
||||
{false, false, 8192, 8192, 8192}
|
||||
};
|
||||
// does the work
|
||||
bool AT, BT;
|
||||
int32_t M, N, K;
|
||||
dtype_t dtype = HALF;
|
||||
for(const auto& c: configs){
|
||||
std::tie(AT, BT, M, N, K) = c;
|
||||
float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype);
|
||||
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user