[DOCS] Made documentation consistent with the new kernel API
This commit is contained in:
committed by
Philippe Tillet
parent
eadaeab299
commit
a5e3397e6e
@@ -57,7 +57,8 @@ As you will see, a wrapper for the above Triton function can be created in just
|
||||
}
|
||||
"""
|
||||
# create callable kernel for the source-code
|
||||
kernel = triton.kernel(src)
|
||||
# options: 4 warps and a -DTILE=1024
|
||||
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
|
||||
|
||||
# Forward pass
|
||||
@staticmethod
|
||||
@@ -72,11 +73,7 @@ As you will see, a wrapper for the above Triton function can be created in just
|
||||
N = x.numel()
|
||||
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
|
||||
# launch kernel
|
||||
# options: 4 warps and a -DTILE=1024
|
||||
_add.kernel(z, x, y, N,
|
||||
grid = grid,
|
||||
num_warps = 4,
|
||||
defines = {'TILE': 1024})
|
||||
_add.kernel(z, x, y, N, grid = grid)
|
||||
# return output
|
||||
return z
|
||||
|
||||
|
@@ -8,4 +8,3 @@ Tutorials
|
||||
triton-vs-cuda
|
||||
matrix-transposition
|
||||
matrix-multiplication
|
||||
putting-it-all-together
|
||||
|
@@ -98,11 +98,9 @@ Now assume that you want to tune the above code for different data types, tile s
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
_vector_add.kernel(y, x, N, grid=grid,
|
||||
defines={'TILE': [256, 512, 1024]},
|
||||
num_warps = [2, 4, 8])
|
||||
kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8])
|
||||
|
||||
would benchmark our above triton-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
|
||||
would benchmark our above triton source-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
|
||||
|
||||
=============================
|
||||
Going Further
|
||||
|
Reference in New Issue
Block a user