diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 024a838d9..1057cfef6 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -149,6 +149,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
   // ir.print(std::cout);
   isel.visit(ir, *llvm);
   shared_static = allocation.allocated_size();
+  if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
+    // Devices with sm < 70 (Pascal) have limited shared memory.
+    // Throw a clear error here instead of failing at kernel launch with "Error: Invalid argument".
+    if (shared_static >= 65536) {
+      throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + " bytes");
+    }
+  }
 
   if (isel.get_extern_lib_map().size() > 0) {
     // If there's any extern lib calls,
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 1282f24d9..5231a5bfa 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -605,6 +605,10 @@ def test_tuples():
     ]
     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
 def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        if dtype_x_str == 'float16':
+            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
     n_programs = 5
 
     # triton kernel
@@ -1042,6 +1046,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
                          if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80:
         if dtype == 'int8':
             pytest.skip("Only test int8 on devices with sm >= 80")
@@ -1227,6 +1233,10 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
 
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
 
     M = 32
diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py
index 9e0c72de9..ebe36e254 100644
--- a/python/test/unit/operators/test_blocksparse.py
+++ b/python/test/unit/operators/test_blocksparse.py
@@ -2,6 +2,7 @@ import pytest
 import torch
 
 import triton
+import triton._C.libtriton.triton as _triton
 
 
 @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -125,6 +126,10 @@ def test_attention_fwd_bwd(
     batch_size=2,
     n_heads=2,
 ):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
     qkvs = [
diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py
index e14ea6ae7..8d20fbae3 100644
--- a/python/test/unit/operators/test_matmul.py
+++ b/python/test/unit/operators/test_matmul.py
@@ -68,6 +68,8 @@ import triton._C.libtriton.triton as _triton
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80 and DTYPE == "bfloat16":
         pytest.skip("Only test bfloat16 on devices with sm >= 80")
     if DTYPE == "bfloat16" and SPLIT_K != 1: