[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)
This massively simplifies the implementation of `reassociate` and also fixes a bunch of bugs. The pass could still be improved, but it can already be used to generate constant pointer offsets in, e.g., the matmul epilogue.
This commit is contained in:
committed by
Philippe Tillet
parent
e16bee1a27
commit
840140bf26
@@ -74,6 +74,33 @@ struct distributed_axis {
|
|||||||
Value* thread_id;
|
Value* thread_id;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Function object used by the generator to emit an LLVM add of two values.
// It stores the address of a Builder pointer rather than the Builder itself
// and dereferences it on every call (the generator constructs it with
// `&builder_`).
// The call operator is defined out of line; the definition in this diff
// re-associates constant addends before falling back to a plain CreateAdd.
class adder{
|
||||||
|
public:
|
||||||
|
adder(Builder** builder): builder_(builder) { }
|
||||||
|
// Returns the value of x + y; `name` is an optional name for the result.
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Pointer to the owner's Builder pointer; not owned by this helper.
Builder** builder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Function object used by the generator to emit an LLVM multiply of two
// values. It stores the address of a Builder pointer and dereferences it on
// every call (the generator constructs it with `&builder_`).
// The call operator is defined out of line; the definition in this diff
// distributes a constant factor over `add(v, constant)` before falling back
// to a plain CreateMul.
class multiplier{
|
||||||
|
public:
|
||||||
|
multiplier(Builder** builder): builder_(builder) { }
|
||||||
|
// Returns the value of x * y; `name` is an optional name for the result.
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||||
|
private:
|
||||||
|
// Pointer to the owner's Builder pointer; not owned by this helper.
Builder** builder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Function object used by the generator to emit LLVM getelementptr
// instructions. It stores the address of a Builder pointer and dereferences
// it on every call (the generator constructs it with `&builder_`).
class geper{
|
||||||
|
public:
|
||||||
|
geper(Builder** builder): builder_(builder) { }
|
||||||
|
// Single-offset form: returns a pointer equal to `ptr` advanced by `off`.
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
|
||||||
|
// Typed, multi-index form.
// NOTE(review): the only definition of this overload visible in this diff is
// commented out, so calling it would presumably fail at link time -- confirm
// whether the declaration should stay.
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Pointer to the owner's Builder pointer; not owned by this helper.
Builder** builder_;
|
||||||
|
};
|
||||||
|
|
||||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||||
private:
|
private:
|
||||||
void init_idx(ir::value *x);
|
void init_idx(ir::value *x);
|
||||||
@@ -143,9 +170,9 @@ public:
|
|||||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||||
void visit_barrier_inst(ir::barrier_inst*);
|
void visit_barrier_inst(ir::barrier_inst*);
|
||||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||||
void visit_make_range_dyn(ir::make_range_dyn*);
|
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||||
void visit_make_range(ir::make_range*);
|
void visit_make_range(ir::make_range*);
|
||||||
void visit_make_range_sta(ir::make_range_sta*);
|
// void visit_make_range_sta(ir::make_range_sta*);
|
||||||
void visit_undef_value(ir::undef_value*);
|
void visit_undef_value(ir::undef_value*);
|
||||||
void visit_constant_int(ir::constant_int*);
|
void visit_constant_int(ir::constant_int*);
|
||||||
void visit_constant_fp(ir::constant_fp*);
|
void visit_constant_fp(ir::constant_fp*);
|
||||||
@@ -195,6 +222,11 @@ private:
|
|||||||
std::map<ir::value*, BasicBlock *> bbs_;
|
std::map<ir::value*, BasicBlock *> bbs_;
|
||||||
std::map<ir::value*, std::vector<int>> ords_;
|
std::map<ir::value*, std::vector<int>> ords_;
|
||||||
|
|
||||||
|
// helper for creating llvm values
|
||||||
|
adder add;
|
||||||
|
multiplier mul;
|
||||||
|
geper gep;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -1,49 +0,0 @@
|
|||||||
#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
|
||||||
#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace triton {
|
|
||||||
|
|
||||||
// forward declaration
|
|
||||||
namespace ir {
|
|
||||||
class module;
|
|
||||||
class value;
|
|
||||||
class builder;
|
|
||||||
class instruction;
|
|
||||||
class getelementptr_inst;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace codegen{
|
|
||||||
|
|
||||||
namespace analysis{
|
|
||||||
class tiles;
|
|
||||||
class align;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace transform{
|
|
||||||
|
|
||||||
class reassociate {
|
|
||||||
struct cst_info {
|
|
||||||
ir::value* dyn_ptr;
|
|
||||||
ir::getelementptr_inst* sta_ptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
|
||||||
ir::instruction* is_bin_add(ir::value *x);
|
|
||||||
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst);
|
|
||||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
|
||||||
|
|
||||||
public:
|
|
||||||
void run(ir::module& module);
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@@ -821,33 +821,33 @@ private:
|
|||||||
int N_;
|
int N_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// On NVIDIA, implementation is such that
|
//// On NVIDIA, implementation is such that
|
||||||
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||||
// so as to enable re-association on nv_static_program_idx which is constant
|
//// so as to enable re-association on nv_static_program_idx which is constant
|
||||||
class make_range_dyn: public instruction {
|
//class make_range_dyn: public instruction {
|
||||||
private:
|
//private:
|
||||||
make_range_dyn(type *ty, const std::string &name, instruction *next);
|
// make_range_dyn(type *ty, const std::string &name, instruction *next);
|
||||||
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||||
_TRITON_DEFINE_CLONE(make_range_dyn)
|
// _TRITON_DEFINE_CLONE(make_range_dyn)
|
||||||
_TRITON_DEFINE_ACCEPT(make_range_dyn)
|
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
|
||||||
|
|
||||||
public:
|
//public:
|
||||||
static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
//};
|
||||||
|
|
||||||
class make_range_sta: public constant {
|
//class make_range_sta: public constant {
|
||||||
private:
|
//private:
|
||||||
make_range_sta(make_range *range);
|
// make_range_sta(make_range *range);
|
||||||
|
|
||||||
public:
|
//public:
|
||||||
static make_range_sta *get(make_range* range);
|
// static make_range_sta *get(make_range* range);
|
||||||
make_range* get_range() const;
|
// make_range* get_range() const;
|
||||||
std::string repr() const { return "nv_static_program_idx"; }
|
// std::string repr() const { return "nv_static_program_idx"; }
|
||||||
_TRITON_DEFINE_ACCEPT(make_range_sta)
|
// _TRITON_DEFINE_ACCEPT(make_range_sta)
|
||||||
|
|
||||||
private:
|
//private:
|
||||||
make_range *range_;
|
// make_range *range_;
|
||||||
};
|
//};
|
||||||
|
|
||||||
|
|
||||||
/* constant range */
|
/* constant range */
|
||||||
|
@@ -144,12 +144,12 @@ public:
|
|||||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||||
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||||
virtual void visit_make_range(make_range*) = 0;
|
virtual void visit_make_range(make_range*) = 0;
|
||||||
|
|
||||||
virtual void visit_function(function*) = 0;
|
virtual void visit_function(function*) = 0;
|
||||||
|
|
||||||
virtual void visit_make_range_sta(make_range_sta*) = 0;
|
// virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||||
virtual void visit_undef_value(undef_value*) = 0;
|
virtual void visit_undef_value(undef_value*) = 0;
|
||||||
virtual void visit_constant_int(constant_int*) = 0;
|
virtual void visit_constant_int(constant_int*) = 0;
|
||||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||||
|
@@ -174,8 +174,6 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
|||||||
return is_constant_.at(v);
|
return is_constant_.at(v);
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||||
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
||||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
|
||||||
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
|
||||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||||
return populate_is_constant_phi(x);
|
return populate_is_constant_phi(x);
|
||||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||||
@@ -322,8 +320,6 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
|||||||
auto shapes = v->get_type()->get_block_shapes();
|
auto shapes = v->get_type()->get_block_shapes();
|
||||||
if(dynamic_cast<ir::make_range*>(v))
|
if(dynamic_cast<ir::make_range*>(v))
|
||||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
|
||||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
|
||||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,10 +482,6 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
|||||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||||
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
|
|
||||||
return add_to_cache(x, {128}, starting_multiple_);
|
|
||||||
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
|
|
||||||
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
|
|
||||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||||
return populate_starting_multiple_gep(x);
|
return populate_starting_multiple_gep(x);
|
||||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||||
|
@@ -12,7 +12,6 @@
|
|||||||
#include "triton/codegen/transform/membar.h"
|
#include "triton/codegen/transform/membar.h"
|
||||||
#include "triton/codegen/transform/peephole.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
#include "triton/codegen/transform/pipeline.h"
|
#include "triton/codegen/transform/pipeline.h"
|
||||||
#include "triton/codegen/transform/reassociate.h"
|
|
||||||
#include "triton/driver/device.h"
|
#include "triton/driver/device.h"
|
||||||
#include "triton/driver/kernel.h"
|
#include "triton/driver/kernel.h"
|
||||||
#include "triton/driver/module.h"
|
#include "triton/driver/module.h"
|
||||||
@@ -48,7 +47,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
|||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole(target.get(), &layouts);
|
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||||
codegen::transform::reassociate reassociate;
|
// codegen::transform::reassociate reassociate;
|
||||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
@@ -76,7 +75,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
|||||||
align.run(ir);
|
align.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
if (target->is_gpu()) {
|
if (target->is_gpu()) {
|
||||||
reassociate.run(ir);
|
// reassociate.run(ir);
|
||||||
cts.run(ir);
|
cts.run(ir);
|
||||||
}
|
}
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
@@ -100,4 +99,4 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace codegen
|
} // namespace codegen
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
@@ -23,6 +23,63 @@ namespace codegen{
|
|||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
|
|
||||||
|
// Emits x + y through the current builder, moving a constant addend to the
// outermost position when one operand is itself an `Add` whose operand 1 is
// a Constant. Two patterns are recognised, tried in this order:
//   (v + cst) + y  ->  (v + y) + cst
//   x + (v + cst)  ->  (x + v) + cst
// A constant sitting in operand 0 of the inner add is not matched.
// NOTE(review): `name` is only forwarded on the default path; the adds created
// by the two rewrites are unnamed.
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
|
||||||
|
// (x + cst) + y -> (x + y) + cst
|
||||||
|
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||||
|
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||||
|
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||||
|
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
|
||||||
|
bin->getOperand(1));
|
||||||
|
}
|
||||||
|
// (x + (y + cst)) -> (x + y) + cst
|
||||||
|
if(auto* bin = dyn_cast<BinaryOperator>(y))
|
||||||
|
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||||
|
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||||
|
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
|
||||||
|
bin->getOperand(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// default
|
||||||
|
return (*builder_)->CreateAdd(x, y, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits x * y through the current builder. When x is an `Add` whose operand 1
// is a Constant and y is itself a Constant, the product is distributed so the
// constant part stays separate:
//   (v + cst1) * y  ->  (v * y) + (cst1 * y)
// Every other combination falls through to a plain CreateMul.
// NOTE(review): `name` is only forwarded on the default path; the instructions
// created by the rewrite are unnamed.
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
|
||||||
|
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
|
||||||
|
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||||
|
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||||
|
if(dyn_cast<Constant>(bin->getOperand(1)))
|
||||||
|
if(dyn_cast<Constant>(y)){
|
||||||
|
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
|
||||||
|
(*builder_)->CreateMul(bin->getOperand(1), y));
|
||||||
|
}
|
||||||
|
// default
|
||||||
|
return (*builder_)->CreateMul(x, y, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits a single-index GEP of `ptr` by `off` through the current builder,
// applying one of two rewrites when it matches (tried in this order):
//   1. `ptr` is a GEP whose first index is a ConstantInt and `off` is a
//      ConstantInt: the two constants are added and applied to the inner
//      GEP's base pointer.
//   2. `off` is an `Add` whose operand 1 is a ConstantInt: the GEP is split
//      into gep(gep(ptr, v), cst) so the constant is the outermost offset.
// NOTE(review): rewrite 1 only inspects the first index of the inner GEP;
// presumably every GEP reaching this helper has exactly one index -- confirm.
// NOTE(review): `cst` in rewrite 2 is bound but never read; operand 1 of the
// add is passed directly instead.
// NOTE(review): `name` is only forwarded on the default path.
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||||
|
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
|
||||||
|
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||||
|
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
|
||||||
|
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
|
||||||
|
return (*builder_)->CreateGEP(gep->getPointerOperand(),
|
||||||
|
(*builder_)->CreateAdd(cst1, cst2));
|
||||||
|
}
|
||||||
|
// ptr + (off + cst) -> (ptr + off) + cst
|
||||||
|
if(auto* bin = dyn_cast<BinaryOperator>(off))
|
||||||
|
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||||
|
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
|
||||||
|
return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
|
||||||
|
bin->getOperand(1));
|
||||||
|
}
|
||||||
|
// default
|
||||||
|
return (*builder_)->CreateGEP(ptr, off, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
|
||||||
|
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
|
||||||
|
//}
|
||||||
|
|
||||||
|
|
||||||
// types
|
// types
|
||||||
#define void_ty builder_->getVoidTy()
|
#define void_ty builder_->getVoidTy()
|
||||||
#define f16_ty builder_->getHalfTy()
|
#define f16_ty builder_->getHalfTy()
|
||||||
@@ -34,7 +91,6 @@ using namespace llvm;
|
|||||||
// constants
|
// constants
|
||||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||||
// ops
|
// ops
|
||||||
#define add(...) builder_->CreateAdd(__VA_ARGS__)
|
|
||||||
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
||||||
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
||||||
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
||||||
@@ -52,7 +108,6 @@ using namespace llvm;
|
|||||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||||
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
|
|
||||||
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
||||||
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
||||||
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
||||||
@@ -64,7 +119,6 @@ using namespace llvm;
|
|||||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
|
||||||
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
||||||
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
||||||
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
||||||
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
|
|||||||
target *tgt,
|
target *tgt,
|
||||||
unsigned num_warps)
|
unsigned num_warps)
|
||||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
||||||
tgt_(tgt), num_warps_(num_warps) {
|
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
|||||||
* \brief Code Generation for `binary_operator`
|
* \brief Code Generation for `binary_operator`
|
||||||
*/
|
*/
|
||||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||||
|
using ll = llvm::Instruction::BinaryOps;
|
||||||
auto cvt = [](ir::binary_op_t op){
|
auto cvt = [](ir::binary_op_t op){
|
||||||
using ll = llvm::Instruction::BinaryOps;
|
|
||||||
using tt = ir::binary_op_t;
|
using tt = ir::binary_op_t;
|
||||||
switch(op) {
|
switch(op) {
|
||||||
case tt::Add: return ll::Add;
|
case tt::Add: return ll::Add;
|
||||||
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
|||||||
for(indices_t idx: idxs_.at(x)){
|
for(indices_t idx: idxs_.at(x)){
|
||||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||||
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
|
auto op = cvt(x->get_op());
|
||||||
|
if(op == ll::Add)
|
||||||
|
vals_[x][idx] = add(lhs, rhs);
|
||||||
|
else if(op == ll::Mul)
|
||||||
|
vals_[x][idx] = mul(lhs, rhs);
|
||||||
|
else
|
||||||
|
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
|
|||||||
std::vector<Value*> vals;
|
std::vector<Value*> vals;
|
||||||
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
|
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
|
||||||
vals.push_back(vals_[*it][idx]);
|
vals.push_back(vals_[*it][idx]);
|
||||||
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
|
assert(vals.size() == 1);
|
||||||
vals_[x][idx] = gep(ty, ptr, vals);
|
vals_[x][idx] = gep(ptr, vals[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|||||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||||
std::vector<llvm::Type*> tys = {f32_ty};
|
std::vector<llvm::Type*> tys = {f32_ty};
|
||||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
|
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||||
for(auto idx: idxs_.at(x)){
|
for(auto idx: idxs_.at(x)){
|
||||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||||
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
|||||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||||
Value *in_base = in_gep->getPointerOperand();
|
Value *in_base = in_gep->getPointerOperand();
|
||||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
|
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||||
in_base = cst ? in_base : in_gep;
|
in_base = cst ? in_base : in_gep;
|
||||||
// output ptr info
|
// output ptr info
|
||||||
Value* out_base = shared[i].first;
|
Value* out_base = shared[i].first;
|
||||||
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
|||||||
call(iasm);
|
call(iasm);
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||||
for(indices_t idx: idxs_.at(x)){
|
// for(indices_t idx: idxs_.at(x)){
|
||||||
assert(idx.size() == 1);
|
// assert(idx.size() == 1);
|
||||||
if(idx[0] == i32(0))
|
// if(idx[0] == i32(0))
|
||||||
vals_[x][idx] = idx[0];
|
// vals_[x][idx] = idx[0];
|
||||||
else{
|
// else{
|
||||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||||
assert(bin_add);
|
// assert(bin_add);
|
||||||
vals_[x][idx] = bin_add->getOperand(0);
|
// vals_[x][idx] = bin_add->getOperand(0);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||||
for(indices_t idx: idxs_.at(x)){
|
// for(indices_t idx: idxs_.at(x)){
|
||||||
assert(idx.size() == 1);
|
// assert(idx.size() == 1);
|
||||||
if(idx[0] == i32(0)){
|
// if(idx[0] == i32(0)){
|
||||||
vals_[x][idx] = idx[0];
|
// vals_[x][idx] = idx[0];
|
||||||
}
|
// }
|
||||||
else{
|
// else{
|
||||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||||
assert(bin_add);
|
// assert(bin_add);
|
||||||
Value *cst = bin_add->getOperand(1);
|
// Value *cst = bin_add->getOperand(1);
|
||||||
assert(isa<Constant>(cst));
|
// assert(isa<Constant>(cst));
|
||||||
vals_[x][idx] = cst;
|
// vals_[x][idx] = cst;
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
}
|
//}
|
||||||
|
|
||||||
void generator::visit_make_range(ir::make_range* x) {
|
void generator::visit_make_range(ir::make_range* x) {
|
||||||
for(indices_t idx: idxs_.at(x)){
|
for(indices_t idx: idxs_.at(x)){
|
||||||
|
@@ -1,267 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include "triton/codegen/transform/reassociate.h"
|
|
||||||
#include "triton/ir/module.h"
|
|
||||||
#include "triton/ir/function.h"
|
|
||||||
#include "triton/ir/basic_block.h"
|
|
||||||
#include "triton/ir/instructions.h"
|
|
||||||
#include "triton/ir/utils.h"
|
|
||||||
|
|
||||||
namespace triton {
|
|
||||||
namespace codegen{
|
|
||||||
namespace transform{
|
|
||||||
|
|
||||||
|
|
||||||
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
|
|
||||||
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
|
|
||||||
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
|
|
||||||
if(is_bin_add)
|
|
||||||
return (ir::instruction*)x;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_cst(ir::value *x) {
|
|
||||||
if(dynamic_cast<ir::constant*>(x))
|
|
||||||
return true;
|
|
||||||
if(dynamic_cast<ir::make_range*>(x))
|
|
||||||
return true;
|
|
||||||
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
|
|
||||||
return is_cst(v->get_operand(0));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
|
||||||
ir::builder &builder,
|
|
||||||
ir::value *&noncst,
|
|
||||||
ir::value *&cst){
|
|
||||||
// value doesn't change by default
|
|
||||||
ir::value* new_value = old_value;
|
|
||||||
cst = nullptr;
|
|
||||||
noncst = old_value;
|
|
||||||
|
|
||||||
// handle retiling
|
|
||||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
|
||||||
auto shapes = op->get_type()->get_block_shapes();
|
|
||||||
ir::value *old_arg = op->get_operand(0);
|
|
||||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
|
||||||
// retile(x + y) = retile(x) + retile(y)
|
|
||||||
if(ir::instruction* bin_add = is_bin_add(new_arg))
|
|
||||||
if(cst){
|
|
||||||
ir::value *old_lhs = bin_add->get_operand(0);
|
|
||||||
ir::value *old_rhs = bin_add->get_operand(1);
|
|
||||||
ir::value *new_lhs = nullptr;
|
|
||||||
ir::value *new_rhs = nullptr;
|
|
||||||
if(dynamic_cast<ir::reshape_inst*>(op)){
|
|
||||||
builder.set_insert_point(op);
|
|
||||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
|
||||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
|
||||||
new_value = builder.create_add(new_lhs, new_rhs);
|
|
||||||
}
|
|
||||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
|
||||||
builder.set_insert_point(op);
|
|
||||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
|
||||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
|
||||||
new_value = builder.create_add(new_lhs, new_rhs);
|
|
||||||
}
|
|
||||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
|
||||||
builder.set_insert_point(op);
|
|
||||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
|
||||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
|
||||||
new_value = builder.create_add(new_lhs, new_rhs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle binary addition
|
|
||||||
if(ir::instruction* op = is_bin_add(old_value)){
|
|
||||||
builder.set_insert_point(op);
|
|
||||||
std::string name = op->get_name();
|
|
||||||
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
|
|
||||||
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
|
|
||||||
builder.set_insert_point(op);
|
|
||||||
// (x + y) + z
|
|
||||||
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
|
|
||||||
ir::value *llhs = bin_lhs->get_operand(0);
|
|
||||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
|
||||||
// (cst + x) + y -> cst + (x + y)
|
|
||||||
if(is_cst(llhs))
|
|
||||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
|
|
||||||
// (x + cst) + y -> cst + (x + y)
|
|
||||||
if(is_cst(rlhs))
|
|
||||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
|
|
||||||
}
|
|
||||||
// x + (y + z)
|
|
||||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
|
||||||
ir::value *lrhs = bin_rhs->get_operand(0);
|
|
||||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
|
||||||
// x + (cst + y) -> cst + (x + y)
|
|
||||||
if(is_cst(lrhs))
|
|
||||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
|
|
||||||
// x + (y + cst) -> cst + (x + y)
|
|
||||||
if(is_cst(rrhs))
|
|
||||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// extract constant and non-constant
|
|
||||||
if(ir::instruction *bin_add = is_bin_add(new_value)){
|
|
||||||
ir::value *new_lhs = bin_add->get_operand(0);
|
|
||||||
ir::value *new_rhs = bin_add->get_operand(1);
|
|
||||||
if(is_cst(new_lhs)){
|
|
||||||
cst = new_lhs;
|
|
||||||
noncst = new_rhs;
|
|
||||||
}
|
|
||||||
if(is_cst(new_rhs)){
|
|
||||||
cst = new_rhs;
|
|
||||||
noncst = new_lhs;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// clean-up if some re-ordering happened
|
|
||||||
if(old_value != new_value)
|
|
||||||
old_value->replace_all_uses_with(new_value);
|
|
||||||
return new_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* run */
|
|
||||||
void reassociate::run(ir::module &mod) {
|
|
||||||
ir::builder &builder = mod.get_builder();
|
|
||||||
|
|
||||||
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
|
|
||||||
for(ir::function *fn: mod.get_function_list()){
|
|
||||||
std::vector<ir::make_range*> ranges;
|
|
||||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
|
||||||
for(ir::basic_block *block: rpo){
|
|
||||||
// iterate through instruction
|
|
||||||
for(ir::instruction *i: block->get_inst_list())
|
|
||||||
for(ir::value* op: i->ops())
|
|
||||||
if(auto *range = dynamic_cast<ir::make_range*>(op))
|
|
||||||
ranges.push_back(range);
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.set_insert_point(rpo.front()->get_first_non_phi());
|
|
||||||
for(ir::make_range* old_range: ranges){
|
|
||||||
ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
|
|
||||||
ir::value* static_range = ir::make_range_sta::get(old_range);
|
|
||||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
|
||||||
old_range->replace_all_uses_with(new_range);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// reassociate
|
|
||||||
std::map<ir::value*, cst_info> infos;
|
|
||||||
std::set<ir::value*> replaced;
|
|
||||||
size_t n_replaced;
|
|
||||||
do{
|
|
||||||
n_replaced = replaced.size();
|
|
||||||
for(ir::function *fn: mod.get_function_list()){
|
|
||||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
|
||||||
// iterate through blocks
|
|
||||||
for(ir::basic_block *block: rpo){
|
|
||||||
// iterate through instruction
|
|
||||||
for(ir::instruction *i: block->get_inst_list()){
|
|
||||||
// retiling
|
|
||||||
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
|
|
||||||
ir::value* op = rt->get_operand(0);
|
|
||||||
if(infos.find(op) != infos.end()){
|
|
||||||
builder.set_insert_point(rt);
|
|
||||||
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
|
|
||||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
|
||||||
ir::value* cst = *sta->idx_begin();
|
|
||||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
|
||||||
auto shapes = rt->get_type()->get_block_shapes();
|
|
||||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
|
||||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
|
||||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
|
||||||
infos[rt] = cst_info{ndyn, nsta};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// getelementptr instruction
|
|
||||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
|
||||||
if(replaced.find(pz) != replaced.end())
|
|
||||||
continue;
|
|
||||||
// unpack GEP instruction
|
|
||||||
ir::value* py = pz->get_pointer_operand();
|
|
||||||
ir::value* offset = *pz->idx_begin();
|
|
||||||
// reassociate index
|
|
||||||
ir::value *sta = nullptr;
|
|
||||||
ir::value *dyn = offset;
|
|
||||||
reassociate_idx(offset, builder, dyn, sta);
|
|
||||||
if(sta){
|
|
||||||
builder.set_insert_point(pz);
|
|
||||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
|
||||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
|
||||||
pz->replace_all_uses_with(sta_ptr);
|
|
||||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
|
||||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
|
||||||
replaced.insert(pz);
|
|
||||||
}
|
|
||||||
// reassociate pointer argument
|
|
||||||
if(infos.find(py) != infos.end()){
|
|
||||||
builder.set_insert_point(pz);
|
|
||||||
ir::getelementptr_inst *sta = infos[py].sta_ptr;
|
|
||||||
ir::value *dyn = infos[py].dyn_ptr;
|
|
||||||
ir::value *cst = *sta->idx_begin();
|
|
||||||
ir::value *off = *pz->idx_begin();
|
|
||||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
|
||||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
|
|
||||||
pz->replace_all_uses_with(pz_sta);
|
|
||||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
|
||||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
|
||||||
replaced.insert(pz);
|
|
||||||
}
|
|
||||||
// reassociate phi-node pointer
|
|
||||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
|
||||||
// only optimize the case where py = phi pa, pz for now
|
|
||||||
std::vector<ir::value*> ops = phi->ops();
|
|
||||||
if(ops.size() != 2)
|
|
||||||
continue;
|
|
||||||
if(ops[0] != pz && ops[1] != pz)
|
|
||||||
continue;
|
|
||||||
// grab incoming
|
|
||||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
|
||||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
|
||||||
// check if pa is known to have constant offset
|
|
||||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
|
||||||
auto it_a = infos.find(vpa);
|
|
||||||
if(it_a == infos.end())
|
|
||||||
continue;
|
|
||||||
// unpack dynamically/statically offset pointer
|
|
||||||
ir::value *pa_dyn = it_a->second.dyn_ptr;
|
|
||||||
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
|
|
||||||
ir::value *pz = phi->get_incoming_value(idx_z);
|
|
||||||
// extract offset
|
|
||||||
ir::value *off = *pa_sta->idx_begin();
|
|
||||||
builder.set_insert_point(phi);
|
|
||||||
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
|
|
||||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
|
||||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
|
||||||
// re-add the offset
|
|
||||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
|
|
||||||
phi_sta->set_name( phi->get_name() + "_sta");
|
|
||||||
phi->replace_all_uses_with(phi_sta);
|
|
||||||
// remove offset from pz
|
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
|
||||||
auto insts = x->get_parent()->get_inst_list();
|
|
||||||
auto it = std::find(insts.begin(), insts.end(), x);
|
|
||||||
it++;
|
|
||||||
builder.set_insert_point(*it);
|
|
||||||
}
|
|
||||||
ir::value *_0 = builder.get_int32(0);
|
|
||||||
if(off->get_type()->is_block_ty())
|
|
||||||
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
|
|
||||||
ir::value *neg_off = builder.create_sub(_0, off);
|
|
||||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
|
||||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
|
||||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
|
||||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
|
||||||
replaced.insert(phi);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}while(replaced.size() != n_replaced);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,22 +1,22 @@
|
|||||||
/* Copyright 2015-2017 Philippe Tillet
|
/* Copyright 2015-2017 Philippe Tillet
|
||||||
*
|
*
|
||||||
* Permission is hereby granted, free of charge, to any person obtaining
|
* Permission is hereby granted, free of charge, to any person obtaining
|
||||||
* a copy of this software and associated documentation files
|
* a copy of this software and associated documentation files
|
||||||
* (the "Software"), to deal in the Software without restriction,
|
* (the "Software"), to deal in the Software without restriction,
|
||||||
* including without limitation the rights to use, copy, modify, merge,
|
* including without limitation the rights to use, copy, modify, merge,
|
||||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||||
* and to permit persons to whom the Software is furnished to do so,
|
* and to permit persons to whom the Software is furnished to do so,
|
||||||
* subject to the following conditions:
|
* subject to the following conditions:
|
||||||
*
|
*
|
||||||
* The above copyright notice and this permission notice shall be
|
* The above copyright notice and this permission notice shall be
|
||||||
* included in all copies or substantial portions of the Software.
|
* included in all copies or substantial portions of the Software.
|
||||||
*
|
*
|
||||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@@ -81,6 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
|
|||||||
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
|
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
|
||||||
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
|
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
|
||||||
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
|
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
|
||||||
|
std::cout << n_reg << std::endl;
|
||||||
if (shared_optin > 49152){
|
if (shared_optin > 49152){
|
||||||
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
|
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
|
||||||
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||||
|
@@ -833,27 +833,27 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// nv_dynamic_program_idx
|
//// nv_dynamic_program_idx
|
||||||
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||||
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||||
|
|
||||||
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||||
return new make_range_dyn(ty, name, next);
|
// return new make_range_dyn(ty, name, next);
|
||||||
}
|
//}
|
||||||
|
|
||||||
// nv_static_program_idx
|
//// nv_static_program_idx
|
||||||
make_range_sta::make_range_sta(make_range *range)
|
//make_range_sta::make_range_sta(make_range *range)
|
||||||
: constant(range->get_type(), 0), range_(range) { }
|
// : constant(range->get_type(), 0), range_(range) { }
|
||||||
|
|
||||||
make_range* make_range_sta::get_range() const
|
//make_range* make_range_sta::get_range() const
|
||||||
{ return range_; }
|
//{ return range_; }
|
||||||
|
|
||||||
make_range_sta* make_range_sta::get(make_range* range) {
|
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||||
static std::map<make_range*, make_range_sta*> cache;
|
// static std::map<make_range*, make_range_sta*> cache;
|
||||||
if(cache.find(range) == cache.end())
|
// if(cache.find(range) == cache.end())
|
||||||
cache.insert({range, new make_range_sta(range)});
|
// cache.insert({range, new make_range_sta(range)});
|
||||||
return cache.at(range);
|
// return cache.at(range);
|
||||||
}
|
//}
|
||||||
|
|
||||||
|
|
||||||
// make_range
|
// make_range
|
||||||
|
@@ -137,8 +137,8 @@ def swish(x):
|
|||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
|
||||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||||
],
|
],
|
||||||
key=['M', 'N', 'K'],
|
key=['M', 'N', 'K'],
|
||||||
)
|
)
|
||||||
@@ -202,11 +202,12 @@ def matmul(a, b, activation=None):
|
|||||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||||
# launch kernel
|
# launch kernel
|
||||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||||
_matmul[grid](
|
pgm = _matmul[grid](
|
||||||
a, b, c, M, N, K, \
|
a, b, c, M, N, K, \
|
||||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||||
ACTIVATION = activation
|
ACTIVATION = activation
|
||||||
)
|
)
|
||||||
|
#print(pgm.asm('ttir'))
|
||||||
# return output
|
# return output
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -218,13 +219,14 @@ def matmul(a, b, activation=None):
|
|||||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||||
|
|
||||||
#torch.manual_seed(0)
|
#torch.manual_seed(0)
|
||||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||||
c_0 = matmul(a, b, activation=swish)
|
# c_0 = matmul(a, b, activation=None)
|
||||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
# c_1 = torch.matmul(a, b)
|
||||||
print(c_0)
|
# print(c_0)
|
||||||
print(c_1)
|
# print(c_1)
|
||||||
print(triton.testing.allclose(c_0, c_1))
|
# print(triton.testing.allclose(c_0, c_1))
|
||||||
|
# exit()
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Benchmark
|
# Benchmark
|
||||||
@@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1))
|
|||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
x_vals=[8192], # different possible values for `x_name`
|
||||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||||
line_vals=['cublas', 'triton'], # possible values for `line_arg``
|
line_vals=['cublas', 'triton'], # possible values for `line_arg``
|
||||||
line_names=["cuBLAS", "Triton"], # label name for the lines
|
line_names=["cuBLAS", "Triton"], # label name for the lines
|
||||||
|
@@ -1,233 +0,0 @@
|
|||||||
#include "triton/driver/backend.h"
|
|
||||||
#include "triton/driver/stream.h"
|
|
||||||
#include <iomanip>
|
|
||||||
#include <cstring>
|
|
||||||
#include <sstream>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <tuple>
|
|
||||||
#include "triton/driver/backend.h"
|
|
||||||
#include "triton/driver/stream.h"
|
|
||||||
#include "triton/tools/bench.hpp"
|
|
||||||
#include "triton/external/half.hpp"
|
|
||||||
#include "triton/runtime/function.h"
|
|
||||||
#include <iomanip>
|
|
||||||
#include <cmath>
|
|
||||||
#include "triton/runtime/function.h"
|
|
||||||
|
|
||||||
namespace drv = triton::driver;
|
|
||||||
namespace rt = triton::runtime;
|
|
||||||
|
|
||||||
namespace src {

// Source of the matrix-multiplication kernel, handed to rt::function as text.
// TYPE, TM, TN, TK, TZ and the STRIDE_* names are not defined in the string;
// the host code in this file supplies them through rt::options_t::defines.
//
// The `#else` branch (TZ != 1) accumulates partial results under a spin-lock.
// It previously indexed the lock array with `rid`, an identifier that is never
// declared anywhere in the kernel; it now uses `pid`, the id of the program
// instance that owns the output tile.
// NOTE(review): the host code in this file always sets TZ to "1", so the
// split-K branch is not exercised here -- confirm the lock layout against the
// caller that allocates `locks` before enabling TZ > 1.
const char *dot =
R"(
#define STM 8
#define STN 8

__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                    TYPE * B __noalias __readonly __aligned(16),
                    TYPE * C __noalias __aligned(16),
                    float alpha,
                    int M __retune,
                    int N __retune,
                    int K __retune __multipleof(16),
                    int lda __multipleof(8),
                    int ldb __multipleof(8),
                    int ldc __multipleof(8),
                    int* locks) {
  // prologue
  int pid = get_program_id(0);
  int pidz = get_program_id(2);
  int gridm = (M + TM - 1) / TM;
  int gridn = (N + TN - 1) / TN;
  int width = STM*gridn;
  int stm = pid / width;
  int RSTM = min(gridm - stm*STM, STM);
  int stn = (pid % width) / (RSTM*STN);
  int RSTN = min(gridn - stn*STN, STN);
  int laneid = pid % (RSTM * RSTN);
  int lanem = laneid / RSTN;
  int lanen = laneid % RSTN;
  int pidm = stm*STM + lanem;
  int pidn = stn*STN + lanen;
  int rm[TM] = pidm * TM + 0 ... TM;
  int rn[TN] = pidn * TN + 0 ... TN;

  // reduction splitting
  K = K / TZ;
  int rk[TK] = pidz * K + 0 ... TK;
  // pointers to operands
  int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
  int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
  TYPE* pa[TM, TK] = A + offa;
  TYPE* pb[TK, TN] = B + offb;
  // prefetches operands
  bool checka[TM, TK] = rk[newaxis, :] < K;
  bool checkb[TK, TN] = rk[:, newaxis] < K;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;
  // reduction loop
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    bool checka[TM, TK] = k > TK;
    bool checkb[TK, TN] = k > TK;
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
    TYPE anext[TM, TK] = *?(checka)pa;
    TYPE bnext[TK, TN] = *?(checkb)pb;
    acc += a @ b;
    a = anext;
    b = bnext;
    // __debug_barrier();
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;

  // epilogue
  int rcm[TM] = pidm * TM + 0 ... TM;
  int rcn[TN] = pidn * TN + 0 ... TN;
  int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
  TYPE* pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = rcm[:, newaxis] < M &&
                        rcn[newaxis, :] < N;
#if (TZ==1)
  *?(checkc) pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + pid;
  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if(count == 0)
    *?(checkc) pc = c;
  else
    *?(checkc) pc = c + *?(checkc)pc;
  atomic_xchg(pcount, (count + 1) % TZ);
  atomic_xchg(plock, 0);
#endif
}
)";

}
|
|
||||||
|
|
||||||
// Element type selector used by bench_dot to pick the triton_dot<T>
// instantiation to run. Kept as an unscoped enum because callers name the
// enumerators without qualification.
enum dtype_t {
  FLOAT  = 0,
  HALF   = 1,
  DOUBLE = 2
};
|
|
||||||
|
|
||||||
template<class T>
|
|
||||||
struct to_string;
|
|
||||||
|
|
||||||
template<> struct to_string<half_float::half>{
|
|
||||||
static constexpr const char* value = "half";
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct to_string<float>{
|
|
||||||
static constexpr const char* value = "float";
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct to_string<double>{
|
|
||||||
static constexpr const char* value = "double";
|
|
||||||
};
|
|
||||||
|
|
||||||
template<class T>
|
|
||||||
float triton_dot(drv::context* context, drv::stream* stream,
|
|
||||||
bool AT, bool BT,
|
|
||||||
int32_t M, int32_t N, int32_t K){
|
|
||||||
std::string ty = to_string<T>::value;
|
|
||||||
size_t dt_nbytes = sizeof(T);
|
|
||||||
drv::device* device = context->device();
|
|
||||||
int32_t lda = AT ? K : M;
|
|
||||||
int32_t ldb = BT ? N : K;
|
|
||||||
int32_t ldc = N;
|
|
||||||
std::vector<std::string> sa = { "1", "lda" };
|
|
||||||
std::vector<std::string> sb = { "1", "ldb" };
|
|
||||||
// inputs
|
|
||||||
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
|
||||||
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
|
|
||||||
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
|
|
||||||
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
|
|
||||||
// initialize buffers
|
|
||||||
std::vector<T> hc(M*N);
|
|
||||||
std::vector<T> ha(M*K);
|
|
||||||
std::vector<T> hb(K*N);
|
|
||||||
for(size_t i = 0; i < ha.size(); i++)
|
|
||||||
ha[i] = (float)rand()/RAND_MAX;
|
|
||||||
for(size_t i = 0; i < hb.size(); i++)
|
|
||||||
hb[i] = (float)rand()/RAND_MAX;
|
|
||||||
stream->write(&*da, true, 0, ha);
|
|
||||||
stream->write(&*db, true, 0, hb);
|
|
||||||
// macros
|
|
||||||
rt::options_t opt;
|
|
||||||
opt.defines["STRIDE_AK"] = AT? "1" : "lda";
|
|
||||||
opt.defines["STRIDE_AM"] = AT? "lda" : "1";
|
|
||||||
opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
|
|
||||||
opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
|
|
||||||
opt.defines["TYPE"] = ty;
|
|
||||||
opt.defines["TM"] = "128";
|
|
||||||
opt.defines["TN"] = "128";
|
|
||||||
opt.defines["TK"] = "64" ;
|
|
||||||
opt.defines["TZ"] = "1";
|
|
||||||
opt.num_warps = 4;
|
|
||||||
// arguments
|
|
||||||
std::stringstream oss;
|
|
||||||
rt::add_arg(oss, *da->cu());
|
|
||||||
rt::add_arg(oss, *db->cu());
|
|
||||||
rt::add_arg(oss, *dc->cu());
|
|
||||||
rt::add_arg(oss, (float)1);
|
|
||||||
rt::add_arg(oss, M);
|
|
||||||
rt::add_arg(oss, N);
|
|
||||||
rt::add_arg(oss, K);
|
|
||||||
rt::add_arg(oss, lda);
|
|
||||||
rt::add_arg(oss, ldb);
|
|
||||||
rt::add_arg(oss, ldc);
|
|
||||||
rt::add_arg(oss, *dlocks->cu());
|
|
||||||
// function
|
|
||||||
rt::function function(src::dot, opt, device);
|
|
||||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
|
|
||||||
// grid
|
|
||||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
|
||||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
|
||||||
return rt::kernel::grid_t{ceil(M, x.D<int>("TM"))*
|
|
||||||
ceil(N, x.D<int>("TN")),
|
|
||||||
(size_t)x.D<int>("TZ")};
|
|
||||||
};
|
|
||||||
|
|
||||||
// metrics
|
|
||||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
|
||||||
double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, stream);
|
|
||||||
return tflops(triton_ns);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs triton_dot with the element type selected by `dtype` and forwards its
// result. Returns 0 when `dtype` is not one of HALF, FLOAT or DOUBLE.
float bench_dot(drv::context* context, drv::stream* stream,
                bool AT, bool BT,
                int32_t M, int32_t N, int32_t K,
                dtype_t dtype) {
  if(dtype == HALF)
    return triton_dot<half_float::half>(context, stream, AT, BT, M, N, K);
  if(dtype == FLOAT)
    return triton_dot<float>(context, stream, AT, BT, M, N, K);
  if(dtype == DOUBLE)
    return triton_dot<double>(context, stream, AT, BT, M, N, K);
  // unknown element type
  return 0;
}
|
|
||||||
|
|
||||||
int main() {
|
|
||||||
// initialize default compute device
|
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
|
||||||
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
|
|
||||||
// shapes to benchmark
|
|
||||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
|
||||||
std::vector<config_t> configs = {
|
|
||||||
{false, false, 8192, 8192, 8192}
|
|
||||||
};
|
|
||||||
// does the work
|
|
||||||
bool AT, BT;
|
|
||||||
int32_t M, N, K;
|
|
||||||
dtype_t dtype = HALF;
|
|
||||||
for(const auto& c: configs){
|
|
||||||
std::tie(AT, BT, M, N, K) = c;
|
|
||||||
float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype);
|
|
||||||
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
Reference in New Issue
Block a user