[FRONTEND] Added tl.clock and tl.globaltimer (#485)

This commit is contained in:
Philippe Tillet
2022-03-28 16:15:43 -07:00
committed by GitHub
parent 76a9ee50a8
commit e0cc488055
13 changed files with 99 additions and 25 deletions

View File

@@ -182,6 +182,8 @@ public:
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
@@ -192,6 +194,7 @@ public:
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);

View File

@@ -28,7 +28,9 @@ public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);

View File

@@ -101,6 +101,10 @@ struct dispatch{
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// utilities
static ir::value *globaltimer(ir::builder *builder);
static ir::value *clock(ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);

View File

@@ -165,6 +165,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};

View File

@@ -971,6 +971,27 @@ private:
constant_int* last_;
};
/* timing utilities */
class clock_inst: public instruction{
clock_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "clock"; }
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)
public:
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class globaltimer_inst: public instruction{
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "globaltimer"; }
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
}
}

View File

@@ -75,6 +75,8 @@ class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;
class make_range_sta;
class undef_value;
@@ -157,6 +159,8 @@ public:
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;

View File

@@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
case tt::Xchg: name = "exch", s_ty = "b"; break;
}
std::string s_vec = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string mod = nbits == 16 ? ".noftz" : "";
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits*vec == 32 ? "r" : "h";
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
@@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_clock_inst(ir::clock_inst* clock){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
vals_[clock][{}] = call(iasm);
}
void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
vals_[timer][{}] = call(iasm);
}
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();

View File

@@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
llvm::TargetMachine* machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())

View File

@@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
ir::value *dispatch::globaltimer(ir::builder *builder) {
return builder->insert(globaltimer_inst::create(builder->get_context()));
}
ir::value *dispatch::clock(ir::builder *builder) {
return builder->insert(clock_inst::create(builder->get_context()));
}
//

View File

@@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
@@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next);
}
//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
// global timer
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
return new globaltimer_inst(ctx, name, next);
}
//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
// clock
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
//make_range* make_range_sta::get_range() const
//{ return range_; }
//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
return new clock_inst(ctx, name, next);
}
// make_range

View File

@@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) {
m.def("cos", &ir::dispatch::cos, ret::reference);
m.def("sin", &ir::dispatch::sin, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// utilities
m.def("clock", &ir::dispatch::clock, ret::reference);
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);

View File

@@ -792,6 +792,19 @@ def sum(input, axis, _builder=None):
def xor_sum(input, axis, _builder=None):
return frontend.xor_sum(input, axis, _builder)
# -----------------------
# Utilities
# -----------------------
@builtin
def globaltimer(_builder=None):
return frontend.globaltimer(_builder)
@builtin
def clock(_builder=None):
return frontend.clock(_builder)
# -----------------------
# Internal for debugging

View File

@@ -24,9 +24,11 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
time_start_ptr, time_end_ptr,
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
tl.atomic_min(time_start_ptr, tl.clock())
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -45,6 +47,7 @@ def add_kernel(
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
tl.atomic_max(time_end_ptr, tl.clock())
# %%
@@ -53,6 +56,8 @@ def add_kernel(
def add(x: torch.Tensor, y: torch.Tensor):
time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
@@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
print((time_end, time_start))
return output