Merge branch 'master' into keren/improve-hook

This commit is contained in:
Jokeren
2022-11-25 15:04:59 -08:00
13 changed files with 403 additions and 409 deletions

View File

@@ -33,6 +33,15 @@ And the latest nightly release:
pip install -U --pre triton pip install -U --pre triton
``` ```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog # Changelog
Version 1.1 is out! New features include: Version 1.1 is out! New features include:

View File

@@ -149,6 +149,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
// ir.print(std::cout); // ir.print(std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
shared_static = allocation.allocated_size(); shared_static = allocation.allocated_size();
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
// sm < 70 (Pascal) has little shared memory resource.
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
if (shared_static >= 65536) {
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
}
}
if (isel.get_extern_lib_map().size() > 0) { if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls, // If there's any extern lib calls,

View File

@@ -128,7 +128,7 @@ elementwise_data = {
1024 * 16: 0.0219, 1024 * 16: 0.0219,
1024 * 64: 0.0791, 1024 * 64: 0.0791,
1024 * 256: 0.243, 1024 * 256: 0.243,
1024 * 1024: 0.534, 1024 * 1024: 0.530,
1024 * 4096: 0.796, 1024 * 4096: 0.796,
1024 * 16384: 0.905, 1024 * 16384: 0.905,
1024 * 65536: 0.939, 1024 * 65536: 0.939,

View File

@@ -605,6 +605,10 @@ def test_tuples():
] ]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
if dtype_x_str == 'float16':
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
n_programs = 5 n_programs = 5
# triton kernel # triton kernel
@@ -1042,6 +1046,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
if not (allow_tf32 and (dtype in ['float16']))]) if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'): def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80: if cc < 80:
if dtype == 'int8': if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80") pytest.skip("Only test int8 on devices with sm >= 80")
@@ -1227,6 +1233,10 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'): def test_masked_load_shared_memory(dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
M = 32 M = 32
@@ -1546,7 +1556,7 @@ def test_num_warps_pow2():
[('int32', 'libdevice.ffs', ''), [('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')]) ('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path): def test_libdevice_tensor(dtype_str, expr, lib_path):
@triton.jit @triton.jit
def kernel(X, Y, BLOCK: tl.constexpr): def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1582,3 +1592,32 @@ def test_libdevice(dtype_str, expr, lib_path):
np.testing.assert_equal(y_ref, to_numpy(y_tri)) np.testing.assert_equal(y_ref, to_numpy(y_tri))
else: else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = X
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
y_ref = np.zeros(shape, dtype=x.dtype)
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref[:] = np.power(x, x)
# triton result
x_tri = to_triton(x)[0].item()
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

View File

@@ -2,6 +2,7 @@ import pytest
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -125,6 +126,10 @@ def test_attention_fwd_bwd(
batch_size=2, batch_size=2,
n_heads=2, n_heads=2,
): ):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs # inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64) qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [ qkvs = [

View File

@@ -68,6 +68,8 @@ import triton._C.libtriton.triton as _triton
) )
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80 and DTYPE == "bfloat16": if cc < 80 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80") pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1: if DTYPE == "bfloat16" and SPLIT_K != 1:

View File

@@ -913,7 +913,10 @@ def ty_to_cpp(ty):
"i64": "int64_t", "i64": "int64_t",
"u32": "uint32_t", "u32": "uint32_t",
"u64": "uint64_t", "u64": "uint64_t",
"fp16": "float",
"bf16": "float",
"fp32": "float", "fp32": "float",
"fp64": "double",
}[ty] }[ty]
@@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature):
'i64': 'int64_t', 'i64': 'int64_t',
'u32': 'uint32_t', 'u32': 'uint32_t',
'u64': 'uint64_t', 'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
'fp32': 'float', 'fp32': 'float',
'fp64': 'double', 'fp64': 'double',
}[ty] }[ty]

View File

@@ -59,6 +59,12 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
:return: the return value of the function :return: the return value of the function
''' '''
dispatch_args = args.copy() dispatch_args = args.copy()
all_scalar = True
ret_shape = None
for dispatch_arg in dispatch_args:
if dispatch_arg.type.is_block():
all_scalar = False
if not all_scalar:
if len(args) == 1: if len(args) == 1:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
ret_shape = dispatch_args[0].shape ret_shape = dispatch_args[0].shape

View File

@@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int32, core.int32,): ("__nv_mulhi", core.int32), {(core.int32, core.int32,): ("__nv_mulhi", core.int32),
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32), (core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
}, _builder) (core.int64, core.int64,): ("__nv_mul64hi", core.int64),
@extern.extern
def mul64hi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64), (core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
}, _builder) }, _builder)
@@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def fmaf_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
}, _builder)
@extern.extern @extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None): def fma_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_rz(arg0, arg1, arg2, _builder=None): def fma_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_rd(arg0, arg1, arg2, _builder=None): def fma_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_ru(arg0, arg1, arg2, _builder=None): def fma_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fast_fdividef(arg0, arg1, _builder=None): def fast_dividef(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32), {(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rn(arg0, arg1, _builder=None): def div_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rz(arg0, arg1, _builder=None): def div_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rd(arg0, arg1, _builder=None): def div_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_ru(arg0, arg1, _builder=None): def div_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rn(arg0, _builder=None): def rcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rn", core.float32), {(core.float32,): ("__nv_frcp_rn", core.float32),
(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rz(arg0, _builder=None): def rcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rz", core.float32), {(core.float32,): ("__nv_frcp_rz", core.float32),
(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rd(arg0, _builder=None): def rcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rd", core.float32), {(core.float32,): ("__nv_frcp_rd", core.float32),
(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_ru(arg0, _builder=None): def rcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_ru", core.float32), {(core.float32,): ("__nv_frcp_ru", core.float32),
(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rn(arg0, _builder=None): def sqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rn", core.float32), {(core.float32,): ("__nv_fsqrt_rn", core.float32),
(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rz(arg0, _builder=None): def sqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rz", core.float32), {(core.float32,): ("__nv_fsqrt_rz", core.float32),
(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rd(arg0, _builder=None): def sqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rd", core.float32), {(core.float32,): ("__nv_fsqrt_rd", core.float32),
(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_ru(arg0, _builder=None): def sqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_ru", core.float32), {(core.float32,): ("__nv_fsqrt_ru", core.float32),
}, _builder) (core.float64,): ("__nv_dsqrt_ru", core.float64),
@extern.extern
def ddiv_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def ddiv_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def ddiv_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def ddiv_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def drcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def drcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def drcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def drcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def dsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def dsqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def dsqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def dsqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_ru", core.float64),
}, _builder) }, _builder)
@@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None):
@extern.extern @extern.extern
def dadd_rn(arg0, arg1, _builder=None): def add_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_rz(arg0, arg1, _builder=None): def add_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_rd(arg0, arg1, _builder=None): def add_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_ru(arg0, arg1, _builder=None): def add_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64), {(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rn(arg0, arg1, _builder=None): def mul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rz(arg0, arg1, _builder=None): def mul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rd(arg0, arg1, _builder=None): def mul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_ru(arg0, arg1, _builder=None): def mul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64), {(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
}, _builder) (core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
@extern.extern
def fadd_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def fadd_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def fmul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def fmul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@extern.extern
def fadd_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def fadd_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def fmul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def fmul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder) }, _builder)
@@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None):
def int2double_rn(arg0, _builder=None): def int2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2double_rn", core.float64), {(core.int32,): ("__nv_int2double_rn", core.float64),
(core.uint32,): ("__nv_uint2double_rn", core.float64), }, _builder)
@extern.extern
def uint2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2double_rn", core.float64),
}, _builder) }, _builder)
@@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None):
def int2float_rn(arg0, _builder=None): def int2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rn", core.float32), {(core.int32,): ("__nv_int2float_rn", core.float32),
(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder) }, _builder)
@@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None):
def int2float_rz(arg0, _builder=None): def int2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rz", core.float32), {(core.int32,): ("__nv_int2float_rz", core.float32),
(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder) }, _builder)
@@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None):
def int2float_rd(arg0, _builder=None): def int2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rd", core.float32), {(core.int32,): ("__nv_int2float_rd", core.float32),
(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder) }, _builder)
@@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None):
def int2float_ru(arg0, _builder=None): def int2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_ru", core.float32), {(core.int32,): ("__nv_int2float_ru", core.float32),
(core.uint32,): ("__nv_uint2float_ru", core.float32), }, _builder)
@extern.extern
def uint2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@extern.extern
def uint2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@extern.extern
def uint2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@extern.extern
def uint2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_ru", core.float32),
}, _builder) }, _builder)
@@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None):
def ll2float_rn(arg0, _builder=None): def ll2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rn", core.float32), {(core.int64,): ("__nv_ll2float_rn", core.float32),
(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder) }, _builder)
@@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None):
def ll2float_rz(arg0, _builder=None): def ll2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rz", core.float32), {(core.int64,): ("__nv_ll2float_rz", core.float32),
(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder) }, _builder)
@@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None):
def ll2float_rd(arg0, _builder=None): def ll2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rd", core.float32), {(core.int64,): ("__nv_ll2float_rd", core.float32),
(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder) }, _builder)
@@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None):
def ll2float_ru(arg0, _builder=None): def ll2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_ru", core.float32), {(core.int64,): ("__nv_ll2float_ru", core.float32),
(core.uint64,): ("__nv_ull2float_ru", core.float32), }, _builder)
@extern.extern
def ull2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@extern.extern
def ull2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@extern.extern
def ull2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@extern.extern
def ull2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_ru", core.float32),
}, _builder) }, _builder)
@@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None):
def ll2double_rn(arg0, _builder=None): def ll2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rn", core.float64), {(core.int64,): ("__nv_ll2double_rn", core.float64),
(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder) }, _builder)
@@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None):
def ll2double_rz(arg0, _builder=None): def ll2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rz", core.float64), {(core.int64,): ("__nv_ll2double_rz", core.float64),
(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder) }, _builder)
@@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None):
def ll2double_rd(arg0, _builder=None): def ll2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rd", core.float64), {(core.int64,): ("__nv_ll2double_rd", core.float64),
(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder) }, _builder)
@@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None):
def ll2double_ru(arg0, _builder=None): def ll2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_ru", core.float64), {(core.int64,): ("__nv_ll2double_ru", core.float64),
(core.uint64,): ("__nv_ull2double_ru", core.float64), }, _builder)
@extern.extern
def ull2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@extern.extern
def ull2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@extern.extern
def ull2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@extern.extern
def ull2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_ru", core.float64),
}, _builder) }, _builder)
@@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None):
def int_as_float(arg0, _builder=None): def int_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int_as_float", core.float32), {(core.int32,): ("__nv_int_as_float", core.float32),
(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder) }, _builder)
@@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def uint_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@extern.extern @extern.extern
def float_as_uint(arg0, _builder=None): def float_as_uint(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
@@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None):
@extern.extern @extern.extern
def pow(arg0, arg1, _builder=None): def fast_powf(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32), {(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder) }, _builder)
@@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None):
@extern.extern @extern.extern
def fsub_rn(arg0, arg1, _builder=None): def sub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_rz(arg0, arg1, _builder=None): def sub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_rd(arg0, arg1, _builder=None): def sub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_ru(arg0, arg1, _builder=None): def sub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32), {(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frsqrt_rn(arg0, _builder=None): def rsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frsqrt_rn", core.float32), {(core.float32,): ("__nv_frsqrt_rn", core.float32),
}, _builder) }, _builder)
@@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None):
@extern.extern @extern.extern
def isnanf(arg0, _builder=None): def isnan(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isnanf", core.int32), {(core.float32,): ("__nv_isnanf", core.int32),
(core.float64,): ("__nv_isnand", core.int32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def signbitf(arg0, _builder=None): def signbit(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_signbitf", core.int32), {(core.float32,): ("__nv_signbitf", core.int32),
(core.float64,): ("__nv_signbitd", core.int32),
}, _builder) }, _builder)
@@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None):
@extern.extern @extern.extern
def isinff(arg0, _builder=None): def isinf(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isinff", core.int32), {(core.float32,): ("__nv_isinff", core.int32),
(core.float64,): ("__nv_isinfd", core.int32),
}, _builder) }, _builder)
@@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None):
@extern.extern @extern.extern
def powi(arg0, arg1, _builder=None): def pow(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.int32,): ("__nv_powif", core.float32), {(core.float32, core.int32,): ("__nv_powif", core.float32),
(core.float64, core.int32,): ("__nv_powi", core.float64), (core.float64, core.int32,): ("__nv_powi", core.float64),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder) }, _builder)
@@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def signbitd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@extern.extern @extern.extern
def isfinited(arg0, _builder=None): def isfinited(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isfinited", core.int32), {(core.float64,): ("__nv_isfinited", core.int32),
}, _builder) }, _builder)
@extern.extern
def isinfd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@extern.extern
def isnand(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def dsub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def dsub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def dsub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def dsub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)

View File

@@ -270,7 +270,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# build stub signature -- includes arguments that are specialized # build stub signature -- includes arguments that are specialized
for i, arg in constants.items(): for i, arg in constants.items():
if callable(arg): if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported") raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs) bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
if not warmup: if not warmup:

View File

@@ -115,7 +115,9 @@ def nvsmi(attrs):
return ret return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False): def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
percentiles=(0.5, 0.2, 0.8),
record_clocks=False, fast_flush=False):
""" """
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile. the 20-th and 80-th performance percentile.
@@ -130,6 +132,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
:type grad_to_none: torch.tensor, optional :type grad_to_none: torch.tensor, optional
:param percentiles: Performance percentile to return in addition to the median. :param percentiles: Performance percentile to return in addition to the median.
:type percentiles: list[float] :type percentiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
""" """
# Estimate the runtime of the function # Estimate the runtime of the function
@@ -151,6 +155,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
# doesn't contain any input data before the run # doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up # Warm-up
for _ in range(n_warmup): for _ in range(n_warmup):

View File

@@ -1,10 +1,24 @@
import argparse import argparse
import subprocess import subprocess
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional
class Symbol: class Symbol:
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: _name: str
_op_name: str
_ret_type: str
_arg_names: List[str]
_arg_types: List[str]
def __init__(
self,
name: str,
op_name: str,
ret_type: str,
arg_names: List[str],
arg_types: List[str],
) -> None:
''' '''
A symbol is a function declaration. A symbol is a function declaration.
@@ -17,31 +31,31 @@ class Symbol:
self._name = name self._name = name
self._op_name = op_name self._op_name = op_name
self._ret_type = ret_type self._ret_type = ret_type
self._arg_names = arg_names self._arg_names = list(arg_names)
self._arg_types = arg_types self._arg_types = list(arg_types)
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def op_name(self): def op_name(self) -> str:
return self._op_name return self._op_name
@property @property
def ret_type(self): def ret_type(self) -> str:
return self._ret_type return self._ret_type
@property @property
def arg_names(self): def arg_names(self) -> List[str]:
return self._arg_names return self._arg_names
@property @property
def arg_types(self): def arg_types(self) -> List[str]:
return self._arg_types return self._arg_types
def convert_type(type_str): def convert_type(type_str) -> Optional[str]:
if type_str == "i32": if type_str == "i32":
return "int32" return "int32"
elif type_str == "u32": elif type_str == "u32":
@@ -59,7 +73,7 @@ def convert_type(type_str):
return None return None
def to_unsigned(type_str): def to_unsigned(type_str) -> str:
if type_str == "int32": if type_str == "int32":
return "uint32" return "uint32"
elif type_str == "int64": elif type_str == "int64":
@@ -69,7 +83,19 @@ def to_unsigned(type_str):
class ExternLibrary(ABC): class ExternLibrary(ABC):
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: _name: str
_path: str
_symbols: Dict[str, Symbol]
_format: bool
_grouping: bool
def __init__(
self,
name: str,
path: str,
format: bool = True,
grouping: bool = True,
) -> None:
''' '''
Abstract class for extern library. Abstract class for extern library.
@@ -80,34 +106,34 @@ class ExternLibrary(ABC):
self._name = name self._name = name
self._path = path self._path = path
self._symbols = {} self._symbols = {}
self._format = True self._format = format
self._grouping = grouping self._grouping = grouping
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def path(self): def path(self) -> str:
return self._path return self._path
@property @property
def symbols(self): def symbols(self) -> Dict[str, Symbol]:
return self._symbols return self._symbols
@property @property
def grouping(self): def grouping(self) -> bool:
return self._grouping return self._grouping
@abstractmethod @abstractmethod
def parse_symbols(self, input_file): def parse_symbols(self, input_file) -> None:
pass pass
@abstractmethod @abstractmethod
def _output_stubs(self) -> str: def _output_stubs(self) -> str:
pass pass
def generate_stub_file(self, output_dir): def generate_stub_file(self, output_dir) -> None:
file_str = self._output_stubs() file_str = self._output_stubs()
if file_str is None or len(file_str) == 0: if file_str is None or len(file_str) == 0:
raise Exception("file_str is empty") raise Exception("file_str is empty")
@@ -123,6 +149,8 @@ class ExternLibrary(ABC):
class Libdevice(ExternLibrary): class Libdevice(ExternLibrary):
_symbol_groups: Dict[str, List[Symbol]]
def __init__(self, path) -> None: def __init__(self, path) -> None:
''' '''
Constructor for Libdevice. Constructor for Libdevice.
@@ -132,7 +160,7 @@ class Libdevice(ExternLibrary):
super().__init__("libdevice", path) super().__init__("libdevice", path)
self._symbol_groups = {} self._symbol_groups = {}
def _extract_symbol(self, line): def _extract_symbol(self, line) -> Optional[Symbol]:
# Extract symbols from line in the following format: # Extract symbols from line in the following format:
# "define [internal] <ret_type> @<name>(<arg_types>,)" # "define [internal] <ret_type> @<name>(<arg_types>,)"
entries = line.split("@") entries = line.split("@")
@@ -149,6 +177,9 @@ class Libdevice(ExternLibrary):
func_strs = func_str.split("(") func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "") func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "") op_name = func_name.replace("__nv_", "")
# To filter some interfaces unlisted in NVIDIA's official documents.
if 'ieee' in op_name:
return None
# Get arg_types # Get arg_types
arg_strs = func_strs[1].split(",") arg_strs = func_strs[1].split(",")
arg_types = [] arg_types = []
@@ -171,66 +202,77 @@ class Libdevice(ExternLibrary):
arg_types[i] = to_unsigned(arg_type) arg_types[i] = to_unsigned(arg_type)
return Symbol(func_name, op_name, ret_type, arg_names, arg_types) return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
def _group_symbols(self): def _group_symbols(self) -> None:
symbol_set = {} symbol_set = {}
for symbol in self._symbols.values(): for symbol in self._symbols.values():
op_name = symbol.op_name op_name = symbol.op_name
symbol_set[op_name] = symbol symbol_set[op_name] = symbol
# The following cases are grouped together:
# op_name, <u/ull/ll>op_name<ll/f/i> # Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
}
for symbol in self._symbols.values(): for symbol in self._symbols.values():
op_name = symbol.op_name op_name = symbol.op_name
if "max" in op_name: if op_name in renaming:
op_name = "max" op_name = renaming[op_name]
elif "min" in op_name:
op_name = "min"
elif "abs" in op_name:
op_name = "abs"
elif "pow" in op_name and "fast" in op_name:
op_name = "pow"
elif "round" in op_name:
if "llround" in op_name:
op_name = "llround"
else:
op_name = "round"
elif "rint" in op_name:
if "llrint" in op_name:
op_name = "llrint"
else:
op_name = "rint"
elif op_name.startswith("ull"):
if "2" not in op_name:
# e.g., ullmax->max
op_name = op_name[3:]
else:
# e.g., ull2double->ll2double
op_name = op_name[1:]
elif op_name.startswith("u"):
if "2" not in op_name:
# e.g., uhadd->hadd
op_name = op_name[1:]
else:
# e.g., uint2double_rn->int2double_rn
op_name = op_name[1:]
elif op_name.startswith("ll"):
if "2" not in op_name:
# e.g., llmax->max
op_name = op_name[2:]
elif op_name.endswith("ll"):
op_name = op_name[:-2]
elif op_name.endswith("f"):
op_name = op_name[:-1]
if op_name in symbol_set:
# Update op_name only if there's an existing symbol
symbol._op_name = op_name symbol._op_name = op_name
else:
op_name = symbol._op_name
if op_name in self._symbol_groups: if op_name in self._symbol_groups:
self._symbol_groups[op_name].append(symbol) self._symbol_groups[op_name].append(symbol)
else: else:
self._symbol_groups[op_name] = [symbol] self._symbol_groups[op_name] = [symbol]
def parse_symbols(self, input_file): def parse_symbols(self, input_file) -> None:
if len(self.symbols) > 0: if len(self.symbols) > 0:
return return
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
@@ -242,7 +284,7 @@ class Libdevice(ExternLibrary):
self._group_symbols() self._group_symbols()
def _output_stubs(self): def _output_stubs(self) -> str:
# Generate python functions in the following format: # Generate python functions in the following format:
# @extern.extern # @extern.extern
# def <op_name>(<args>, _builder=None): # def <op_name>(<args>, _builder=None):
@@ -250,7 +292,7 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder) # return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n" import_str = "from . import core, extern\n"
import_str += "import os\n" import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
func_str = "" func_str = ""
for symbols in self._symbol_groups.values(): for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n" func_str += "@extern.extern\n"
@@ -283,7 +325,10 @@ class Libdevice(ExternLibrary):
class LLVMDisassembler: class LLVMDisassembler:
def __init__(self, path): _path: str
_ll_file: str
def __init__(self, path) -> None:
''' '''
Invoke llvm-dis to disassemble the given file. Invoke llvm-dis to disassemble the given file.
@@ -292,23 +337,28 @@ class LLVMDisassembler:
self._path = path self._path = path
self._ll_file = "/tmp/extern_lib.ll" self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path): def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file], subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate() stdout=subprocess.PIPE).communicate()
@property @property
def ll_file(self): def ll_file(self) -> str:
return self._ll_file return self._ll_file
@property @property
def path(self): def path(self) -> str:
return self._path return self._path
extern_libs = ["libdevice"] extern_libs = ["libdevice"]
def build(llvm_dis_path, lib_path, lib_name, output_dir): def build(
llvm_dis_path: str,
lib_path: str,
lib_name: str,
output_dir: str,
) -> None:
''' '''
Interface function to build the library file. Interface function to build the library file.
@@ -331,10 +381,10 @@ def build(llvm_dis_path, lib_path, lib_name, output_dir):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
args = parser.parse_args() args = parser.parse_args()
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)

View File

@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`. Using triton, you can simply call `tl.libdevice.asin`.
triton automatically selects the correct underlying device function to invoke based on input and output types. triton automatically selects the correct underlying device function to invoke based on input and output types.
""" """