[ALL] Merge master (#447)
This commit is contained in:
@@ -23,6 +23,8 @@ def get_llvm():
|
||||
paths = [p for p in paths if p is not None]
|
||||
if paths:
|
||||
return '', ''
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||
dir = '/tmp'
|
||||
@@ -104,7 +106,7 @@ class CMakeBuild(build_ext):
|
||||
build_args = ["--config", cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
|
||||
cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
|
||||
if sys.maxsize > 2**32:
|
||||
cmake_args += ["-A", "x64"]
|
||||
build_args += ["--", "/m"]
|
||||
|
@@ -15,6 +15,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include "Python.h"
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
|
@@ -25,13 +25,13 @@ def get_p2p_matrix():
|
||||
def get_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "OK")
|
||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
||||
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||
|
||||
|
||||
def get_non_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "NS")
|
||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
||||
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||
|
||||
|
||||
p2p_devices = get_p2p_devices()
|
||||
|
@@ -358,9 +358,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
def visit_Str(self, node):
|
||||
return ast.literal_eval(node)
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
@@ -441,9 +438,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Index(self, node):
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_NameConstant(self, node):
|
||||
return node.value
|
||||
|
||||
def visit_keyword(self, node):
|
||||
return {node.arg: self.visit(node.value)}
|
||||
|
||||
@@ -460,10 +454,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
for arg in args]
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Num(self, node):
|
||||
return triton.language.constexpr(node.n)
|
||||
def visit_Constant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
def visit_NameConstant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
def visit_Num(self, node):
|
||||
return triton.language.constexpr(node.n)
|
||||
|
||||
def visit_Str(self, node):
|
||||
return triton.language.constexpr(ast.literal_eval(node))
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
|
@@ -130,6 +130,94 @@ float64 = dtype(ir.type.get_fp64)
|
||||
# pointer types
|
||||
pi32_t = pointer_dtype(int32)
|
||||
|
||||
# -----------------------
|
||||
# constexpr
|
||||
# -----------------------
|
||||
|
||||
|
||||
class constexpr:
|
||||
"""
|
||||
This class is used to store a value that is known at compile-time.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
if isinstance(value, constexpr):
|
||||
self.value = value.value
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"constexpr[{self.value}]"
|
||||
|
||||
#
|
||||
def __add__(self, other):
|
||||
return self.value + other.value
|
||||
|
||||
def __radd__(self, other):
|
||||
return other.value + self.value
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.value - other.value
|
||||
|
||||
def __rsub__(self, other):
|
||||
return other.value - self.value
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.value * other.value
|
||||
|
||||
def __rmul__(self, other):
|
||||
return other.value * self.value
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.value / other.value
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return other.value / self.value
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return self.value // other.value
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return other.value // self.value
|
||||
|
||||
#
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.value > other.value
|
||||
|
||||
def __rgt__(self, other):
|
||||
return other.value > self.value
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.value >= other.value
|
||||
|
||||
def __rge__(self, other):
|
||||
return other.value >= self.value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __rlt__(self, other):
|
||||
return other.value < self.value
|
||||
|
||||
def __le__(self, other):
|
||||
return self.value <= other.value
|
||||
|
||||
def __rle__(self, other):
|
||||
return other.value <= self.value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.value != other.value
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.value(*args, **kwds)
|
||||
|
||||
|
||||
class block:
|
||||
@staticmethod
|
||||
@@ -296,7 +384,7 @@ class block:
|
||||
dst_shape = []
|
||||
curr = 0
|
||||
for sl in slices:
|
||||
if sl is None:
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
dst_shape.append(1)
|
||||
elif sl == slice(None, None, None):
|
||||
dst_shape.append(src_shape[curr].value)
|
||||
@@ -312,93 +400,6 @@ class block:
|
||||
return frontend.cast(self, dtype, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# constexpr
|
||||
# -----------------------
|
||||
|
||||
class constexpr:
|
||||
"""
|
||||
This class is used to store a value that is known at compile-time.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
if isinstance(value, constexpr):
|
||||
self.value = value.value
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"constexpr[{self.value}]"
|
||||
|
||||
#
|
||||
def __add__(self, other):
|
||||
return self.value + other.value
|
||||
|
||||
def __radd__(self, other):
|
||||
return other.value + self.value
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.value - other.value
|
||||
|
||||
def __rsub__(self, other):
|
||||
return other.value - self.value
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.value * other.value
|
||||
|
||||
def __rmul__(self, other):
|
||||
return other.value * self.value
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.value / other.value
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return other.value / self.value
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return self.value // other.value
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return other.value // self.value
|
||||
|
||||
#
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.value > other.value
|
||||
|
||||
def __rgt__(self, other):
|
||||
return other.value > self.value
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.value >= other.value
|
||||
|
||||
def __rge__(self, other):
|
||||
return other.value >= self.value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __rlt__(self, other):
|
||||
return other.value < self.value
|
||||
|
||||
def __le__(self, other):
|
||||
return self.value <= other.value
|
||||
|
||||
def __rle__(self, other):
|
||||
return other.value <= self.value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.value != other.value
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.value(*args, **kwds)
|
||||
|
||||
# -----------------------
|
||||
# SPMD Programming Model
|
||||
# -----------------------
|
||||
|
Reference in New Issue
Block a user