diff --git a/python/tests/test_math_ops.py b/python/tests/test_math_ops.py index c464818fb..6b3463490 100644 --- a/python/tests/test_math_ops.py +++ b/python/tests/test_math_ops.py @@ -13,8 +13,8 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr): y1 = tl.sin(x1) y2 = tl.libdevice.sin(x2) - y3 = tl.libdevice.fdiv_rn(x3, x3) - y4 = tl.libdevice.fmaf_rd(x4, x4, x4) + y3 = tl.libdevice.div_rn(x3, x3) + y4 = tl.libdevice.fma_rd(x4, x4, x4) tl.store(x1_ptr + offsets, y1, mask=offsets < n) tl.store(x2_ptr + offsets, y2, mask=offsets < n) diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py index 226480fa2..cae14a797 100644 --- a/python/triton/language/libdevice.py +++ b/python/triton/language/libdevice.py @@ -58,13 +58,7 @@ def mulhi(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), - }, _builder) - - -@extern.extern -def mul64hi(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), }, _builder) @@ -157,262 +151,138 @@ def saturatef(arg0, _builder=None): }, _builder) -@extern.extern -def fmaf_rn(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_rz(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_rd(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_ru(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rn", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rz", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rd", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_ru", core.dtype("fp32")), - }, _builder) - - @extern.extern def fma_rn(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), }, _builder) @extern.extern def fma_rz(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), }, _builder) @extern.extern def fma_rd(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), }, _builder) @extern.extern def fma_ru(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], - {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), }, _builder) @extern.extern -def fast_fdividef(arg0, arg1, _builder=None): +def fast_dividef(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), }, _builder) @extern.extern -def fdiv_rn(arg0, arg1, _builder=None): +def div_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), }, _builder) @extern.extern -def fdiv_rz(arg0, arg1, _builder=None): +def div_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), }, _builder) @extern.extern -def fdiv_rd(arg0, arg1, _builder=None): +def div_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), }, _builder) @extern.extern -def fdiv_ru(arg0, arg1, _builder=None): +def div_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), }, _builder) @extern.extern -def frcp_rn(arg0, _builder=None): +def rcp_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), }, _builder) @extern.extern -def frcp_rz(arg0, _builder=None): +def rcp_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), }, _builder) @extern.extern -def frcp_rd(arg0, _builder=None): +def rcp_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), }, _builder) @extern.extern -def frcp_ru(arg0, _builder=None): +def rcp_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), }, _builder) @extern.extern -def fsqrt_rn(arg0, _builder=None): +def sqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), }, _builder) @extern.extern -def fsqrt_rz(arg0, _builder=None): +def sqrt_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), }, _builder) @extern.extern -def fsqrt_rd(arg0, _builder=None): +def sqrt_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), }, _builder) @extern.extern -def fsqrt_ru(arg0, _builder=None): +def sqrt_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def ddiv_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def ddiv_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def ddiv_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def ddiv_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def drcp_rn(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def drcp_rz(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def drcp_rd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def drcp_ru(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsqrt_rn(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsqrt_rz(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsqrt_rd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsqrt_ru(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + (core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), }, _builder) @@ -425,114 +295,66 @@ def sqrt(arg0, _builder=None): @extern.extern -def dadd_rn(arg0, arg1, _builder=None): +def add_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), }, _builder) @extern.extern -def dadd_rz(arg0, arg1, _builder=None): +def add_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), }, _builder) @extern.extern -def dadd_rd(arg0, arg1, _builder=None): +def add_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), }, _builder) @extern.extern -def dadd_ru(arg0, arg1, _builder=None): +def add_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), }, _builder) @extern.extern -def dmul_rn(arg0, arg1, _builder=None): +def mul_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), }, _builder) @extern.extern -def dmul_rz(arg0, arg1, _builder=None): +def mul_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), }, _builder) @extern.extern -def dmul_rd(arg0, arg1, _builder=None): +def mul_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), }, _builder) @extern.extern -def dmul_ru(arg0, arg1, _builder=None): +def mul_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def fadd_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fadd_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmul_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmul_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fadd_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fadd_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmul_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), - }, _builder) - - -@extern.extern -def fmul_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), }, _builder) @@ -624,7 +446,13 @@ def double2uint_ru(arg0, _builder=None): def int2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), - (core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def uint2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), }, _builder) @@ -688,7 +516,6 @@ def float2uint_ru(arg0, _builder=None): def int2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), - (core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), }, _builder) @@ -696,7 +523,6 @@ def int2float_rn(arg0, _builder=None): def int2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), - (core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), }, _builder) @@ -704,7 +530,6 @@ def int2float_rz(arg0, _builder=None): def int2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), - (core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), }, _builder) @@ -712,7 +537,34 @@ def int2float_rd(arg0, _builder=None): def int2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), - (core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def uint2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def uint2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def uint2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def uint2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), }, _builder) @@ -853,7 +705,6 @@ def double2ull_ru(arg0, _builder=None): def ll2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), - (core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), }, _builder) @@ -861,7 +712,6 @@ def ll2float_rn(arg0, _builder=None): def ll2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), - (core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), }, _builder) @@ -869,7 +719,6 @@ def ll2float_rz(arg0, _builder=None): def ll2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), - (core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), }, _builder) @@ -877,7 +726,34 @@ def ll2float_rd(arg0, _builder=None): def ll2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), - (core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ull2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ull2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ull2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ull2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), }, _builder) @@ -885,7 +761,6 @@ def ll2float_ru(arg0, _builder=None): def ll2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), - (core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), }, _builder) @@ -893,7 +768,6 @@ def ll2double_rn(arg0, _builder=None): def ll2double_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), - (core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), }, _builder) @@ -901,7 +775,6 @@ def ll2double_rz(arg0, _builder=None): def ll2double_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), - (core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), }, _builder) @@ -909,7 +782,34 @@ def ll2double_rd(arg0, _builder=None): def ll2double_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), - (core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ull2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ull2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ull2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ull2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), }, _builder) @@ -917,7 +817,6 @@ def ll2double_ru(arg0, _builder=None): def int_as_float(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), - (core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), }, _builder) @@ -928,6 +827,13 @@ def float_as_int(arg0, _builder=None): }, _builder) +@extern.extern +def uint_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, _builder) + + @extern.extern def float_as_uint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], @@ -1006,11 +912,9 @@ def fast_log10f(arg0, _builder=None): @extern.extern -def pow(arg0, arg1, _builder=None): +def fast_powf(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), }, _builder) @@ -1031,35 +935,39 @@ def rhadd(arg0, arg1, _builder=None): @extern.extern -def fsub_rn(arg0, arg1, _builder=None): +def sub_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), }, _builder) @extern.extern -def fsub_rz(arg0, arg1, _builder=None): +def sub_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), }, _builder) @extern.extern -def fsub_rd(arg0, arg1, _builder=None): +def sub_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), }, _builder) @extern.extern -def fsub_ru(arg0, arg1, _builder=None): +def sub_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), }, _builder) @extern.extern -def frsqrt_rn(arg0, _builder=None): +def rsqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), }, _builder) @@ -1098,16 +1006,18 @@ def nearbyint(arg0, _builder=None): @extern.extern -def isnanf(arg0, _builder=None): +def isnan(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), }, _builder) @extern.extern -def signbitf(arg0, _builder=None): +def signbit(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), }, _builder) @@ -1127,9 +1037,10 @@ def finitef(arg0, _builder=None): @extern.extern -def isinff(arg0, _builder=None): +def isinf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), }, _builder) @@ -1550,10 +1461,12 @@ def fma(arg0, arg1, arg2, _builder=None): @extern.extern -def powi(arg0, arg1, _builder=None): +def pow(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), }, _builder) @@ -1605,57 +1518,8 @@ def logb(arg0, _builder=None): }, _builder) -@extern.extern -def signbitd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), - }, _builder) - - @extern.extern def isfinited(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), }, _builder) - - -@extern.extern -def isinfd(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), - }, _builder) - - -@extern.extern -def isnand(arg0, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], - {(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), - }, _builder) - - -@extern.extern -def dsub_rn(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsub_rz(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsub_ru(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), - }, _builder) - - -@extern.extern -def dsub_rd(arg0, arg1, _builder=None): - return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), - }, _builder) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py new file mode 100644 index 000000000..bcf29f2b9 --- /dev/null +++ b/python/triton/tools/build_extern.py @@ -0,0 +1,348 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod + + +class Symbol: + def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = arg_names + self._arg_types = arg_types + + @property + def name(self): + return self._name + + @property + def op_name(self): + return self._op_name + + @property + def ret_type(self): + return self._ret_type + + @property + def arg_names(self): + return self._arg_names + + @property + def arg_types(self): + return self._arg_types + + +def convert_type(type_str): + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str): + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = True + self._grouping = grouping + + @property + def name(self): + return self._name + + @property + def path(self): + return self._path + + @property + def symbols(self): + return self._symbols + + @property + def grouping(self): + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file): + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir): + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], + stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + + def _extract_symbol(self, line): + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self): + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', + 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn', + 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', + 'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin', + 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', + 'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt', + 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', + 'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi', + 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', + 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru', + 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', + 'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx', + 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', + 'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs', + 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', + 'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor', + 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', + 'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', + 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', + 'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', + 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', + 'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10', + 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', + 'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max', + 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', + 'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', + 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru', + 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', + 'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi', + 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter', + 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', + 'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', + 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd', + 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', + 'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', + 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', + 'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn', + 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', + 'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh', + 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', + 'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', + 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz', + 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', + 'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', + 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', + 'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file): + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self): + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return extern.dispatch("libdevice", , , , _builder) + import_str = "from . import core, extern\n" + import_str += "import os\n" + header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@extern.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += ", _builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + def __init__(self, path): + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path): + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], + stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self): + return self._ll_file + + @property + def path(self): + return self._path + + +extern_libs = ["libdevice"] + + +def build(llvm_dis_path, lib_path, lib_name, output_dir): + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)