diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 02c6f8f46..206a2dd9f 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -1,5 +1,7 @@ import os import struct +from typing import Optional, Dict, List + import torch # C bindings import triton._C.libtriton.triton as _triton @@ -33,14 +35,31 @@ def synchronize(device): dev_id = -1 if dev_id is None else dev_id _torch_utils.synchronize(dev_id) -def read(path, kernel_names=[]): +def read(path, kernel_names:Optional[List]=None): + if kernel_names is None: + kernel_names = [] with open(path, 'r') as f: source = f.read() source = _triton.extract_kernels(source, kernel_names) return source class kernel: - def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]): + def __init__(self, + src, + device, + defines: Optional[Dict]=None, + num_warps:int=4, + autotune_vals:Optional[List]=None, + autotune_key:Optional[List]=None): + + if defines is None: + defines = {} + if autotune_vals is None: + autotune_vals = [] + if autotune_key is None: + autotune_key = [] + + # check if src is empty if src == '': raise ValueError('Kernel source code is empty')