[BACKEND] Add isRow attribute for DotOp tensors whose parent is mmav1 (#970)
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
@@ -32,7 +32,7 @@ def matmul_no_scf_kernel(
     (shape, num_warps, trans_a, trans_b)
     for shape in [
         [128, 256, 32],
-        [256, 128, 16],
+        # [256, 128, 16],
         [128, 16, 32],
         [32, 128, 64],
         [128, 128, 64],
@@ -43,8 +43,6 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
-
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
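Note: the TRANS_A/TRANS_B paths above build the transposed operand by allocating the tensor with swapped dimensions and then taking a view with .T, so the logical shape stays (SIZE_M, SIZE_K) while the memory layout becomes column-major. A minimal sketch of that pattern, relying only on standard PyTorch behavior (the sizes are placeholders, not values from the patch):

    import torch

    SIZE_M, SIZE_K = 128, 32
    # Allocate as (K, M), then transpose: the logical shape is (M, K),
    # but the result is a non-contiguous (column-major) view, not a copy.
    a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
    assert a.shape == (SIZE_M, SIZE_K)
    assert not a.is_contiguous()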
@@ -83,7 +81,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True)
+    guard_for_volta(is_int8=True)
 
     SIZE_M, SIZE_N, SIZE_K = SHAPE
 
@@ -199,7 +197,6 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
 
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -276,7 +273,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)
 
-    guard_for_volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32)
+    guard_for_volta(is_tf32=allow_tf32)
 
     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -302,7 +299,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
 
 
-def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+def guard_for_volta(is_int8=False, is_tf32=False):
     '''
     Tell whether the test case is valid on Volta GPU.
     Some features are WIP, so the corresponding support are missing.
@@ -311,8 +308,7 @@ def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
     is_on_Volta = capability[0] < 8
     # TODO[Superjomn]: Remove the constraints below when features are ready
     is_feature_supported = not (is_int8 or is_tf32)
-    is_feature_ready = not (trans_a or trans_b)
 
     if is_on_Volta:
-        if (not is_feature_supported) or (not is_feature_ready):
-            pytest.skip("Not valid on Volta")
+        if (not is_feature_supported):
+            pytest.skip("Not valid on Volta")
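Note: after this change guard_for_volta no longer looks at num_warps or the transpose flags; on Volta only int8 and tf32 remain gated. A sketch of how the helper reads once the patch is applied (the capability lookup is outside the hunk and is assumed here to use torch.cuda.get_device_capability()):

    import pytest
    import torch

    def guard_for_volta(is_int8=False, is_tf32=False):
        '''Skip the test on Volta when it needs a feature that is still WIP there.'''
        capability = torch.cuda.get_device_capability()  # assumed source of `capability`
        is_on_Volta = capability[0] < 8
        # TODO[Superjomn]: Remove the constraints below when features are ready
        is_feature_supported = not (is_int8 or is_tf32)

        if is_on_Volta:
            if not is_feature_supported:
                pytest.skip("Not valid on Volta")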
@@ -1385,6 +1385,8 @@ arg_type_pattern = {
 
 # def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
 def compile(fn, **kwargs):
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
     # we get the kernel, i.e. the first function generated in the module
     # if fn is not a JITFunction, then it
     # has to be a path to a file
@@ -1392,11 +1394,9 @@ def compile(fn, **kwargs):
     asm = dict()
     constants = kwargs.get("constants", dict())
     num_warps = kwargs.get("num_warps", 4)
-    num_stages = kwargs.get("num_stages", 3)
+    num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
     # build compilation stages
     stages = {
         "ast": (lambda path: fn, None),
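Note: the two capability lines are hoisted to the top of compile (previous hunk) so the num_stages default can depend on them. The capability is encoded as major * 10 + minor, so devices at or above Turing (7.5) default to 3 pipeline stages while older ones default to 2. A sketch of that arithmetic, not code from the patch (the helper name and the GPU examples are illustrative):

    # Illustrative only: how the new num_stages default resolves per device.
    def default_num_stages(major, minor):
        capability = major * 10 + minor  # same encoding compile() uses
        return 3 if capability >= 75 else 2

    assert default_num_stages(7, 0) == 2   # V100 (Volta)
    assert default_num_stages(7, 5) == 3   # T4 (Turing)
    assert default_num_stages(8, 0) == 3   # A100 (Ampere)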