[PYTHON] Avoid dangerous global variables in kwarg default values (#68)

This commit is contained in:
Tom B Brown
2021-02-18 14:56:54 -08:00
committed by Philippe Tillet
parent dcd14c4e8d
commit 7aa4d080b3

View File

@@ -1,5 +1,7 @@
import os
import struct
from typing import Optional, Dict, List
import torch
# C bindings
import triton._C.libtriton.triton as _triton
@@ -33,14 +35,31 @@ def synchronize(device):
dev_id = -1 if dev_id is None else dev_id
_torch_utils.synchronize(dev_id)
def read(path, kernel_names=[]):
def read(path, kernel_names:Optional[List]=None):
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f:
source = f.read()
source = _triton.extract_kernels(source, kernel_names)
return source
class kernel:
def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
def __init__(self,
src,
device,
defines: Optional[Dict]=None,
num_warps:int=4,
autotune_vals:Optional[List]=None,
autotune_key:Optional[List]=None):
if defines is None:
defines = {}
if autotune_vals is None:
autotune_vals = []
if autotune_key is None:
autotune_key = []
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')