diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index a80bc471f..357fffc6a 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -130,9 +130,9 @@ public: value *create_xor(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs); // Input/Output - value *create_load(value *arg, load_inst::CACHE_MODIFIER cache); + value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile); value *create_store(value *ptr, value *val); - value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); + value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index b6cc7dcc1..c90480f1e 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -68,7 +68,7 @@ struct dispatch{ static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder); + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 699d22257..7c147f634 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -402,8 +402,11 @@ public: }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } + bool get_is_volatile() const { return is_volatile_; } + protected: load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); std::string get_cache_modifier_repr() const { if (cache_ == CA) return ".ca"; @@ -412,20 +415,24 @@ protected: } CACHE_MODIFIER cache_; + std::string get_volatile_repr() { + return is_volatile_ ? ".volatile" : ""; + } + bool is_volatile_; + private: static type *get_pointee_type(type *ty); - }; // unmasked load class unmasked_load_inst: public load_inst { private: std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } - unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next); + unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); public: static unmasked_load_inst* create(value *ptr, - CACHE_MODIFIER cache, + CACHE_MODIFIER cache, bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_load_inst) @@ -436,7 +443,7 @@ public: class masked_load_inst: public load_inst { private: std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } - masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, + masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); public: @@ -445,7 +452,7 @@ public: value *get_false_value_operand() { return get_operand(2); } // factory method static masked_load_inst* create(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, + CACHE_MODIFIER cache, bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_inst) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a838bfa5a..a6148b2d1 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -629,7 +629,11 @@ void generator::visit_load_inst(ir::load_inst* x){ // ----- std::ostringstream asm_oss; asm_oss << "@$" << n_words; // predicate - asm_oss << " ld.global"; + asm_oss << " ld"; +// std::cout << x->get_is_volatile() << std::endl; + if(x->get_is_volatile()) + asm_oss << ".volatile"; + asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; if(n_words > 1) diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index f7ebdad80..864fea85c 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -207,7 +207,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), if_value->get_mask_operand(), select->get_else_value_op(), - if_value->get_cache_modifier()); + if_value->get_cache_modifier(), + if_value->get_is_volatile()); select->replace_all_uses_with(new_load); return true; } diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index bc249841b..eb3fe6164 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) { false_value = remat_false_value; } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier()); + first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile()); for (int stage = 1; stage < num_stages-1; ++stage) { // mask is the loop condition of the previous iteration @@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) { first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); false_value = remat_false_value; } - first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier()); + first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile()); } // create new phis for induction variables @@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // phi node @@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) { } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); ir::value* next_ptr = ptr->get_value_for_block(block); @@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // phi node builder.set_insert_point(block->get_first_non_phi()); ir::phi_node* new_load = builder.create_phi(ty, 2); @@ -327,4 +327,4 @@ void pipeline::run(ir::module &mod) { } } -} \ No newline at end of file +} diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index cc1d354ee..feac3c6b6 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){ - return insert(unmasked_load_inst::create(ptr, cache)); +value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){ + return insert(unmasked_load_inst::create(ptr, cache, is_volatile)); } value *builder::create_store(value *ptr, value *val){ return insert(unmasked_store_inst::create(ptr, val)); } -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){ - return insert(masked_load_inst::create(ptr, mask, false_value, cache)); +value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){ + return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile)); } value *builder::create_masked_store(value *ptr, value *val, value *mask){ diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index c4e8ccafb..2b8ae0259 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -500,7 +500,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build // Memory Operators //===----------------------------------------------------------------------===// -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) { +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) { if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); if(ptr->get_type()->is_block_ty()){ @@ -520,6 +520,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); ptr = dispatch::cast(ptr, ptr_ty, builder); } + // cache modifier load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default if (!cache_modifier.empty()) { if (cache_modifier == ".ca") @@ -530,7 +531,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); } if (!mask && !other) - return builder->create_load(ptr, cache); + return builder->create_load(ptr, cache, is_volatile); if (!mask) throw std::runtime_error("`other` cannot be provided without `mask`"); auto shape = ptr->get_type()->get_block_shapes(); @@ -539,7 +540,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con if(ptr->get_type()->is_block_ty()) other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); } - return builder->create_masked_load(ptr, mask, other, cache); + return builder->create_masked_load(ptr, mask, other, cache, is_volatile); } ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 00d801616..0206b7e77 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n { } // load_inst -load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache) +load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile) { } // load @@ -448,35 +448,35 @@ type *load_inst::get_pointee_type(type *ty) { } // unmasked_load -unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) { +unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) + : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) { set_operand(0, ptr); } -unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) { - return new unmasked_load_inst(ptr, cache, name, next); +unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { + return new unmasked_load_inst(ptr, cache, is_volatile, name, next); } // masked load -masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, +masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) { + : load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { - return new masked_load_inst(ptr, mask, false_value, cache, name, next); + return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next); } // masked load async masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) { + : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); diff --git a/python/src/triton.cc b/python/src/triton.cc index cec6fba94..783b0406a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -600,7 +600,8 @@ void init_triton_ir(py::module &&m) { py::class_(m, "constant_int") .def_property_readonly("value", &ir::constant_int::get_value) - .def("__int__", [](ir::constant_int *self) { return self->get_value(); }); + .def("__int__", [](ir::constant_int *self) { return self->get_value(); }) + .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") .def_property_readonly("value", &ir::constant_fp::get_value); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e939319aa..d5d3313e5 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -519,7 +519,7 @@ def dot(input, other, _builder=None): @builtin -def load(pointer, mask=None, other=None, cache_modifier="", _builder=None): +def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None): """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. @@ -536,7 +536,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", _builder=None): :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, _builder) + return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder) @builtin