[FRONTEND] Added tl.clock and tl.globaltimer (#485)
@@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) {
   m.def("cos", &ir::dispatch::cos, ret::reference);
   m.def("sin", &ir::dispatch::sin, ret::reference);
   m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
+  // utilities
+  m.def("clock", &ir::dispatch::clock, ret::reference);
+  m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
   // internal (debugging only)
   m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
   m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
@@ -792,6 +792,19 @@ def sum(input, axis, _builder=None):
 def xor_sum(input, axis, _builder=None):
     return frontend.xor_sum(input, axis, _builder)

+# -----------------------
+# Utilities
+# -----------------------
+
+
+@builtin
+def globaltimer(_builder=None):
+    return frontend.globaltimer(_builder)
+
+
+@builtin
+def clock(_builder=None):
+    return frontend.clock(_builder)

 # -----------------------
 # Internal for debugging
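For context, here is a minimal sketch (not part of this commit) of how the two new builtins could be exercised from a `@triton.jit` kernel. It assumes `tl.clock()` reads a per-SM cycle counter and `tl.globaltimer()` a 64-bit global nanosecond timer; the kernel and buffer names below are made up for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def timestamp_kernel(clock_ptr, timer_ptr):
    # Record one reading of each counter per program instance.
    pid = tl.program_id(axis=0)
    tl.store(clock_ptr + pid, tl.clock())        # per-SM cycle counter (assumed semantics)
    tl.store(timer_ptr + pid, tl.globaltimer())  # global nanosecond timer (assumed semantics)


n_programs = 4
clocks = torch.zeros(n_programs, dtype=torch.int64, device='cuda')
timers = torch.zeros(n_programs, dtype=torch.int64, device='cuda')
timestamp_kernel[(n_programs,)](clocks, timers)
torch.cuda.synchronize()  # the launch is asynchronous
print(clocks, timers)

Because each reading is stored per program id, the host can compare timestamps across program instances after the synchronize.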
@@ -24,9 +24,11 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
+    time_start_ptr, time_end_ptr,
     BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
     # NOTE: `constexpr` so it can be used as a shape value
 ):
+    tl.atomic_min(time_start_ptr, tl.clock())
     # There are multiple 'program's processing different data. We identify which program
     # we are here
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
@@ -45,6 +47,7 @@ def add_kernel(
     output = x + y
     # Write x + y back to DRAM
     tl.store(output_ptr + offsets, output, mask=mask)
+    tl.atomic_max(time_end_ptr, tl.clock())


 # %%
@@ -53,6 +56,8 @@ def add_kernel(


 def add(x: torch.Tensor, y: torch.Tensor):
+    time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
+    time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
     # We need to preallocate the output
     output = torch.empty_like(x)
     assert x.is_cuda and y.is_cuda and output.is_cuda
@@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
     # - each torch.tensor object is implicitly converted into a pointer to its first element.
     # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
     # - don't forget to pass meta-parameters as keywords arguments
-    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
     # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
     # running asynchronously at this point.
+    print((time_end, time_start))
     return output

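As the comment above notes, the launch is asynchronous, so `print((time_end, time_start))` may run before the kernel has written the timestamps. A hedged sketch of reading them back once the kernel has finished, assuming `add_kernel` from the tutorial above is in scope:

import torch
import triton

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
output = torch.empty_like(x)
time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
torch.cuda.synchronize()  # wait for the kernel before reading the timestamps
print(time_start.item(), time_end.item())  # values recorded via tl.atomic_min / tl.atomic_max

Since `tl.clock()` appears to read a per-SM counter, readings taken by programs scheduled on different SMs are not directly comparable; `tl.globaltimer()` is the more natural choice when a globally consistent timestamp is needed.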