[FRONTEND] Added tl.clock and tl.globaltimer (#485)

This commit is contained in:
Philippe Tillet
2022-03-28 16:15:43 -07:00
committed by GitHub
parent 76a9ee50a8
commit e0cc488055
13 changed files with 99 additions and 25 deletions

View File

@@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
case tt::Xchg: name = "exch", s_ty = "b"; break;
}
std::string s_vec = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string mod = nbits == 16 ? ".noftz" : "";
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits*vec == 32 ? "r" : "h";
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
@@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
// Lowers an ir::clock_inst to an inline-asm call that copies the PTX
// %clock64 special register into a 64-bit integer result, and records that
// result as the value generated for `clock`.
void generator::visit_clock_inst(ir::clock_inst* clock){
  // The asm snippet takes no arguments and produces a single i64.
  FunctionType *asm_ty = FunctionType::get(builder_->getInt64Ty(), {});
  // "=l" ties $0 to the output operand; the trailing `true` is forwarded to
  // InlineAsm::get (NOTE(review): presumably hasSideEffects, which would keep
  // repeated reads from being merged or dropped — confirm against LLVM docs).
  InlineAsm *read_clock = InlineAsm::get(asm_ty, "mov.u64 $0, %clock64;", "=l", true);
  vals_[clock][{}] = call(read_clock);
}
// Lowers an ir::globaltimer_inst to an inline-asm call that copies the PTX
// %globaltimer special register into a 64-bit integer result, and records
// that result as the value generated for `timer`.
void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
  // The asm snippet takes no arguments and produces a single i64.
  FunctionType *asm_ty = FunctionType::get(builder_->getInt64Ty(), {});
  // "=l" ties $0 to the output operand; the trailing `true` is forwarded to
  // InlineAsm::get (NOTE(review): presumably hasSideEffects — confirm).
  InlineAsm *read_timer = InlineAsm::get(asm_ty, "mov.u64 $0, %globaltimer;", "=l", true);
  vals_[timer][{}] = call(read_timer);
}
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();

View File

@@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
llvm::TargetMachine* machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())

View File

@@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
// Creates a globaltimer instruction in the builder's context and inserts it
// through the builder, returning whatever the builder's insert() yields.
ir::value *dispatch::globaltimer(ir::builder *builder) {
  auto *timer = globaltimer_inst::create(builder->get_context());
  return builder->insert(timer);
}
// Creates a clock instruction in the builder's context and inserts it
// through the builder, returning whatever the builder's insert() yields.
ir::value *dispatch::clock(ir::builder *builder) {
  auto *clk = clock_inst::create(builder->get_context());
  return builder->insert(clk);
}
//

View File

@@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
// Constructs a barrier instruction: the result type is void and the
// instruction id is INST_BARRIER; `name` and `next` are forwarded unchanged
// to the instruction base constructor.
// NOTE(review): the literal 0 is presumably the operand count — confirm
// against the ir::instruction constructor.
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
@@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next);
}
//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
// global timer
// Constructs a global-timer instruction: the result type is a 64-bit integer
// and the instruction id is INST_GLOBALTIMER; `name` and `next` are forwarded
// unchanged to the instruction base constructor.
// NOTE(review): the literal 0 is presumably the operand count — confirm.
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
  : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next)
{
}
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}
// Factory: allocates a globaltimer_inst with `new`, forwarding all three
// arguments to the constructor, and returns the raw pointer.
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
  globaltimer_inst *timer = new globaltimer_inst(ctx, name, next);
  return timer;
}
//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
// clock
// Constructs a clock instruction: the result type is a 64-bit integer and
// the instruction id is INST_CLOCK; `name` and `next` are forwarded unchanged
// to the instruction base constructor.
// NOTE(review): the literal 0 is presumably the operand count — confirm.
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
  : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next)
{
}
//make_range* make_range_sta::get_range() const
//{ return range_; }
//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
// Factory: allocates a clock_inst with `new`, forwarding all three arguments
// to the constructor, and returns the raw pointer.
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
  clock_inst *clk = new clock_inst(ctx, name, next);
  return clk;
}
// make_range