From e0cc48805506b1f3cd99a08d25e331ed3dbf8a45 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 28 Mar 2022 16:15:43 -0700 Subject: [PATCH] [FRONTEND] Added `tl.clock` and `tl.globaltimer` (#485) --- include/triton/codegen/selection/generator.h | 3 ++ include/triton/ir/builder.h | 4 ++- include/triton/ir/dispatch.h | 4 +++ include/triton/ir/enums.h | 2 ++ include/triton/ir/instructions.h | 21 +++++++++++++ include/triton/ir/visitor.h | 4 +++ lib/codegen/selection/generator.cc | 16 ++++++++-- lib/driver/llvm.cc | 3 +- lib/ir/dispatch.cc | 10 ++++++ lib/ir/instructions.cc | 33 ++++++++------------ python/src/triton.cc | 3 ++ python/triton/language/core.py | 13 ++++++++ python/tutorials/01-vector-add.py | 8 ++++- 13 files changed, 99 insertions(+), 25 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index ad7d01a55..293aa8908 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -182,6 +182,8 @@ public: void visit_async_wait_inst(ir::async_wait_inst*); // void visit_make_range_dyn(ir::make_range_dyn*); void visit_make_range(ir::make_range*); + void visit_clock_inst(ir::clock_inst*); + void visit_globaltimer_inst(ir::globaltimer_inst*); // void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); @@ -192,6 +194,7 @@ public: void visit_argument(ir::argument*); void visit(ir::module &, llvm::Module &); + // layouts void visit_layout_mma(analysis::mma_layout*); void visit_layout_scanline(analysis::scanline_layout*); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 2b6bc6ab3..45a7d5111 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -28,7 +28,9 @@ public: // Constructor builder(context &ctx); // Getters - const context& get_context() { return ctx_; } + // const context& get_context() const { return ctx_; } + context& get_context() { return ctx_; } + // Setters void set_insert_point(iterator instr); void set_insert_point(instruction* i); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index ef14043dd..c7f23779c 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -101,6 +101,10 @@ struct dispatch{ static ir::value *sin(ir::value *x, ir::builder *builder); static ir::value *sqrt(ir::value *x, ir::builder *builder); + // utilities + static ir::value *globaltimer(ir::builder *builder); + static ir::value *clock(ir::builder *builder); + // internal (debug/optimization) static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 8cb7835f0..2d4c09d79 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -165,6 +165,8 @@ enum value_id_t: unsigned { INST_MAKE_RANGE_STA, INST_MAKE_RANGE, INST_PREFETCH_S, + INST_GLOBALTIMER, + INST_CLOCK, }; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 0fb85db02..e9e0f0f11 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -971,6 +971,27 @@ private: constant_int* last_; }; +/* timing utilities */ +class clock_inst: public instruction{ + clock_inst(context &ctx, const std::string &name, instruction *next); + std::string repr_impl() const { return "clock"; } + _TRITON_DEFINE_CLONE(clock_inst) + _TRITON_DEFINE_ACCEPT(clock_inst) + +public: + static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr); +}; + +class globaltimer_inst: public instruction{ + globaltimer_inst(context &ctx, const std::string &name, instruction *next); + std::string repr_impl() const { return "globaltimer"; } + _TRITON_DEFINE_CLONE(globaltimer_inst) + _TRITON_DEFINE_ACCEPT(globaltimer_inst) + +public: + static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr); +}; + } } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 4979b0b52..25ce578e3 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -75,6 +75,8 @@ class async_wait_inst; class make_range_dyn; class make_range; class prefetch_s_inst; +class clock_inst; +class globaltimer_inst; class make_range_sta; class undef_value; @@ -157,6 +159,8 @@ public: virtual void visit_make_range(make_range*) = 0; virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0; virtual void visit_function(function*) = 0; + virtual void visit_clock_inst(clock_inst*) = 0; + virtual void visit_globaltimer_inst(globaltimer_inst*) = 0; virtual void visit_undef_value(undef_value*) = 0; virtual void visit_constant_int(constant_int*) = 0; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index f8cf08cba..b36f51d92 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { case tt::Xchg: name = "exch", s_ty = "b"; break; } std::string s_vec = vec == 2 ? "x2" : ""; - std::string mod = nbits == 32 ? "" : ".noftz"; + std::string mod = nbits == 16 ? ".noftz" : ""; std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;"; - std::string ty_id = nbits*vec == 32 ? "r" : "h"; + std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h"); std::string constraint = "=" + ty_id + ",b,l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); @@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) { add_barrier(); } +void generator::visit_clock_inst(ir::clock_inst* clock){ + InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true); + vals_[clock][{}] = call(iasm); +} + +void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){ + InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true); + vals_[timer][{}] = call(iasm); +} + + + void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) { ir::value *v = i->get_operand(0); int inc = i->get_inc(); diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 0d1c9c3d2..726ac9a97 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // create machine module->setTargetTriple(triple); std::string error; + llvm::TargetMachine* machine; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); llvm::TargetOptions opt; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, + machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); // set data layout if(layout.empty()) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 664fbb983..aabbc4385 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { return builder->create_sqrt(x); } +// + +ir::value *dispatch::globaltimer(ir::builder *builder) { + return builder->insert(globaltimer_inst::create(builder->get_context())); +} + +ir::value *dispatch::clock(ir::builder *builder) { + return builder->insert(clock_inst::create(builder->get_context())); + +} // diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index c225b315f..d1f81f136 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri } // barrier -barrier_inst::barrier_inst(context &ctx, const std::string &name, - instruction *next) +barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next) : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { } barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { @@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons return new prefetch_s_inst(ctx, arg, inc, name, next); } -//// nv_dynamic_program_idx -//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) -// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } +// global timer +globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next) + : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { } -//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { -// return new make_range_dyn(ty, name, next); -//} +globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) { + return new globaltimer_inst(ctx, name, next); +} -//// nv_static_program_idx -//make_range_sta::make_range_sta(make_range *range) -// : constant(range->get_type(), 0), range_(range) { } +// clock +clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next) + : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { } -//make_range* make_range_sta::get_range() const -//{ return range_; } - -//make_range_sta* make_range_sta::get(make_range* range) { -// static std::map cache; -// if(cache.find(range) == cache.end()) -// cache.insert({range, new make_range_sta(range)}); -// return cache.at(range); -//} +clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) { + return new clock_inst(ctx, name, next); +} // make_range diff --git a/python/src/triton.cc b/python/src/triton.cc index 9e53cc341..22017ebf5 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) { m.def("cos", &ir::dispatch::cos, ret::reference); m.def("sin", &ir::dispatch::sin, ret::reference); m.def("sqrt", &ir::dispatch::sqrt, ret::reference); + // utilities + m.def("clock", &ir::dispatch::clock, ret::reference); + m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference); // internal (debugging only) m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df25e59fb..0312d8146 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -792,6 +792,19 @@ def sum(input, axis, _builder=None): def xor_sum(input, axis, _builder=None): return frontend.xor_sum(input, axis, _builder) +# ----------------------- +# Utilities +# ----------------------- + + +@builtin +def globaltimer(_builder=None): + return frontend.globaltimer(_builder) + + +@builtin +def clock(_builder=None): + return frontend.clock(_builder) # ----------------------- # Internal for debugging diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index d684106f1..c0fb85328 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -24,9 +24,11 @@ def add_kernel( y_ptr, # *Pointer* to second input vector output_ptr, # *Pointer* to output vector n_elements, # Size of the vector + time_start_ptr, time_end_ptr, BLOCK_SIZE: tl.constexpr, # Number of elements each program should process # NOTE: `constexpr` so it can be used as a shape value ): + tl.atomic_min(time_start_ptr, tl.clock()) # There are multiple 'program's processing different data. We identify which program # we are here pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 @@ -45,6 +47,7 @@ def add_kernel( output = x + y # Write x + y back to DRAM tl.store(output_ptr + offsets, output, mask=mask) + tl.atomic_max(time_end_ptr, tl.clock()) # %% @@ -53,6 +56,8 @@ def add_kernel( def add(x: torch.Tensor, y: torch.Tensor): + time_start = torch.zeros(1, dtype=torch.int64, device='cuda') + time_end = torch.zeros(1, dtype=torch.int64, device='cuda') # We need to preallocate the output output = torch.empty_like(x) assert x.is_cuda and y.is_cuda and output.is_cuda @@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor): # - each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel # - don't forget to pass meta-parameters as keywords arguments - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. + print((time_end, time_start)) return output