[FRONTEND] Added tl.clock
and tl.globaltimer
(#485)
This commit is contained in:
@@ -182,6 +182,8 @@ public:
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
void visit_clock_inst(ir::clock_inst*);
|
||||
void visit_globaltimer_inst(ir::globaltimer_inst*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
@@ -192,6 +194,7 @@ public:
|
||||
void visit_argument(ir::argument*);
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
|
||||
// layouts
|
||||
void visit_layout_mma(analysis::mma_layout*);
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
|
@@ -28,7 +28,9 @@ public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// const context& get_context() const { return ctx_; }
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
|
@@ -101,6 +101,10 @@ struct dispatch{
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// utilities
|
||||
static ir::value *globaltimer(ir::builder *builder);
|
||||
static ir::value *clock(ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
|
||||
|
@@ -165,6 +165,8 @@ enum value_id_t: unsigned {
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
INST_GLOBALTIMER,
|
||||
INST_CLOCK,
|
||||
};
|
||||
|
||||
|
||||
|
@@ -971,6 +971,27 @@ private:
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
/* timing utilities */
|
||||
class clock_inst: public instruction{
|
||||
clock_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "clock"; }
|
||||
_TRITON_DEFINE_CLONE(clock_inst)
|
||||
_TRITON_DEFINE_ACCEPT(clock_inst)
|
||||
|
||||
public:
|
||||
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class globaltimer_inst: public instruction{
|
||||
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "globaltimer"; }
|
||||
_TRITON_DEFINE_CLONE(globaltimer_inst)
|
||||
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
|
||||
|
||||
public:
|
||||
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -75,6 +75,8 @@ class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
class clock_inst;
|
||||
class globaltimer_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
@@ -157,6 +159,8 @@ public:
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
virtual void visit_function(function*) = 0;
|
||||
virtual void visit_clock_inst(clock_inst*) = 0;
|
||||
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
|
||||
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
|
@@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
case tt::Xchg: name = "exch", s_ty = "b"; break;
|
||||
}
|
||||
std::string s_vec = vec == 2 ? "x2" : "";
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string mod = nbits == 16 ? ".noftz" : "";
|
||||
|
||||
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits*vec == 32 ? "r" : "h";
|
||||
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
|
||||
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
@@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
void generator::visit_clock_inst(ir::clock_inst* clock){
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
|
||||
vals_[clock][{}] = call(iasm);
|
||||
}
|
||||
|
||||
void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
|
||||
vals_[timer][{}] = call(iasm);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
|
||||
ir::value *v = i->get_operand(0);
|
||||
int inc = i->get_inc();
|
||||
|
@@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
llvm::TargetMachine* machine;
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
|
@@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_sqrt(x);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
ir::value *dispatch::globaltimer(ir::builder *builder) {
|
||||
return builder->insert(globaltimer_inst::create(builder->get_context()));
|
||||
}
|
||||
|
||||
ir::value *dispatch::clock(ir::builder *builder) {
|
||||
return builder->insert(clock_inst::create(builder->get_context()));
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
|
@@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
|
||||
}
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
|
||||
|
||||
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
@@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
|
||||
return new prefetch_s_inst(ctx, arg, inc, name, next);
|
||||
}
|
||||
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
// global timer
|
||||
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
|
||||
: instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }
|
||||
|
||||
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
// return new make_range_dyn(ty, name, next);
|
||||
//}
|
||||
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new globaltimer_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
//// nv_static_program_idx
|
||||
//make_range_sta::make_range_sta(make_range *range)
|
||||
// : constant(range->get_type(), 0), range_(range) { }
|
||||
// clock
|
||||
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
|
||||
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
|
||||
|
||||
//make_range* make_range_sta::get_range() const
|
||||
//{ return range_; }
|
||||
|
||||
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||
// static std::map<make_range*, make_range_sta*> cache;
|
||||
// if(cache.find(range) == cache.end())
|
||||
// cache.insert({range, new make_range_sta(range)});
|
||||
// return cache.at(range);
|
||||
//}
|
||||
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new clock_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
|
||||
// make_range
|
||||
|
@@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// utilities
|
||||
m.def("clock", &ir::dispatch::clock, ret::reference);
|
||||
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
|
||||
|
@@ -792,6 +792,19 @@ def sum(input, axis, _builder=None):
|
||||
def xor_sum(input, axis, _builder=None):
|
||||
return frontend.xor_sum(input, axis, _builder)
|
||||
|
||||
# -----------------------
|
||||
# Utilities
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def globaltimer(_builder=None):
|
||||
return frontend.globaltimer(_builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def clock(_builder=None):
|
||||
return frontend.clock(_builder)
|
||||
|
||||
# -----------------------
|
||||
# Internal for debugging
|
||||
|
@@ -24,9 +24,11 @@ def add_kernel(
|
||||
y_ptr, # *Pointer* to second input vector
|
||||
output_ptr, # *Pointer* to output vector
|
||||
n_elements, # Size of the vector
|
||||
time_start_ptr, time_end_ptr,
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
|
||||
# NOTE: `constexpr` so it can be used as a shape value
|
||||
):
|
||||
tl.atomic_min(time_start_ptr, tl.clock())
|
||||
# There are multiple 'program's processing different data. We identify which program
|
||||
# we are here
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
|
||||
@@ -45,6 +47,7 @@ def add_kernel(
|
||||
output = x + y
|
||||
# Write x + y back to DRAM
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
tl.atomic_max(time_end_ptr, tl.clock())
|
||||
|
||||
|
||||
# %%
|
||||
@@ -53,6 +56,8 @@ def add_kernel(
|
||||
|
||||
|
||||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
|
||||
time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
|
||||
# We need to preallocate the output
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda and y.is_cuda and output.is_cuda
|
||||
@@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously at this point.
|
||||
print((time_end, time_start))
|
||||
return output
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user