[PYTHON] Avoid dangerous global variables in kwarg default values (#68)
This commit is contained in:
committed by
Philippe Tillet
parent
dcd14c4e8d
commit
7aa4d080b3
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import struct
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
import torch
|
||||
# C bindings
|
||||
import triton._C.libtriton.triton as _triton
|
||||
@@ -33,14 +35,31 @@ def synchronize(device):
|
||||
dev_id = -1 if dev_id is None else dev_id
|
||||
_torch_utils.synchronize(dev_id)
|
||||
|
||||
def read(path, kernel_names=[]):
|
||||
def read(path, kernel_names:Optional[List]=None):
|
||||
if kernel_names is None:
|
||||
kernel_names = []
|
||||
with open(path, 'r') as f:
|
||||
source = f.read()
|
||||
source = _triton.extract_kernels(source, kernel_names)
|
||||
return source
|
||||
|
||||
class kernel:
|
||||
def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
|
||||
def __init__(self,
|
||||
src,
|
||||
device,
|
||||
defines: Optional[Dict]=None,
|
||||
num_warps:int=4,
|
||||
autotune_vals:Optional[List]=None,
|
||||
autotune_key:Optional[List]=None):
|
||||
|
||||
if defines is None:
|
||||
defines = {}
|
||||
if autotune_vals is None:
|
||||
autotune_vals = []
|
||||
if autotune_key is None:
|
||||
autotune_key = []
|
||||
|
||||
|
||||
# check if src is empty
|
||||
if src == '':
|
||||
raise ValueError('Kernel source code is empty')
|
||||
|
Reference in New Issue
Block a user