[Frontend] Interface fixes for libdevice (#830)

- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
This commit is contained in:
Chenggang Zhao
2022-11-02 01:51:58 +08:00
committed by GitHub
parent 578ada7740
commit f16138d447
3 changed files with 233 additions and 355 deletions

View File

@@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int32, core.int32,): ("__nv_mulhi", core.int32),
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
}, _builder)
@extern.extern
def mul64hi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
}, _builder)
@@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None):
}, _builder)
@extern.extern
def fmaf_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
}, _builder)
@extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
}, _builder)
@extern.extern
def fma_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
}, _builder)
@extern.extern
def fma_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
}, _builder)
@extern.extern
def fma_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
}, _builder)
@extern.extern
def fast_fdividef(arg0, arg1, _builder=None):
def fast_dividef(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
}, _builder)
@extern.extern
def fdiv_rn(arg0, arg1, _builder=None):
def div_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def fdiv_rz(arg0, arg1, _builder=None):
def div_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def fdiv_rd(arg0, arg1, _builder=None):
def div_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def fdiv_ru(arg0, arg1, _builder=None):
def div_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def frcp_rn(arg0, _builder=None):
def rcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rn", core.float32),
(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def frcp_rz(arg0, _builder=None):
def rcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rz", core.float32),
(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def frcp_rd(arg0, _builder=None):
def rcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rd", core.float32),
(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def frcp_ru(arg0, _builder=None):
def rcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_ru", core.float32),
(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def fsqrt_rn(arg0, _builder=None):
def sqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rn", core.float32),
(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def fsqrt_rz(arg0, _builder=None):
def sqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rz", core.float32),
(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def fsqrt_rd(arg0, _builder=None):
def sqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rd", core.float32),
(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def fsqrt_ru(arg0, _builder=None):
def sqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_ru", core.float32),
}, _builder)
@extern.extern
def ddiv_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def ddiv_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def ddiv_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def ddiv_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def drcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def drcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def drcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def drcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def dsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def dsqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def dsqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def dsqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_ru", core.float64),
(core.float64,): ("__nv_dsqrt_ru", core.float64),
}, _builder)
@@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None):
@extern.extern
def dadd_rn(arg0, arg1, _builder=None):
def add_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def dadd_rz(arg0, arg1, _builder=None):
def add_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def dadd_rd(arg0, arg1, _builder=None):
def add_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def dadd_ru(arg0, arg1, _builder=None):
def add_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def dmul_rn(arg0, arg1, _builder=None):
def mul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def dmul_rz(arg0, arg1, _builder=None):
def mul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder)
@extern.extern
def dmul_rd(arg0, arg1, _builder=None):
def mul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def dmul_ru(arg0, arg1, _builder=None):
def mul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
}, _builder)
@extern.extern
def fadd_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def fadd_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def fmul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def fmul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@extern.extern
def fadd_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def fadd_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def fmul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def fmul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None):
def int2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2double_rn", core.float64),
(core.uint32,): ("__nv_uint2double_rn", core.float64),
}, _builder)
@extern.extern
def uint2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2double_rn", core.float64),
}, _builder)
@@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None):
def int2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rn", core.float32),
(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None):
def int2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rz", core.float32),
(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None):
def int2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rd", core.float32),
(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None):
def int2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_ru", core.float32),
(core.uint32,): ("__nv_uint2float_ru", core.float32),
}, _builder)
@extern.extern
def uint2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@extern.extern
def uint2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@extern.extern
def uint2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@extern.extern
def uint2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_ru", core.float32),
}, _builder)
@@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None):
def ll2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rn", core.float32),
(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None):
def ll2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rz", core.float32),
(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None):
def ll2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rd", core.float32),
(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None):
def ll2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_ru", core.float32),
(core.uint64,): ("__nv_ull2float_ru", core.float32),
}, _builder)
@extern.extern
def ull2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@extern.extern
def ull2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@extern.extern
def ull2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@extern.extern
def ull2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_ru", core.float32),
}, _builder)
@@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None):
def ll2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rn", core.float64),
(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None):
def ll2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rz", core.float64),
(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None):
def ll2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rd", core.float64),
(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None):
def ll2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_ru", core.float64),
(core.uint64,): ("__nv_ull2double_ru", core.float64),
}, _builder)
@extern.extern
def ull2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@extern.extern
def ull2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@extern.extern
def ull2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@extern.extern
def ull2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_ru", core.float64),
}, _builder)
@@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None):
def int_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int_as_float", core.float32),
(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None):
}, _builder)
@extern.extern
def uint_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@extern.extern
def float_as_uint(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
@@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None):
@extern.extern
def pow(arg0, arg1, _builder=None):
def fast_powf(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder)
@@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None):
@extern.extern
def fsub_rn(arg0, arg1, _builder=None):
def sub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def fsub_rz(arg0, arg1, _builder=None):
def sub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def fsub_rd(arg0, arg1, _builder=None):
def sub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)
@extern.extern
def fsub_ru(arg0, arg1, _builder=None):
def sub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def frsqrt_rn(arg0, _builder=None):
def rsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frsqrt_rn", core.float32),
}, _builder)
@@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None):
@extern.extern
def isnanf(arg0, _builder=None):
def isnan(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isnanf", core.int32),
(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def signbitf(arg0, _builder=None):
def signbit(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_signbitf", core.int32),
(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None):
@extern.extern
def isinff(arg0, _builder=None):
def isinf(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isinff", core.int32),
(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None):
@extern.extern
def powi(arg0, arg1, _builder=None):
def pow(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.int32,): ("__nv_powif", core.float32),
(core.float64, core.int32,): ("__nv_powi", core.float64),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder)
@@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None):
}, _builder)
@extern.extern
def signbitd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@extern.extern
def isfinited(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isfinited", core.int32),
}, _builder)
@extern.extern
def isinfd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@extern.extern
def isnand(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def dsub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def dsub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def dsub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def dsub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)

View File

@@ -149,6 +149,9 @@ class Libdevice(ExternLibrary):
func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "")
# To filter some interfaces unlisted in NVIDIA's official documents.
if 'ieee' in op_name:
return None
# Get arg_types
arg_strs = func_strs[1].split(",")
arg_types = []
@@ -176,55 +179,66 @@ class Libdevice(ExternLibrary):
for symbol in self._symbols.values():
op_name = symbol.op_name
symbol_set[op_name] = symbol
# The following cases are grouped together:
# op_name, <u/ull/ll>op_name<ll/f/i>
# Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
}
for symbol in self._symbols.values():
op_name = symbol.op_name
if "max" in op_name:
op_name = "max"
elif "min" in op_name:
op_name = "min"
elif "abs" in op_name:
op_name = "abs"
elif "pow" in op_name and "fast" in op_name:
op_name = "pow"
elif "round" in op_name:
if "llround" in op_name:
op_name = "llround"
else:
op_name = "round"
elif "rint" in op_name:
if "llrint" in op_name:
op_name = "llrint"
else:
op_name = "rint"
elif op_name.startswith("ull"):
if "2" not in op_name:
# e.g., ullmax->max
op_name = op_name[3:]
else:
# e.g., ull2double->ll2double
op_name = op_name[1:]
elif op_name.startswith("u"):
if "2" not in op_name:
# e.g., uhadd->hadd
op_name = op_name[1:]
else:
# e.g., uint2double_rn->int2double_rn
op_name = op_name[1:]
elif op_name.startswith("ll"):
if "2" not in op_name:
# e.g., llmax->max
op_name = op_name[2:]
elif op_name.endswith("ll"):
op_name = op_name[:-2]
elif op_name.endswith("f"):
op_name = op_name[:-1]
if op_name in symbol_set:
# Update op_name only if there's an existing symbol
if op_name in renaming:
op_name = renaming[op_name]
symbol._op_name = op_name
else:
op_name = symbol._op_name
if op_name in self._symbol_groups:
self._symbol_groups[op_name].append(symbol)
else:
@@ -250,7 +264,7 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
@@ -331,10 +345,10 @@ def build(llvm_dis_path, lib_path, lib_name, output_dir):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
args = parser.parse_args()
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)

View File

@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`.
Using triton, you can simply call `tl.libdevice.asin`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
"""