@@ -63,7 +63,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
# NOTE:
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
|
||||
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
|
@@ -80,7 +80,7 @@ def softmax_kernel(
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract maximum for numerical stability
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
|
||||
#
|
||||
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
|
||||
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
|
||||
# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
|
||||
# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
|
||||
|
@@ -15,7 +15,7 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
|
Reference in New Issue
Block a user