[codegen/optimize_dce.cpp] fixed bugs whereby barriers were removed by DCE
This commit is contained in:
@@ -130,7 +130,7 @@ public:
|
||||
// create profile
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
|
||||
// blocksparse matmul
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
|
||||
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
|
||||
Tensor *tmp = nullptr;
|
||||
TensorShape tmp_shapes;
|
||||
|
@@ -7,10 +7,12 @@
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
/* Dense matrix multiplication */
|
||||
|
||||
typedef std::vector<unsigned> params_t;
|
||||
typedef std::tuple<bool, bool> trans_key_t;
|
||||
typedef std::tuple<size_t, size_t> size_key_t;
|
||||
static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
|
||||
static const std::map<trans_key_t, std::map<size_key_t, params_t>> dot_params = {
|
||||
/* NN */
|
||||
{trans_key_t(false, false), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
|
||||
@@ -108,7 +110,7 @@ static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
|
||||
// small search space for partial auto-tuning
|
||||
inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
std::vector<params_t> result;
|
||||
for(auto x: params.at(trans_key_t{AT, BT}))
|
||||
for(auto x: dot_params.at(trans_key_t{AT, BT}))
|
||||
result.push_back(x.second);
|
||||
return result;
|
||||
}
|
||||
@@ -118,9 +120,41 @@ inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
// return {4, 4, 128, 8, 4, 128, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
return dot_params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
|
||||
/* Block-sparse matrix multiplication */
|
||||
|
||||
static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
|
||||
/* 32x32 */
|
||||
{{true, 32}, std::map<size_t, params_t>{
|
||||
{32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
|
||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
|
||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
|
||||
}},
|
||||
{{false, 32}, std::map<size_t, params_t>{
|
||||
{32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
|
||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
|
||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
|
||||
}}
|
||||
};
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
inline std::vector<params_t> bsdot_search_space(bool is_fprop, size_t block_size) {
|
||||
std::vector<params_t> result;
|
||||
for(auto x: bsdot_params.at({is_fprop, block_size}))
|
||||
result.push_back(x.second);
|
||||
return result;
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t bsdot_heuristics(bool is_fprop, size_t block_size, size_t N, size_t S) {
|
||||
return bsdot_params.at({is_fprop,block_size}).at(128);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -303,6 +303,7 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << " " << max_contiguous_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -20,7 +20,8 @@ void optimize_dce::run(ir::module &mod) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|
||||
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)
|
||||
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i) ){
|
||||
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i)
|
||||
|| dynamic_cast<ir::barrier_inst*>(i)){
|
||||
work_list.push_back(i);
|
||||
marked.insert(i);
|
||||
}
|
||||
|
@@ -368,6 +368,8 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
Value *res = builder.CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
|
||||
builder.CreateBr(tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, builder);
|
||||
tgt_->add_barrier(module, builder);
|
||||
return (Instruction*)res;
|
||||
}
|
||||
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
|
||||
|
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
}
|
||||
|
@@ -61,8 +61,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
// params_t params = heuristics();
|
||||
params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
params_t params = heuristics();
|
||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include "triton/dnn/blocksparse/dot.h"
|
||||
|
||||
namespace triton{
|
||||
@@ -18,11 +19,11 @@ bool dot::operator <(const base& other) const {
|
||||
}
|
||||
|
||||
std::vector<params_t> dot::search_space() const {
|
||||
throw std::runtime_error("not implemented");
|
||||
return bsdot_search_space(op_ == FPROP, BS_);
|
||||
}
|
||||
|
||||
params_t dot::heuristics() const {
|
||||
throw std::runtime_error("not implemented");
|
||||
return bsdot_heuristics(op_ == FPROP, BS_, N_, S_);
|
||||
}
|
||||
|
||||
base * dot::clone() const {
|
||||
@@ -116,7 +117,8 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
int32 column = *(header + 2);
|
||||
int32 lockid = *(header + 3);
|
||||
int32 *plut = lut + offset * 2;
|
||||
for(int32 k = K; k > 0; k = k - 1){
|
||||
for(int32 k = K; k > 0; k = k - 1)
|
||||
{
|
||||
int32 ak = *(plut + 0);
|
||||
int32 bk = *(plut + 1);
|
||||
)" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
|
||||
@@ -133,16 +135,19 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
|
||||
if(lockid == 0)
|
||||
@checkc *pc = c;
|
||||
else {
|
||||
else
|
||||
{
|
||||
int32 *plock = locks + ridx*nlocks + lockid - 1;
|
||||
int32 *pcount = plock + get_num_program(0)*nlocks;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 count = *pcount;
|
||||
if(count == 0)
|
||||
if(count == 0){
|
||||
@checkc *pc = c;
|
||||
else
|
||||
__atomic_exch(pcount, 1);
|
||||
}
|
||||
else{
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = 1;
|
||||
}
|
||||
__atomic_exch(plock, 0);
|
||||
}
|
||||
})";
|
||||
|
Reference in New Issue
Block a user