From 5c7122004c25266f5dfd65c5613813107732278c Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 14 Apr 2022 17:33:44 -0700
Subject: [PATCH] [TUTORIALS] Tutorial shouldn't expose `clock`. Just removed it.

---
 python/tutorials/01-vector-add.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 51de7ac6c..d684106f1 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -24,11 +24,9 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
-    time_start_ptr, time_end_ptr,
     BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                                # NOTE: `constexpr` so it can be used as a shape value
 ):
-    tl.atomic_min(time_start_ptr, tl.clock())
     # There are multiple 'program's processing different data. We identify which program
     # we are here
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
@@ -47,7 +45,6 @@ def add_kernel(
     output = x + y
     # Write x + y back to DRAM
     tl.store(output_ptr + offsets, output, mask=mask)
-    tl.atomic_max(time_end_ptr, tl.clock())
 
 
 # %%
@@ -56,8 +53,6 @@ def add_kernel(
 
 
 def add(x: torch.Tensor, y: torch.Tensor):
-    time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
-    time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
     # We need to preallocate the output
     output = torch.empty_like(x)
     assert x.is_cuda and y.is_cuda and output.is_cuda
@@ -70,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
     #  - each torch.tensor object is implicitly converted into a pointer to its first element.
     #  - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
     #  - don't forget to pass meta-parameters as keywords arguments
-    add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
+    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
     # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
     # running asynchronously at this point.
     return output
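
A note on timing, now that the in-kernel `tl.clock()` instrumentation is gone: the kernel can still be timed entirely from the host, so nothing timing-related needs to be threaded through the kernel signature. The sketch below is illustrative only and not part of the patch; it assumes the tutorial's `add` wrapper above is in scope, that a CUDA device is available, and the vector size (98432) is an arbitrary choice here. For steadier measurements, Triton's `triton.testing.do_bench` helper (which warms up and repeats the call) is the usual tool.

import torch

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')

# CUDA events measure GPU wall-clock time around the launch.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

add(x, y)                   # warm-up: triggers JIT compilation of add_kernel
start.record()
output = add(x, y)          # kernel is launched asynchronously
end.record()
torch.cuda.synchronize()    # wait for the kernel before reading the timer

print(f'add(x, y) took {start.elapsed_time(end):.3f} ms')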