#include #include "triton/lang/statement.h" #include "triton/lang/declaration.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/basic_block.h" #include "triton/ir/builder.h" #include "triton/ir/type.h" #include "triton/ir/metadata.h" namespace triton{ namespace lang{ /* Declaration specifier */ ir::type* typed_declaration_specifier::type(ir::module *mod) const { ir::context &ctx = mod->get_context(); switch (ty_) { case VOID_T: return ir::type::get_void_ty(ctx); case INT1_T: return ir::type::get_int1_ty(ctx); case INT8_T: return ir::type::get_int8_ty(ctx); case INT16_T: return ir::type::get_int16_ty(ctx); case INT32_T: return ir::type::get_int32_ty(ctx); case INT64_T: return ir::type::get_int64_ty(ctx); case FLOAT16_T: return ir::type::get_half_ty(ctx); case FLOAT32_T: return ir::type::get_float_ty(ctx); case FLOAT64_T: return ir::type::get_double_ty(ctx); default: throw std::runtime_error("unreachable"); } } std::vector typed_declaration_specifier::modifiers() const { return {}; } ir::type* declaration_modifier::type(ir::module *mod) const { return decl_spec_->type(mod); } std::vector declaration_modifier::modifiers() const { auto result = decl_spec_->modifiers(); result.push_back(mod_); return result; } /* Parameter */ ir::type* parameter::type(ir::module *mod) const { return decl_->type(mod, spec_->type(mod), {}); } std::vector parameter::modifiers() const { return spec_->modifiers(); } const identifier *parameter::id() const { return decl_->id(); } /* Declarators */ ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ if(ptr_) return type_impl(mod, ptr_->type(mod, type, storage), storage); return type_impl(mod, type, storage); } // Identifier ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{ return type; } const std::string &identifier::name() const{ return name_; } // Tile ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{ ir::type::tile_shapes_t shapes; for(expression *expr: shapes_->values()){ ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); if(shape == nullptr) throw std::runtime_error("tile shapes must be constant expressions"); shapes.push_back(shape); } return ir::tile_type::get(type, shapes); } // Pointer ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ auto is_cst = [](modifier* x){ return x->is_cst_space(); }; bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end(); return ir::pointer_type::get(type, is_ptr_to_const?4:1); } // Function void function::bind_parameters(ir::module *mod, ir::function *fn) const{ std::vector args = fn->args(); assert(args.size() == args_->values().size()); for(size_t i = 0; i < args.size(); i++){ parameter *param_i = args_->values().at(i); const identifier *id_i = param_i->id(); if(id_i){ args[i]->set_name(id_i->name()); mod->set_value(id_i->name(), nullptr, args[i]); mod->get_scope().types[id_i->name()] = args[i]->get_type(); } } } ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{ std::vector types; for(parameter* param: args_->values()) types.push_back(param->type(mod)); return ir::function_type::get(type, types); } /* Declaration */ ir::value* declaration::codegen(ir::module* mod) const{ for(initializer *init: init_->values()) init->set_specifier(spec_); init_->codegen(mod); return nullptr; } /* Initializer */ ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ return decl_->type(mod, type, storage); } void initializer::set_specifier(const declaration_specifier *spec) { spec_ = spec; } ir::value* initializer::codegen(ir::module * mod) const{ std::vector modifiers = spec_->modifiers(); ir::type *ty = decl_->type(mod, spec_->type(mod), modifiers); std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); auto is_tunable = [](modifier* x){ return x->is_tunable(); }; if(std::find_if(modifiers.begin(), modifiers.end(), is_tunable) != modifiers.end()){ auto csts = dynamic_cast*>((node*)expr_); if(csts == nullptr) throw std::runtime_error("must specify constant list for metaparameters"); std::vector values; for(constant* cst: csts->values()) values.push_back(cst->value()); value = ir::metaparameter::create(mod->get_context(), ty, values); mod->register_global(name, value); } else if(expr_){ value = expr_->codegen(mod); value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); implicit_broadcast(mod, ty, value); } value->set_name(name); // metadata auto is_multiple_of = [](modifier* x){ return x->is_multiple_of(); }; auto it = std::find_if(modifiers.begin(), modifiers.end(), is_multiple_of); if(it != modifiers.end()) (*it)->add_metadata(mod, name); // register mod->set_value(name, value); mod->get_scope().types[name] = ty; if(auto *x = dynamic_cast(value)) mod->add_alloc(x); // constants auto is_cst = [](modifier* x){ return x->is_cst(); }; if(std::find_if(modifiers.begin(), modifiers.end(), is_cst) != modifiers.end()) mod->set_const(name); return value; } /* Type name */ ir::type *type_name::type(ir::module *mod) const{ return decl_->type(mod, spec_->type(mod), {}); } /* Storage specifier */ inline ir::attribute_kind_t get_ir_attr(STORAGE_SPEC_T spec){ switch(spec){ case RESTRICT_T: return ir::noalias; case READONLY_T: return ir::readonly; case WRITEONLY_T: return ir::writeonly; default: throw std::runtime_error("cannot convert storage specifier to IR function attribute"); } } void storage_specifier::add_attr(ir::function* fn, size_t pos) { fn->add_attr(pos, ir::attribute(get_ir_attr(value_))); } void storage_specifier::add_metadata(ir::module*, std::string) { throw std::runtime_error("storage specifier is not a metadata"); } /* Alignment specifier */ void alignment_specifier::add_attr(ir::function* fn, size_t pos) { fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value())); } void alignment_specifier::add_metadata(ir::module *mod, std::string name) { throw std::runtime_error("alignment specifier is not a metadata"); } /* Multiple-Of specifier */ void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) { fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value())); } void multiple_of_specifier::add_metadata(ir::module *mod, std::string name) { mod->add_metadata(name, {ir::metadata::multiple_of, cst_->value()}); } /* Function definition */ ir::value* function_definition::codegen(ir::module *mod) const{ ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers()); const std::string &name = header_->id()->name(); ir::function *fn = mod->get_or_insert_function(name, prototype); for(unsigned i = 0; i < header_->get_num_args(); i++){ parameter *param = header_->get_arg(i); std::vector modifiers = param->modifiers(); for(modifier* m: modifiers) m->add_attr(fn, 1 + i); } header_->bind_parameters(mod, fn); ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn); mod->seal_block(entry); mod->get_builder().set_insert_point(entry); body_->codegen(mod); mod->get_builder().create_ret_void(); return nullptr; } } }