[Frontend] Interface fixes for libdevice (#830)

- Unifying interfaces that differ only in operand type into a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_powf` was incorrectly grouped under the `pow`
interface, where its `(float32, float32)` signature collided with that of `powf`;
- Explicit interfaces for casting functions, e.g. splitting
`ll2float_ru` into `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documentation, e.g.
`fmaf_ieee_rn`, which is easily confused with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
This commit is contained in:
Chenggang Zhao
2022-11-02 01:51:58 +08:00
committed by GitHub
parent 578ada7740
commit f16138d447
3 changed files with 233 additions and 355 deletions

View File

@@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int32, core.int32,): ("__nv_mulhi", core.int32), {(core.int32, core.int32,): ("__nv_mulhi", core.int32),
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32), (core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
}, _builder) (core.int64, core.int64,): ("__nv_mul64hi", core.int64),
@extern.extern
def mul64hi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64), (core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
}, _builder) }, _builder)
@@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def fmaf_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
}, _builder)
@extern.extern @extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None): def fma_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_rz(arg0, arg1, arg2, _builder=None): def fma_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_rd(arg0, arg1, arg2, _builder=None): def fma_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fma_ru(arg0, arg1, arg2, _builder=None): def fma_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64), {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fast_fdividef(arg0, arg1, _builder=None): def fast_dividef(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32), {(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rn(arg0, arg1, _builder=None): def div_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rz(arg0, arg1, _builder=None): def div_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_rd(arg0, arg1, _builder=None): def div_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fdiv_ru(arg0, arg1, _builder=None): def div_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32), {(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rn(arg0, _builder=None): def rcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rn", core.float32), {(core.float32,): ("__nv_frcp_rn", core.float32),
(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rz(arg0, _builder=None): def rcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rz", core.float32), {(core.float32,): ("__nv_frcp_rz", core.float32),
(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_rd(arg0, _builder=None): def rcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rd", core.float32), {(core.float32,): ("__nv_frcp_rd", core.float32),
(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frcp_ru(arg0, _builder=None): def rcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_ru", core.float32), {(core.float32,): ("__nv_frcp_ru", core.float32),
(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rn(arg0, _builder=None): def sqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rn", core.float32), {(core.float32,): ("__nv_fsqrt_rn", core.float32),
(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rz(arg0, _builder=None): def sqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rz", core.float32), {(core.float32,): ("__nv_fsqrt_rz", core.float32),
(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_rd(arg0, _builder=None): def sqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rd", core.float32), {(core.float32,): ("__nv_fsqrt_rd", core.float32),
(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsqrt_ru(arg0, _builder=None): def sqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_ru", core.float32), {(core.float32,): ("__nv_fsqrt_ru", core.float32),
}, _builder) (core.float64,): ("__nv_dsqrt_ru", core.float64),
@extern.extern
def ddiv_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def ddiv_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def ddiv_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def ddiv_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def drcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def drcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def drcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def drcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def dsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def dsqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def dsqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def dsqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_ru", core.float64),
}, _builder) }, _builder)
@@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None):
@extern.extern @extern.extern
def dadd_rn(arg0, arg1, _builder=None): def add_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_rz(arg0, arg1, _builder=None): def add_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_rd(arg0, arg1, _builder=None): def add_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64), {(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dadd_ru(arg0, arg1, _builder=None): def add_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64), {(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rn(arg0, arg1, _builder=None): def mul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rz(arg0, arg1, _builder=None): def mul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_rd(arg0, arg1, _builder=None): def mul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64), {(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def dmul_ru(arg0, arg1, _builder=None): def mul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64), {(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
}, _builder) (core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
@extern.extern
def fadd_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def fadd_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def fmul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def fmul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@extern.extern
def fadd_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def fadd_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def fmul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def fmul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder) }, _builder)
@@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None):
def int2double_rn(arg0, _builder=None): def int2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2double_rn", core.float64), {(core.int32,): ("__nv_int2double_rn", core.float64),
(core.uint32,): ("__nv_uint2double_rn", core.float64), }, _builder)
@extern.extern
def uint2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2double_rn", core.float64),
}, _builder) }, _builder)
@@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None):
def int2float_rn(arg0, _builder=None): def int2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rn", core.float32), {(core.int32,): ("__nv_int2float_rn", core.float32),
(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder) }, _builder)
@@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None):
def int2float_rz(arg0, _builder=None): def int2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rz", core.float32), {(core.int32,): ("__nv_int2float_rz", core.float32),
(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder) }, _builder)
@@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None):
def int2float_rd(arg0, _builder=None): def int2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rd", core.float32), {(core.int32,): ("__nv_int2float_rd", core.float32),
(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder) }, _builder)
@@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None):
def int2float_ru(arg0, _builder=None): def int2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_ru", core.float32), {(core.int32,): ("__nv_int2float_ru", core.float32),
(core.uint32,): ("__nv_uint2float_ru", core.float32), }, _builder)
@extern.extern
def uint2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@extern.extern
def uint2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@extern.extern
def uint2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@extern.extern
def uint2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_ru", core.float32),
}, _builder) }, _builder)
@@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None):
def ll2float_rn(arg0, _builder=None): def ll2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rn", core.float32), {(core.int64,): ("__nv_ll2float_rn", core.float32),
(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder) }, _builder)
@@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None):
def ll2float_rz(arg0, _builder=None): def ll2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rz", core.float32), {(core.int64,): ("__nv_ll2float_rz", core.float32),
(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder) }, _builder)
@@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None):
def ll2float_rd(arg0, _builder=None): def ll2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rd", core.float32), {(core.int64,): ("__nv_ll2float_rd", core.float32),
(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder) }, _builder)
@@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None):
def ll2float_ru(arg0, _builder=None): def ll2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_ru", core.float32), {(core.int64,): ("__nv_ll2float_ru", core.float32),
(core.uint64,): ("__nv_ull2float_ru", core.float32), }, _builder)
@extern.extern
def ull2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@extern.extern
def ull2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@extern.extern
def ull2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@extern.extern
def ull2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_ru", core.float32),
}, _builder) }, _builder)
@@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None):
def ll2double_rn(arg0, _builder=None): def ll2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rn", core.float64), {(core.int64,): ("__nv_ll2double_rn", core.float64),
(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder) }, _builder)
@@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None):
def ll2double_rz(arg0, _builder=None): def ll2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rz", core.float64), {(core.int64,): ("__nv_ll2double_rz", core.float64),
(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder) }, _builder)
@@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None):
def ll2double_rd(arg0, _builder=None): def ll2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rd", core.float64), {(core.int64,): ("__nv_ll2double_rd", core.float64),
(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder) }, _builder)
@@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None):
def ll2double_ru(arg0, _builder=None): def ll2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_ru", core.float64), {(core.int64,): ("__nv_ll2double_ru", core.float64),
(core.uint64,): ("__nv_ull2double_ru", core.float64), }, _builder)
@extern.extern
def ull2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@extern.extern
def ull2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@extern.extern
def ull2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@extern.extern
def ull2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_ru", core.float64),
}, _builder) }, _builder)
@@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None):
def int_as_float(arg0, _builder=None): def int_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int_as_float", core.float32), {(core.int32,): ("__nv_int_as_float", core.float32),
(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder) }, _builder)
@@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def uint_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@extern.extern @extern.extern
def float_as_uint(arg0, _builder=None): def float_as_uint(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
@@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None):
@extern.extern @extern.extern
def pow(arg0, arg1, _builder=None): def fast_powf(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32), {(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder) }, _builder)
@@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None):
@extern.extern @extern.extern
def fsub_rn(arg0, arg1, _builder=None): def sub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_rz(arg0, arg1, _builder=None): def sub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_rd(arg0, arg1, _builder=None): def sub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32), {(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def fsub_ru(arg0, arg1, _builder=None): def sub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32), {(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder) }, _builder)
@extern.extern @extern.extern
def frsqrt_rn(arg0, _builder=None): def rsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frsqrt_rn", core.float32), {(core.float32,): ("__nv_frsqrt_rn", core.float32),
}, _builder) }, _builder)
@@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None):
@extern.extern @extern.extern
def isnanf(arg0, _builder=None): def isnan(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isnanf", core.int32), {(core.float32,): ("__nv_isnanf", core.int32),
(core.float64,): ("__nv_isnand", core.int32),
}, _builder) }, _builder)
@extern.extern @extern.extern
def signbitf(arg0, _builder=None): def signbit(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_signbitf", core.int32), {(core.float32,): ("__nv_signbitf", core.int32),
(core.float64,): ("__nv_signbitd", core.int32),
}, _builder) }, _builder)
@@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None):
@extern.extern @extern.extern
def isinff(arg0, _builder=None): def isinf(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isinff", core.int32), {(core.float32,): ("__nv_isinff", core.int32),
(core.float64,): ("__nv_isinfd", core.int32),
}, _builder) }, _builder)
@@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None):
@extern.extern @extern.extern
def powi(arg0, arg1, _builder=None): def pow(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.int32,): ("__nv_powif", core.float32), {(core.float32, core.int32,): ("__nv_powif", core.float32),
(core.float64, core.int32,): ("__nv_powi", core.float64), (core.float64, core.int32,): ("__nv_powi", core.float64),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder) }, _builder)
@@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None):
}, _builder) }, _builder)
@extern.extern
def signbitd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@extern.extern @extern.extern
def isfinited(arg0, _builder=None): def isfinited(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isfinited", core.int32), {(core.float64,): ("__nv_isfinited", core.int32),
}, _builder) }, _builder)
@extern.extern
def isinfd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@extern.extern
def isnand(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def dsub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def dsub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def dsub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def dsub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)

View File

@@ -149,6 +149,9 @@ class Libdevice(ExternLibrary):
func_strs = func_str.split("(") func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "") func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "") op_name = func_name.replace("__nv_", "")
# To filter some interfaces unlisted in NVIDIA's official documents.
if 'ieee' in op_name:
return None
# Get arg_types # Get arg_types
arg_strs = func_strs[1].split(",") arg_strs = func_strs[1].split(",")
arg_types = [] arg_types = []
@@ -176,55 +179,66 @@ class Libdevice(ExternLibrary):
for symbol in self._symbols.values(): for symbol in self._symbols.values():
op_name = symbol.op_name op_name = symbol.op_name
symbol_set[op_name] = symbol symbol_set[op_name] = symbol
# The following cases are grouped together:
# op_name, <u/ull/ll>op_name<ll/f/i> # Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
}
for symbol in self._symbols.values(): for symbol in self._symbols.values():
op_name = symbol.op_name op_name = symbol.op_name
if "max" in op_name: if op_name in renaming:
op_name = "max" op_name = renaming[op_name]
elif "min" in op_name:
op_name = "min"
elif "abs" in op_name:
op_name = "abs"
elif "pow" in op_name and "fast" in op_name:
op_name = "pow"
elif "round" in op_name:
if "llround" in op_name:
op_name = "llround"
else:
op_name = "round"
elif "rint" in op_name:
if "llrint" in op_name:
op_name = "llrint"
else:
op_name = "rint"
elif op_name.startswith("ull"):
if "2" not in op_name:
# e.g., ullmax->max
op_name = op_name[3:]
else:
# e.g., ull2double->ll2double
op_name = op_name[1:]
elif op_name.startswith("u"):
if "2" not in op_name:
# e.g., uhadd->hadd
op_name = op_name[1:]
else:
# e.g., uint2double_rn->int2double_rn
op_name = op_name[1:]
elif op_name.startswith("ll"):
if "2" not in op_name:
# e.g., llmax->max
op_name = op_name[2:]
elif op_name.endswith("ll"):
op_name = op_name[:-2]
elif op_name.endswith("f"):
op_name = op_name[:-1]
if op_name in symbol_set:
# Update op_name only if there's an existing symbol
symbol._op_name = op_name symbol._op_name = op_name
else:
op_name = symbol._op_name
if op_name in self._symbol_groups: if op_name in self._symbol_groups:
self._symbol_groups[op_name].append(symbol) self._symbol_groups[op_name].append(symbol)
else: else:
@@ -250,7 +264,7 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder) # return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n" import_str = "from . import core, extern\n"
import_str += "import os\n" import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
func_str = "" func_str = ""
for symbols in self._symbol_groups.values(): for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n" func_str += "@extern.extern\n"
@@ -331,10 +345,10 @@ def build(llvm_dis_path, lib_path, lib_name, output_dir):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
args = parser.parse_args() args = parser.parse_args()
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)

View File

@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re
In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`. Using triton, you can simply call `tl.libdevice.asin`.
triton automatically selects the correct underlying device function to invoke based on input and output types. triton automatically selects the correct underlying device function to invoke based on input and output types.
""" """