New python binding
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -395,30 +395,30 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
# infer dtype from ir type
|
||||
@staticmethod
|
||||
def _to_dtype(ir_type):
|
||||
# block type
|
||||
if ir_type.is_block():
|
||||
scalar_ty = tensor._to_dtype(ir_type.scalar)
|
||||
return block_type(scalar_ty, ir_type.get_block_shapes())
|
||||
# pointer type
|
||||
if ir_type.is_ptr():
|
||||
element_ty = tensor._to_dtype(ir_type.element)
|
||||
return pointer_type(element_ty)
|
||||
# primitive type
|
||||
if ir_type.is_void(): return void
|
||||
if ir_type.is_int1(): return int1
|
||||
if ir_type.is_int8(): return int8
|
||||
if ir_type.is_int16(): return int16
|
||||
if ir_type.is_int32(): return int32
|
||||
if ir_type.is_int64(): return int64
|
||||
if ir_type.is_fp8(): return float8
|
||||
if ir_type.is_fp16(): return float16
|
||||
if ir_type.is_bf16(): return bfloat16
|
||||
if ir_type.is_fp32(): return float32
|
||||
if ir_type.is_fp64(): return float64
|
||||
raise ValueError(f"Unsupported type {ir_type.repr()}")
|
||||
# # infer dtype from ir type
|
||||
# @staticmethod
|
||||
# def _to_dtype(ir_type):
|
||||
# # block type
|
||||
# if ir_type.is_block():
|
||||
# scalar_ty = tensor._to_dtype(ir_type.scalar)
|
||||
# return block_type(scalar_ty, ir_type.get_block_shapes())
|
||||
# # pointer type
|
||||
# if ir_type.is_ptr():
|
||||
# element_ty = tensor._to_dtype(ir_type.element)
|
||||
# return pointer_type(element_ty)
|
||||
# # primitive type
|
||||
# if ir_type.is_void(): return void
|
||||
# if ir_type.is_int1(): return int1
|
||||
# if ir_type.is_int8(): return int8
|
||||
# if ir_type.is_int16(): return int16
|
||||
# if ir_type.is_int32(): return int32
|
||||
# if ir_type.is_int64(): return int64
|
||||
# if ir_type.is_fp8(): return float8
|
||||
# if ir_type.is_fp16(): return float16
|
||||
# if ir_type.is_bf16(): return bfloat16
|
||||
# if ir_type.is_fp32(): return float32
|
||||
# if ir_type.is_fp64(): return float64
|
||||
# raise ValueError(f"Unsupported type {ir_type.repr()}")
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
|
Reference in New Issue
Block a user