[codegen][coalesce] more bugfixes
@@ -26,7 +26,7 @@ public:
   bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
   ir::value *get_reference(ir::value *x);
   void replace(ir::value* before, ir::value *after);
+  void copy(ir::value* y, ir::value *x);
 
 private:
   std::set<ir::value*> shared_;
@@ -298,7 +298,7 @@ void grids::run(ir::module &mod) {
   unsigned current = num_threads;
   std::string nts = "nts.d" + s_ld;
   std::string mts = "mts.d" + s_ld;
-  params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 8));
+  params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 4));
   params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
   current = current / params_.at(i).at(mts)->get_value();
   for(size_t d = 1; d < shapes.size(); d++){
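Note: the clamp above caps the per-thread tile ("nts", now at most 4 elements) before deriving the per-dimension thread count ("mts"). A minimal standalone sketch of how the two values interact, with a hypothetical clamp helper and illustrative sizes in place of the pass's params_ map:

    #include <algorithm>
    #include <cstdio>

    // Hypothetical stand-in for the pass's clamp helper.
    static int clamp(int x, int lo, int hi) { return std::min(std::max(x, lo), hi); }

    int main() {
      int num_threads = 128;                      // threads available for the tile
      int size = 1024;                            // shape of the leading dimension
      int current = num_threads;
      int nts = clamp(size / num_threads, 1, 4);  // elements per thread, now capped at 4
      int mts = clamp(current, 1, size / nts);    // threads assigned to this dimension
      current = current / mts;                    // threads left for the other dimensions
      std::printf("nts=%d mts=%d remaining=%d\n", nts, mts, current);  // nts=4 mts=128 remaining=1
    }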
@@ -34,6 +34,16 @@ void meminfo::replace(ir::value* before, ir::value *after) {
   }
 }
 
+void meminfo::copy(ir::value* y, ir::value *x) {
+  if(shared_.find(x) != shared_.end())
+    shared_.insert(y);
+  if(refs_.find(x) != refs_.end())
+    refs_[y] = refs_[x];
+  if(double_.find(x) != double_.end())
+    double_.insert(y);
+}
+
+
 inline bool get_is_shared(ir::value* v) {
   if(dynamic_cast<ir::atomic_cas_inst*>(v))
     return true;
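Note: the new meminfo::copy mirrors what appears to be x's shared-memory bookkeeping (shared_, refs_, double_) onto y, which is what a pass needs when it clones an instruction and wants the clone to inherit the original's metadata. A rough illustration with plain standard containers; the surrounding class here is hypothetical and only the member names mirror the real one:

    #include <map>
    #include <set>

    struct meminfo_sketch {
      std::set<void*> shared_, double_;     // values tracked by the analysis
      std::map<void*, void*> refs_;         // value -> the buffer it references

      // propagate x's metadata onto its clone y
      void copy(void* y, void* x) {
        if(shared_.count(x)) shared_.insert(y);
        if(refs_.count(x))   refs_[y] = refs_[x];
        if(double_.count(x)) double_.insert(y);
      }
    };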
@@ -556,15 +556,15 @@ inline int32_t ceil(int32_t num, int32_t div){
   return (num + div - 1)/div;
 }
 
-inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
+inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
   static const size_t warp_size = 32;
   size_t nthreads = 1, nwarps = 1;
   nw.resize(bs.size());
   ws.resize(bs.size());
   for(size_t i = 0; i < bs.size(); ++i){
     nthreads *= bs[i];
-    nw[i] = ceil(nthreads, nwarps*warp_size);
-    nwarps *= nw[i];
+    nw[order[i]] = ceil(nthreads, nwarps*warp_size);
+    nwarps *= nw[order[i]];
   }
   for(size_t i = 0; i < bs.size(); ++i){
     ws[i] = bs[i] / nw[i];
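Note: to_warps now takes the dimension order so warps are assigned to dimensions in the requested priority. A compilable sketch of the updated decomposition (int instead of unsigned, hypothetical driver in main); with a 16x16 thread block and order {0, 1} it yields nw = {1, 8} and ws = {16, 2}, i.e. 8 warps along dimension 1:

    #include <cstdio>
    #include <vector>

    static int ceil_div(int num, int div) { return (num + div - 1) / div; }

    // bs: threads per dimension, order: dimension priority,
    // nw: warps per dimension (out), ws: threads-per-warp per dimension (out).
    static void to_warps_sketch(const std::vector<int>& bs, const std::vector<int>& order,
                                std::vector<int>& nw, std::vector<int>& ws) {
      const int warp_size = 32;
      int nthreads = 1, nwarps = 1;
      nw.assign(bs.size(), 1);
      ws.assign(bs.size(), 1);
      for(size_t i = 0; i < bs.size(); ++i) {
        nthreads *= bs[i];
        nw[order[i]] = ceil_div(nthreads, nwarps * warp_size);
        nwarps *= nw[order[i]];
      }
      for(size_t i = 0; i < bs.size(); ++i)
        ws[i] = bs[i] / nw[i];
    }

    int main() {
      std::vector<int> nw, ws;
      to_warps_sketch({16, 16}, {0, 1}, nw, ws);   // 256 threads -> 8 warps
      std::printf("nw = {%d, %d}, ws = {%d, %d}\n", nw[0], nw[1], ws[0], ws[1]);
    }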
@@ -585,7 +585,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
     contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
     block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
   }
-  to_warps(block_size, n_warps, warp_size);
+  to_warps(block_size, order, n_warps, warp_size);
   std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
   std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
   // Create axes
@@ -711,52 +711,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   }
 }
 
-void selection::create_grids(std::vector<ir::value*> &grids,
-                             std::map<unsigned, ir::value*> &references,
-                             ir::function *fn) {
-  // get number of dimensions greater than 1
-  auto get_tile_gt1_dim = [&](ir::value *v){
-    unsigned result = 0;
-    for(auto shape: v->get_type()->get_tile_shapes()) {
-      result += (shape > 1)? shape : 0;
-    }
-    return result;
-  };
-  // bind references
-  std::set<ir::value*> seen;
-  std::function<void(ir::value*)> bind_references = [&](ir::value *v)
-  {
-    // skip
-    if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
-      return;
-    // recurse
-    if(auto *user = dynamic_cast<ir::user*>(v))
-      for(ir::value *op: user->ops())
-        bind_references(op);
-    // bind
-    const auto& shapes = v->get_type()->get_tile_shapes();
-    if(buffer_info_->is_shared(v))
-      return;
-    for(size_t d = 0; d < shapes.size(); d++){
-      if(shapes[d] == 1)
-        continue;
-      unsigned x = params_->get_param_group(v, d);
-      ir::value *&r = references[x];
-      if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
-        r = v;
-    }
-  };
-
-  for(ir::basic_block *block: fn->blocks())
-    for(ir::instruction *i: block->get_inst_list())
-      bind_references(i);
-
-  // create grid
-  for(auto &ref: references)
-    if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
-      grids.push_back(ref.second);
-}
-
 bool static inline has_phi_user(ir::value *v) {
   for(ir::user *usr: v->get_users()){
     if(dynamic_cast<ir::phi_node*>(usr))
@@ -28,16 +28,17 @@ void coalesce::run(ir::module &mod) {
   std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
     if(order_.find(v) != order_.end())
       return;
-    order_[v] = {};
+    ir::type *tile_ty = v->get_type();
+    if(auto *x = dynamic_cast<ir::store_inst*>(v))
+      tile_ty = x->get_operand(0)->get_type();
+    if(!tile_ty->is_tile_ty())
+      return;
+    std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
+    std::iota(order.begin(), order.end(), 0);
+    order_[v] = order;
     if(ir::user* u = dynamic_cast<ir::user*>(v))
       for(ir::value* op: u->ops())
        set_order(op);
-    ir::type* ty = v->get_type();
-    if(!ty->is_tile_ty())
-      return;
-    std::vector<unsigned> order(ty->get_tile_shapes().size());
-    std::iota(order.begin(), order.end(), 0);
-    order_[v] = order;
   };
 
   // initialize work-list
@@ -52,56 +53,58 @@ void coalesce::run(ir::module &mod) {
     set_order(i);
   }
 
-  // ir::builder &builder = mod.get_builder();
-  // std::set<ir::value*> seen;
-  // for(ir::io_inst *i: io) {
-  //   ir::value *ptr = i->get_pointer_operand();
-  //   auto max_contiguous = align_->get_max_contiguous_vec(ptr);
-  //   std::vector<unsigned> order(max_contiguous.size());
-  //   std::iota(order.begin(), order.end(), 0);
-  //   std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
-  //   std::list<ir::instruction*> work_list;
-  //   if(order != order_[i])
-  //     work_list.push_back(i);
-  //   // rematerialize recursively
-  //   while(!work_list.empty()) {
-  //     ir::instruction* current = work_list.back();
-  //     order_[current] = order;
-  //     work_list.pop_back();
-  //     for(ir::value *op: current->ops()) {
-  //       ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
-  //       if(!seen.insert(op).second)
-  //         continue;
-  //       if(!i_op)
-  //         continue;
-  //       ir::type *ty = i_op->get_type();
-  //       if(!ty->is_tile_ty())
-  //         continue;
-  //       auto& inst_list = i_op->get_parent()->get_inst_list();
-  //       auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
-  //       it++;
-  //       builder.set_insert_point(it);
-  //       // found a load; write to shared memory and stop recursion
-  //       ir::instruction *n_op = nullptr;
-  //       if(mem_->is_shared(i_op)){
-  //         continue;
-  //       }
-  //       if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
-  //         n_op = ir::copy_to_shared_inst::create(ld);
-  //       }
-  //       // not a load; rematerialize and recurse
-  //       else {
-  //         n_op = i_op->clone();
-  //         work_list.push_back(n_op);
-  //       }
-  //       n_op = builder.insert(n_op);
-  //       order_[n_op] = order;
-  //       align_->copy(n_op, i_op);
-  //       current->replace_uses_of_with(i_op, n_op);
-  //     }
-  //   }
+  ir::builder &builder = mod.get_builder();
+  std::map<ir::value*, ir::value*> replaced;
+  for(ir::io_inst *i: io) {
+    ir::value *ptr = i->get_pointer_operand();
+    auto max_contiguous = align_->get_max_contiguous_vec(ptr);
+    std::vector<unsigned> order(max_contiguous.size());
+    std::iota(order.begin(), order.end(), 0);
+    std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
+    std::list<ir::instruction*> work_list;
+    if(order != order_[i])
+      work_list.push_back(i);
+    // rematerialize recursively
+    while(!work_list.empty()) {
+      ir::instruction* current = work_list.back();
+      order_[current] = order;
+      work_list.pop_back();
+      for(ir::value *op: current->ops()) {
+        ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
+        if(replaced.find(i_op) != replaced.end()){
+          current->replace_uses_of_with(i_op, replaced.at(i_op));
+          continue;
+        }
+        if(!i_op)
+          continue;
+        ir::type *ty = i_op->get_type();
+        if(!ty->is_tile_ty())
+          continue;
+        auto& inst_list = i_op->get_parent()->get_inst_list();
+        auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
+        it++;
+        builder.set_insert_point(it);
+        // found a load; write to shared memory and stop recursion
+        ir::instruction *n_op = nullptr;
+        if(mem_->is_shared(i_op))
+          continue;
+        if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
+          n_op = ir::copy_to_shared_inst::create(ld);
+        // not a load; rematerialize and recurse
+        else {
+          n_op = i_op->clone();
+          work_list.push_back(n_op);
+        }
+        n_op = builder.insert(n_op);
+        replaced.insert({i_op, n_op});
+        order_[n_op] = order;
+        align_->copy(n_op, i_op);
+        // mem_->copy(n_op, i_op);
+        current->replace_uses_of_with(i_op, n_op);
+      }
+    }
 
-  // }
+  }
 }
 
 
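Note: the now-active loop rewrites the operand graph of every mis-ordered io instruction. The new replaced map memoizes clones, so an operand that has already been rematerialized is simply rewired to its existing clone instead of being cloned a second time; the earlier, commented-out version used a seen set and just skipped such operands. A toy, self-contained sketch of that memoized-rewrite pattern on plain structs (node, the clone handling and main are all hypothetical):

    #include <cstdio>
    #include <list>
    #include <map>
    #include <vector>

    // Toy value graph: each node just remembers its operands.
    struct node {
      std::vector<node*> ops;
      void replace_uses_of_with(node* before, node* after) {
        for(node*& op : ops)
          if(op == before) op = after;
      }
    };

    int main() {
      node a, b{{&a}}, c{{&a, &b}};           // a feeds both b and c
      std::map<node*, node*> replaced;        // original operand -> its single clone
      std::vector<node*> owned;               // keep clones alive for cleanup
      std::list<node*> work_list{&c};
      while(!work_list.empty()) {
        node* current = work_list.back();
        work_list.pop_back();
        for(node* op : current->ops) {
          auto it = replaced.find(op);
          if(it != replaced.end()) {          // already rematerialized: just rewire
            current->replace_uses_of_with(op, it->second);
            continue;
          }
          node* n_op = new node(*op);         // rematerialize and recurse
          owned.push_back(n_op);
          replaced[op] = n_op;
          work_list.push_back(n_op);
          current->replace_uses_of_with(op, n_op);
        }
      }
      std::printf("clones created: %zu (a is reused, not re-cloned)\n", owned.size());
      for(node* n : owned) delete n;
    }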
@@ -3,6 +3,7 @@
 #include "triton/ir/module.h"
 #include "triton/ir/cfg.h"
 #include "triton/codegen/transform/dce.h"
+#include <iostream>
 
 namespace triton {
 namespace codegen{
@@ -35,9 +36,10 @@ void dce::run(ir::module &mod) {
     work_list.pop_back();
     // mark instruction operands
     for(ir::value* op: current->ops()) {
-      if(auto *i = dynamic_cast<ir::instruction*>(op))
+      if(auto *i = dynamic_cast<ir::instruction*>(op)){
         if(marked.insert(i).second)
           work_list.push_back(i);
+      }
     }
     // TODO: mark last intstruction of current's reverse-dominance frontier
   }
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
+  // std::cout << source_ << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -212,10 +212,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   alignment_info.run(module);
   if(target->is_gpu())
     shmem_info.run(module);
-  // ir::print(module, std::cout);
   reorder.run(module);
   dce.run(module);
-  // ir::print(module, std::cout);
   grids.run(module);
   reassociate.run(module);
   dce.run(module);
@@ -48,23 +48,23 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   opt.defines.push_back({"TM", {"128"}});
   opt.defines.push_back({"TN", {"128"}});
   opt.defines.push_back({"TK", {"8"}});
-  opt.num_warps = {4};
+  opt.num_warps = {8};
   // create function
   rt::function function(src::dot, opt);
   // benchmark available libraries
   std::vector<double> result;
   auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
-  // cublas
-  if(cublas::cublasinit()){
-    NumericT alpha(static_cast<double>(1));
-    NumericT beta(static_cast<double>(0));
-    cublasGemmAlgo_t fastest;
-    cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-    double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
-                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-                                                               ldc, nullptr, fastest); }, stream);
-    result.push_back(tflops(cublas_ms));
-  }
+  // // cublas
+  // if(cublas::cublasinit()){
+  //   NumericT alpha(static_cast<double>(1));
+  //   NumericT beta(static_cast<double>(0));
+  //   cublasGemmAlgo_t fastest;
+  //   cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+  //   double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
+  //                                                              &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+  //                                                              ldc, nullptr, fastest); }, stream);
+  //   result.push_back(tflops(cublas_ms));
+  // }
   // triton
   double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);}, stream);
   result.push_back(tflops(triton_ms));
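Note: the tflops lambda above converts a GEMM runtime in nanoseconds into TFLOP/s, counting 2*M*N*K floating-point operations. A quick worked check with assumed sizes (square 2048 matrices, 1 ms kernel time):

    #include <cstdio>

    int main() {
      double M = 2048, N = 2048, K = 2048;
      double nanosec = 1.0e6;                          // assume the kernel takes 1 ms
      // flops / ns gives GFLOP/s; the extra 1e-3 scales that to TFLOP/s
      double tflops = 2. * M * N * K / nanosec * 1e-3;
      std::printf("%.2f TFLOP/s\n", tflops);           // ~17.18 TFLOP/s
    }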
@@ -64,7 +64,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // epilogue
   int rxc[TM] = ridx * TM + 0 ... TM;
   int ryc[TN] = ridy * TN + 0 ... TN;
-  TYPE* pc[TM, TN] = C + ryc[newaxis, :] * ldc + rxc[:, newaxis];
+  TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
   *pc = c;
 }
 )";
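Note: with the corrected epilogue, the rxc index (the TM range) is the one scaled by the leading dimension ldc, so element (m, n) of the output tile lands at C[ryc[n] + rxc[m] * ldc]. A scalar C++ sketch of that addressing, with hypothetical tile sizes and an accumulator assumed to be laid out row-major:

    // Scalar sketch of the corrected store: rxc strides by ldc, ryc is contiguous.
    void store_tile(float* C, const float* c, int ridx, int ridy, int ldc, int TM, int TN) {
      for(int m = 0; m < TM; ++m)
        for(int n = 0; n < TN; ++n) {
          int rxc = ridx * TM + m;
          int ryc = ridy * TN + n;
          C[ryc + rxc * ldc] = c[m * TN + n];   // previously: C[ryc * ldc + rxc]
        }
    }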