[DOCS] Added matrix multiplication tutorial

Philippe Tillet
2021-03-14 18:49:59 -04:00
parent f4fb209dad
commit 183878dce5
9 changed files with 395 additions and 18 deletions


@@ -126,10 +126,19 @@ def make_kernel(N, device):
     # Now our kernels are indexed not only by the provided device but also
     # by the rounded number of columns in the input matrix
     BLOCK = next_power_of_2(N)
-    key = (BLOCK, device)
+    # Another trick we can use is to ask the compiler to parallelize the
+    # row-normalization of longer vectors more aggressively -- i.e., with more warps
+    # You will see in the next tutorial how to auto-tune this value in a more natural
+    # way so you don't have to come up with manual heuristics yourself
+    num_warps = 4
+    if BLOCK >= 2048: num_warps = 8
+    if BLOCK >= 4096: num_warps = 16
+    # Each (BLOCK, num_warps, device) results in a different kernel
+    key = (BLOCK, num_warps, device)
     if key not in cache:
         defines = {'BLOCK': BLOCK}
-        cache[key] = triton.kernel(_src, device=device, defines=defines)
+        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
     return cache[key]
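
Taken out of the diff, the heuristic introduced above is easy to exercise on its own. The following is a minimal, self-contained sketch (not part of the commit) that maps an input width N to the (BLOCK, num_warps) pair chosen above, assuming the usual bit-twiddling definition of next_power_of_2; the actual compilation via triton.kernel is omitted.

def next_power_of_2(n):
    # smallest power of two >= n, e.g. 1000 -> 1024
    return 1 << (n - 1).bit_length()

def launch_params(n_cols):
    # mirrors the heuristic above: wider rows get more warps
    BLOCK = next_power_of_2(n_cols)
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    return BLOCK, num_warps

# e.g. launch_params(1000) == (1024, 4) and launch_params(3000) == (4096, 16)
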
@@ -174,7 +183,7 @@ print(torch.allclose(y_tri, y_ref))
 # As expected, the results are identical.
 # %%
-# Benchmarking
+# Benchmark
 # -------------
 # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
 # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
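
For reference, a rough sketch of such a benchmark is given below. It is not the tutorial's actual code (which relies on Triton's own benchmarking helpers); it simply times torch.softmax, the naive_softmax defined above, and the Triton-backed softmax wrapper from this tutorial using CUDA events, sweeping the number of columns at a fixed 4096 rows.

import torch

def bench(fn, warmup=10, rep=100):
    # crude CUDA-event timer; returns the median runtime in milliseconds
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(rep):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sorted(times)[len(times) // 2]

for N in [256, 512, 1024, 2048, 4096]:
    x = torch.randn(4096, N, device='cuda')
    ms_torch = bench(lambda: torch.softmax(x, dim=-1))
    ms_naive = bench(lambda: naive_softmax(x))  # naive_softmax from earlier in the tutorial
    ms_triton = bench(lambda: softmax(x))       # Triton-backed softmax defined in this tutorial
    print(f'N={N}: torch={ms_torch:.3f}ms  naive={ms_naive:.3f}ms  triton={ms_triton:.3f}ms')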