[DOCS] Added matrix multiplication tutorial
@@ -126,10 +126,19 @@ def make_kernel(N, device):
     # Now our kernels are indexed not only by the provided device but also
     # by the rounded number of columns in the input matrix
     BLOCK = next_power_of_2(N)
-    key = (BLOCK, device)
+    # Another trick we can use is to ask the compiler to parallelize each
+    # row-normalization more aggressively -- i.e., with more warps -- for vectors
+    # that are longer
+    # You will see in the next tutorial how to auto-tune this value in a more natural
+    # way so you don't have to come up with manual heuristics yourself
+    num_warps = 4
+    if BLOCK >= 2048: num_warps = 8
+    if BLOCK >= 4096: num_warps = 16
+    # Each (BLOCK, num_warps, device) results in a different kernel
+    key = (BLOCK, num_warps, device)
     if key not in cache:
         defines = {'BLOCK': BLOCK}
-        cache[key] = triton.kernel(_src, device=device, defines=defines)
+        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
     return cache[key]
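For readers who want to poke at the heuristic in isolation, here is a minimal, standalone sketch of how the (BLOCK, num_warps, device) key above is derived. The next_power_of_2 helper shown here is only an assumption about how such a utility might be written; the tutorial relies on its own definition, and no Triton import is needed to run this sketch.

# Standalone sketch of the caching heuristic above (no Triton required).
# next_power_of_2 is a hypothetical reimplementation, not the tutorial's helper.

def next_power_of_2(n):
    """Smallest power of two >= n (bit-twiddling, assumes n >= 1)."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    return n + 1

def kernel_key(N, device='cuda'):
    # Round the number of columns up to a power of two ...
    BLOCK = next_power_of_2(N)
    # ... and use more warps for longer rows, mirroring the heuristic above
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    return (BLOCK, num_warps, device)

for N in (1000, 2048, 5000):
    print(N, '->', kernel_key(N))
# 1000 -> (1024, 4, 'cuda')
# 2048 -> (2048, 8, 'cuda')
# 5000 -> (8192, 16, 'cuda')

Each distinct key compiles and caches a separate specialization of the kernel, so repeated calls with the same rounded column count reuse the compiled kernel instead of recompiling it.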
@@ -174,7 +183,7 @@ print(torch.allclose(y_tri, y_ref))
 # As expected, the results are identical.
 
 # %%
-# Benchmarking
+# Benchmark
 # -------------
 # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
 # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
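As a rough illustration of the comparison described in those two comment lines, here is a hedged, standalone timing sketch in plain PyTorch. It does not use the tutorial's own benchmarking utilities, and the naive_softmax below is only an illustrative stand-in for the one defined earlier in the tutorial, not a copy of it.

# Plain-PyTorch timing sketch: torch.softmax vs. a naive softmax, 4096 rows.
import torch

def naive_softmax(x):
    # Numerically-stable softmax built from elementary ops; materializes
    # several intermediate tensors, which is why it reads more memory.
    x = x - x.max(dim=1, keepdim=True)[0]
    num = torch.exp(x)
    return num / num.sum(dim=1, keepdim=True)

def time_ms(fn, x, warmup=10, rep=50):
    # Average runtime in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / rep

if torch.cuda.is_available():
    for N in [256, 1024, 4096]:
        x = torch.randn(4096, N, device='cuda')  # 4096 rows, as in the tutorial
        print(N,
              time_ms(lambda t: torch.softmax(t, dim=1), x),
              time_ms(naive_softmax, x))

Sweeping N while keeping the row count fixed shows how each implementation scales with row length, which is the quantity the BLOCK/num_warps heuristic above is keyed on.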