[LANG] Added support for constexpr (#361)
This commit is contained in:
@@ -182,17 +182,13 @@ def matmul_kernel(
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
**meta,
|
||||
):
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
# extract meta-parameters
|
||||
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
|
||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
||||
GROUP_SIZE_M = 8
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse
|
||||
|
Reference in New Issue
Block a user