From 411bacb2a84c2047a1bfd77b7de334d0a666cd92 Mon Sep 17 00:00:00 2001 From: Sophia Wisdom Date: Wed, 4 Jan 2023 18:06:32 -0800 Subject: [PATCH] [FRONTEND] Add logical operations on constexprs (#1033) --- python/triton/compiler.py | 4 ---- python/triton/language/core.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 96ebfc3cf..aa2760e64 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -734,10 +734,6 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.values) == 2 lhs = self.visit(node.values[0]) rhs = self.visit(node.values[1]) - if isinstance(lhs, triton.language.constexpr): - lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): - rhs = rhs.value fn = { ast.And: 'logical_and', diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 15dd8462a..9750c2237 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -403,6 +403,18 @@ class constexpr: def __neg__(self): return constexpr(-self.value) + def __and__(self, other): + return constexpr(self.value & other.value) + + def logical_and(self, other): + return constexpr(self.value and other.value) + + def __or__(self, other): + return constexpr(self.value | other.value) + + def logical_or(self, other): + return constexpr(self.value or other.value) + def __pos__(self): return constexpr(+self.value)