From 5ca1ed01016530056c4507661c24d6c21efc983d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 24 Oct 2022 19:41:25 -0700 Subject: [PATCH 01/10] Add bf16/fp16/fp64 support for ty_to_cpp (#800) In ```torch._inductor```, we [convert 0d CPU tensor to scalar during triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need add missing triton support for bf16/fp16/fp64. --- python/triton/compiler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 1332f2c76..ab7733b60 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -913,7 +913,10 @@ def ty_to_cpp(ty): "i64": "int64_t", "u32": "uint32_t", "u64": "uint64_t", + "fp16": "float", + "bf16": "float", "fp32": "float", + "fp64": "double", }[ty] @@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature): 'i64': 'int64_t', 'u32': 'uint32_t', 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', 'fp32': 'float', 'fp64': 'double', }[ty] From 3ca667dfa8df4b64bb47309bfbf7ffcad75bda51 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 28 Oct 2022 23:27:06 -0700 Subject: [PATCH 02/10] [Frontend] Return a scalar if all input args are scalar (#816) --- python/test/unit/language/test_core.py | 31 +++++++++++++++- python/triton/language/extern.py | 50 ++++++++++++++------------ 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d7d9130d5..1282f24d9 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1546,7 +1546,7 @@ def test_num_warps_pow2(): [('int32', 'libdevice.ffs', ''), ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), ('float64', 'libdevice.norm4d', '')]) -def test_libdevice(dtype_str, expr, lib_path): +def test_libdevice_tensor(dtype_str, expr, lib_path): @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): @@ -1582,3 +1582,32 @@ def test_libdevice(dtype_str, expr, lib_path): np.testing.assert_equal(y_ref, to_numpy(y_tri)) else: np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('float32', 'libdevice.pow', '')]) +def test_libdevice_scalar(dtype_str, expr, lib_path): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = X + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((1,), dtype_str=dtype_str, rs=rs) + y_ref = np.zeros(shape, dtype=x.dtype) + + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref[:] = np.power(x, x) + + # triton result + x_tri = to_triton(x)[0].item() + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py index a306a2e9a..1f3c9371c 100644 --- a/python/triton/language/extern.py +++ b/python/triton/language/extern.py @@ -59,28 +59,34 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: :return: the return value of the function ''' dispatch_args = args.copy() - if len(args) == 1: - dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) - ret_shape = dispatch_args[0].shape - elif len(args) == 2: - dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) - dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) - dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( - dispatch_args[0], dispatch_args[1], _builder) - ret_shape = dispatch_args[0].shape - else: - for i in range(len(dispatch_args)): - dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) - broadcast_arg = dispatch_args[0] - # Get the broadcast shape over all the arguments - for i in range(len(dispatch_args)): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder) - # Change the shape of each argument based on the broadcast shape - for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder) - ret_shape = broadcast_arg.shape + all_scalar = True + ret_shape = None + for dispatch_arg in dispatch_args: + if dispatch_arg.type.is_block(): + all_scalar = False + if not all_scalar: + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape func = getattr(_builder, "create_extern_elementwise") return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) From 584086f08c0aafe554f1b50da7b131afe04141c4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 29 Oct 2022 16:59:06 -0700 Subject: [PATCH 03/10] [BUILD] Now using cibuildwheel default --- .github/workflows/wheels.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index d627888c5..74d376bd5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -30,8 +30,6 @@ jobs: export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014" export CIBW_BEFORE_BUILD="pip install cmake;\ yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;" - export CIBW_SKIP="{cp,pp}35-*" - export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse From 6311d7040666930dd027513c7a5093b28dc2c7b4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 29 Oct 2022 17:15:47 -0700 Subject: [PATCH 04/10] Revert "[BUILD] Now using cibuildwheel default" This reverts commit 584086f08c0aafe554f1b50da7b131afe04141c4. --- .github/workflows/wheels.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 74d376bd5..d627888c5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -30,6 +30,8 @@ jobs: export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014" export CIBW_BEFORE_BUILD="pip install cmake;\ yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;" + export CIBW_SKIP="{cp,pp}35-*" + export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse From 578ada774098449e8046da2f23d129fc373ebee6 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 31 Oct 2022 11:08:18 -0700 Subject: [PATCH 05/10] [DOCS] Add install from source instructions to README (#821) --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index ed0fc71b1..bab417daa 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,15 @@ And the latest nightly release: pip install -U --pre triton ``` +# Install from source + +``` +git clone https://github.com/openai/triton.git; +cd triton/python; +pip install cmake; # build time dependency +pip install -e . +``` + # Changelog Version 1.1 is out! New features include: From f16138d447bccc54641a9c48ffedbd449a1a40a7 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 2 Nov 2022 01:51:58 +0800 Subject: [PATCH 06/10] [Frontend] Interface fixes for libdevice (#830) - Unifying several interfaces with different types to a single one, e.g. `fsub_ru` and `dsub_ru` -> `sub_ru`; - Minor bug fix: `fast_pow` is incorrectly classified into the `pow` interface, of which arguments are the same as `powf`; - Explicit interfaces for casting functions, e.g. decoupling `ll2float_ru` to `ll2float_ru` and `ull2float_ru`; - Removing interfaces that are not in NVIDIA's official documents, e.g. `fmaf_ieee_rn`, which is confusing together with `fmaf_rn`. Note that this PR for the master branch is different from #829, which is for the MLIR branch. --- python/triton/language/libdevice.py | 470 ++++++++-------------- python/triton/tools/build_extern.py | 116 +++--- python/tutorials/07-libdevice-function.py | 2 +- 3 files changed, 233 insertions(+), 355 deletions(-) diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py index be0ab417e..25e75c89d 100644 --- a/python/triton/language/libdevice.py +++ b/python/triton/language/libdevice.py @@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.int32, core.int32,): ("__nv_mulhi", core.int32), (core.uint32, core.uint32,): ("__nv_umulhi", core.uint32), - }, _builder) - - -@extern.extern -def mul64hi(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.int64, core.int64,): ("__nv_mul64hi", core.int64), + (core.int64, core.int64,): ("__nv_mul64hi", core.int64), (core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64), }, _builder) @@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None): }, _builder) -@extern.extern -def fmaf_rn(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32), - }, _builder) - - -@extern.extern -def fmaf_rz(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32), - }, _builder) - - -@extern.extern -def fmaf_rd(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32), - }, _builder) - - -@extern.extern -def fmaf_ru(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32), - }, _builder) - - -@extern.extern -def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32), - }, _builder) - - -@extern.extern -def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32), - }, _builder) - - -@extern.extern -def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32), - }, _builder) - - -@extern.extern -def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32), - }, _builder) - - @extern.extern def fma_rn(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64), + {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32), + (core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64), }, _builder) @extern.extern def fma_rz(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64), + {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32), + (core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64), }, _builder) @extern.extern def fma_rd(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64), + {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32), + (core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64), }, _builder) @extern.extern def fma_ru(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64), + {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32), + (core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64), }, _builder) @extern.extern -def fast_fdividef(arg0, arg1, _builder=None): +def fast_dividef(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32), }, _builder) @extern.extern -def fdiv_rn(arg0, arg1, _builder=None): +def div_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32), + (core.float64, core.float64,): ("__nv_ddiv_rn", core.float64), }, _builder) @extern.extern -def fdiv_rz(arg0, arg1, _builder=None): +def div_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32), + (core.float64, core.float64,): ("__nv_ddiv_rz", core.float64), }, _builder) @extern.extern -def fdiv_rd(arg0, arg1, _builder=None): +def div_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32), + (core.float64, core.float64,): ("__nv_ddiv_rd", core.float64), }, _builder) @extern.extern -def fdiv_ru(arg0, arg1, _builder=None): +def div_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32), + (core.float64, core.float64,): ("__nv_ddiv_ru", core.float64), }, _builder) @extern.extern -def frcp_rn(arg0, _builder=None): +def rcp_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_frcp_rn", core.float32), + (core.float64,): ("__nv_drcp_rn", core.float64), }, _builder) @extern.extern -def frcp_rz(arg0, _builder=None): +def rcp_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_frcp_rz", core.float32), + (core.float64,): ("__nv_drcp_rz", core.float64), }, _builder) @extern.extern -def frcp_rd(arg0, _builder=None): +def rcp_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_frcp_rd", core.float32), + (core.float64,): ("__nv_drcp_rd", core.float64), }, _builder) @extern.extern -def frcp_ru(arg0, _builder=None): +def rcp_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_frcp_ru", core.float32), + (core.float64,): ("__nv_drcp_ru", core.float64), }, _builder) @extern.extern -def fsqrt_rn(arg0, _builder=None): +def sqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_fsqrt_rn", core.float32), + (core.float64,): ("__nv_dsqrt_rn", core.float64), }, _builder) @extern.extern -def fsqrt_rz(arg0, _builder=None): +def sqrt_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_fsqrt_rz", core.float32), + (core.float64,): ("__nv_dsqrt_rz", core.float64), }, _builder) @extern.extern -def fsqrt_rd(arg0, _builder=None): +def sqrt_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_fsqrt_rd", core.float32), + (core.float64,): ("__nv_dsqrt_rd", core.float64), }, _builder) @extern.extern -def fsqrt_ru(arg0, _builder=None): +def sqrt_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_fsqrt_ru", core.float32), - }, _builder) - - -@extern.extern -def ddiv_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64), - }, _builder) - - -@extern.extern -def ddiv_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64), - }, _builder) - - -@extern.extern -def ddiv_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64), - }, _builder) - - -@extern.extern -def ddiv_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64), - }, _builder) - - -@extern.extern -def drcp_rn(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_drcp_rn", core.float64), - }, _builder) - - -@extern.extern -def drcp_rz(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_drcp_rz", core.float64), - }, _builder) - - -@extern.extern -def drcp_rd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_drcp_rd", core.float64), - }, _builder) - - -@extern.extern -def drcp_ru(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_drcp_ru", core.float64), - }, _builder) - - -@extern.extern -def dsqrt_rn(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_dsqrt_rn", core.float64), - }, _builder) - - -@extern.extern -def dsqrt_rz(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_dsqrt_rz", core.float64), - }, _builder) - - -@extern.extern -def dsqrt_rd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_dsqrt_rd", core.float64), - }, _builder) - - -@extern.extern -def dsqrt_ru(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_dsqrt_ru", core.float64), + (core.float64,): ("__nv_dsqrt_ru", core.float64), }, _builder) @@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None): @extern.extern -def dadd_rn(arg0, arg1, _builder=None): +def add_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dadd_rn", core.float64), + (core.float32, core.float32,): ("__nv_fadd_rn", core.float32), }, _builder) @extern.extern -def dadd_rz(arg0, arg1, _builder=None): +def add_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dadd_rz", core.float64), + (core.float32, core.float32,): ("__nv_fadd_rz", core.float32), }, _builder) @extern.extern -def dadd_rd(arg0, arg1, _builder=None): +def add_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dadd_rd", core.float64), + (core.float32, core.float32,): ("__nv_fadd_rd", core.float32), }, _builder) @extern.extern -def dadd_ru(arg0, arg1, _builder=None): +def add_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dadd_ru", core.float64), + (core.float32, core.float32,): ("__nv_fadd_ru", core.float32), }, _builder) @extern.extern -def dmul_rn(arg0, arg1, _builder=None): +def mul_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dmul_rn", core.float64), + (core.float32, core.float32,): ("__nv_fmul_rn", core.float32), }, _builder) @extern.extern -def dmul_rz(arg0, arg1, _builder=None): +def mul_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dmul_rz", core.float64), + (core.float32, core.float32,): ("__nv_fmul_rz", core.float32), }, _builder) @extern.extern -def dmul_rd(arg0, arg1, _builder=None): +def mul_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dmul_rd", core.float64), + (core.float32, core.float32,): ("__nv_fmul_rd", core.float32), }, _builder) @extern.extern -def dmul_ru(arg0, arg1, _builder=None): +def mul_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float64, core.float64,): ("__nv_dmul_ru", core.float64), - }, _builder) - - -@extern.extern -def fadd_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fadd_rd", core.float32), - }, _builder) - - -@extern.extern -def fadd_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fadd_ru", core.float32), - }, _builder) - - -@extern.extern -def fmul_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fmul_rd", core.float32), - }, _builder) - - -@extern.extern -def fmul_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fmul_ru", core.float32), - }, _builder) - - -@extern.extern -def fadd_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fadd_rn", core.float32), - }, _builder) - - -@extern.extern -def fadd_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fadd_rz", core.float32), - }, _builder) - - -@extern.extern -def fmul_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fmul_rn", core.float32), - }, _builder) - - -@extern.extern -def fmul_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float32, core.float32,): ("__nv_fmul_rz", core.float32), + (core.float32, core.float32,): ("__nv_fmul_ru", core.float32), }, _builder) @@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None): def int2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int2double_rn", core.float64), - (core.uint32,): ("__nv_uint2double_rn", core.float64), + }, _builder) + + +@extern.extern +def uint2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint2double_rn", core.float64), }, _builder) @@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None): def int2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int2float_rn", core.float32), - (core.uint32,): ("__nv_uint2float_rn", core.float32), }, _builder) @@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None): def int2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int2float_rz", core.float32), - (core.uint32,): ("__nv_uint2float_rz", core.float32), }, _builder) @@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None): def int2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int2float_rd", core.float32), - (core.uint32,): ("__nv_uint2float_rd", core.float32), }, _builder) @@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None): def int2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int2float_ru", core.float32), - (core.uint32,): ("__nv_uint2float_ru", core.float32), + }, _builder) + + +@extern.extern +def uint2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint2float_rn", core.float32), + }, _builder) + + +@extern.extern +def uint2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint2float_rz", core.float32), + }, _builder) + + +@extern.extern +def uint2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint2float_rd", core.float32), + }, _builder) + + +@extern.extern +def uint2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint2float_ru", core.float32), }, _builder) @@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None): def ll2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2float_rn", core.float32), - (core.uint64,): ("__nv_ull2float_rn", core.float32), }, _builder) @@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None): def ll2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2float_rz", core.float32), - (core.uint64,): ("__nv_ull2float_rz", core.float32), }, _builder) @@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None): def ll2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2float_rd", core.float32), - (core.uint64,): ("__nv_ull2float_rd", core.float32), }, _builder) @@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None): def ll2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2float_ru", core.float32), - (core.uint64,): ("__nv_ull2float_ru", core.float32), + }, _builder) + + +@extern.extern +def ull2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2float_rn", core.float32), + }, _builder) + + +@extern.extern +def ull2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2float_rz", core.float32), + }, _builder) + + +@extern.extern +def ull2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2float_rd", core.float32), + }, _builder) + + +@extern.extern +def ull2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2float_ru", core.float32), }, _builder) @@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None): def ll2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2double_rn", core.float64), - (core.uint64,): ("__nv_ull2double_rn", core.float64), }, _builder) @@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None): def ll2double_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2double_rz", core.float64), - (core.uint64,): ("__nv_ull2double_rz", core.float64), }, _builder) @@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None): def ll2double_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2double_rd", core.float64), - (core.uint64,): ("__nv_ull2double_rd", core.float64), }, _builder) @@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None): def ll2double_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int64,): ("__nv_ll2double_ru", core.float64), - (core.uint64,): ("__nv_ull2double_ru", core.float64), + }, _builder) + + +@extern.extern +def ull2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2double_rn", core.float64), + }, _builder) + + +@extern.extern +def ull2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2double_rz", core.float64), + }, _builder) + + +@extern.extern +def ull2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2double_rd", core.float64), + }, _builder) + + +@extern.extern +def ull2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint64,): ("__nv_ull2double_ru", core.float64), }, _builder) @@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None): def int_as_float(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.int32,): ("__nv_int_as_float", core.float32), - (core.uint32,): ("__nv_uint_as_float", core.float32), }, _builder) @@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None): }, _builder) +@extern.extern +def uint_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.uint32,): ("__nv_uint_as_float", core.float32), + }, _builder) + + @extern.extern def float_as_uint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], @@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None): @extern.extern -def pow(arg0, arg1, _builder=None): +def fast_powf(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fast_powf", core.float32), - (core.float32, core.float32,): ("__nv_powf", core.float32), - (core.float64, core.float64,): ("__nv_pow", core.float64), }, _builder) @@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None): @extern.extern -def fsub_rn(arg0, arg1, _builder=None): +def sub_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fsub_rn", core.float32), + (core.float64, core.float64,): ("__nv_dsub_rn", core.float64), }, _builder) @extern.extern -def fsub_rz(arg0, arg1, _builder=None): +def sub_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fsub_rz", core.float32), + (core.float64, core.float64,): ("__nv_dsub_rz", core.float64), }, _builder) @extern.extern -def fsub_rd(arg0, arg1, _builder=None): +def sub_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fsub_rd", core.float32), + (core.float64, core.float64,): ("__nv_dsub_rd", core.float64), }, _builder) @extern.extern -def fsub_ru(arg0, arg1, _builder=None): +def sub_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.float32,): ("__nv_fsub_ru", core.float32), + (core.float64, core.float64,): ("__nv_dsub_ru", core.float64), }, _builder) @extern.extern -def frsqrt_rn(arg0, _builder=None): +def rsqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_frsqrt_rn", core.float32), }, _builder) @@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None): @extern.extern -def isnanf(arg0, _builder=None): +def isnan(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_isnanf", core.int32), + (core.float64,): ("__nv_isnand", core.int32), }, _builder) @extern.extern -def signbitf(arg0, _builder=None): +def signbit(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_signbitf", core.int32), + (core.float64,): ("__nv_signbitd", core.int32), }, _builder) @@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None): @extern.extern -def isinff(arg0, _builder=None): +def isinf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float32,): ("__nv_isinff", core.int32), + (core.float64,): ("__nv_isinfd", core.int32), }, _builder) @@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None): @extern.extern -def powi(arg0, arg1, _builder=None): +def pow(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.float32, core.int32,): ("__nv_powif", core.float32), (core.float64, core.int32,): ("__nv_powi", core.float64), + (core.float32, core.float32,): ("__nv_powf", core.float32), + (core.float64, core.float64,): ("__nv_pow", core.float64), }, _builder) @@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None): }, _builder) -@extern.extern -def signbitd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_signbitd", core.int32), - }, _builder) - - @extern.extern def isfinited(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.float64,): ("__nv_isfinited", core.int32), }, _builder) - - -@extern.extern -def isinfd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_isinfd", core.int32), - }, _builder) - - -@extern.extern -def isnand(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.float64,): ("__nv_isnand", core.int32), - }, _builder) - - -@extern.extern -def dsub_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_dsub_rn", core.float64), - }, _builder) - - -@extern.extern -def dsub_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_dsub_rz", core.float64), - }, _builder) - - -@extern.extern -def dsub_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_dsub_ru", core.float64), - }, _builder) - - -@extern.extern -def dsub_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.float64, core.float64,): ("__nv_dsub_rd", core.float64), - }, _builder) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index 6ced7e7ed..551def69d 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -149,6 +149,9 @@ class Libdevice(ExternLibrary): func_strs = func_str.split("(") func_name = func_strs[0].replace("@", "") op_name = func_name.replace("__nv_", "") + # To filter some interfaces unlisted in NVIDIA's official documents. + if 'ieee' in op_name: + return None # Get arg_types arg_strs = func_strs[1].split(",") arg_types = [] @@ -176,55 +179,66 @@ class Libdevice(ExternLibrary): for symbol in self._symbols.values(): op_name = symbol.op_name symbol_set[op_name] = symbol - # The following cases are grouped together: - # op_name, op_name + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', + 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn', + 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', + 'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin', + 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', + 'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt', + 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', + 'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi', + 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', + 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru', + 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', + 'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx', + 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', + 'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs', + 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', + 'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor', + 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', + 'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', + 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', + 'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', + 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', + 'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10', + 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', + 'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max', + 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', + 'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', + 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru', + 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', + 'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi', + 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter', + 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', + 'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', + 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd', + 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', + 'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', + 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', + 'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn', + 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', + 'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh', + 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', + 'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', + 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz', + 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', + 'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', + 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', + 'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn' + } + for symbol in self._symbols.values(): op_name = symbol.op_name - if "max" in op_name: - op_name = "max" - elif "min" in op_name: - op_name = "min" - elif "abs" in op_name: - op_name = "abs" - elif "pow" in op_name and "fast" in op_name: - op_name = "pow" - elif "round" in op_name: - if "llround" in op_name: - op_name = "llround" - else: - op_name = "round" - elif "rint" in op_name: - if "llrint" in op_name: - op_name = "llrint" - else: - op_name = "rint" - elif op_name.startswith("ull"): - if "2" not in op_name: - # e.g., ullmax->max - op_name = op_name[3:] - else: - # e.g., ull2double->ll2double - op_name = op_name[1:] - elif op_name.startswith("u"): - if "2" not in op_name: - # e.g., uhadd->hadd - op_name = op_name[1:] - else: - # e.g., uint2double_rn->int2double_rn - op_name = op_name[1:] - elif op_name.startswith("ll"): - if "2" not in op_name: - # e.g., llmax->max - op_name = op_name[2:] - elif op_name.endswith("ll"): - op_name = op_name[:-2] - elif op_name.endswith("f"): - op_name = op_name[:-1] - if op_name in symbol_set: - # Update op_name only if there's an existing symbol + if op_name in renaming: + op_name = renaming[op_name] symbol._op_name = op_name - else: - op_name = symbol._op_name if op_name in self._symbol_groups: self._symbol_groups[op_name].append(symbol) else: @@ -250,7 +264,7 @@ class Libdevice(ExternLibrary): # return extern.dispatch("libdevice", , , , _builder) import_str = "from . import core, extern\n" import_str += "import os\n" - header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" + header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" func_str = "" for symbols in self._symbol_groups.values(): func_str += "@extern.extern\n" @@ -331,10 +345,10 @@ def build(llvm_dis_path, lib_path, lib_name, output_dir): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") - parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") - parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") - parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") args = parser.parse_args() build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/python/tutorials/07-libdevice-function.py b/python/tutorials/07-libdevice-function.py index bb5f7b26d..19e6cac7a 100644 --- a/python/tutorials/07-libdevice-function.py +++ b/python/tutorials/07-libdevice-function.py @@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. -Using triton, you can simply call `tl.libdevice.asinf`. +Using triton, you can simply call `tl.libdevice.asin`. triton automatically selects the correct underlying device function to invoke based on input and output types. """ From 77bc5187b57f544c038a1734979409476ca3f9a7 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Thu, 3 Nov 2022 00:11:52 -0700 Subject: [PATCH 07/10] Better NVIDIA Pascal GPU Support (#827) This PR clarifies which features are supported on P100 via its tests, though Pascal is not officially and fully supported by Triton. ## What this PR does - Skip unsupported tests on P100. - Atomic RMW - `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()` tests do not work on P100). - Add an explicit error if shared memory size >= 64K on P100. - Otherwise it causes `Invalid CUDA argument` error at `cuLaunchKernel()`, but this error is not very straightforward to understand. Instead of this generic CUDA argument error, this PR makes Triton show an error during codegen when `sm < 70`. This check happens in C/C++ so won't add an overhead in Triton's Python runtime. - 3 tests (see below) are currently failing, but these are not marked as skipped because any codegen update in the future can change the kernel size of the other tests. - This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic `tl.dot()` implementation would support P100. Importantly, Triton passed all the other tests on P100. Though this support is not official, it is great for, for example, PyTorch's TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its backend (https://github.com/pytorch/torchdynamo/issues/1591). ### Results on P100 (Google Cloud) ```sh $ pytest test/unit ... ================================================================================== short test summary info ================================================================================== FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes ================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ================================================================== ```
Environment Details (collapsed)

### VM details (Google Cloud) https://cloud.google.com/ ``` # You need a paid account (free trial does not cover GPUs) Google Cloud -> New Project -> Compute-Engine -> VM Instance Machine: GPU: NVIDIA Tesla P100 x 1 CPU: 2 vCPUs, 7.5GB memory Boot disk: OS: Ubuntu 18.04 LTS Disk: 40GB (cannot build Triton on the default 10GB disk) - When I tried, about $1.2 per hour. - US instances were full when I tried. I used Asia or Australia. - Needed a paid account (GPU is not covered by free trial) - Needed quota request for any GPU instance (by default, no GPU instance is allowed). Needed to wait an hour for approval ``` ### Reproducer ```sh ## 1. Install CUDA and a driver # Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/) sudo apt-key del 7fa2af80 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb # Download CUDA as instructed wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600 sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /" sudo apt-get update sudo apt-get -y install cuda # Are you using P100? nvidia-smi | grep "Tesla P100" ## 2. Setup the build environment sudo apt update sudo apt install -y build-essential wget git libz-dev wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3 eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)" conda create -y --name triton_base conda activate triton_base conda install -y cmake setuptools ## 3. Build Triton git clone https://github.com/openai/triton.git cd triton/python pip3 install -e '.[tests]' ## 4. Test pytest test/unit ``` ### Environment ```sh $ nvidia-smi +-----------------------------------------------------------------------------+ | NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |===============================+======================+======================| | 0 Tesla P100-PCIE... On | 00000000:00:04.0 Off | 0 | | N/A 36C P0 25W / 250W | 0MiB / 16384MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ ```

--- lib/codegen/pass.cc | 7 +++++++ python/test/unit/language/test_core.py | 10 ++++++++++ python/test/unit/operators/test_blocksparse.py | 5 +++++ python/test/unit/operators/test_matmul.py | 2 ++ 4 files changed, 24 insertions(+) diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 024a838d9..1057cfef6 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -149,6 +149,13 @@ std::unique_ptr add_passes_to_emit_bin( // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); + if (target->as_nvidia() && target->as_nvidia()->sm() < 70) { + // sm < 70 (Pascal) has little shared memory resource. + // Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here. + if (shared_static >= 65536) { + throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes"); + } + } if (isel.get_extern_lib_map().size() > 0) { // If there's any extern lib calls, diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 1282f24d9..5231a5bfa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -605,6 +605,10 @@ def test_tuples(): ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 70: + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on devices with sm >= 70") n_programs = 5 # triton kernel @@ -1042,6 +1046,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'): if not (allow_tf32 and (dtype in ['float16']))]) def test_dot(epilogue, allow_tf32, dtype, device='cuda'): cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 70: + pytest.skip("Only test tl.dot() on devices with sm >= 70") if cc < 80: if dtype == 'int8': pytest.skip("Only test int8 on devices with sm >= 80") @@ -1227,6 +1233,10 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device='cuda'): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 70: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested M = 32 diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 9e0c72de9..ebe36e254 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -2,6 +2,7 @@ import pytest import torch import triton +import triton._C.libtriton.triton as _triton @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @@ -125,6 +126,10 @@ def test_attention_fwd_bwd( batch_size=2, n_heads=2, ): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 70: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + # inputs qkv_shape = (batch_size, n_heads, n_ctx, 64) qkvs = [ diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index e14ea6ae7..8d20fbae3 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -68,6 +68,8 @@ import triton._C.libtriton.triton as _triton ) def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 70: + pytest.skip("Only test tl.dot() on devices with sm >= 70") if cc < 80 and DTYPE == "bfloat16": pytest.skip("Only test bfloat16 on devices with sm >= 80") if DTYPE == "bfloat16" and SPLIT_K != 1: From 0d7e7532279e45672555e344646f5c19c3972331 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Fri, 4 Nov 2022 18:05:16 -0700 Subject: [PATCH 08/10] [TESTING] use torch.int for autotuning cache (#840) For stupid reasons, ops on int8 are 3 times slower than on int, and for another set of stupid reasons we are not using cudaMemset for `zero_`, so using `int8` buffer in `do_bench` makes it slow. Co-authored-by: Philippe Tillet --- python/test/regression/test_performance.py | 2 +- python/triton/testing.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 16811eaa9..afec19019 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -128,7 +128,7 @@ elementwise_data = { 1024 * 16: 0.0219, 1024 * 64: 0.0791, 1024 * 256: 0.243, - 1024 * 1024: 0.534, + 1024 * 1024: 0.530, 1024 * 4096: 0.796, 1024 * 16384: 0.905, 1024 * 65536: 0.939, diff --git a/python/triton/testing.py b/python/triton/testing.py index b474baaf8..c83c1e682 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -115,7 +115,9 @@ def nvsmi(attrs): return ret -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, + percentiles=(0.5, 0.2, 0.8), + record_clocks=False, fast_flush=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -130,6 +132,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0 :type grad_to_none: torch.tensor, optional :param percentiles: Performance percentile to return in addition to the median. :type percentiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool """ # Estimate the runtime of the function @@ -151,7 +155,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0 # doesn't contain any input data before the run start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') # Warm-up for _ in range(n_warmup): fn() From 0e4691e6dd91e001a8d33b71badf8b3314325459 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 17 Nov 2022 09:45:30 -0800 Subject: [PATCH 09/10] [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883) Ran mypy over `build_extern.py`, cleaned up type annotations. Found a fixed a bug where `ExternLibrary(format=)` was being ignored. --- python/triton/tools/build_extern.py | 90 ++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index 551def69d..f4141c31f 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -1,10 +1,24 @@ import argparse import subprocess from abc import ABC, abstractmethod +from typing import Dict, List, Optional class Symbol: - def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: ''' A symbol is a function declaration. @@ -17,31 +31,31 @@ class Symbol: self._name = name self._op_name = op_name self._ret_type = ret_type - self._arg_names = arg_names - self._arg_types = arg_types + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) @property - def name(self): + def name(self) -> str: return self._name @property - def op_name(self): + def op_name(self) -> str: return self._op_name @property - def ret_type(self): + def ret_type(self) -> str: return self._ret_type @property - def arg_names(self): + def arg_names(self) -> List[str]: return self._arg_names @property - def arg_types(self): + def arg_types(self) -> List[str]: return self._arg_types -def convert_type(type_str): +def convert_type(type_str) -> Optional[str]: if type_str == "i32": return "int32" elif type_str == "u32": @@ -59,7 +73,7 @@ def convert_type(type_str): return None -def to_unsigned(type_str): +def to_unsigned(type_str) -> str: if type_str == "int32": return "uint32" elif type_str == "int64": @@ -69,7 +83,19 @@ def to_unsigned(type_str): class ExternLibrary(ABC): - def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: ''' Abstract class for extern library. @@ -80,34 +106,34 @@ class ExternLibrary(ABC): self._name = name self._path = path self._symbols = {} - self._format = True + self._format = format self._grouping = grouping @property - def name(self): + def name(self) -> str: return self._name @property - def path(self): + def path(self) -> str: return self._path @property - def symbols(self): + def symbols(self) -> Dict[str, Symbol]: return self._symbols @property - def grouping(self): + def grouping(self) -> bool: return self._grouping @abstractmethod - def parse_symbols(self, input_file): + def parse_symbols(self, input_file) -> None: pass @abstractmethod def _output_stubs(self) -> str: pass - def generate_stub_file(self, output_dir): + def generate_stub_file(self, output_dir) -> None: file_str = self._output_stubs() if file_str is None or len(file_str) == 0: raise Exception("file_str is empty") @@ -123,6 +149,8 @@ class ExternLibrary(ABC): class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + def __init__(self, path) -> None: ''' Constructor for Libdevice. @@ -132,7 +160,7 @@ class Libdevice(ExternLibrary): super().__init__("libdevice", path) self._symbol_groups = {} - def _extract_symbol(self, line): + def _extract_symbol(self, line) -> Optional[Symbol]: # Extract symbols from line in the following format: # "define [internal] @(,)" entries = line.split("@") @@ -174,7 +202,7 @@ class Libdevice(ExternLibrary): arg_types[i] = to_unsigned(arg_type) return Symbol(func_name, op_name, ret_type, arg_names, arg_types) - def _group_symbols(self): + def _group_symbols(self) -> None: symbol_set = {} for symbol in self._symbols.values(): op_name = symbol.op_name @@ -244,7 +272,7 @@ class Libdevice(ExternLibrary): else: self._symbol_groups[op_name] = [symbol] - def parse_symbols(self, input_file): + def parse_symbols(self, input_file) -> None: if len(self.symbols) > 0: return output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() @@ -256,7 +284,7 @@ class Libdevice(ExternLibrary): self._group_symbols() - def _output_stubs(self): + def _output_stubs(self) -> str: # Generate python functions in the following format: # @extern.extern # def (, _builder=None): @@ -297,7 +325,10 @@ class Libdevice(ExternLibrary): class LLVMDisassembler: - def __init__(self, path): + _path: str + _ll_file: str + + def __init__(self, path) -> None: ''' Invoke llvm-dis to disassemble the given file. @@ -306,23 +337,28 @@ class LLVMDisassembler: self._path = path self._ll_file = "/tmp/extern_lib.ll" - def disasm(self, lib_path): + def disasm(self, lib_path: str) -> None: subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() @property - def ll_file(self): + def ll_file(self) -> str: return self._ll_file @property - def path(self): + def path(self) -> str: return self._path extern_libs = ["libdevice"] -def build(llvm_dis_path, lib_path, lib_name, output_dir): +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: ''' Interface function to build the library file. From 44f577984d28ee979f704e2c28a1dcbac9639840 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Sun, 20 Nov 2022 11:44:42 -0800 Subject: [PATCH 10/10] Fix format double substitution bug: `{i}` => `{{i}}` (#886) The previous `{i}` was silently expanding to the `i` from the enumeration loop on `regular_args` (when it wasn't empty). --- python/triton/runtime/jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 89ad3e2ca..5a234afc2 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -270,7 +270,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # build stub signature -- includes arguments that are specialized for i, arg in constants.items(): if callable(arg): - raise TypeError(f"Callable constexpr at index {i} is not supported") + raise TypeError(f"Callable constexpr at index {{i}} is not supported") if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) if not warmup: