Allow multiple_of and max_contiguous to accept n-d values (#617)
This commit is contained in:
@@ -625,13 +625,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.def(py::init<>());
|
||||
|
||||
py::class_<ir::value>(m, "value")
|
||||
.def("multiple_of", [](ir::value *self, int val) {
|
||||
.def("multiple_of", [](ir::value *self, std::vector<unsigned> val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::multiple_of, val);
|
||||
} else
|
||||
throw std::runtime_error("multiple_of");
|
||||
})
|
||||
.def("max_contiguous", [](ir::value *self, int val) {
|
||||
.def("max_contiguous", [](ir::value *self, std::vector<unsigned> val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::max_contiguous, val);
|
||||
} else
|
||||
|
@@ -1088,21 +1088,35 @@ def debug_barrier(_builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def multiple_of(input, value, _builder=None):
|
||||
def multiple_of(input, values, _builder=None):
|
||||
"""
|
||||
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
|
||||
"""
|
||||
value = _constexpr_to_value(value)
|
||||
return semantic.multiple_of(input, value)
|
||||
if isinstance(values, constexpr):
|
||||
values = [values]
|
||||
for i, d in enumerate(values):
|
||||
if not isinstance(d, constexpr):
|
||||
raise TypeError(f"values element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.multiple_of(input, values)
|
||||
|
||||
|
||||
@builtin
|
||||
def max_contiguous(input, value, _builder=None):
|
||||
def max_contiguous(input, values, _builder=None):
|
||||
"""
|
||||
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
|
||||
"""
|
||||
value = _constexpr_to_value(value)
|
||||
return semantic.max_contiguous(input, value)
|
||||
if isinstance(values, constexpr):
|
||||
values = [values]
|
||||
for i, d in enumerate(values):
|
||||
if not isinstance(d, constexpr):
|
||||
raise TypeError(f"values element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_contiguous(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
|
@@ -1090,13 +1090,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
##
|
||||
|
||||
def multiple_of(x: tl.tensor, value: int) -> tl.tensor:
|
||||
x.handle.multiple_of(value)
|
||||
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
x.handle.multiple_of(values)
|
||||
return x
|
||||
|
||||
|
||||
def max_contiguous(x: tl.tensor, value: int) -> tl.tensor:
|
||||
x.handle.max_contiguous(value)
|
||||
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
||||
x.handle.max_contiguous(values)
|
||||
return x
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user