[LANG] Various (relatively minor) improvements (#320)
@@ -103,7 +103,8 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                arg_values.append(self.constants[i])
+                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                arg_values.append(cst)
             else:
                 if i in self.attributes:
                     is_ptr = fn.args[i].type.is_ptr()
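
A minimal sketch of what this first hunk buys: compile-time constants are now lowered to IR values before being appended, so arg_values holds IR values uniformly rather than a mix of raw Python constants and IR arguments. IRConstant and to_ir below are hypothetical stand-ins for the value returned by triton.language.core._to_ir, not real Triton API.

    # Sketch only: IRConstant/to_ir are hypothetical stand-ins.
    class IRConstant:
        def __init__(self, value):
            self.value = value

    def to_ir(obj):
        # lower a raw Python constant to an IR value, mirroring the new
        # `cst = triton.language.core._to_ir(...)` line in the hunk above
        return IRConstant(obj)

    constants = {0: 128, 2: True}           # positions bound at compile time
    arg_names = ['BLOCK', 'x_ptr', 'EVEN']
    arg_values = []
    for i, name in enumerate(arg_names):
        if i in constants:
            arg_values.append(to_ir(constants[i]))   # was: constants[i] directly
        else:
            arg_values.append(f'<runtime {name}>')   # placeholder for IR args

    print([type(v).__name__ for v in arg_values])
    # ['IRConstant', 'str', 'IRConstant'] -- constants are IR values now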
@@ -463,9 +464,6 @@ class Kernel:
     @staticmethod
     def _type_name(obj):
         type_names = {
-            int: 'I',
-            float: 'f',
-            bool: 'B',
             triton.language.float8: 'f8',
             torch.bfloat16: 'bf16',
             torch.float16: 'f16',
@@ -477,12 +475,25 @@ class Kernel:
             torch.int32: 'i32',
             torch.int64: 'i64',
         }
-        return type_names[obj]
+        if hasattr(obj, 'data_ptr'):
+            return type_names[obj.dtype]
+        if isinstance(obj, int):
+            if abs(obj) <= 0xffffffff:
+                return 'I'
+            return 'L'
+        if isinstance(obj, float):
+            return 'f'
+        if isinstance(obj, bool):
+            return 'B'
+        assert False
+
+
 
     @staticmethod
     def _to_triton_ir(context, obj):
         type_map = {
             'I': _triton.ir.type.get_int32,
+            'L': _triton.ir.type.get_int64,
             'f': _triton.ir.type.get_fp32,
             'B': _triton.ir.type.get_int1,
             'f8': _triton.ir.type.get_fp8,
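
These two hunks turn _type_name from a pure dict lookup into a dispatch on the argument itself: tensor-like objects are keyed by obj.dtype, and Python scalars are classified inline, with ints split into a 32-bit code 'I' and a 64-bit code 'L'. A standalone sketch of the scalar path (the tensor branch needs the dtype table above and is omitted here):

    # Scalar dispatch of the new _type_name, reproduced standalone.
    def type_name(obj):
        if isinstance(obj, int):
            # fits in 32 bits -> 'I', otherwise 64-bit 'L'
            return 'I' if abs(obj) <= 0xffffffff else 'L'
        if isinstance(obj, float):
            return 'f'
        if isinstance(obj, bool):
            return 'B'   # note: bool is a subclass of int in Python, so
                         # this branch is shadowed by the int check above
        assert False

    print(type_name(3))        # I
    print(type_name(2 ** 40))  # L
    print(type_name(0.5))      # f
    print(type_name(True))     # I  (caught by the int branch, not 'B')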
@@ -498,11 +509,11 @@ class Kernel:
         }
         # convert torch.Tensor to Triton IR pointers
         if hasattr(obj, 'data_ptr'):
-            name = Kernel._type_name(obj.dtype)
+            name = Kernel._type_name(obj)
             elt_ty = type_map[name](context)
             return _triton.ir.type.make_ptr(elt_ty, 1)
         # default path returns triton.ir.type directly
-        name = Kernel._type_name(obj.__class__)
+        name = Kernel._type_name(obj)
         return type_map[name](context)
 
     @staticmethod
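
With _type_name dispatching on the object itself, _to_triton_ir can pass obj straight through on both paths instead of pre-extracting obj.dtype or obj.__class__. A toy model of the two lowering paths (TensorStub and the string "types" are hypothetical stand-ins, not _triton.ir classes):

    # Toy model: tensors lower to a pointer to their element type,
    # scalars lower to the element type directly.
    class TensorStub:
        dtype = 'f16'
        def data_ptr(self):    # presence of data_ptr marks tensor-likes
            return 0

    type_map = {'I': 'i32', 'L': 'i64', 'f': 'f32', 'f16': 'f16'}

    def to_triton_ir(obj, code):
        if hasattr(obj, 'data_ptr'):
            return type_map[code] + '*'    # make_ptr(elt_ty, 1) in the diff
        return type_map[code]

    print(to_triton_ir(TensorStub(), 'f16'))  # f16*
    print(to_triton_ir(42, 'I'))              # i32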
@@ -511,7 +522,7 @@ class Kernel:
         types_key = [None] * len(wargs)
         for i, arg in enumerate(wargs):
             prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
+            suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
             types_key[i] = prefix + suffix
         return tuple(types_key)
 
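
This tuple feeds the kernel cache key: pointer arguments contribute 'P' plus their element-type code, scalars just their code. A sketch with made-up arguments (only the codes themselves come from the diff):

    # Sketch of the specialization key; the example arguments are made up.
    def types_key(wargs, tensor_idxs, codes):
        key = [None] * len(wargs)
        for i, arg in enumerate(wargs):
            prefix = 'P' if i in tensor_idxs else ''
            key[i] = prefix + codes[i]
        return tuple(key)

    # kernel(x, n, scale) where x is an f16 tensor, n an int, scale a float:
    print(types_key(['<tensor>', 1024, 0.5], {0}, ['f16', 'I', 'f']))
    # ('Pf16', 'I', 'f')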
@@ -646,7 +657,7 @@ class Kernel:
 
             drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
         callable = drv_cache[key]