[DOCS] Made documentation consistent with the new kernel API
committed by Philippe Tillet
parent eadaeab299
commit a5e3397e6e
@@ -57,7 +57,8 @@ As you will see, a wrapper for the above Triton function can be created in just
 }
 """
 # create callable kernel for the source-code
-kernel = triton.kernel(src)
+# options: 4 warps and a -DTILE=1024
+kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
 
 # Forward pass
 @staticmethod
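Under the new API shown in this hunk, compile-time options (`defines` and `num_warps`) are bound once when the kernel object is created. A minimal sketch of the new construction is given below; the Triton-C signature and the placeholder body for `src` are assumptions, only the `triton.kernel(...)` call itself comes from the change above.

.. code-block:: python

  import triton

  # `src` is assumed to hold the Triton-C source defined earlier in the
  # tutorial; its body is elided here.
  src = """
  __global__ void add(float* z, float* x, float* y, int N) {
      /* element-wise add, see the tutorial above */
  }
  """

  # New API: macro definitions and warp counts are fixed at construction
  # time instead of being passed to every launch.
  kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])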
@@ -72,11 +73,7 @@ As you will see, a wrapper for the above Triton function can be created in just
 N = x.numel()
 grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
 # launch kernel
-# options: 4 warps and a -DTILE=1024
-_add.kernel(z, x, y, N,
-            grid = grid,
-            num_warps = 4,
-            defines = {'TILE': 1024})
+_add.kernel(z, x, y, N, grid = grid)
 # return output
 return z
 
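Combined with the previous hunk, the forward pass now reduces to a plain launch because the options already live on the kernel object. A sketch of how the surrounding `torch.autograd.Function` wrapper might fit together is shown below; the class name `_add`, the output allocation, and the argument list of `forward` are assumptions taken from context, while the grid lambda and the launch call are the lines from this diff.

.. code-block:: python

  import torch
  import triton

  class _add(torch.autograd.Function):
      # `src` (the Triton-C source string) is assumed to be defined above.
      # Compile-time options are bound once, at class-definition time.
      kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

      # Forward pass
      @staticmethod
      def forward(ctx, x, y):
          # allocate output (assumed; not shown in the hunk)
          z = torch.empty_like(x)
          N = x.numel()
          # one program instance per TILE elements
          grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
          # launch kernel; no per-launch options are needed anymore
          _add.kernel(z, x, y, N, grid = grid)
          # return output
          return z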
@@ -8,4 +8,3 @@ Tutorials
    triton-vs-cuda
    matrix-transposition
    matrix-multiplication
-   putting-it-all-together
@@ -97,12 +97,10 @@ Auto-Tuning
 Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example:
 
 .. code-block:: python
 
-  _vector_add.kernel(y, x, N, grid=grid,
-                     defines={'TILE': [256, 512, 1024]},
-                     num_warps = [2, 4, 8])
+  kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8])
 
-would benchmark our above triton-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
+would benchmark our above triton source-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
 
 =============================
 Going Further
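Since passing lists to `defines` and `num_warps` turns one kernel into a family of auto-tuned variants, a short end-to-end sketch is given below; the tensor sizes and the CUDA device placement are illustrative assumptions, while the constructor signature and the launch shape come from the documentation text above.

.. code-block:: python

  import torch
  import triton

  # `src` is assumed to be the element-wise add source from the tutorial.
  # Lists ask Triton to benchmark every combination of TILE and warp count
  # and to cache the fastest variant.
  kernel = triton.kernel(src,
                         defines={'TILE': [256, 512, 1024]},
                         num_warps=[2, 4, 8])

  x = torch.randn(1024 * 1024, device='cuda')
  y = torch.randn(1024 * 1024, device='cuda')
  z = torch.empty_like(x)
  N = x.numel()

  # The launch itself does not change; the tuner selects TILE and num_warps.
  grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
  kernel(z, x, y, N, grid=grid)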