Allow multiple_of and max_contiguous to accept n-d values (#617)
This commit is contained in:
@@ -59,8 +59,8 @@ public:
|
|||||||
std::string repr() const { return repr_impl(); }
|
std::string repr() const { return repr_impl(); }
|
||||||
// metadata
|
// metadata
|
||||||
void set_metadata(ir::metadata::kind_t kind,
|
void set_metadata(ir::metadata::kind_t kind,
|
||||||
unsigned value) { metadatas_[kind] = value;}
|
std::vector<unsigned> value) { metadatas_[kind] = value;}
|
||||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||||
// cloning
|
// cloning
|
||||||
ir::instruction* clone() {
|
ir::instruction* clone() {
|
||||||
ir::instruction* res = clone_impl();
|
ir::instruction* res = clone_impl();
|
||||||
@@ -77,7 +77,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
basic_block *parent_;
|
basic_block *parent_;
|
||||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
|
||||||
value_id_t id_;
|
value_id_t id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -3,6 +3,8 @@
|
|||||||
#ifndef _TRITON_IR_METADATA_H_
|
#ifndef _TRITON_IR_METADATA_H_
|
||||||
#define _TRITON_IR_METADATA_H_
|
#define _TRITON_IR_METADATA_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
|
|
||||||
@@ -16,14 +18,14 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
metadata(kind_t kind, unsigned value);
|
metadata(kind_t kind, std::vector<unsigned> value);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static metadata* get(kind_t kind, unsigned value);
|
static metadata* get(kind_t kind, std::vector<unsigned> value);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
kind_t kind_;
|
kind_t kind_;
|
||||||
unsigned value_;
|
std::vector<unsigned> value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -70,7 +70,7 @@ private:
|
|||||||
|
|
||||||
class module {
|
class module {
|
||||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
|
||||||
friend class function;
|
friend class function;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@@ -366,9 +366,9 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
|||||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||||
return max_contiguous_.at(v);
|
return max_contiguous_.at(v);
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||||
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
||||||
if(max_contiguous > 0)
|
if(!max_contiguous.empty())
|
||||||
return add_to_cache(x, {max_contiguous}, max_contiguous_);
|
return add_to_cache(x, max_contiguous, max_contiguous_);
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||||
return populate_max_contiguous_cast(x);
|
return populate_max_contiguous_cast(x);
|
||||||
@@ -521,9 +521,9 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
|||||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||||
return starting_multiple_.at(v);
|
return starting_multiple_.at(v);
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||||
if(multiple_of > 0)
|
if(!multiple_of.empty())
|
||||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
return add_to_cache(x, multiple_of, starting_multiple_);
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||||
return populate_starting_multiple_cast(x);
|
return populate_starting_multiple_cast(x);
|
||||||
|
@@ -3,10 +3,10 @@
|
|||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
|
|
||||||
metadata::metadata(kind_t kind, unsigned value)
|
metadata::metadata(kind_t kind, std::vector<unsigned> value)
|
||||||
: kind_(kind), value_(value) { }
|
: kind_(kind), value_(value) { }
|
||||||
|
|
||||||
metadata* metadata::get(kind_t kind, unsigned value) {
|
metadata* metadata::get(kind_t kind, std::vector<unsigned> value) {
|
||||||
return new metadata(kind, value);
|
return new metadata(kind, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -625,13 +625,13 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def(py::init<>());
|
.def(py::init<>());
|
||||||
|
|
||||||
py::class_<ir::value>(m, "value")
|
py::class_<ir::value>(m, "value")
|
||||||
.def("multiple_of", [](ir::value *self, int val) {
|
.def("multiple_of", [](ir::value *self, std::vector<unsigned> val) {
|
||||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||||
instr->set_metadata(ir::metadata::multiple_of, val);
|
instr->set_metadata(ir::metadata::multiple_of, val);
|
||||||
} else
|
} else
|
||||||
throw std::runtime_error("multiple_of");
|
throw std::runtime_error("multiple_of");
|
||||||
})
|
})
|
||||||
.def("max_contiguous", [](ir::value *self, int val) {
|
.def("max_contiguous", [](ir::value *self, std::vector<unsigned> val) {
|
||||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||||
instr->set_metadata(ir::metadata::max_contiguous, val);
|
instr->set_metadata(ir::metadata::max_contiguous, val);
|
||||||
} else
|
} else
|
||||||
|
@@ -1088,21 +1088,35 @@ def debug_barrier(_builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def multiple_of(input, value, _builder=None):
|
def multiple_of(input, values, _builder=None):
|
||||||
"""
|
"""
|
||||||
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
|
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
|
||||||
"""
|
"""
|
||||||
value = _constexpr_to_value(value)
|
if isinstance(values, constexpr):
|
||||||
return semantic.multiple_of(input, value)
|
values = [values]
|
||||||
|
for i, d in enumerate(values):
|
||||||
|
if not isinstance(d, constexpr):
|
||||||
|
raise TypeError(f"values element {i} must have type `constexpr`")
|
||||||
|
if not isinstance(d.value, int):
|
||||||
|
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||||
|
values = [x.value for x in values]
|
||||||
|
return semantic.multiple_of(input, values)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def max_contiguous(input, value, _builder=None):
|
def max_contiguous(input, values, _builder=None):
|
||||||
"""
|
"""
|
||||||
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
|
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
|
||||||
"""
|
"""
|
||||||
value = _constexpr_to_value(value)
|
if isinstance(values, constexpr):
|
||||||
return semantic.max_contiguous(input, value)
|
values = [values]
|
||||||
|
for i, d in enumerate(values):
|
||||||
|
if not isinstance(d, constexpr):
|
||||||
|
raise TypeError(f"values element {i} must have type `constexpr`")
|
||||||
|
if not isinstance(d.value, int):
|
||||||
|
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||||
|
values = [x.value for x in values]
|
||||||
|
return semantic.max_contiguous(input, values)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
@@ -1090,13 +1090,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|||||||
|
|
||||||
##
|
##
|
||||||
|
|
||||||
def multiple_of(x: tl.tensor, value: int) -> tl.tensor:
|
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||||
x.handle.multiple_of(value)
|
if len(x.shape) != len(values):
|
||||||
|
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||||
|
x.handle.multiple_of(values)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def max_contiguous(x: tl.tensor, value: int) -> tl.tensor:
|
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||||
x.handle.max_contiguous(value)
|
if len(x.shape) != len(values):
|
||||||
|
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
||||||
|
x.handle.max_contiguous(values)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user