From 3e2953f3570acc807f975eee58f1565483841b72 Mon Sep 17 00:00:00 2001 From: Da Yan Date: Wed, 10 Aug 2022 09:59:32 -0700 Subject: [PATCH] Allow multiple_of and max_contiguous to accept n-d values (#617) --- include/triton/ir/instructions.h | 6 +++--- include/triton/ir/metadata.h | 8 +++++--- include/triton/ir/module.h | 2 +- lib/codegen/analysis/align.cc | 12 ++++++------ lib/ir/metadata.cc | 4 ++-- python/src/triton.cc | 4 ++-- python/triton/language/core.py | 26 ++++++++++++++++++++------ python/triton/language/semantic.py | 12 ++++++++---- 8 files changed, 47 insertions(+), 27 deletions(-) diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 1bad86c33..8a1c3f7cf 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -59,8 +59,8 @@ public: std::string repr() const { return repr_impl(); } // metadata void set_metadata(ir::metadata::kind_t kind, - unsigned value) { metadatas_[kind] = value;} - unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} + std::vector value) { metadatas_[kind] = value;} + std::vector get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} // cloning ir::instruction* clone() { ir::instruction* res = clone_impl(); @@ -77,7 +77,7 @@ public: private: basic_block *parent_; - std::map metadatas_; + std::map> metadatas_; value_id_t id_; }; diff --git a/include/triton/ir/metadata.h b/include/triton/ir/metadata.h index 9d4fb1137..69512c6b0 100644 --- a/include/triton/ir/metadata.h +++ b/include/triton/ir/metadata.h @@ -3,6 +3,8 @@ #ifndef _TRITON_IR_METADATA_H_ #define _TRITON_IR_METADATA_H_ +#include + namespace triton{ namespace ir{ @@ -16,14 +18,14 @@ public: }; private: - metadata(kind_t kind, unsigned value); + metadata(kind_t kind, std::vector value); public: - static metadata* get(kind_t kind, unsigned value); + static metadata* get(kind_t kind, std::vector value); private: kind_t kind_; - unsigned value_; + std::vector value_; }; } diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index aa279af98..1ed0b6646 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -70,7 +70,7 @@ private: class module { typedef std::pair val_key_t; - typedef std::pair md_pair_t; + typedef std::pair> md_pair_t; friend class function; public: diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 1c48a4c05..6bd6e4ef9 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -366,9 +366,9 @@ std::vector align::populate_max_contiguous(ir::value *v){ if(max_contiguous_.find(v) != max_contiguous_.end()) return max_contiguous_.at(v); if(auto *x = dynamic_cast(v)){ - unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); - if(max_contiguous > 0) - return add_to_cache(x, {max_contiguous}, max_contiguous_); + std::vector max_contiguous = x->get_metadata(ir::metadata::max_contiguous); + if(!max_contiguous.empty()) + return add_to_cache(x, max_contiguous, max_contiguous_); } if(auto *x = dynamic_cast(v)) return populate_max_contiguous_cast(x); @@ -521,9 +521,9 @@ std::vector align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); if(auto *x = dynamic_cast(v)){ - unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); - if(multiple_of > 0) - return add_to_cache(x, {multiple_of}, starting_multiple_); + std::vector multiple_of = x->get_metadata(ir::metadata::multiple_of); + if(!multiple_of.empty()) + return add_to_cache(x, multiple_of, starting_multiple_); } if(auto *x = dynamic_cast(v)) return populate_starting_multiple_cast(x); diff --git a/lib/ir/metadata.cc b/lib/ir/metadata.cc index 16bc059c5..9d31963c2 100644 --- a/lib/ir/metadata.cc +++ b/lib/ir/metadata.cc @@ -3,10 +3,10 @@ namespace triton{ namespace ir{ -metadata::metadata(kind_t kind, unsigned value) +metadata::metadata(kind_t kind, std::vector value) : kind_(kind), value_(value) { } -metadata* metadata::get(kind_t kind, unsigned value) { +metadata* metadata::get(kind_t kind, std::vector value) { return new metadata(kind, value); } diff --git a/python/src/triton.cc b/python/src/triton.cc index f72513395..a606c051a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -625,13 +625,13 @@ void init_triton_ir(py::module &&m) { .def(py::init<>()); py::class_(m, "value") - .def("multiple_of", [](ir::value *self, int val) { + .def("multiple_of", [](ir::value *self, std::vector val) { if (auto *instr = dynamic_cast(self)) { instr->set_metadata(ir::metadata::multiple_of, val); } else throw std::runtime_error("multiple_of"); }) - .def("max_contiguous", [](ir::value *self, int val) { + .def("max_contiguous", [](ir::value *self, std::vector val) { if (auto *instr = dynamic_cast(self)) { instr->set_metadata(ir::metadata::max_contiguous, val); } else diff --git a/python/triton/language/core.py b/python/triton/language/core.py index fdf9063a7..29a128321 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1088,21 +1088,35 @@ def debug_barrier(_builder=None): @builtin -def multiple_of(input, value, _builder=None): +def multiple_of(input, values, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - value = _constexpr_to_value(value) - return semantic.multiple_of(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) @builtin -def max_contiguous(input, value, _builder=None): +def max_contiguous(input, values, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - value = _constexpr_to_value(value) - return semantic.max_contiguous(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index ecd740114..a860dc763 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1090,13 +1090,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: ## -def multiple_of(x: tl.tensor, value: int) -> tl.tensor: - x.handle.multiple_of(value) +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.multiple_of(values) return x -def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: - x.handle.max_contiguous(value) +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.max_contiguous(values) return x