New python binding

This commit is contained in:
Yan Da
2022-03-22 21:53:22 +08:00
parent 419bbe0f6e
commit f2ab318614
5 changed files with 593 additions and 461 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -395,30 +395,30 @@ class constexpr:
class tensor:
# infer dtype from ir type
@staticmethod
def _to_dtype(ir_type):
# block type
if ir_type.is_block():
scalar_ty = tensor._to_dtype(ir_type.scalar)
return block_type(scalar_ty, ir_type.get_block_shapes())
# pointer type
if ir_type.is_ptr():
element_ty = tensor._to_dtype(ir_type.element)
return pointer_type(element_ty)
# primitive type
if ir_type.is_void(): return void
if ir_type.is_int1(): return int1
if ir_type.is_int8(): return int8
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_bf16(): return bfloat16
if ir_type.is_fp32(): return float32
if ir_type.is_fp64(): return float64
raise ValueError(f"Unsupported type {ir_type.repr()}")
# # infer dtype from ir type
# @staticmethod
# def _to_dtype(ir_type):
# # block type
# if ir_type.is_block():
# scalar_ty = tensor._to_dtype(ir_type.scalar)
# return block_type(scalar_ty, ir_type.get_block_shapes())
# # pointer type
# if ir_type.is_ptr():
# element_ty = tensor._to_dtype(ir_type.element)
# return pointer_type(element_ty)
# # primitive type
# if ir_type.is_void(): return void
# if ir_type.is_int1(): return int1
# if ir_type.is_int8(): return int8
# if ir_type.is_int16(): return int16
# if ir_type.is_int32(): return int32
# if ir_type.is_int64(): return int64
# if ir_type.is_fp8(): return float8
# if ir_type.is_fp16(): return float16
# if ir_type.is_bf16(): return bfloat16
# if ir_type.is_fp32(): return float32
# if ir_type.is_fp64(): return float64
# raise ValueError(f"Unsupported type {ir_type.repr()}")
def __init__(self, handle, type: dtype):
# IR handle