Merge branch 'master' into keren/improve-hook
This commit is contained in:
@@ -33,6 +33,15 @@ And the latest nightly release:
|
|||||||
pip install -U --pre triton
|
pip install -U --pre triton
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Install from source
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/openai/triton.git;
|
||||||
|
cd triton/python;
|
||||||
|
pip install cmake; # build time dependency
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
Version 1.1 is out! New features include:
|
Version 1.1 is out! New features include:
|
||||||
|
@@ -149,6 +149,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
|||||||
// ir.print(std::cout);
|
// ir.print(std::cout);
|
||||||
isel.visit(ir, *llvm);
|
isel.visit(ir, *llvm);
|
||||||
shared_static = allocation.allocated_size();
|
shared_static = allocation.allocated_size();
|
||||||
|
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
|
||||||
|
// sm < 70 (Pascal) has little shared memory resource.
|
||||||
|
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
|
||||||
|
if (shared_static >= 65536) {
|
||||||
|
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (isel.get_extern_lib_map().size() > 0) {
|
if (isel.get_extern_lib_map().size() > 0) {
|
||||||
// If there's any extern lib calls,
|
// If there's any extern lib calls,
|
||||||
|
@@ -128,7 +128,7 @@ elementwise_data = {
|
|||||||
1024 * 16: 0.0219,
|
1024 * 16: 0.0219,
|
||||||
1024 * 64: 0.0791,
|
1024 * 64: 0.0791,
|
||||||
1024 * 256: 0.243,
|
1024 * 256: 0.243,
|
||||||
1024 * 1024: 0.534,
|
1024 * 1024: 0.530,
|
||||||
1024 * 4096: 0.796,
|
1024 * 4096: 0.796,
|
||||||
1024 * 16384: 0.905,
|
1024 * 16384: 0.905,
|
||||||
1024 * 65536: 0.939,
|
1024 * 65536: 0.939,
|
||||||
|
@@ -605,6 +605,10 @@ def test_tuples():
|
|||||||
]
|
]
|
||||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||||
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||||
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 70:
|
||||||
|
if dtype_x_str == 'float16':
|
||||||
|
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
|
||||||
n_programs = 5
|
n_programs = 5
|
||||||
|
|
||||||
# triton kernel
|
# triton kernel
|
||||||
@@ -1042,6 +1046,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
|||||||
if not (allow_tf32 and (dtype in ['float16']))])
|
if not (allow_tf32 and (dtype in ['float16']))])
|
||||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 70:
|
||||||
|
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||||
if cc < 80:
|
if cc < 80:
|
||||||
if dtype == 'int8':
|
if dtype == 'int8':
|
||||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||||
@@ -1227,6 +1233,10 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
|
||||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||||
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 70:
|
||||||
|
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||||
|
|
||||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||||
|
|
||||||
M = 32
|
M = 32
|
||||||
@@ -1546,7 +1556,7 @@ def test_num_warps_pow2():
|
|||||||
[('int32', 'libdevice.ffs', ''),
|
[('int32', 'libdevice.ffs', ''),
|
||||||
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
||||||
('float64', 'libdevice.norm4d', '')])
|
('float64', 'libdevice.norm4d', '')])
|
||||||
def test_libdevice(dtype_str, expr, lib_path):
|
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||||
@@ -1582,3 +1592,32 @@ def test_libdevice(dtype_str, expr, lib_path):
|
|||||||
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||||
else:
|
else:
|
||||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||||
|
[('float32', 'libdevice.pow', '')])
|
||||||
|
def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||||
|
x = X
|
||||||
|
y = GENERATE_TEST_HERE
|
||||||
|
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||||
|
|
||||||
|
shape = (128, )
|
||||||
|
rs = RandomState(17)
|
||||||
|
# limit the range of integers so that the sum does not overflow
|
||||||
|
x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
|
||||||
|
y_ref = np.zeros(shape, dtype=x.dtype)
|
||||||
|
|
||||||
|
# numpy does not allow negative factors in power, so we use abs()
|
||||||
|
x = np.abs(x)
|
||||||
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||||
|
y_ref[:] = np.power(x, x)
|
||||||
|
|
||||||
|
# triton result
|
||||||
|
x_tri = to_triton(x)[0].item()
|
||||||
|
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||||
|
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||||
|
# compare
|
||||||
|
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||||
|
@@ -2,6 +2,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
|
import triton._C.libtriton.triton as _triton
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||||
@@ -125,6 +126,10 @@ def test_attention_fwd_bwd(
|
|||||||
batch_size=2,
|
batch_size=2,
|
||||||
n_heads=2,
|
n_heads=2,
|
||||||
):
|
):
|
||||||
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 70:
|
||||||
|
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||||
|
|
||||||
# inputs
|
# inputs
|
||||||
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
||||||
qkvs = [
|
qkvs = [
|
||||||
|
@@ -68,6 +68,8 @@ import triton._C.libtriton.triton as _triton
|
|||||||
)
|
)
|
||||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 70:
|
||||||
|
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||||
if cc < 80 and DTYPE == "bfloat16":
|
if cc < 80 and DTYPE == "bfloat16":
|
||||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||||
if DTYPE == "bfloat16" and SPLIT_K != 1:
|
if DTYPE == "bfloat16" and SPLIT_K != 1:
|
||||||
|
@@ -913,7 +913,10 @@ def ty_to_cpp(ty):
|
|||||||
"i64": "int64_t",
|
"i64": "int64_t",
|
||||||
"u32": "uint32_t",
|
"u32": "uint32_t",
|
||||||
"u64": "uint64_t",
|
"u64": "uint64_t",
|
||||||
|
"fp16": "float",
|
||||||
|
"bf16": "float",
|
||||||
"fp32": "float",
|
"fp32": "float",
|
||||||
|
"fp64": "double",
|
||||||
}[ty]
|
}[ty]
|
||||||
|
|
||||||
|
|
||||||
@@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature):
|
|||||||
'i64': 'int64_t',
|
'i64': 'int64_t',
|
||||||
'u32': 'uint32_t',
|
'u32': 'uint32_t',
|
||||||
'u64': 'uint64_t',
|
'u64': 'uint64_t',
|
||||||
|
'fp16': 'float',
|
||||||
|
'bf16': 'float',
|
||||||
'fp32': 'float',
|
'fp32': 'float',
|
||||||
'fp64': 'double',
|
'fp64': 'double',
|
||||||
}[ty]
|
}[ty]
|
||||||
|
@@ -59,6 +59,12 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
|
|||||||
:return: the return value of the function
|
:return: the return value of the function
|
||||||
'''
|
'''
|
||||||
dispatch_args = args.copy()
|
dispatch_args = args.copy()
|
||||||
|
all_scalar = True
|
||||||
|
ret_shape = None
|
||||||
|
for dispatch_arg in dispatch_args:
|
||||||
|
if dispatch_arg.type.is_block():
|
||||||
|
all_scalar = False
|
||||||
|
if not all_scalar:
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
|
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
|
||||||
ret_shape = dispatch_args[0].shape
|
ret_shape = dispatch_args[0].shape
|
||||||
|
@@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None):
|
|||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.int32, core.int32,): ("__nv_mulhi", core.int32),
|
{(core.int32, core.int32,): ("__nv_mulhi", core.int32),
|
||||||
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
|
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
|
||||||
}, _builder)
|
(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def mul64hi(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
|
|
||||||
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
|
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
@@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None):
|
|||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_rn(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_rz(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_rd(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_ru(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
|
||||||
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fma_rn(arg0, arg1, arg2, _builder=None):
|
def fma_rn(arg0, arg1, arg2, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
||||||
{(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
|
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
|
||||||
|
(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fma_rz(arg0, arg1, arg2, _builder=None):
|
def fma_rz(arg0, arg1, arg2, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
||||||
{(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
|
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
|
||||||
|
(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fma_rd(arg0, arg1, arg2, _builder=None):
|
def fma_rd(arg0, arg1, arg2, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
||||||
{(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
|
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
|
||||||
|
(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fma_ru(arg0, arg1, arg2, _builder=None):
|
def fma_ru(arg0, arg1, arg2, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
|
||||||
{(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
|
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
|
||||||
|
(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fast_fdividef(arg0, arg1, _builder=None):
|
def fast_dividef(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
|
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fdiv_rn(arg0, arg1, _builder=None):
|
def div_rn(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
|
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fdiv_rz(arg0, arg1, _builder=None):
|
def div_rz(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
|
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fdiv_rd(arg0, arg1, _builder=None):
|
def div_rd(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
|
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fdiv_ru(arg0, arg1, _builder=None):
|
def div_ru(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
|
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def frcp_rn(arg0, _builder=None):
|
def rcp_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_frcp_rn", core.float32),
|
{(core.float32,): ("__nv_frcp_rn", core.float32),
|
||||||
|
(core.float64,): ("__nv_drcp_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def frcp_rz(arg0, _builder=None):
|
def rcp_rz(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_frcp_rz", core.float32),
|
{(core.float32,): ("__nv_frcp_rz", core.float32),
|
||||||
|
(core.float64,): ("__nv_drcp_rz", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def frcp_rd(arg0, _builder=None):
|
def rcp_rd(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_frcp_rd", core.float32),
|
{(core.float32,): ("__nv_frcp_rd", core.float32),
|
||||||
|
(core.float64,): ("__nv_drcp_rd", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def frcp_ru(arg0, _builder=None):
|
def rcp_ru(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_frcp_ru", core.float32),
|
{(core.float32,): ("__nv_frcp_ru", core.float32),
|
||||||
|
(core.float64,): ("__nv_drcp_ru", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsqrt_rn(arg0, _builder=None):
|
def sqrt_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_fsqrt_rn", core.float32),
|
{(core.float32,): ("__nv_fsqrt_rn", core.float32),
|
||||||
|
(core.float64,): ("__nv_dsqrt_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsqrt_rz(arg0, _builder=None):
|
def sqrt_rz(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_fsqrt_rz", core.float32),
|
{(core.float32,): ("__nv_fsqrt_rz", core.float32),
|
||||||
|
(core.float64,): ("__nv_dsqrt_rz", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsqrt_rd(arg0, _builder=None):
|
def sqrt_rd(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_fsqrt_rd", core.float32),
|
{(core.float32,): ("__nv_fsqrt_rd", core.float32),
|
||||||
|
(core.float64,): ("__nv_dsqrt_rd", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsqrt_ru(arg0, _builder=None):
|
def sqrt_ru(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_fsqrt_ru", core.float32),
|
{(core.float32,): ("__nv_fsqrt_ru", core.float32),
|
||||||
}, _builder)
|
(core.float64,): ("__nv_dsqrt_ru", core.float64),
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def ddiv_rn(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def ddiv_rz(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def ddiv_rd(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def ddiv_ru(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def drcp_rn(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_drcp_rn", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def drcp_rz(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_drcp_rz", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def drcp_rd(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_drcp_rd", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def drcp_ru(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_drcp_ru", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsqrt_rn(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_dsqrt_rn", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsqrt_rz(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_dsqrt_rz", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsqrt_rd(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_dsqrt_rd", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsqrt_ru(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_dsqrt_ru", core.float64),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dadd_rn(arg0, arg1, _builder=None):
|
def add_rn(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
|
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dadd_rz(arg0, arg1, _builder=None):
|
def add_rz(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
|
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dadd_rd(arg0, arg1, _builder=None):
|
def add_rd(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
|
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dadd_ru(arg0, arg1, _builder=None):
|
def add_ru(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
|
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dmul_rn(arg0, arg1, _builder=None):
|
def mul_rn(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
|
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dmul_rz(arg0, arg1, _builder=None):
|
def mul_rz(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
|
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dmul_rd(arg0, arg1, _builder=None):
|
def mul_rd(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
|
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def dmul_ru(arg0, arg1, _builder=None):
|
def mul_ru(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
|
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
|
||||||
}, _builder)
|
(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fadd_rd(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fadd_ru(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmul_rd(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmul_ru(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fadd_rn(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fadd_rz(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmul_rn(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def fmul_rz(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None):
|
|||||||
def int2double_rn(arg0, _builder=None):
|
def int2double_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int2double_rn", core.float64),
|
{(core.int32,): ("__nv_int2double_rn", core.float64),
|
||||||
(core.uint32,): ("__nv_uint2double_rn", core.float64),
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint2double_rn(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint2double_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None):
|
|||||||
def int2float_rn(arg0, _builder=None):
|
def int2float_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int2float_rn", core.float32),
|
{(core.int32,): ("__nv_int2float_rn", core.float32),
|
||||||
(core.uint32,): ("__nv_uint2float_rn", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None):
|
|||||||
def int2float_rz(arg0, _builder=None):
|
def int2float_rz(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int2float_rz", core.float32),
|
{(core.int32,): ("__nv_int2float_rz", core.float32),
|
||||||
(core.uint32,): ("__nv_uint2float_rz", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None):
|
|||||||
def int2float_rd(arg0, _builder=None):
|
def int2float_rd(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int2float_rd", core.float32),
|
{(core.int32,): ("__nv_int2float_rd", core.float32),
|
||||||
(core.uint32,): ("__nv_uint2float_rd", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None):
|
|||||||
def int2float_ru(arg0, _builder=None):
|
def int2float_ru(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int2float_ru", core.float32),
|
{(core.int32,): ("__nv_int2float_ru", core.float32),
|
||||||
(core.uint32,): ("__nv_uint2float_ru", core.float32),
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint2float_rn(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint2float_rn", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint2float_rz(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint2float_rz", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint2float_rd(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint2float_rd", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint2float_ru(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint2float_ru", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None):
|
|||||||
def ll2float_rn(arg0, _builder=None):
|
def ll2float_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2float_rn", core.float32),
|
{(core.int64,): ("__nv_ll2float_rn", core.float32),
|
||||||
(core.uint64,): ("__nv_ull2float_rn", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None):
|
|||||||
def ll2float_rz(arg0, _builder=None):
|
def ll2float_rz(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2float_rz", core.float32),
|
{(core.int64,): ("__nv_ll2float_rz", core.float32),
|
||||||
(core.uint64,): ("__nv_ull2float_rz", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None):
|
|||||||
def ll2float_rd(arg0, _builder=None):
|
def ll2float_rd(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2float_rd", core.float32),
|
{(core.int64,): ("__nv_ll2float_rd", core.float32),
|
||||||
(core.uint64,): ("__nv_ull2float_rd", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None):
|
|||||||
def ll2float_ru(arg0, _builder=None):
|
def ll2float_ru(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2float_ru", core.float32),
|
{(core.int64,): ("__nv_ll2float_ru", core.float32),
|
||||||
(core.uint64,): ("__nv_ull2float_ru", core.float32),
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2float_rn(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2float_rn", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2float_rz(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2float_rz", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2float_rd(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2float_rd", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2float_ru(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2float_ru", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None):
|
|||||||
def ll2double_rn(arg0, _builder=None):
|
def ll2double_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2double_rn", core.float64),
|
{(core.int64,): ("__nv_ll2double_rn", core.float64),
|
||||||
(core.uint64,): ("__nv_ull2double_rn", core.float64),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None):
|
|||||||
def ll2double_rz(arg0, _builder=None):
|
def ll2double_rz(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2double_rz", core.float64),
|
{(core.int64,): ("__nv_ll2double_rz", core.float64),
|
||||||
(core.uint64,): ("__nv_ull2double_rz", core.float64),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None):
|
|||||||
def ll2double_rd(arg0, _builder=None):
|
def ll2double_rd(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2double_rd", core.float64),
|
{(core.int64,): ("__nv_ll2double_rd", core.float64),
|
||||||
(core.uint64,): ("__nv_ull2double_rd", core.float64),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None):
|
|||||||
def ll2double_ru(arg0, _builder=None):
|
def ll2double_ru(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int64,): ("__nv_ll2double_ru", core.float64),
|
{(core.int64,): ("__nv_ll2double_ru", core.float64),
|
||||||
(core.uint64,): ("__nv_ull2double_ru", core.float64),
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2double_rn(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2double_rn", core.float64),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2double_rz(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2double_rz", core.float64),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2double_rd(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2double_rd", core.float64),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def ull2double_ru(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint64,): ("__nv_ull2double_ru", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None):
|
|||||||
def int_as_float(arg0, _builder=None):
|
def int_as_float(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.int32,): ("__nv_int_as_float", core.float32),
|
{(core.int32,): ("__nv_int_as_float", core.float32),
|
||||||
(core.uint32,): ("__nv_uint_as_float", core.float32),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None):
|
|||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
@extern.extern
|
||||||
|
def uint_as_float(arg0, _builder=None):
|
||||||
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
|
{(core.uint32,): ("__nv_uint_as_float", core.float32),
|
||||||
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def float_as_uint(arg0, _builder=None):
|
def float_as_uint(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
@@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def pow(arg0, arg1, _builder=None):
|
def fast_powf(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
|
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
|
||||||
(core.float32, core.float32,): ("__nv_powf", core.float32),
|
|
||||||
(core.float64, core.float64,): ("__nv_pow", core.float64),
|
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsub_rn(arg0, arg1, _builder=None):
|
def sub_rn(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
|
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsub_rz(arg0, arg1, _builder=None):
|
def sub_rz(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
|
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsub_rd(arg0, arg1, _builder=None):
|
def sub_rd(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
|
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def fsub_ru(arg0, arg1, _builder=None):
|
def sub_ru(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
|
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def frsqrt_rn(arg0, _builder=None):
|
def rsqrt_rn(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_frsqrt_rn", core.float32),
|
{(core.float32,): ("__nv_frsqrt_rn", core.float32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
@@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def isnanf(arg0, _builder=None):
|
def isnan(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_isnanf", core.int32),
|
{(core.float32,): ("__nv_isnanf", core.int32),
|
||||||
|
(core.float64,): ("__nv_isnand", core.int32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def signbitf(arg0, _builder=None):
|
def signbit(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_signbitf", core.int32),
|
{(core.float32,): ("__nv_signbitf", core.int32),
|
||||||
|
(core.float64,): ("__nv_signbitd", core.int32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def isinff(arg0, _builder=None):
|
def isinf(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float32,): ("__nv_isinff", core.int32),
|
{(core.float32,): ("__nv_isinff", core.int32),
|
||||||
|
(core.float64,): ("__nv_isinfd", core.int32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def powi(arg0, arg1, _builder=None):
|
def pow(arg0, arg1, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
||||||
{(core.float32, core.int32,): ("__nv_powif", core.float32),
|
{(core.float32, core.int32,): ("__nv_powif", core.float32),
|
||||||
(core.float64, core.int32,): ("__nv_powi", core.float64),
|
(core.float64, core.int32,): ("__nv_powi", core.float64),
|
||||||
|
(core.float32, core.float32,): ("__nv_powf", core.float32),
|
||||||
|
(core.float64, core.float64,): ("__nv_pow", core.float64),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None):
|
|||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def signbitd(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_signbitd", core.int32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
@extern.extern
|
||||||
def isfinited(arg0, _builder=None):
|
def isfinited(arg0, _builder=None):
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
||||||
{(core.float64,): ("__nv_isfinited", core.int32),
|
{(core.float64,): ("__nv_isfinited", core.int32),
|
||||||
}, _builder)
|
}, _builder)
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def isinfd(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_isinfd", core.int32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def isnand(arg0, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
|
|
||||||
{(core.float64,): ("__nv_isnand", core.int32),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsub_rn(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsub_rz(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsub_ru(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
|
||||||
|
|
||||||
@extern.extern
|
|
||||||
def dsub_rd(arg0, arg1, _builder=None):
|
|
||||||
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
|
|
||||||
{(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
|
|
||||||
}, _builder)
|
|
||||||
|
@@ -270,7 +270,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
# build stub signature -- includes arguments that are specialized
|
# build stub signature -- includes arguments that are specialized
|
||||||
for i, arg in constants.items():
|
for i, arg in constants.items():
|
||||||
if callable(arg):
|
if callable(arg):
|
||||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||||
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||||
if not warmup:
|
if not warmup:
|
||||||
|
@@ -115,7 +115,9 @@ def nvsmi(attrs):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
|
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||||
|
percentiles=(0.5, 0.2, 0.8),
|
||||||
|
record_clocks=False, fast_flush=False):
|
||||||
"""
|
"""
|
||||||
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
|
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
|
||||||
the 20-th and 80-th performance percentile.
|
the 20-th and 80-th performance percentile.
|
||||||
@@ -130,6 +132,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
|
|||||||
:type grad_to_none: torch.tensor, optional
|
:type grad_to_none: torch.tensor, optional
|
||||||
:param percentiles: Performance percentile to return in addition to the median.
|
:param percentiles: Performance percentile to return in addition to the median.
|
||||||
:type percentiles: list[float]
|
:type percentiles: list[float]
|
||||||
|
:param fast_flush: Use faster kernel to flush L2 between measurements
|
||||||
|
:type fast_flush: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Estimate the runtime of the function
|
# Estimate the runtime of the function
|
||||||
@@ -151,6 +155,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
|
|||||||
# doesn't contain any input data before the run
|
# doesn't contain any input data before the run
|
||||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||||
|
if fast_flush:
|
||||||
|
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||||
|
else:
|
||||||
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
||||||
# Warm-up
|
# Warm-up
|
||||||
for _ in range(n_warmup):
|
for _ in range(n_warmup):
|
||||||
|
@@ -1,10 +1,24 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class Symbol:
|
class Symbol:
|
||||||
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
|
_name: str
|
||||||
|
_op_name: str
|
||||||
|
_ret_type: str
|
||||||
|
_arg_names: List[str]
|
||||||
|
_arg_types: List[str]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
op_name: str,
|
||||||
|
ret_type: str,
|
||||||
|
arg_names: List[str],
|
||||||
|
arg_types: List[str],
|
||||||
|
) -> None:
|
||||||
'''
|
'''
|
||||||
A symbol is a function declaration.
|
A symbol is a function declaration.
|
||||||
|
|
||||||
@@ -17,31 +31,31 @@ class Symbol:
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._op_name = op_name
|
self._op_name = op_name
|
||||||
self._ret_type = ret_type
|
self._ret_type = ret_type
|
||||||
self._arg_names = arg_names
|
self._arg_names = list(arg_names)
|
||||||
self._arg_types = arg_types
|
self._arg_types = list(arg_types)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def op_name(self):
|
def op_name(self) -> str:
|
||||||
return self._op_name
|
return self._op_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ret_type(self):
|
def ret_type(self) -> str:
|
||||||
return self._ret_type
|
return self._ret_type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arg_names(self):
|
def arg_names(self) -> List[str]:
|
||||||
return self._arg_names
|
return self._arg_names
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arg_types(self):
|
def arg_types(self) -> List[str]:
|
||||||
return self._arg_types
|
return self._arg_types
|
||||||
|
|
||||||
|
|
||||||
def convert_type(type_str):
|
def convert_type(type_str) -> Optional[str]:
|
||||||
if type_str == "i32":
|
if type_str == "i32":
|
||||||
return "int32"
|
return "int32"
|
||||||
elif type_str == "u32":
|
elif type_str == "u32":
|
||||||
@@ -59,7 +73,7 @@ def convert_type(type_str):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def to_unsigned(type_str):
|
def to_unsigned(type_str) -> str:
|
||||||
if type_str == "int32":
|
if type_str == "int32":
|
||||||
return "uint32"
|
return "uint32"
|
||||||
elif type_str == "int64":
|
elif type_str == "int64":
|
||||||
@@ -69,7 +83,19 @@ def to_unsigned(type_str):
|
|||||||
|
|
||||||
|
|
||||||
class ExternLibrary(ABC):
|
class ExternLibrary(ABC):
|
||||||
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
|
_name: str
|
||||||
|
_path: str
|
||||||
|
_symbols: Dict[str, Symbol]
|
||||||
|
_format: bool
|
||||||
|
_grouping: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
format: bool = True,
|
||||||
|
grouping: bool = True,
|
||||||
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Abstract class for extern library.
|
Abstract class for extern library.
|
||||||
|
|
||||||
@@ -80,34 +106,34 @@ class ExternLibrary(ABC):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._path = path
|
self._path = path
|
||||||
self._symbols = {}
|
self._symbols = {}
|
||||||
self._format = True
|
self._format = format
|
||||||
self._grouping = grouping
|
self._grouping = grouping
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def path(self) -> str:
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def symbols(self):
|
def symbols(self) -> Dict[str, Symbol]:
|
||||||
return self._symbols
|
return self._symbols
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def grouping(self):
|
def grouping(self) -> bool:
|
||||||
return self._grouping
|
return self._grouping
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_symbols(self, input_file):
|
def parse_symbols(self, input_file) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _output_stubs(self) -> str:
|
def _output_stubs(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def generate_stub_file(self, output_dir):
|
def generate_stub_file(self, output_dir) -> None:
|
||||||
file_str = self._output_stubs()
|
file_str = self._output_stubs()
|
||||||
if file_str is None or len(file_str) == 0:
|
if file_str is None or len(file_str) == 0:
|
||||||
raise Exception("file_str is empty")
|
raise Exception("file_str is empty")
|
||||||
@@ -123,6 +149,8 @@ class ExternLibrary(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class Libdevice(ExternLibrary):
|
class Libdevice(ExternLibrary):
|
||||||
|
_symbol_groups: Dict[str, List[Symbol]]
|
||||||
|
|
||||||
def __init__(self, path) -> None:
|
def __init__(self, path) -> None:
|
||||||
'''
|
'''
|
||||||
Constructor for Libdevice.
|
Constructor for Libdevice.
|
||||||
@@ -132,7 +160,7 @@ class Libdevice(ExternLibrary):
|
|||||||
super().__init__("libdevice", path)
|
super().__init__("libdevice", path)
|
||||||
self._symbol_groups = {}
|
self._symbol_groups = {}
|
||||||
|
|
||||||
def _extract_symbol(self, line):
|
def _extract_symbol(self, line) -> Optional[Symbol]:
|
||||||
# Extract symbols from line in the following format:
|
# Extract symbols from line in the following format:
|
||||||
# "define [internal] <ret_type> @<name>(<arg_types>,)"
|
# "define [internal] <ret_type> @<name>(<arg_types>,)"
|
||||||
entries = line.split("@")
|
entries = line.split("@")
|
||||||
@@ -149,6 +177,9 @@ class Libdevice(ExternLibrary):
|
|||||||
func_strs = func_str.split("(")
|
func_strs = func_str.split("(")
|
||||||
func_name = func_strs[0].replace("@", "")
|
func_name = func_strs[0].replace("@", "")
|
||||||
op_name = func_name.replace("__nv_", "")
|
op_name = func_name.replace("__nv_", "")
|
||||||
|
# To filter some interfaces unlisted in NVIDIA's official documents.
|
||||||
|
if 'ieee' in op_name:
|
||||||
|
return None
|
||||||
# Get arg_types
|
# Get arg_types
|
||||||
arg_strs = func_strs[1].split(",")
|
arg_strs = func_strs[1].split(",")
|
||||||
arg_types = []
|
arg_types = []
|
||||||
@@ -171,66 +202,77 @@ class Libdevice(ExternLibrary):
|
|||||||
arg_types[i] = to_unsigned(arg_type)
|
arg_types[i] = to_unsigned(arg_type)
|
||||||
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
|
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
|
||||||
|
|
||||||
def _group_symbols(self):
|
def _group_symbols(self) -> None:
|
||||||
symbol_set = {}
|
symbol_set = {}
|
||||||
for symbol in self._symbols.values():
|
for symbol in self._symbols.values():
|
||||||
op_name = symbol.op_name
|
op_name = symbol.op_name
|
||||||
symbol_set[op_name] = symbol
|
symbol_set[op_name] = symbol
|
||||||
# The following cases are grouped together:
|
|
||||||
# op_name, <u/ull/ll>op_name<ll/f/i>
|
# Group functions together by renaming.
|
||||||
|
renaming = {
|
||||||
|
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
|
||||||
|
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
|
||||||
|
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
|
||||||
|
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
|
||||||
|
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
|
||||||
|
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
|
||||||
|
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
|
||||||
|
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
|
||||||
|
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
|
||||||
|
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
|
||||||
|
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
|
||||||
|
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
|
||||||
|
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
|
||||||
|
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
|
||||||
|
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
|
||||||
|
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
|
||||||
|
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
|
||||||
|
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
|
||||||
|
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
|
||||||
|
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
|
||||||
|
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
|
||||||
|
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
|
||||||
|
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
|
||||||
|
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
|
||||||
|
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
|
||||||
|
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
|
||||||
|
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
|
||||||
|
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
|
||||||
|
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
|
||||||
|
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
|
||||||
|
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
|
||||||
|
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
|
||||||
|
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
|
||||||
|
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
|
||||||
|
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
|
||||||
|
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
|
||||||
|
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
|
||||||
|
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
|
||||||
|
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
|
||||||
|
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
|
||||||
|
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
|
||||||
|
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
|
||||||
|
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
|
||||||
|
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
|
||||||
|
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
|
||||||
|
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
|
||||||
|
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
|
||||||
|
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
|
||||||
|
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
|
||||||
|
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
|
||||||
|
}
|
||||||
|
|
||||||
for symbol in self._symbols.values():
|
for symbol in self._symbols.values():
|
||||||
op_name = symbol.op_name
|
op_name = symbol.op_name
|
||||||
if "max" in op_name:
|
if op_name in renaming:
|
||||||
op_name = "max"
|
op_name = renaming[op_name]
|
||||||
elif "min" in op_name:
|
|
||||||
op_name = "min"
|
|
||||||
elif "abs" in op_name:
|
|
||||||
op_name = "abs"
|
|
||||||
elif "pow" in op_name and "fast" in op_name:
|
|
||||||
op_name = "pow"
|
|
||||||
elif "round" in op_name:
|
|
||||||
if "llround" in op_name:
|
|
||||||
op_name = "llround"
|
|
||||||
else:
|
|
||||||
op_name = "round"
|
|
||||||
elif "rint" in op_name:
|
|
||||||
if "llrint" in op_name:
|
|
||||||
op_name = "llrint"
|
|
||||||
else:
|
|
||||||
op_name = "rint"
|
|
||||||
elif op_name.startswith("ull"):
|
|
||||||
if "2" not in op_name:
|
|
||||||
# e.g., ullmax->max
|
|
||||||
op_name = op_name[3:]
|
|
||||||
else:
|
|
||||||
# e.g., ull2double->ll2double
|
|
||||||
op_name = op_name[1:]
|
|
||||||
elif op_name.startswith("u"):
|
|
||||||
if "2" not in op_name:
|
|
||||||
# e.g., uhadd->hadd
|
|
||||||
op_name = op_name[1:]
|
|
||||||
else:
|
|
||||||
# e.g., uint2double_rn->int2double_rn
|
|
||||||
op_name = op_name[1:]
|
|
||||||
elif op_name.startswith("ll"):
|
|
||||||
if "2" not in op_name:
|
|
||||||
# e.g., llmax->max
|
|
||||||
op_name = op_name[2:]
|
|
||||||
elif op_name.endswith("ll"):
|
|
||||||
op_name = op_name[:-2]
|
|
||||||
elif op_name.endswith("f"):
|
|
||||||
op_name = op_name[:-1]
|
|
||||||
if op_name in symbol_set:
|
|
||||||
# Update op_name only if there's an existing symbol
|
|
||||||
symbol._op_name = op_name
|
symbol._op_name = op_name
|
||||||
else:
|
|
||||||
op_name = symbol._op_name
|
|
||||||
if op_name in self._symbol_groups:
|
if op_name in self._symbol_groups:
|
||||||
self._symbol_groups[op_name].append(symbol)
|
self._symbol_groups[op_name].append(symbol)
|
||||||
else:
|
else:
|
||||||
self._symbol_groups[op_name] = [symbol]
|
self._symbol_groups[op_name] = [symbol]
|
||||||
|
|
||||||
def parse_symbols(self, input_file):
|
def parse_symbols(self, input_file) -> None:
|
||||||
if len(self.symbols) > 0:
|
if len(self.symbols) > 0:
|
||||||
return
|
return
|
||||||
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
|
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
|
||||||
@@ -242,7 +284,7 @@ class Libdevice(ExternLibrary):
|
|||||||
|
|
||||||
self._group_symbols()
|
self._group_symbols()
|
||||||
|
|
||||||
def _output_stubs(self):
|
def _output_stubs(self) -> str:
|
||||||
# Generate python functions in the following format:
|
# Generate python functions in the following format:
|
||||||
# @extern.extern
|
# @extern.extern
|
||||||
# def <op_name>(<args>, _builder=None):
|
# def <op_name>(<args>, _builder=None):
|
||||||
@@ -250,7 +292,7 @@ class Libdevice(ExternLibrary):
|
|||||||
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||||
import_str = "from . import core, extern\n"
|
import_str = "from . import core, extern\n"
|
||||||
import_str += "import os\n"
|
import_str += "import os\n"
|
||||||
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
|
header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
|
||||||
func_str = ""
|
func_str = ""
|
||||||
for symbols in self._symbol_groups.values():
|
for symbols in self._symbol_groups.values():
|
||||||
func_str += "@extern.extern\n"
|
func_str += "@extern.extern\n"
|
||||||
@@ -283,7 +325,10 @@ class Libdevice(ExternLibrary):
|
|||||||
|
|
||||||
|
|
||||||
class LLVMDisassembler:
|
class LLVMDisassembler:
|
||||||
def __init__(self, path):
|
_path: str
|
||||||
|
_ll_file: str
|
||||||
|
|
||||||
|
def __init__(self, path) -> None:
|
||||||
'''
|
'''
|
||||||
Invoke llvm-dis to disassemble the given file.
|
Invoke llvm-dis to disassemble the given file.
|
||||||
|
|
||||||
@@ -292,23 +337,28 @@ class LLVMDisassembler:
|
|||||||
self._path = path
|
self._path = path
|
||||||
self._ll_file = "/tmp/extern_lib.ll"
|
self._ll_file = "/tmp/extern_lib.ll"
|
||||||
|
|
||||||
def disasm(self, lib_path):
|
def disasm(self, lib_path: str) -> None:
|
||||||
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
||||||
stdout=subprocess.PIPE).communicate()
|
stdout=subprocess.PIPE).communicate()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ll_file(self):
|
def ll_file(self) -> str:
|
||||||
return self._ll_file
|
return self._ll_file
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def path(self) -> str:
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
|
|
||||||
extern_libs = ["libdevice"]
|
extern_libs = ["libdevice"]
|
||||||
|
|
||||||
|
|
||||||
def build(llvm_dis_path, lib_path, lib_name, output_dir):
|
def build(
|
||||||
|
llvm_dis_path: str,
|
||||||
|
lib_path: str,
|
||||||
|
lib_name: str,
|
||||||
|
output_dir: str,
|
||||||
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Interface function to build the library file.
|
Interface function to build the library file.
|
||||||
|
|
||||||
@@ -331,10 +381,10 @@ def build(llvm_dis_path, lib_path, lib_name, output_dir):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
|
parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
|
||||||
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
|
parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
|
||||||
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
|
parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
|
||||||
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
|
parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)
|
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)
|
||||||
|
@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re
|
|||||||
|
|
||||||
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
||||||
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
||||||
Using triton, you can simply call `tl.libdevice.asinf`.
|
Using triton, you can simply call `tl.libdevice.asin`.
|
||||||
triton automatically selects the correct underlying device function to invoke based on input and output types.
|
triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user