[ALL] Merge master (#447)

This commit is contained in:
Philippe Tillet
2022-01-30 20:21:20 -08:00
committed by GitHub
parent bef76b142a
commit 807d8a1945
14 changed files with 199 additions and 130 deletions

View File

@@ -23,6 +23,8 @@ def get_llvm():
paths = [p for p in paths if p is not None]
if paths:
return '', ''
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
@@ -104,7 +106,7 @@ class CMakeBuild(build_ext):
build_args = ["--config", cfg]
if platform.system() == "Windows":
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ["-A", "x64"]
build_args += ["--", "/m"]

View File

@@ -15,6 +15,7 @@
#include <pybind11/stl.h>
#include "Python.h"
#include <regex>
#include <sstream>
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"

View File

@@ -25,13 +25,13 @@ def get_p2p_matrix():
def get_p2p_devices():
    """
    Return a pair of device strings whose P2P-matrix entry equals "OK".

    The matrix comes from `get_p2p_matrix()`. The first match reported by
    `np.where` is used: its first index becomes the first device and its
    second index becomes the second device.

    Returns:
        `["cuda:<i>", "cuda:<j>"]` for the first match, or an empty list
        when no entry of the matrix equals "OK".
    """
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "OK")
    # No match: return an empty list instead of letting the indexing below
    # raise IndexError on empty index arrays.
    if len(idx[0]) == 0:
        return []
    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"]
def get_non_p2p_devices():
    """
    Return a pair of device strings whose P2P-matrix entry equals "NS".

    The matrix comes from `get_p2p_matrix()`. The first match reported by
    `np.where` is used: its first index becomes the first device and its
    second index becomes the second device.

    Returns:
        `["cuda:<i>", "cuda:<j>"]` for the first match, or an empty list
        when no entry of the matrix equals "NS".
    """
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "NS")
    # No match: return an empty list instead of letting the indexing below
    # raise IndexError on empty index arrays.
    if len(idx[0]) == 0:
        return []
    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"]
p2p_devices = get_p2p_devices()

View File

@@ -358,9 +358,6 @@ class CodeGenerator(ast.NodeVisitor):
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Str(self, node):
    """Return the Python value denoted by the literal AST node `node`."""
    literal = ast.literal_eval(node)
    return literal
def visit_Subscript(self, node):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
@@ -441,9 +438,6 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Index(self, node):
    """Visit the expression stored in `node.value` and return that result."""
    wrapped = node.value
    return self.visit(wrapped)
def visit_NameConstant(self, node):
    """Return the value stored on `node`, unchanged."""
    constant = node.value
    return constant
def visit_keyword(self, node):
    """Return a one-entry dict mapping `node.arg` to the visited `node.value`."""
    name = node.arg
    visited = self.visit(node.value)
    return {name: visited}
@@ -460,10 +454,23 @@ class CodeGenerator(ast.NodeVisitor):
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
for arg in args]
return fn(*args, **kws)
def visit_Num(self, node):
    """Wrap the number held in `node.n` in a `triton.language.constexpr`."""
    wrap = triton.language.constexpr
    return wrap(node.n)
def visit_Constant(self, node):
    """Wrap the literal held in `node.value` in a `triton.language.constexpr`."""
    wrap = triton.language.constexpr
    return wrap(node.value)
# Visitors for the literal node classes (NameConstant / Num / Str), defined
# only on interpreters older than Python 3.8.
if not sys.version_info >= (3, 8):
    def visit_NameConstant(self, node):
        """Wrap the value held in `node.value` as a constexpr."""
        wrap = triton.language.constexpr
        return wrap(node.value)

    def visit_Num(self, node):
        """Wrap the number held in `node.n` as a constexpr."""
        wrap = triton.language.constexpr
        return wrap(node.n)

    def visit_Str(self, node):
        """Evaluate the string literal node and wrap the result as a constexpr."""
        wrap = triton.language.constexpr
        return wrap(ast.literal_eval(node))
def visit_Attribute(self, node):
lhs = self.visit(node.value)

View File

@@ -130,6 +130,94 @@ float64 = dtype(ir.type.get_fp64)
# pointer types
pi32_t = pointer_dtype(int32)
# -----------------------
# constexpr
# -----------------------
class constexpr:
    """
    This class is used to store a value that is known at compile-time.

    Wrapping a `constexpr` in another `constexpr` stores the inner value, so
    wrappers never nest. Arithmetic and comparison operators accept either a
    `constexpr` or a plain value as the other operand and return the plain
    (unwrapped) result of applying the operator to the underlying values.

    Note: because `__eq__` is defined without `__hash__`, instances are
    unhashable.
    """

    def __init__(self, value):
        if isinstance(value, constexpr):
            self.value = value.value
        else:
            self.value = value

    def __repr__(self) -> str:
        return f"constexpr[{self.value}]"

    @staticmethod
    def _unwrap(operand):
        """Return `operand.value` for a constexpr, else `operand` itself."""
        return operand.value if isinstance(operand, constexpr) else operand

    # Arithmetic. The reflected variants are invoked by Python when the left
    # operand is NOT a constexpr, so `other` must not be assumed to carry a
    # `.value` attribute.
    def __add__(self, other):
        return self.value + constexpr._unwrap(other)

    def __radd__(self, other):
        return constexpr._unwrap(other) + self.value

    def __sub__(self, other):
        return self.value - constexpr._unwrap(other)

    def __rsub__(self, other):
        return constexpr._unwrap(other) - self.value

    def __mul__(self, other):
        return self.value * constexpr._unwrap(other)

    def __rmul__(self, other):
        return constexpr._unwrap(other) * self.value

    def __truediv__(self, other):
        return self.value / constexpr._unwrap(other)

    def __rtruediv__(self, other):
        return constexpr._unwrap(other) / self.value

    def __floordiv__(self, other):
        return self.value // constexpr._unwrap(other)

    def __rfloordiv__(self, other):
        return constexpr._unwrap(other) // self.value

    # Comparisons. Python has no reflected comparison hooks: `x > c` with a
    # non-constexpr `x` falls back to `c.__lt__(x)`. The `__r*__` comparison
    # methods below are therefore never called implicitly; they are kept only
    # so that code calling them explicitly keeps working.
    def __gt__(self, other):
        return self.value > constexpr._unwrap(other)

    def __rgt__(self, other):
        return constexpr._unwrap(other) > self.value

    def __ge__(self, other):
        return self.value >= constexpr._unwrap(other)

    def __rge__(self, other):
        return constexpr._unwrap(other) >= self.value

    def __lt__(self, other):
        return self.value < constexpr._unwrap(other)

    def __rlt__(self, other):
        return constexpr._unwrap(other) < self.value

    def __le__(self, other):
        return self.value <= constexpr._unwrap(other)

    def __rle__(self, other):
        return constexpr._unwrap(other) <= self.value

    def __eq__(self, other):
        return self.value == constexpr._unwrap(other)

    def __ne__(self, other):
        return self.value != constexpr._unwrap(other)

    def __bool__(self):
        return bool(self.value)

    def __call__(self, *args, **kwds):
        # Forward the call to the wrapped value (which must be callable).
        return self.value(*args, **kwds)
class block:
@staticmethod
@@ -296,7 +384,7 @@ class block:
dst_shape = []
curr = 0
for sl in slices:
if sl is None:
if isinstance(sl, constexpr) and sl.value is None:
dst_shape.append(1)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr].value)
@@ -312,93 +400,6 @@ class block:
return frontend.cast(self, dtype, _builder)
# -----------------------
# constexpr
# -----------------------
class constexpr:
    """
    This class is used to store a value that is known at compile-time.

    Wrapping a `constexpr` in another `constexpr` stores the inner value, so
    wrappers never nest. Arithmetic and comparison operators accept either a
    `constexpr` or a plain value as the other operand and return the plain
    (unwrapped) result of applying the operator to the underlying values.

    Note: because `__eq__` is defined without `__hash__`, instances are
    unhashable.
    """

    def __init__(self, value):
        if isinstance(value, constexpr):
            self.value = value.value
        else:
            self.value = value

    def __repr__(self) -> str:
        return f"constexpr[{self.value}]"

    @staticmethod
    def _unwrap(operand):
        """Return `operand.value` for a constexpr, else `operand` itself."""
        return operand.value if isinstance(operand, constexpr) else operand

    # Arithmetic. The reflected variants are invoked by Python when the left
    # operand is NOT a constexpr, so `other` must not be assumed to carry a
    # `.value` attribute.
    def __add__(self, other):
        return self.value + constexpr._unwrap(other)

    def __radd__(self, other):
        return constexpr._unwrap(other) + self.value

    def __sub__(self, other):
        return self.value - constexpr._unwrap(other)

    def __rsub__(self, other):
        return constexpr._unwrap(other) - self.value

    def __mul__(self, other):
        return self.value * constexpr._unwrap(other)

    def __rmul__(self, other):
        return constexpr._unwrap(other) * self.value

    def __truediv__(self, other):
        return self.value / constexpr._unwrap(other)

    def __rtruediv__(self, other):
        return constexpr._unwrap(other) / self.value

    def __floordiv__(self, other):
        return self.value // constexpr._unwrap(other)

    def __rfloordiv__(self, other):
        return constexpr._unwrap(other) // self.value

    # Comparisons. Python has no reflected comparison hooks: `x > c` with a
    # non-constexpr `x` falls back to `c.__lt__(x)`. The `__r*__` comparison
    # methods below are therefore never called implicitly; they are kept only
    # so that code calling them explicitly keeps working.
    def __gt__(self, other):
        return self.value > constexpr._unwrap(other)

    def __rgt__(self, other):
        return constexpr._unwrap(other) > self.value

    def __ge__(self, other):
        return self.value >= constexpr._unwrap(other)

    def __rge__(self, other):
        return constexpr._unwrap(other) >= self.value

    def __lt__(self, other):
        return self.value < constexpr._unwrap(other)

    def __rlt__(self, other):
        return constexpr._unwrap(other) < self.value

    def __le__(self, other):
        return self.value <= constexpr._unwrap(other)

    def __rle__(self, other):
        return constexpr._unwrap(other) <= self.value

    def __eq__(self, other):
        return self.value == constexpr._unwrap(other)

    def __ne__(self, other):
        return self.value != constexpr._unwrap(other)

    def __bool__(self):
        return bool(self.value)

    def __call__(self, *args, **kwds):
        # Forward the call to the wrapped value (which must be callable).
        return self.value(*args, **kwds)
# -----------------------
# SPMD Programming Model
# -----------------------