[PYTHON] Avoid dangerous global variables in kwarg default values (#68)
This commit is contained in:
committed by
Philippe Tillet
parent
dcd14c4e8d
commit
7aa4d080b3
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Optional, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
# C bindings
|
# C bindings
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
@@ -33,14 +35,31 @@ def synchronize(device):
|
|||||||
dev_id = -1 if dev_id is None else dev_id
|
dev_id = -1 if dev_id is None else dev_id
|
||||||
_torch_utils.synchronize(dev_id)
|
_torch_utils.synchronize(dev_id)
|
||||||
|
|
||||||
def read(path, kernel_names=[]):
|
def read(path, kernel_names:Optional[List]=None):
|
||||||
|
if kernel_names is None:
|
||||||
|
kernel_names = []
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
source = _triton.extract_kernels(source, kernel_names)
|
source = _triton.extract_kernels(source, kernel_names)
|
||||||
return source
|
return source
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
|
def __init__(self,
|
||||||
|
src,
|
||||||
|
device,
|
||||||
|
defines: Optional[Dict]=None,
|
||||||
|
num_warps:int=4,
|
||||||
|
autotune_vals:Optional[List]=None,
|
||||||
|
autotune_key:Optional[List]=None):
|
||||||
|
|
||||||
|
if defines is None:
|
||||||
|
defines = {}
|
||||||
|
if autotune_vals is None:
|
||||||
|
autotune_vals = []
|
||||||
|
if autotune_key is None:
|
||||||
|
autotune_key = []
|
||||||
|
|
||||||
|
|
||||||
# check if src is empty
|
# check if src is empty
|
||||||
if src == '':
|
if src == '':
|
||||||
raise ValueError('Kernel source code is empty')
|
raise ValueError('Kernel source code is empty')
|
||||||
|
Reference in New Issue
Block a user