[code generation] bugfix in double-buffering
This commit is contained in:
@@ -60,27 +60,11 @@ void matmul(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
int1 checkc1[TN] = ryc < N;\
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
for(k = K; k > 0; k = k - TK){\
|
||||
int1 checka[TM, TK] = (k > 8);\
|
||||
int1 checkb[TN, TK] = (k > 8);\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[TK];\
|
||||
int1 checkb0[TN];\
|
||||
int1 checkb1[TK];\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > 8)\
|
||||
continue;\
|
||||
checka0 = rxa < M;\
|
||||
checka1 = rka < k;\
|
||||
checkb0 = ryb < N;\
|
||||
checkb1 = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
a = *pa;\
|
||||
b = *pb;\
|
||||
}\
|
||||
@checkc *pc = C;\
|
||||
}\
|
||||
@@ -219,21 +203,22 @@ int main() {
|
||||
|
||||
// tuning parameters
|
||||
tune.run(module);
|
||||
|
||||
std::vector<unsigned> params = {
|
||||
// shapes
|
||||
16, 16, 8,
|
||||
8, 8, 8,
|
||||
// a0
|
||||
2, 8, 1,
|
||||
1, 8, 1,
|
||||
// b0
|
||||
4, 4, 1,
|
||||
1, 8, 1,
|
||||
// c0
|
||||
2, 8, 1,
|
||||
1, 8, 1,
|
||||
// c1
|
||||
4, 4, 1,
|
||||
1, 4, 2,
|
||||
// a1
|
||||
2, 4, 1,
|
||||
1, 4, 2,
|
||||
// b1
|
||||
1, 8, 1
|
||||
1, 4, 2
|
||||
};
|
||||
// meta-parameters
|
||||
unsigned i = 0;
|
||||
@@ -255,23 +240,22 @@ int main() {
|
||||
|
||||
|
||||
// run passes
|
||||
triton::ir::print(module, std::cout);
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
// triton::ir::print(module, std::cout);
|
||||
vectorize.run(module);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
llvm::legacy::PassManager manager;
|
||||
manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createVerifierPass(true));
|
||||
manager.run(llvm_module);
|
||||
|
||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||
std::cout << src << std::endl;
|
||||
|
||||
// compile machine code
|
||||
CUdevice cu_device;
|
||||
@@ -285,16 +269,17 @@ int main() {
|
||||
// execute machine code
|
||||
// Allocate buffers
|
||||
typedef float numeric_t;
|
||||
size_t M = 32, N = 32, K = 32;
|
||||
size_t M = 128, N = 128, K = 128;
|
||||
size_t bound = 8;
|
||||
std::vector<numeric_t> c(M*N);
|
||||
std::vector<numeric_t> rc(M*N);
|
||||
std::vector<numeric_t> a(M*K);
|
||||
std::vector<numeric_t> b(K*N);
|
||||
srand(0);
|
||||
for(size_t i = 0; i < a.size(); i++)
|
||||
a[i] = (float)rand() / RAND_MAX;
|
||||
a[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < b.size(); i++)
|
||||
b[i] = (float)rand() / RAND_MAX;
|
||||
b[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < c.size(); i++)
|
||||
c[i] = 0;
|
||||
CUdeviceptr d_a, d_b, d_c;
|
||||
@@ -311,7 +296,7 @@ int main() {
|
||||
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
|
||||
unsigned TM = context.p_impl->mp_constants_[0]->get_value();
|
||||
unsigned TN = context.p_impl->mp_constants_[1]->get_value();
|
||||
unsigned nthreads = 32;
|
||||
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuStreamSynchronize(cu_stream));
|
||||
// Write back
|
||||
|
@@ -12,7 +12,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
unsigned allocation::get_num_bytes(ir::value *x) {
|
||||
unsigned result = x->get_type()->get_tile_bitwidth();
|
||||
unsigned result = x->get_type()->get_tile_bitwidth() / 8;
|
||||
if(buffer_info_->is_double(x))
|
||||
result *= 2;
|
||||
return result;
|
||||
|
@@ -16,7 +16,7 @@ bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
|
||||
return br->get_true_dest() == phi->get_parent()
|
||||
|| br->get_false_dest() == phi->get_parent();
|
||||
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
|
||||
return br->get_dest() == phi->get_parent();
|
||||
return false;
|
||||
else
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
@@ -376,6 +376,13 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
grids.push_back(ref.second);
|
||||
}
|
||||
|
||||
bool static inline has_phi_user(ir::value *v) {
|
||||
for(ir::user *usr: v->get_users()){
|
||||
if(dynamic_cast<ir::phi_node*>(usr))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
const std::map<unsigned*, ir::value*>& references,
|
||||
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
|
||||
@@ -394,8 +401,9 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
// TODO - buffer info not up-to-date with references
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
|
||||
if(buffer_info_->get_reference(v) == nullptr){
|
||||
if(!has_phi_user(v)){
|
||||
size_t offset = alloc_->get_offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
@@ -417,7 +425,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
// next pointer
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset);
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes2, ptr, builder, offset)});
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(i);
|
||||
@@ -720,12 +728,13 @@ void selection::run(ir::module &src, Module &dst){
|
||||
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
||||
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(n);
|
||||
ir::value* inc_val = phi->get_incoming_value(n);
|
||||
ir::value* terminator = inc_block->get_inst_list().back();
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
|
||||
GetElementPtrInst *inc_ptr = dyn_cast<GetElementPtrInst>(inc_shared->get_pointer());
|
||||
if(inc_ptr && ptr == inc_ptr->getPointerOperand()){
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
if(is_loop_latch){
|
||||
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
|
||||
Value *next_offset = dst_builder.CreateNeg(offset);
|
||||
offset->addIncoming(next_offset, llvm_inc_block);
|
||||
|
@@ -15,8 +15,10 @@ void vectorize::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
builder.set_insert_point(i);
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(*params_->get_param(x, "p0.d0") == 1)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
|
Reference in New Issue
Block a user