[PYTHON] Avoid dangerous global variables in kwarg default values (#68)

This commit is contained in:
Tom B Brown
2021-02-18 14:56:54 -08:00
committed by Philippe Tillet
parent dcd14c4e8d
commit 7aa4d080b3

View File

@@ -1,5 +1,7 @@
import os import os
import struct import struct
from typing import Optional, Dict, List
import torch import torch
# C bindings # C bindings
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
@@ -33,14 +35,31 @@ def synchronize(device):
dev_id = -1 if dev_id is None else dev_id dev_id = -1 if dev_id is None else dev_id
_torch_utils.synchronize(dev_id) _torch_utils.synchronize(dev_id)
def read(path, kernel_names=[]): def read(path, kernel_names:Optional[List]=None):
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f: with open(path, 'r') as f:
source = f.read() source = f.read()
source = _triton.extract_kernels(source, kernel_names) source = _triton.extract_kernels(source, kernel_names)
return source return source
class kernel: class kernel:
def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]): def __init__(self,
src,
device,
defines: Optional[Dict]=None,
num_warps:int=4,
autotune_vals:Optional[List]=None,
autotune_key:Optional[List]=None):
if defines is None:
defines = {}
if autotune_vals is None:
autotune_vals = []
if autotune_key is None:
autotune_key = []
# check if src is empty # check if src is empty
if src == '': if src == '':
raise ValueError('Kernel source code is empty') raise ValueError('Kernel source code is empty')