diff --git a/CMakeLists.txt b/CMakeLists.txt index 5faf6c8bb..b57c0859d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,18 +7,6 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") option(BUILD_EXAMPLES "Build C++ Triton examples" ON) option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) -# FLEX/YACC -find_package(BISON) -find_package(FLEX) -BISON_TARGET(Parser ${CMAKE_CURRENT_SOURCE_DIR}/include/triton/lang/parser.y ${CMAKE_CURRENT_SOURCE_DIR}/lib/lang/parser.cpp) -FLEX_TARGET(Lexer ${CMAKE_CURRENT_SOURCE_DIR}/include/triton/lang/scanner.l ${CMAKE_CURRENT_SOURCE_DIR}/lib/lang/scanner.cpp) -get_filename_component(BISON_Parser_INCLUDE_DIRECTORIES ${BISON_Parser_OUTPUT_HEADER} DIRECTORY) -include_directories(${BISON_Parser_INCLUDE_DIRECTORIES}) - -#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)" -# OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE) -#add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) - # LLVM find_package(LLVM REQUIRED CONFIG) include_directories(${LLVM_INCLUDE_DIRS}) @@ -32,7 +20,7 @@ if(NOT CMAKE_BUILD_TYPE) endif() # Gather headers for cmake-based IDEs -file( GLOB_RECURSE ALL_SRC *.cpp *.hpp *.h *.py *.y *.l CMakeLists*) +file( GLOB_RECURSE ALL_SRC *.cpp *.hpp *.h *.py CMakeLists*) add_custom_target( ALL SOURCES ${ALL_SRC} ) # Compiler flags @@ -63,7 +51,7 @@ endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp lib/*.cc) -add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS}) +add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC}) target_link_libraries(triton LLVM) # Warning level diff --git a/include/triton/driver/helpers/CL/infos.hpp b/include/triton/driver/helpers/CL/infos.hpp deleted file mode 100644 index dcd80928c..000000000 --- a/include/triton/driver/helpers/CL/infos.hpp +++ /dev/null @@ -1,413 +0,0 @@ -#ifndef ISAAC_DRIVER_HELPERS_OCL_INFOS_HPP_ -#define ISAAC_DRIVER_HELPERS_OCL_INFOS_HPP_ - -/* ========================================================================= - Copyright (c) 2010-2012, Institute for Microelectronics, - Institute for Analysis and Scientific Computing, - TU Wien. - - ----------------- - ViennaCL - The Vienna Computing Library - ----------------- - - Project Head: Karl Rupp rupp@iue.tuwien.ac.at - - (A list of authors and contributors can be found in the PDF manual) - - License: MIT (X11), see file LICENSE in the base directory -============================================================================= */ - - - -#include "triton/driver/error.h" -#include -#include - -namespace triton -{ -namespace driver -{ -namespace ocl -{ - - /** @brief Implementation details for the OpenCL managment layer in ViennaCL */ -namespace detail{ - -/** @brief Helper class for obtaining informations from the OpenCL backend. Deprecated! */ -template -struct info; - -/** \cond */ -template<> -struct info -{ - typedef cl_mem_info type; - - static void get(cl_mem handle, cl_mem_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret) - { - cl_int err = dispatch::clGetMemObjectInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_device_info type; - - static void get(cl_device_id handle, cl_device_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret) - { - cl_int err = dispatch::clGetDeviceInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_kernel_info type; - - static void get(cl_kernel handle, cl_kernel_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetKernelInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } - - static void get(cl_kernel handle, cl_device_id dev_id, cl_kernel_work_group_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetKernelWorkGroupInfo(handle, dev_id, param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_context_info type; - - static void get(cl_context handle, cl_context_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetContextInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_program_info type; - - static void get(cl_program handle, cl_program_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetProgramInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } - - static void get(cl_program handle, cl_device_id device, cl_program_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetProgramBuildInfo(handle,device,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - - -template<> -struct info -{ - typedef cl_profiling_info type; - static void get(cl_event handle, cl_profiling_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetEventProfilingInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_command_queue_info type; - static void get(cl_command_queue handle, cl_profiling_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetCommandQueueInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -template<> -struct info -{ - typedef cl_command_queue_info type; - static void get(cl_platform_id handle, cl_profiling_info param_name,size_t param_value_size,void *param_value,size_t *param_value_size_ret){ - cl_int err = dispatch::clGetPlatformInfo(handle,param_name,param_value_size,param_value,param_value_size_ret); - check(err); - } -}; - -//Info getter -//Some intelligence is needed for some types -template -struct get_info_impl{ - - template - RES_T operator()(MEM_T const & mem, INFO_T const & info){ - RES_T res; - detail::info::get(mem,info,sizeof(RES_T),&res,NULL); - return res; - } - - template - RES_T operator()(MEM_T const & mem, ARG_MEM_T const & arg_mem, INFO_T const & info){ - RES_T res; - detail::info::get(mem,arg_mem, info,sizeof(RES_T),&res,NULL); - return res; - } -}; - -template<> -struct get_info_impl{ - - template - std::string operator()(const MEM_T &mem, const INFO_T &info){ - char buff[1024]; - detail::info::get(mem,info,1024,buff,NULL); - return std::string(buff); - } - - template - std::string operator()(MEM_T const & mem, ARG_MEM_T const & arg_mem, INFO_T const & info){ - char buff[1024]; - detail::info::get(mem,arg_mem,info,1024,buff,NULL); - return std::string(buff); - } -}; - -template -struct get_info_impl > -{ - template - std::vector operator()(const MEM_T &mem, const INFO_T &info) - { - size_t vec_size; - detail::info::get(mem,info,0,NULL,&vec_size); - std::vector res(vec_size/sizeof(T)); - detail::info::get(mem,info,vec_size,res.data(),NULL); - return res; - } - - template - std::vector operator()(MEM_T const & mem, ARG_MEM_T const & arg_mem, INFO_T const & info) - { - size_t vec_size; - detail::info::get(mem,arg_mem,info,0,NULL,&vec_size); - std::vector res(vec_size/sizeof(T)); - detail::info::get(mem,arg_mem,info,vec_size,res.data(),NULL); - return res; - } -}; - -template::type param> -struct return_type; -/** \endcond */ - -/** \cond */ - #define SET_INFO_RETURN_TYPE(DATA_TYPE,NAME,RETURN_TYPE) template<> struct return_type { typedef RETURN_TYPE Result; } - -SET_INFO_RETURN_TYPE(cl_command_queue, CL_QUEUE_CONTEXT, cl_context); -SET_INFO_RETURN_TYPE(cl_command_queue, CL_QUEUE_DEVICE, cl_device_id); -SET_INFO_RETURN_TYPE(cl_command_queue, CL_QUEUE_REFERENCE_COUNT, cl_uint); -SET_INFO_RETURN_TYPE(cl_command_queue, CL_QUEUE_PROPERTIES, cl_command_queue_properties); - -SET_INFO_RETURN_TYPE(cl_context, CL_CONTEXT_DEVICES, std::vector); -SET_INFO_RETURN_TYPE(cl_context, CL_CONTEXT_NUM_DEVICES, cl_uint); -SET_INFO_RETURN_TYPE(cl_context, CL_CONTEXT_REFERENCE_COUNT, cl_uint); -SET_INFO_RETURN_TYPE(cl_context, CL_CONTEXT_PROPERTIES, cl_context_properties); - -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_ADDRESS_BITS, cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_AVAILABLE, cl_bool); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_COMPILER_AVAILABLE, cl_bool); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV, cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_WAVEFRONT_WIDTH_AMD, cl_uint); - -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_DOUBLE_FP_CONFIG, cl_device_fp_config); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_ENDIAN_LITTLE, cl_bool); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_ERROR_CORRECTION_SUPPORT, cl_bool); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_EXECUTION_CAPABILITIES, cl_device_exec_capabilities); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_EXTENSIONS, std::string); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, cl_ulong); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE, cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_GLOBAL_MEM_SIZE, cl_ulong); -//SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_HALF_FP_CONFIG, cl_device_fp_config); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE_SUPPORT, cl_bool); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE2D_MAX_HEIGHT , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE2D_MAX_WIDTH , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE3D_MAX_DEPTH , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE3D_MAX_HEIGHT , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_IMAGE3D_MAX_WIDTH , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_LOCAL_MEM_SIZE, cl_ulong); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_LOCAL_MEM_TYPE, cl_device_local_mem_type); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_CLOCK_FREQUENCY , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_COMPUTE_UNITS , cl_uint); //The minimum value is 1 -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_CONSTANT_ARGS , cl_uint); //The minimum value is 8 -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE , cl_ulong); //The minimum value is 64 KB -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_MEM_ALLOC_SIZE , cl_ulong); //The minimum value is max (1/4th of CL_DEVICE_GLOBAL_MEM_SIZE, 128*1024*1024) -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_PARAMETER_SIZE , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_READ_IMAGE_ARGS , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_SAMPLERS , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_WORK_ITEM_SIZES , std::vector); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MAX_WRITE_IMAGE_ARGS , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MEM_BASE_ADDR_ALIGN , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_MIN_DATA_TYPE_ALIGN_SIZE , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_NAME , std::string); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PLATFORM , cl_platform_id); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PREFERRED_VECTOR_WIDTH_CHAR , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PREFERRED_VECTOR_WIDTH_SHORT , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PREFERRED_VECTOR_WIDTH_INT , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PREFERRED_VECTOR_WIDTH_FLOAT , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PREFERRED_VECTOR_WIDTH_DOUBLE , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PROFILE , std::string); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_PROFILING_TIMER_RESOLUTION , size_t); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_QUEUE_PROPERTIES , cl_command_queue_properties); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_SINGLE_FP_CONFIG , cl_device_fp_config); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_TYPE , cl_device_type); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_VENDOR , std::string); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_VENDOR_ID , cl_uint); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DEVICE_VERSION , std::string); -SET_INFO_RETURN_TYPE(cl_device_id, CL_DRIVER_VERSION , std::string); - -SET_INFO_RETURN_TYPE(cl_event, CL_PROFILING_COMMAND_QUEUED, cl_ulong); -SET_INFO_RETURN_TYPE(cl_event, CL_PROFILING_COMMAND_SUBMIT, cl_ulong); -SET_INFO_RETURN_TYPE(cl_event, CL_PROFILING_COMMAND_START, cl_ulong); -SET_INFO_RETURN_TYPE(cl_event, CL_PROFILING_COMMAND_END, cl_ulong); - -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_FUNCTION_NAME, std::string); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_NUM_ARGS, cl_uint); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_REFERENCE_COUNT, cl_uint); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_CONTEXT, cl_context); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_PROGRAM, cl_program); - - -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_WORK_GROUP_SIZE, size_t); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_COMPILE_WORK_GROUP_SIZE, std::vector); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_LOCAL_MEM_SIZE, cl_ulong); -SET_INFO_RETURN_TYPE(cl_kernel,CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE, size_t); - -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_TYPE, cl_mem_object_type); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_FLAGS, cl_mem_flags); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_SIZE, size_t); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_HOST_PTR, void*); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_MAP_COUNT, cl_uint); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_REFERENCE_COUNT, cl_uint); -SET_INFO_RETURN_TYPE(cl_mem,CL_MEM_CONTEXT, cl_context); - -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_CONTEXT,cl_context); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_DEVICES,std::vector); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_NUM_DEVICES,cl_uint); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_SOURCE,std::string); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_BINARY_SIZES,std::vector); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_BINARIES,std::vector); -//Build -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_BUILD_STATUS, cl_build_status); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_BUILD_OPTIONS, std::string); -SET_INFO_RETURN_TYPE(cl_program,CL_PROGRAM_BUILD_LOG, std::string); - -SET_INFO_RETURN_TYPE(cl_platform_id,CL_PLATFORM_PROFILE, std::string); -SET_INFO_RETURN_TYPE(cl_platform_id,CL_PLATFORM_VERSION, std::string); -SET_INFO_RETURN_TYPE(cl_platform_id,CL_PLATFORM_NAME, std::string); -SET_INFO_RETURN_TYPE(cl_platform_id,CL_PLATFORM_VENDOR, std::string); -SET_INFO_RETURN_TYPE(cl_platform_id,CL_PLATFORM_EXTENSIONS, std::string); - -#undef SET_INFO_RETURN_TYPE - - /** \endcond */ -} - -template -typename detail::return_type::Result info(cl_device_id const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -template -typename detail::return_type::Result info(cl_mem const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -//Program - -template -typename detail::return_type::Result info(cl_program const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -template<> -inline typename detail::return_type::Result info(cl_program const & handle) -{ - std::vector res; - std::vector sizes = info(handle); - for(size_t s: sizes) - res.push_back(new unsigned char[s]); - dispatch::clGetProgramInfo(handle, CL_PROGRAM_BINARIES, sizeof(unsigned char**), (void*)res.data(), NULL); - return res; -} - -template -typename detail::return_type::Result info(cl_program const & phandle, cl_device_id const & dhandle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(phandle,dhandle,param); -} - -//Kernel -template -typename detail::return_type::Result info(cl_kernel const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -template -typename detail::return_type::Result info(cl_kernel const & khandle, cl_device_id const & dhandle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(khandle,dhandle,param); -} - -//Context -template -typename detail::return_type::Result info(cl_context const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -//Event -template -typename detail::return_type::Result info(cl_event const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -//Command queue -template -typename detail::return_type::Result info(cl_command_queue const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -//Plaftform -template -typename detail::return_type::Result info(cl_platform_id const & handle){ - typedef typename detail::return_type::Result res_t; - return detail::get_info_impl()(handle,param); -} - -template::type param> -typename detail::return_type::Result info(OCL_TYPE const & handle){ - return info(handle.get()); -} - - - -template::type param> -typename detail::return_type::Result info(OCL_TYPE const & handle, OCL_TYPE_ARG const & arg_handle){ - return info(handle.get(), arg_handle.get()); -} - -} -} -} -#endif // INFOS_HPP diff --git a/include/triton/lang/wgtcc/ast.h b/include/triton/lang/ast.h similarity index 100% rename from include/triton/lang/wgtcc/ast.h rename to include/triton/lang/ast.h diff --git a/include/triton/lang/wgtcc/code_gen.h b/include/triton/lang/code_gen.h similarity index 100% rename from include/triton/lang/wgtcc/code_gen.h rename to include/triton/lang/code_gen.h diff --git a/include/triton/lang/wgtcc/cpp.h b/include/triton/lang/cpp.h similarity index 100% rename from include/triton/lang/wgtcc/cpp.h rename to include/triton/lang/cpp.h diff --git a/include/triton/lang/declaration.h b/include/triton/lang/declaration.h deleted file mode 100644 index e406f00d8..000000000 --- a/include/triton/lang/declaration.h +++ /dev/null @@ -1,265 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_DECLARATION_H -#define TRITON_INCLUDE_LANG_DECLARATION_H - -#include "node.h" -#include - -namespace triton{ - - -namespace ir{ - class function; - class value; - class type; - class builder; - class module; -} - -namespace lang{ - -class expression; -class pointer; -class identifier; -class constant; -class compound_statement; -class initializer; -class declaration_specifier; - - -class declaration: public block_item{ -public: - declaration(node *spec, node *init) - : spec_((declaration_specifier*)spec), init_((list*)init) { } - - ir::value* codegen(ir::module * mod) const; - -public: - const declaration_specifier *spec_; - const list *init_; -}; - -// Types -class modifier: public node { -public: - virtual bool is_cst_space() const { return false; } - virtual bool is_tunable() const { return false; } - virtual bool is_cst() const { return false; } - virtual bool is_multiple_of() const { return false; } - virtual void add_attr(ir::function* fn, size_t pos) = 0; - virtual void add_metadata(ir::module* mod, std::string name) = 0; -}; - -class storage_specifier: public modifier { -public: - storage_specifier(STORAGE_SPEC_T value): value_(value) {} - STORAGE_SPEC_T value() const { return value_; } - bool is_cst_space() const { return value_ == CONSTANT_SPACE_T; } - bool is_tunable() const { return value_ == TUNABLE_T; } - bool is_cst() const { return value_ == CONST_T; } - void add_attr(ir::function* fn, size_t pos); - void add_metadata(ir::module* mod, std::string name); - -private: - const STORAGE_SPEC_T value_; -}; - -class alignment_specifier: public modifier { -public: - alignment_specifier(node* value): cst_((constant*)value) { } - void add_attr(ir::function* fn, size_t pos); - void add_metadata(ir::module* mod, std::string name); - -private: - constant* cst_; -}; - -class multiple_of_specifier: public modifier { -public: - multiple_of_specifier(node* value): cst_((constant*)value) {} - void add_attr(ir::function* fn, size_t pos); - void add_metadata(ir::module* mod, std::string name); - bool is_multiple_of() const { return true; } - -private: - constant* cst_; -}; - -// declaration specifier -class declaration_specifier: public node{ -public: - virtual ir::type* type(ir::module *mod) const = 0; - virtual std::vector modifiers() const = 0; -}; - -class typed_declaration_specifier: public declaration_specifier { -public: - typed_declaration_specifier(TYPE_T ty): ty_(ty){ } - ir::type* type(ir::module *mod) const; - std::vector modifiers() const; - -private: - const TYPE_T ty_; -}; - -// declaration modifier -class declaration_modifier: public declaration_specifier { -public: - declaration_modifier(node* mod, node *decl_spec) - : mod_((modifier*)mod), decl_spec_((declaration_specifier*)decl_spec) {} - ir::type* type(ir::module *mod) const; - std::vector modifiers() const; - -private: - modifier* mod_; - const declaration_specifier* decl_spec_; -}; - - -class declarator; -class parameter: public node { -public: - parameter(node *spec, node *decl) - : spec_((declaration_specifier*)spec), - decl_((declarator*)decl) { } - - ir::type* type(ir::module *mod) const; - std::vector modifiers() const; - const identifier* id() const; - -public: - const declaration_specifier *spec_; - const declarator *decl_; -}; - -/* Declarators */ -class declarator: public node{ -protected: - typedef std::vector storage_spec_vec_t; - typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t; - -public: - virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0; - -public: - declarator(node *lhs) - : lhs_((declarator*)lhs), ptr_(nullptr){ } - - ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - - const identifier* id() const { - return (const identifier*)lhs_; - } - - declarator *set_ptr(node *ptr){ - ptr_ = (pointer*)ptr; - return this; - } - - void set_addr_space(unsigned addr_space){ - addr_space_ = addr_space; - } - -protected: - declarator *lhs_; - pointer *ptr_; - unsigned addr_space_; -}; - -class identifier: public declarator { - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - identifier(char *&name): declarator(this), name_(name) { } - const std::string &name() const; - -private: - std::string name_; -}; - -class pointer: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - pointer(node *id): declarator(id) { } -}; - -class tile: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - tile(node *id, node *shapes) - : declarator(id), shapes_((list*)(shapes)) { } - -public: - const list* shapes_; -}; - -class function: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - function(node *id, node *args) - : declarator(id), args_((list*)args) { } - - void bind_parameters(ir::module *mod, ir::function *fn) const; - unsigned get_num_args() const { return args_->values().size(); } - parameter* get_arg(unsigned i) const { return args_->values().at(i); } - -public: - const list* args_; -}; - - -class initializer : public declarator{ -private: - ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - initializer(node *decl, node *init) - : declarator((node*)((declarator*)decl)->id()), - decl_((declarator*)decl), expr_((expression*)init){ } - - void set_specifier(const declaration_specifier *spec); - ir::value* codegen(ir::module *) const; - -public: - const declaration_specifier *spec_; - declarator *decl_; - const expression *expr_; -}; - - -class type_name: public node{ -public: - type_name(node *spec, node * decl) - : spec_((declaration_specifier*)spec), decl_((declarator*)decl) { } - - ir::type *type(ir::module *mod) const; - -public: - const declaration_specifier *spec_; - const declarator *decl_; -}; - -/* Function definition */ -class function_definition: public node{ -public: - function_definition(node *spec, node *header, node *body) - : spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { } - - ir::value* codegen(ir::module * mod) const; - -public: - const declaration_specifier *spec_; - const function *header_; - const compound_statement *body_; -}; - -} - -} - -#endif diff --git a/include/triton/lang/wgtcc/encoding.h b/include/triton/lang/encoding.h similarity index 100% rename from include/triton/lang/wgtcc/encoding.h rename to include/triton/lang/encoding.h diff --git a/include/triton/lang/error.h b/include/triton/lang/error.h index 70e70a387..fdae7e060 100644 --- a/include/triton/lang/error.h +++ b/include/triton/lang/error.h @@ -1,20 +1,15 @@ -#ifndef TRITON_INCLUDE_LANG_ERROR_H -#define TRITON_INCLUDE_LANG_ERROR_H - -#include "parser.hpp" +#ifndef _WGTCC_ERROR_H_ +#define _WGTCC_ERROR_H_ -namespace triton{ -namespace lang{ +struct SourceLocation; +class Token; +class Expr; -void update_location(const char *t); -void print_error(const char *error); -char return_impl(char t, const char * yytext); -yytokentype return_impl(yytokentype t, const char * yytext); -void return_void(const char * yytext); - -} -} +[[noreturn]] void Error(const char* format, ...); +[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...); +[[noreturn]] void Error(const Token* tok, const char* format, ...); +[[noreturn]] void Error(const Expr* expr, const char* format, ...); #endif diff --git a/include/triton/lang/wgtcc/evaluator.h b/include/triton/lang/evaluator.h similarity index 100% rename from include/triton/lang/wgtcc/evaluator.h rename to include/triton/lang/evaluator.h diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h deleted file mode 100644 index a3574f15d..000000000 --- a/include/triton/lang/expression.h +++ /dev/null @@ -1,357 +0,0 @@ -#ifndef TDL_INCLUDE_LANG_EXPRESSION_H -#define TDL_INCLUDE_LANG_EXPRESSION_H - -#include "lang.h" -#include -#include - - -namespace triton{ - - -namespace ir{ - class function; - class value; - class type; - class builder; - class module; -} - -namespace lang{ - - -enum slice_enum_t{ - ALL, - NEWAXIS -}; - -class slice: public node{ -public: - slice(slice_enum_t type) - : type_(type){} - - slice_enum_t type() const{ - return type_; - } - -public: - const slice_enum_t type_; -}; - - -class named_expression; - -class expression: public node{ -public: - virtual ir::value* codegen(ir::module *) const = 0; - named_expression *lvalue() const { return lvalue_; } - -protected: - named_expression *lvalue_; -}; - -class postfix_expression: public expression{ - -}; - -class builtin_expression: public node{ - -}; - -class typed_declaration_specifier; -class alloc_const_expression: public builtin_expression{ -public: - alloc_const_expression(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { } - ir::value* codegen(ir::module *mod) const; - -private: - const typed_declaration_specifier* spec_; - const constant* size_; -}; - -class get_program_id_expression: public builtin_expression{ -public: - get_program_id_expression(node *axis): axis_((constant*)axis) { } - ir::value* codegen(ir::module *) const; - -private: - const constant* axis_; -}; - -class get_num_program_expression: public builtin_expression{ -public: - get_num_program_expression(node *axis): axis_((constant*)axis) { } - ir::value* codegen(ir::module *mod) const; - -private: - const constant* axis_; -}; - -class atomic_cas_expression: public builtin_expression{ -public: - atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { } - ir::value* codegen(ir::module *) const; - -private: - const node *ptr_; - const node *cmp_; - const node *val_; -}; - -class atomic_exch_expression: public builtin_expression{ -public: - atomic_exch_expression(node *ptr, node *val): ptr_(ptr), val_(val) { } - ir::value* codegen(ir::module *) const; - -private: - const node *ptr_; - const node *val_; -}; - - -class atomic_add_expression: public builtin_expression{ -public: - atomic_add_expression(node *ptr, node *val): ptr_(ptr), val_(val) { } - ir::value* codegen(ir::module *) const; - -private: - const node *ptr_; - const node *val_; -}; - - -class matmul_expression: public builtin_expression{ -public: - matmul_expression(node* A, node *B, node *C): - A_((expression*)A), B_((expression*)B), C_((expression*)C) { } - ir::value* codegen(ir::module *) const; - -private: - const expression *A_; - const expression *B_; - const expression *C_; -}; - -class reshape_expression: public builtin_expression{ -public: - reshape_expression(node *arg, node *shapes): arg_(arg), shapes_((list*)shapes) { } - ir::value* codegen(ir::module *) const; - -private: - const node *arg_; - const list* shapes_; -}; - -class max_expression: public builtin_expression{ -public: - max_expression(node* x, node* y) - : x_((expression*)x), y_((expression*)y){ } - ir::value* codegen(ir::module *) const; - -private: - const expression *x_; - const expression *y_; -}; - -class min_expression: public builtin_expression{ -public: - min_expression(node* x, node* y) - : x_((expression*)x), y_((expression*)y){ } - ir::value* codegen(ir::module *mod) const; - -private: - const expression *x_; - const expression *y_; -}; - -class select_expression: public builtin_expression{ -public: - select_expression(node* pred, node* if_value, node* else_value) - : pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { } - ir::value* codegen(ir::module *mod) const; - -private: - const expression *pred_; - const expression *if_value_; - const expression *else_value_; -}; - -class trans_expression: public builtin_expression{ -public: - trans_expression(node *arg, node *perm): arg_(arg), perm_((list*)perm) {} - ir::value* codegen(ir::module *mod) const; - -private: - node* arg_; - const list* perm_; -}; - -class sqrt_expression: public builtin_expression{ -public: - sqrt_expression(node *arg): arg_(arg) {} - ir::value* codegen(ir::module *) const; - -private: - node* arg_; -}; - -class reduce_expression: public builtin_expression{ -public: - reduce_expression(node *arg, node *axis): arg_(arg), axis_((constant*)axis) {} - ir::value* codegen(ir::module *mod) const; - -private: - node* arg_; - constant* axis_; -}; - -class indexing_expression: public postfix_expression{ -public: - indexing_expression(node *lhs, node *slices) - : lhs_((const expression*)lhs), slices_((const list*)slices) {} - - ir::value* codegen(ir::module *) const; - -private: - const expression* lhs_; - const list* slices_; -}; - - - -class named_expression: public expression { -public: - named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; } - const identifier *id() const { return id_; } - ir::value* codegen(ir::module * mod) const; - -private: - const identifier *id_; -}; - -class binary_expression: public expression{ -private: - ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const; - -public: - binary_expression(BIN_OP_T op, node *lhs, node *rhs) - : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { - } - ir::value* codegen(ir::module *) const; - -private: - const BIN_OP_T op_; - const expression *lhs_; - const expression *rhs_; -}; - - -class constant: public expression{ -public: - constant(int value): value_(value) { } - ir::value* codegen(ir::module *mod) const; - int value() const; - -private: - const int value_; -}; - -class constant_range: public expression { -public: - constant_range(node *first, node *last) - : first_((constant*)first), last_((constant*)last) { } - - ir::value* codegen(ir::module *mod) const; - -private: - constant *first_; - constant *last_; -}; - -class string_literal: public expression{ -public: - string_literal(char *&value): value_(value) { } - ir::value* codegen(ir::module *mod) const; - -public: - std::string value_; -}; - -class unary_expression: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const; - -public: - unary_expression(UNARY_OP_T op, node *arg) - : op_(op), - arg_((expression*)arg) { - if(op == DEREF) - this->lvalue_ = arg_->lvalue(); - } - - UNARY_OP_T get_op() const { return op_; } - ir::value* codegen(ir::module *mod) const; - -private: - const UNARY_OP_T op_; - const expression *arg_; -}; - -class type_name; -class cast_expression: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const; - -public: - cast_expression(node *T, node *arg): - T_((type_name*)T), - arg_((expression*)arg) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const type_name *T_; - const expression *arg_; -}; - -class conditional_expression: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, - ir::value *cond, ir::value *true_value, ir::value *false_value, - const std::string &name) const; - -public: - conditional_expression(node *cond, node *true_value, node *false_value) - : cond_((expression*)cond), - true_value_((expression*)true_value), - false_value_((expression*)false_value) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const expression *cond_; - const expression *true_value_; - const expression *false_value_; -}; - -class assignment_expression: public expression{ -public: - assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) - : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } - - ir::value* codegen(ir::module *mod) const; - const expression *lvalue() const { return lvalue_; } - const expression *rvalue() const { return rvalue_; } - -public: - const expression *lvalue_; - ASSIGN_OP_T op_; - const expression *rvalue_; -}; - - -} - -} - -#endif diff --git a/include/triton/lang/lang.h b/include/triton/lang/lang.h deleted file mode 100644 index ba1d1a2d8..000000000 --- a/include/triton/lang/lang.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_LANG_H -#define TRITON_INCLUDE_LANG_LANG_H - -#include "parser.hpp" -#include "declaration.h" -#include "error.h" -#include "expression.h" -#include "node.h" -#include "ops.h" -#include "module.h" -#include "statement.h" - -#endif diff --git a/include/triton/lang/wgtcc/mem_pool.h b/include/triton/lang/mem_pool.h similarity index 100% rename from include/triton/lang/wgtcc/mem_pool.h rename to include/triton/lang/mem_pool.h diff --git a/include/triton/lang/module.h b/include/triton/lang/module.h deleted file mode 100644 index 7ac6c2960..000000000 --- a/include/triton/lang/module.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_MODULE_H -#define TRITON_INCLUDE_LANG_MODULE_H - -#include "node.h" - -namespace triton{ -namespace lang{ - -/* Translation Unit */ -class translation_unit: public node{ -public: - translation_unit(node *item) - : decls_(item) { } - - translation_unit *add(node *item) { - decls_.append(item); - return this; - } - - ir::value* codegen(ir::module * mod) const; - -private: - list decls_; -}; - -} - -} - -#endif diff --git a/include/triton/lang/node.h b/include/triton/lang/node.h deleted file mode 100644 index c9bd0b011..000000000 --- a/include/triton/lang/node.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_NODE_H -#define TRITON_INCLUDE_LANG_NODE_H - -#include -#include "ops.h" - -namespace triton{ - - -namespace ir{ - class function; - class value; - class type; - class builder; - class module; -} - -namespace lang{ - -class expression; -class pointer; -class identifier; -class constant; -class compound_statement; -class initializer; -class modifier; -class function; - -// Node -class node { -protected: - static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); - static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src); - static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); - static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, - bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); -public: - virtual ir::value* codegen(ir::module *) const { return nullptr; } -}; - -class block_item: public node{ -}; - -template -class list: public node { -public: - list(const T& x): values_(1, x) {} - - node* append(const T& x){ - values_.push_back(x); - return this; - } - - ir::value* codegen(ir::module * mod) const{ - for(T x: values_){ - x->codegen(mod); - } - return nullptr; - } - - const std::vector &values() const - { return values_; } - -private: - std::vector values_; -}; - -} - -} - -#endif diff --git a/include/triton/lang/ops.h b/include/triton/lang/ops.h deleted file mode 100644 index 38fc200bf..000000000 --- a/include/triton/lang/ops.h +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_OPS_H -#define TRITON_INCLUDE_LANG_OPS_H - -namespace triton{ -namespace lang{ - -enum ASSIGN_OP_T{ - ASSIGN, - INPLACE_MUL, INPLACE_DIV, INPLACE_MOD, - INPLACE_ADD, INPLACE_SUB, - INPLACE_LSHIFT, INPLACE_RSHIFT, - INPLACE_AND, INPLACE_XOR, - INPLACE_OR -}; - -enum BIN_OP_T{ - MUL, DIV, MOD, - ADD, SUB, - LEFT_SHIFT, RIGHT_SHIFT, - LT, GT, - LE, GE, - EQ, NE, - AND, XOR, OR, - LAND, LOR -}; - -enum UNARY_OP_T{ - INC, DEC, - PLUS, MINUS, - ADDR, DEREF, - COMPL, NOT -}; - -enum TYPE_T{ - VOID_T, - UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, - INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, - FLOAT16_T, FLOAT32_T, FLOAT64_T -}; - -enum STORAGE_SPEC_T{ - CONST_T, - TUNABLE_T, - KERNEL_T, - RESTRICT_T, - READONLY_T, - CONSTANT_SPACE_T, - WRITEONLY_T -}; - -} -} - -#endif diff --git a/include/triton/lang/wgtcc/parser.h b/include/triton/lang/parser.h similarity index 100% rename from include/triton/lang/wgtcc/parser.h rename to include/triton/lang/parser.h diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y deleted file mode 100644 index e3c22c132..000000000 --- a/include/triton/lang/parser.y +++ /dev/null @@ -1,424 +0,0 @@ -%define parse.error verbose - -%{ -namespace triton{ -namespace lang{ -class node; -} -} -using namespace triton::lang; -#define YYSTYPE node* -#include "../include/triton/lang/lang.h" - -extern char* yytext; -void yyerror(const char *s); -int yylex(void); - -translation_unit *ast_root; - -/* wrap token in AST node */ -struct token: public node{ - token(ASSIGN_OP_T value): assign_op(value){ } - token(BIN_OP_T value): bin_op(value){ } - token(UNARY_OP_T value): unary_op(value){ } - token(TYPE_T value): type(value){ } - token(STORAGE_SPEC_T value): storage_spec(value){ } - - union { - ASSIGN_OP_T assign_op; - BIN_OP_T bin_op; - UNARY_OP_T unary_op; - TYPE_T type; - STORAGE_SPEC_T storage_spec; - }; -}; - -/* shortcut to append in list */ -template -node* append_ptr_list(node *result, node *in){ - return static_cast*>(result)->append((T*)in); -} - -/* shortcut to access token value */ -ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; } -UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } -TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } -STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} -%} - -%token IDENTIFIER CONSTANT STRING_LITERAL -%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE ALIGN MULTIPLE_OF -%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP -%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN -%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN -%token XOR_ASSIGN OR_ASSIGN TYPE_NAME -%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 -%token IF ELSE FOR CONTINUE WHILE -%token NEWAXIS ELLIPSIS AT -%token GET_NUM_PROGRAM GET_PROGRAM_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE - -%start translation_unit -%% - - -/* -------------------------- */ -/* Types */ -/* -------------------------- */ - -type_specifier - : VOID { $$ = new token(VOID_T); } - | UINT1 { $$ = new token(UINT1_T); } - | UINT8 { $$ = new token(UINT8_T); } - | UINT16 { $$ = new token(UINT16_T); } - | UINT32 { $$ = new token(UINT32_T); } - | UINT64 { $$ = new token(UINT64_T); } - | INT1 { $$ = new token(INT1_T);} - | INT8 { $$ = new token(INT8_T); } - | INT16 { $$ = new token(INT16_T); } - | INT32 { $$ = new token(INT32_T); } - | INT64 { $$ = new token(INT64_T); } - | FP16 { $$ = new token(FLOAT16_T); } - | FP32 { $$ = new token(FLOAT32_T); } - | FP64 { $$ = new token(FLOAT64_T); } - ; - -pointer - : '*' { $$ = new pointer(nullptr); } - | '*' pointer { $$ = new pointer($1); } - -abstract_declarator - : pointer { $$ = $1; } - | pointer direct_abstract_declarator { $$ = ((declarator*)$2)->set_ptr($1); } - | direct_abstract_declarator { $$ = $1; } - ; - -direct_abstract_declarator - : '[' constant_expression_list ']' { $$ = new tile(nullptr, $2); } - -type_name - : declaration_specifiers { $$ = new type_name($1, nullptr); } - | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); } - ; - -/* -------------------------- */ -/* Expressions */ -/* -------------------------- */ - -/* Constants */ -constant - : CONSTANT { $$ = new constant(atoi(yytext)); } - ; - -constant_list - : constant { $$ = new list((constant*)$1); } - | constant_list ',' constant { $$ = append_ptr_list($1, $3); } - ; - -identifier - : IDENTIFIER { $$ = new identifier(yytext); } - ; - -/* Built-in */ -builtin_expression - : GET_PROGRAM_ID '(' constant ')' { $$ = new get_program_id_expression($3); } - | GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); } - | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } - | SQRT '(' expression ')' { $$ = new sqrt_expression($3); } - | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } - | TRANS '(' expression ',' constant_expression_list ')' { $$ = new trans_expression($3, $5); } - | TRANS '(' expression ')' { $$ = new trans_expression($3, nullptr); } - | REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);} - | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); } - | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } - | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); } - | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); } - | ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); } - | ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); } - | RESHAPE '(' expression ',' constant_expression_list ')' { $$ = new reshape_expression($3, $5); } - ; - -/* Primary */ -primary_expression - : identifier { $$ = new named_expression($1); } - | constant { $$ = $1; } - | primary_expression ELLIPSIS primary_expression { $$ = new constant_range($1, $3); } - | builtin_expression { $$ = $1; } - | STRING_LITERAL { $$ = new string_literal(yytext); } - | '(' expression ')' { $$ = $2; } - ; - -/* Postfix */ -slice - : ':' { $$ = new slice(triton::lang::ALL); } - | NEWAXIS { $$ = new slice(triton::lang::NEWAXIS); } - -slice_list - : slice { $$ = new list((slice*)$1); } - | slice_list ',' slice { $$ = append_ptr_list($1, $3); } - -postfix_expression - : primary_expression { $$ = $1;} - | primary_expression '[' slice_list ']' { $$ = new indexing_expression($1, $3);} - ; - -/* Unary */ -unary_operator - : '&' { $$ = new token(ADDR); } - | '*' { $$ = new token(DEREF); } - | '+' { $$ = new token(PLUS); } - | '-' { $$ = new token(MINUS); } - | '~' { $$ = new token(COMPL); } - | '!' { $$ = new token(NOT); } - ; - -unary_expression - : postfix_expression { $$ = $1; } - | INC_OP unary_expression { $$ = new unary_expression(INC, $2); } - | DEC_OP unary_expression { $$ = new unary_expression(DEC, $2); } - | unary_operator cast_expression { $$ = new unary_expression(get_unary_op($1), $2); } - ; - -cast_expression - : unary_expression { $$ = $1; } - | '(' type_name ')' cast_expression { $$ = new cast_expression($2, $4); } - ; - -multiplicative_expression - : cast_expression { $$ = $1; } - | multiplicative_expression '*' cast_expression { $$ = new binary_expression(MUL, $1, $3); } - | multiplicative_expression '/' cast_expression { $$ = new binary_expression(DIV, $1, $3); } - | multiplicative_expression '%' cast_expression { $$ = new binary_expression(MOD, $1, $3); } - ; - -additive_expression - : multiplicative_expression { $$ = $1; } - | additive_expression '+' multiplicative_expression { $$ = new binary_expression(ADD, $1, $3); } - | additive_expression '-' multiplicative_expression { $$ = new binary_expression(SUB, $1, $3); } - ; - -shift_expression - : additive_expression { $$ = $1; } - | shift_expression LEFT_OP additive_expression { $$ = new binary_expression(LEFT_SHIFT, $1, $3); } - | shift_expression RIGHT_OP additive_expression { $$ = new binary_expression(RIGHT_SHIFT, $1, $3); } - ; - -/* Comparison */ -relational_expression - : shift_expression { $$ = $1; } - | relational_expression '<' shift_expression { $$ = new binary_expression(LT, $1, $3); } - | relational_expression '>' shift_expression { $$ = new binary_expression(GT, $1, $3); } - | relational_expression LE_OP shift_expression { $$ = new binary_expression(LE, $1, $3); } - | relational_expression GE_OP shift_expression { $$ = new binary_expression(GE, $1, $3); } - ; - -equality_expression - : relational_expression { $$ = $1; } - | equality_expression EQ_OP relational_expression { $$ = new binary_expression(EQ, $1, $3); } - | equality_expression NE_OP relational_expression { $$ = new binary_expression(NE, $1, $3); } - ; - -/* Binary */ -and_expression - : equality_expression { $$ = $1; } - | and_expression '&' equality_expression { $$ = new binary_expression(AND, $1, $3); } - ; - -exclusive_or_expression - : and_expression { $$ = $1; } - | exclusive_or_expression '^' and_expression { $$ = new binary_expression(XOR, $1, $3); } - ; - -inclusive_or_expression - : exclusive_or_expression { $$ = $1; } - | inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_expression(OR, $1, $3); } - ; - -/* Logical */ -logical_and_expression - : inclusive_or_expression { $$ = $1; } - | logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_expression(LAND, $1, $3); } - ; - -logical_or_expression - : logical_and_expression { $$ = $1; } - | logical_or_expression OR_OP logical_and_expression { $$ = new binary_expression(LOR, $1, $3); } - ; - -/* Conditional */ -conditional_expression - : logical_or_expression { $$ = $1; } - | logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); } - ; - -/* Assignment */ -assignment_operator - : '=' { $$ = new token(ASSIGN); } - | MUL_ASSIGN { $$ = new token(INPLACE_MUL); } - | DIV_ASSIGN { $$ = new token(INPLACE_DIV); } - | MOD_ASSIGN { $$ = new token(INPLACE_MOD); } - | ADD_ASSIGN { $$ = new token(INPLACE_ADD); } - | SUB_ASSIGN { $$ = new token(INPLACE_SUB); } - | LEFT_ASSIGN { $$ = new token(INPLACE_LSHIFT); } - | RIGHT_ASSIGN { $$ = new token(INPLACE_RSHIFT); } - | AND_ASSIGN { $$ = new token(INPLACE_AND); } - | XOR_ASSIGN { $$ = new token(INPLACE_XOR); } - | OR_ASSIGN { $$ = new token(INPLACE_OR); } - ; - -assignment_expression - : conditional_expression { $$ = $1; } - | unary_expression assignment_operator assignment_expression { $$ = new assignment_expression($1, get_assign_op($2), $3); } - ; - -/* Expression */ -expression - : assignment_expression { $$ = $1; } - ; - -constant_expression_list - : expression { $$ = new list((expression*)$1); } - | constant_expression_list ',' expression { $$ = append_ptr_list($1, $3); } - -/* Initialization */ -initialization_expression - : assignment_expression { $$ = $1; } - | '{' constant_list '}' { $$ = $2; } - ; - - -/* -------------------------- */ -/* Statements */ -/* -------------------------- */ - -statement - : compound_statement { $$ = $1; } - | expression_statement { $$ = $1; } - | selection_statement { $$ = $1; } - | iteration_statement { $$ = $1; } - | jump_statement { $$ = $1; } - ; - -compound_statement - : '{' '}' { $$ = new compound_statement(nullptr); } - | '{' block_item_list '}' { $$ = new compound_statement($2); } - -block_item_list - : block_item { $$ = new list((block_item*)$1); } - | block_item_list block_item { $$ = append_ptr_list($1, $2); } - -block_item - : declaration { $$ = $1; } - | statement { $$ = $1; } - -expression_statement - : ';' { $$ = new no_op(); } - | expression ';' { $$ = new expression_statement($1); } - | AT primary_expression expression ';' { $$ = new expression_statement($3, $2); } - ; - -selection_statement - : IF '(' expression ')' statement { $$ = new selection_statement($3, $5); } - | IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($3, $5, $7); } - ; - -iteration_statement - : FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } - | FOR '(' declaration expression_statement ')' statement { $$ = new iteration_statement($3, $4, nullptr, $6); } - | FOR '(' declaration expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } - | WHILE '(' expression ')' statement { $$ = new while_statement($3, $5); }; - -jump_statement - : CONTINUE ';' { $$ = new continue_statement(); } -; - -/* -------------------------- */ -/* Declarator */ -/* -------------------------- */ - - -direct_declarator - : identifier { $$ = $1; } - | identifier '[' constant_expression_list ']' { $$ = new tile($1, $3); } - | identifier '(' parameter_list ')' { $$ = new function($1, $3); } - | identifier '(' ')' { $$ = new function($1, nullptr); } - ; - - -parameter_list - : parameter_declaration { $$ = new list((parameter*)$1); } - | parameter_list ',' parameter_declaration { $$ = append_ptr_list($1, $3); } - ; - -parameter_declaration - : declaration_specifiers declarator { $$ = new parameter($1, $2); } - | declaration_specifiers abstract_declarator { $$ = new parameter($1, $2); } - ; - - -declaration_specifiers - : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); } - | storage_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } - | alignment_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } - | multiple_of_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } - ; - -init_declarator_list - : init_declarator { $$ = new list((initializer*)$1); } - | init_declarator_list ',' init_declarator { $$ = append_ptr_list($1, $3); } - ; - -declaration - : declaration_specifiers ';' { $$ = new declaration($1, nullptr); } - | declaration_specifiers init_declarator_list ';' { $$ = new declaration($1, $2); } - ; - -declarator - : pointer direct_declarator { $$ = ((declarator*)$2)->set_ptr($1); } - | direct_declarator { $$ = $1; } - ; - -init_declarator - : declarator { $$ = new initializer($1, nullptr); } - | declarator '=' initialization_expression { $$ = new initializer($1, $3); } - ; - -storage_class_specifier - : CONST { $$ = new storage_specifier(CONST_T); } - | TUNABLE { $$ = new storage_specifier(TUNABLE_T); } - | KERNEL { $$ = new storage_specifier(KERNEL_T); } - | RESTRICT { $$ = new storage_specifier(RESTRICT_T); } - | READONLY { $$ = new storage_specifier(READONLY_T); } - | WRITEONLY { $$ = new storage_specifier(WRITEONLY_T); } - | CONSTANT_SPACE { $$ = new storage_specifier(CONSTANT_SPACE_T); } -; - -alignment_class_specifier - : ALIGN '(' constant ')' { $$ = new alignment_specifier($3); } - -multiple_of_class_specifier - : MULTIPLE_OF '(' constant ')' { $$ = new multiple_of_specifier($3); } - - -external_declaration - : function_definition { $$ = $1; } - | declaration { $$ = $1; } - ; - -function_definition - : declaration_specifiers declarator compound_statement { $$ = new function_definition($1, $2, $3); } - ; - -/* -------------------------- */ -/* Translation Unit */ -/* -------------------------- */ - -translation_unit - : external_declaration { ast_root = new translation_unit($1); $$ = ast_root; } - | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); } - ; - - -%% -void yyerror (const char *s){ - print_error(s); -} diff --git a/include/triton/lang/wgtcc/scanner.h b/include/triton/lang/scanner.h similarity index 100% rename from include/triton/lang/wgtcc/scanner.h rename to include/triton/lang/scanner.h diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l deleted file mode 100644 index 6062a51ad..000000000 --- a/include/triton/lang/scanner.l +++ /dev/null @@ -1,119 +0,0 @@ -D [0-9] -L [a-zA-Z_] -H [a-fA-F0-9] -E [Ee][+-]?{D}+ -FS (f|F|l|L) -IS (u|U|l|L)* - -%{ -#include -#include "parser.hpp" -#include "../include/triton/lang/lang.h" -using triton::lang::return_impl; -using triton::lang::return_void; -%} - -%% -"__constant__" { return return_impl(CONSTANT_SPACE, yytext); } -"const" { return return_impl(CONST, yytext); } -"tunable" { return return_impl(TUNABLE, yytext); } -"kernel" { return return_impl(KERNEL, yytext); } -"restrict" { return return_impl(RESTRICT, yytext); } -"read_only" { return return_impl(READONLY, yytext); } -"write_only" { return return_impl(WRITEONLY, yytext); } -"align" { return return_impl(ALIGN, yytext); } -"multiple_of" { return return_impl(MULTIPLE_OF, yytext); } -"@" { return return_impl(AT, yytext); } -"newaxis" { return return_impl(NEWAXIS, yytext); } -"if" { return return_impl(IF, yytext); } -"else" { return return_impl(ELSE, yytext); } -"for" { return return_impl(FOR, yytext); } -"while" { return return_impl(WHILE, yytext); } -"void" { return return_impl(VOID, yytext); } -"uchar" { return return_impl(UINT8, yytext); } -"ushort" { return return_impl(UINT16, yytext); } -"uint" { return return_impl(UINT32, yytext); } -"ulong" { return return_impl(UINT64, yytext); } -"bool" { return return_impl(INT1, yytext); } -"char" { return return_impl(INT8, yytext); } -"short" { return return_impl(INT16, yytext); } -"int" { return return_impl(INT32, yytext); } -"long" { return return_impl(INT64, yytext); } -"half" { return return_impl(FP16, yytext); } -"float" { return return_impl(FP32, yytext); } -"double" { return return_impl(FP64, yytext); } -"..." { return return_impl(ELLIPSIS, yytext); } -"get_program_id" { return return_impl(GET_PROGRAM_ID, yytext); } -"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); } -"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); } -"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); } -"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); } -"__sum" { return return_impl(REDUCE_SUM, yytext); } -"__reshape" { return return_impl(RESHAPE, yytext); } -"sqrt" { return return_impl(SQRT, yytext); } -"dot" { return return_impl(DOT, yytext); } -"max" { return return_impl(MAX, yytext); } -"min" { return return_impl(MIN, yytext); } -"select" { return return_impl(SELECT, yytext); } -"trans" { return return_impl(TRANS, yytext); } -"continue" { return return_impl(CONTINUE, yytext); } -"alloc_const" { return return_impl(ALLOC_CONST, yytext); } -{L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); } -0[xX]{H}+{IS}? { return return_impl(CONSTANT, yytext); } -0{D}+{IS}? { return return_impl(CONSTANT, yytext); } -{D}+{IS}? { return return_impl(CONSTANT, yytext); } -L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); } -{D}+{E}{FS}? { return return_impl(CONSTANT, yytext); } -L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); } -">>=" { return return_impl(RIGHT_ASSIGN, yytext); } -"<<=" { return return_impl(LEFT_ASSIGN, yytext); } -"+=" { return return_impl(ADD_ASSIGN, yytext); } -"-=" { return return_impl(SUB_ASSIGN, yytext); } -"*=" { return return_impl(MUL_ASSIGN, yytext); } -"/=" { return return_impl(DIV_ASSIGN, yytext); } -"%=" { return return_impl(MOD_ASSIGN, yytext); } -"&=" { return return_impl(AND_ASSIGN, yytext); } -"^=" { return return_impl(XOR_ASSIGN, yytext); } -"|=" { return return_impl(OR_ASSIGN, yytext); } -">>" { return return_impl(RIGHT_OP, yytext); } -"<<" { return return_impl(LEFT_OP, yytext); } -"++" { return return_impl(INC_OP, yytext); } -"--" { return return_impl(DEC_OP, yytext); } -"->" { return return_impl(PTR_OP, yytext); } -"&&" { return return_impl(AND_OP, yytext); } -"||" { return return_impl(OR_OP, yytext); } -"<=" { return return_impl(LE_OP, yytext); } -">=" { return return_impl(GE_OP, yytext); } -"==" { return return_impl(EQ_OP, yytext); } -"!=" { return return_impl(NE_OP, yytext); } -";" { return return_impl(';', yytext); } -("{"|"<%") { return return_impl('{', yytext); } -("}"|"%>") { return return_impl('}', yytext); } -"," { return return_impl(',', yytext); } -":" { return return_impl(':', yytext); } -"=" { return return_impl('=', yytext); } -"(" { return return_impl('(', yytext); } -")" { return return_impl(')', yytext); } -("["|"<:") { return return_impl('[', yytext); } -("]"|":>") { return return_impl(']', yytext); } -"." { return return_impl('.', yytext); } -"&" { return return_impl('&', yytext); } -"!" { return return_impl('!', yytext); } -"~" { return return_impl('~', yytext); } -"-" { return return_impl('-', yytext); } -"+" { return return_impl('+', yytext); } -"*" { return return_impl('*', yytext); } -"/" { return return_impl('/', yytext); } -"%" { return return_impl('%', yytext); } -"<" { return return_impl('<', yytext); } -">" { return return_impl('>', yytext); } -"^" { return return_impl('^', yytext); } -"|" { return return_impl('|', yytext); } -"?" { return return_impl('?', yytext); } -[ \t\v\n\f] { return_void(yytext);} -. { /* ignore bad characters */ } - -%% - -int yywrap() -{ return(1); } diff --git a/include/triton/lang/wgtcc/scope.h b/include/triton/lang/scope.h similarity index 100% rename from include/triton/lang/wgtcc/scope.h rename to include/triton/lang/scope.h diff --git a/include/triton/lang/statement.h b/include/triton/lang/statement.h deleted file mode 100644 index 42b4140dc..000000000 --- a/include/triton/lang/statement.h +++ /dev/null @@ -1,115 +0,0 @@ -#ifndef TRITON_INCLUDE_LANG_STATEMENT_H -#define TRITON_INCLUDE_LANG_STATEMENT_H - -#include "expression.h" - -namespace triton{ - - -namespace ir{ - class function; - class value; - class type; - class builder; - class module; -} - -namespace lang{ - -class declaration; - -class statement: public block_item{ -}; - -// Expression -class expression_statement: public statement{ -public: - expression_statement(node *expr, node *mask = nullptr) - : expr_((expression*)expr), pred_((expression*)mask){ } - - ir::value* codegen(ir::module * mod) const; - -private: - expression *expr_; - expression *pred_; -}; - -// Compound -class compound_statement: public statement{ - typedef list* declarations_t; - typedef list* statements_t; - -public: - compound_statement(node* items) - : items_((list*)items){} - - ir::value* codegen(ir::module * mod) const; - -private: - list* items_; -}; - -// Selection -class selection_statement: public statement{ -public: - selection_statement(node *cond, node *if_value, node *else_value = nullptr) - : cond_(cond), then_value_(if_value), else_value_(else_value) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const node *cond_; - const node *then_value_; - const node *else_value_; -}; - -// Iteration -class iteration_statement: public statement{ -public: - iteration_statement(node *init, node *stop, node *exec, node *statements) - : init_(init), stop_(stop), exec_(exec), statements_(statements) - { } - - ir::value* codegen(ir::module *mod) const; - -private: - const node *init_; - const node *stop_; - const node *exec_; - const node *statements_; -}; - -// While -class while_statement: public statement{ -public: - while_statement(node *cond, node *statements) - : cond_(cond), statements_(statements) - { } - - ir::value* codegen(ir::module *) const; - -private: - const node *cond_; - const node *statements_; -}; - -// Jump -class jump_statement: public statement{ -public: - using statement::statement; -}; - -// Continue -class continue_statement: public jump_statement{ -public: - ir::value* codegen(ir::module *mod) const; -}; - -// No op -class no_op: public statement { }; - -} - -} - -#endif diff --git a/include/triton/lang/wgtcc/token.h b/include/triton/lang/token.h similarity index 100% rename from include/triton/lang/wgtcc/token.h rename to include/triton/lang/token.h diff --git a/include/triton/lang/wgtcc/type.h b/include/triton/lang/type.h similarity index 100% rename from include/triton/lang/wgtcc/type.h rename to include/triton/lang/type.h diff --git a/include/triton/lang/wgtcc/visitor.h b/include/triton/lang/visitor.h similarity index 100% rename from include/triton/lang/wgtcc/visitor.h rename to include/triton/lang/visitor.h diff --git a/include/triton/lang/wgtcc/error.h b/include/triton/lang/wgtcc/error.h deleted file mode 100644 index fdae7e060..000000000 --- a/include/triton/lang/wgtcc/error.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _WGTCC_ERROR_H_ -#define _WGTCC_ERROR_H_ - - -struct SourceLocation; -class Token; -class Expr; - - -[[noreturn]] void Error(const char* format, ...); -[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...); -[[noreturn]] void Error(const Token* tok, const char* format, ...); -[[noreturn]] void Error(const Expr* expr, const char* format, ...); - -#endif diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 63c91de9b..f30cdabfd 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -20,7 +20,7 @@ #include "triton/codegen/transform/shmem/barriers.h" #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/vectorize.h" -#include "triton/lang/wgtcc/parser.h" +#include "triton/lang/parser.h" namespace llvm { class Module; diff --git a/lib/driver/device.cpp b/lib/driver/device.cpp index 41a9561eb..fceb2754e 100755 --- a/lib/driver/device.cpp +++ b/lib/driver/device.cpp @@ -25,7 +25,6 @@ #include #include #include -#include "triton/driver/helpers/CL/infos.hpp" #include "triton/driver/device.h" #include "triton/driver/context.h" #include "triton/codegen/selection/target.h" @@ -51,11 +50,13 @@ std::unique_ptr host_device::make_target() const { // maximum amount of shared memory per block size_t ocl_device::max_shared_memory() const { - return ocl::info(*cl_); + throw std::runtime_error("not implemented"); +// return ocl::info(*cl_); } size_t ocl_device::max_threads_per_block() const { - return ocl::info(*cl_).at(0); + throw std::runtime_error("not implemented"); +// return ocl::info(*cl_).at(0); } std::unique_ptr ocl_device::make_target() const { diff --git a/lib/lang/wgtcc/ast.cc b/lib/lang/ast.cc similarity index 99% rename from lib/lang/wgtcc/ast.cc rename to lib/lang/ast.cc index 47bc6d3a4..7d7e28471 100644 --- a/lib/lang/wgtcc/ast.cc +++ b/lib/lang/ast.cc @@ -1,9 +1,9 @@ -#include "triton/lang/wgtcc/ast.h" -#include "triton/lang/wgtcc/error.h" -#include "triton/lang/wgtcc/evaluator.h" -#include "triton/lang/wgtcc/mem_pool.h" -#include "triton/lang/wgtcc/parser.h" -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/ast.h" +#include "triton/lang/error.h" +#include "triton/lang/evaluator.h" +#include "triton/lang/mem_pool.h" +#include "triton/lang/parser.h" +#include "triton/lang/token.h" static MemPoolImp binaryOpPool; diff --git a/lib/lang/wgtcc/code_gen.cc b/lib/lang/code_gen.cc similarity index 99% rename from lib/lang/wgtcc/code_gen.cc rename to lib/lang/code_gen.cc index d7188b2b1..cfdeee1f6 100644 --- a/lib/lang/wgtcc/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -1,7 +1,7 @@ -#include "triton/lang/wgtcc/code_gen.h" -#include "triton/lang/wgtcc/evaluator.h" -#include "triton/lang/wgtcc/parser.h" -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/code_gen.h" +#include "triton/lang/evaluator.h" +#include "triton/lang/parser.h" +#include "triton/lang/token.h" #include "triton/ir/module.h" #include "triton/ir/function.h" diff --git a/lib/lang/wgtcc/cpp.cc b/lib/lang/cpp.cc similarity index 99% rename from lib/lang/wgtcc/cpp.cc rename to lib/lang/cpp.cc index 543bf3194..308eba1e6 100644 --- a/lib/lang/wgtcc/cpp.cc +++ b/lib/lang/cpp.cc @@ -1,7 +1,7 @@ -#include "triton/lang/wgtcc/cpp.h" +#include "triton/lang/cpp.h" -#include "triton/lang/wgtcc/evaluator.h" -#include "triton/lang/wgtcc/parser.h" +#include "triton/lang/evaluator.h" +#include "triton/lang/parser.h" #include #include @@ -823,7 +823,7 @@ void Preprocessor::Init() { AddSearchPath("/usr/include/x86_64-linux-gnu/"); AddSearchPath("/usr/include/linux/"); AddSearchPath("/usr/include/"); - AddSearchPath("/usr/local/wgtcc/include/"); + AddSearchPath("/usr/local/include/"); // The __FILE__ and __LINE__ macro is empty // They are handled seperately diff --git a/lib/lang/declaration.cpp b/lib/lang/declaration.cpp deleted file mode 100644 index 3f706bee1..000000000 --- a/lib/lang/declaration.cpp +++ /dev/null @@ -1,241 +0,0 @@ -#include -#include "triton/lang/statement.h" -#include "triton/lang/declaration.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/builder.h" -#include "triton/ir/type.h" -#include "triton/ir/metadata.h" - - -namespace triton{ - -namespace lang{ - -/* Declaration specifier */ -ir::type* typed_declaration_specifier::type(ir::module *mod) const { - ir::context &ctx = mod->get_context(); - switch (ty_) { - case VOID_T: return ir::type::get_void_ty(ctx); - case INT1_T: return ir::type::get_int1_ty(ctx); - case INT8_T: return ir::type::get_int8_ty(ctx); - case INT16_T: return ir::type::get_int16_ty(ctx); - case INT32_T: return ir::type::get_int32_ty(ctx); - case INT64_T: return ir::type::get_int64_ty(ctx); - case FLOAT16_T: return ir::type::get_half_ty(ctx); - case FLOAT32_T: return ir::type::get_float_ty(ctx); - case FLOAT64_T: return ir::type::get_double_ty(ctx); - default: throw std::runtime_error("unreachable"); - } -} - -std::vector typed_declaration_specifier::modifiers() const { - return {}; -} - - -ir::type* declaration_modifier::type(ir::module *mod) const { - return decl_spec_->type(mod); -} - -std::vector declaration_modifier::modifiers() const { - auto result = decl_spec_->modifiers(); - result.push_back(mod_); - return result; -} - - -/* Parameter */ -ir::type* parameter::type(ir::module *mod) const { - return decl_->type(mod, spec_->type(mod), {}); -} - -std::vector parameter::modifiers() const { - return spec_->modifiers(); -} - -const identifier *parameter::id() const { - return decl_->id(); -} - -/* Declarators */ -ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - if(ptr_) - return type_impl(mod, ptr_->type(mod, type, storage), storage); - return type_impl(mod, type, storage); -} - -// Identifier -ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{ - return type; -} - -const std::string &identifier::name() const{ - return name_; -} - -// Tile -ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{ - ir::type::tile_shapes_t shapes; - for(expression *expr: shapes_->values()){ - ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); - if(shape == nullptr) - throw std::runtime_error("tile shapes must be constant expressions"); - shapes.push_back(shape); - } - return ir::tile_type::get(type, shapes); -} - - -// Pointer -ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - auto is_cst = [](modifier* x){ return x->is_cst_space(); }; - bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end(); - return ir::pointer_type::get(type, is_ptr_to_const?4:1); -} - -// Function -void function::bind_parameters(ir::module *mod, ir::function *fn) const{ - std::vector args = fn->args(); - assert(args.size() == args_->values().size()); - for(size_t i = 0; i < args.size(); i++){ - parameter *param_i = args_->values().at(i); - const identifier *id_i = param_i->id(); - if(id_i){ - args[i]->set_name(id_i->name()); - mod->set_value(id_i->name(), nullptr, args[i]); - mod->get_scope().types[id_i->name()] = args[i]->get_type(); - } - } -} - -ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{ - std::vector types; - for(parameter* param: args_->values()) - types.push_back(param->type(mod)); - return ir::function_type::get(type, types); -} - - -/* Declaration */ -ir::value* declaration::codegen(ir::module* mod) const{ - for(initializer *init: init_->values()) - init->set_specifier(spec_); - init_->codegen(mod); - return nullptr; -} - -/* Initializer */ -ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - return decl_->type(mod, type, storage); -} - -void initializer::set_specifier(const declaration_specifier *spec) { - spec_ = spec; -} - -ir::value* initializer::codegen(ir::module * mod) const{ - std::vector modifiers = spec_->modifiers(); - ir::type *ty = decl_->type(mod, spec_->type(mod), modifiers); - std::string name = decl_->id()->name(); - ir::value *value = ir::undef_value::get(ty); - auto is_tunable = [](modifier* x){ return x->is_tunable(); }; - if(std::find_if(modifiers.begin(), modifiers.end(), is_tunable) != modifiers.end()){ - auto csts = dynamic_cast*>((node*)expr_); - if(csts == nullptr) - throw std::runtime_error("must specify constant list for metaparameters"); - std::vector values; - for(constant* cst: csts->values()) - values.push_back(cst->value()); - value = ir::metaparameter::create(mod->get_context(), ty, values); - mod->register_global(name, value); - } - else if(expr_){ - value = expr_->codegen(mod); - value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); - implicit_broadcast(mod, ty, value); - } - value->set_name(name); - // metadata - auto is_multiple_of = [](modifier* x){ return x->is_multiple_of(); }; - auto it = std::find_if(modifiers.begin(), modifiers.end(), is_multiple_of); - if(it != modifiers.end()) - (*it)->add_metadata(mod, name); - // register - mod->set_value(name, value); - mod->get_scope().types[name] = ty; - if(auto *x = dynamic_cast(value)) - mod->add_alloc(x); - // constants - auto is_cst = [](modifier* x){ return x->is_cst(); }; - if(std::find_if(modifiers.begin(), modifiers.end(), is_cst) != modifiers.end()) - mod->set_const(name); - return value; -} - -/* Type name */ -ir::type *type_name::type(ir::module *mod) const{ - return decl_->type(mod, spec_->type(mod), {}); -} - -/* Storage specifier */ -inline ir::attribute_kind_t get_ir_attr(STORAGE_SPEC_T spec){ - switch(spec){ - case RESTRICT_T: return ir::noalias; - case READONLY_T: return ir::readonly; - case WRITEONLY_T: return ir::writeonly; - default: throw std::runtime_error("cannot convert storage specifier to IR function attribute"); - } -} - -void storage_specifier::add_attr(ir::function* fn, size_t pos) { - fn->add_attr(pos, ir::attribute(get_ir_attr(value_))); -} - -void storage_specifier::add_metadata(ir::module*, std::string) { - throw std::runtime_error("storage specifier is not a metadata"); -} - -/* Alignment specifier */ -void alignment_specifier::add_attr(ir::function* fn, size_t pos) { - fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value())); -} - -void alignment_specifier::add_metadata(ir::module *mod, std::string name) { - throw std::runtime_error("alignment specifier is not a metadata"); -} - -/* Multiple-Of specifier */ -void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) { - fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value())); -} - -void multiple_of_specifier::add_metadata(ir::module *mod, std::string name) { - mod->add_metadata(name, {ir::metadata::multiple_of, cst_->value()}); -} - - -/* Function definition */ -ir::value* function_definition::codegen(ir::module *mod) const{ - ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers()); - const std::string &name = header_->id()->name(); - ir::function *fn = mod->get_or_insert_function(name, prototype); - for(unsigned i = 0; i < header_->get_num_args(); i++){ - parameter *param = header_->get_arg(i); - std::vector modifiers = param->modifiers(); - for(modifier* m: modifiers) - m->add_attr(fn, 1 + i); - } - header_->bind_parameters(mod, fn); - ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn); - mod->seal_block(entry); - mod->get_builder().set_insert_point(entry); - body_->codegen(mod); - mod->get_builder().create_ret_void(); - return nullptr; -} - -} - -} diff --git a/lib/lang/wgtcc/encoding.cc b/lib/lang/encoding.cc similarity index 96% rename from lib/lang/wgtcc/encoding.cc rename to lib/lang/encoding.cc index d5d1f99d1..931e4fc30 100644 --- a/lib/lang/wgtcc/encoding.cc +++ b/lib/lang/encoding.cc @@ -1,4 +1,4 @@ -#include "triton/lang/wgtcc/encoding.h" +#include "triton/lang/encoding.h" #include #include diff --git a/lib/lang/wgtcc/error.cc b/lib/lang/error.cc similarity index 94% rename from lib/lang/wgtcc/error.cc rename to lib/lang/error.cc index 618a83181..baf944468 100644 --- a/lib/lang/wgtcc/error.cc +++ b/lib/lang/error.cc @@ -1,7 +1,7 @@ -#include "triton/lang/wgtcc/error.h" +#include "triton/lang/error.h" -#include "triton/lang/wgtcc/ast.h" -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/ast.h" +#include "triton/lang/token.h" #include #include diff --git a/lib/lang/error.cpp b/lib/lang/error.cpp deleted file mode 100644 index 77076fba0..000000000 --- a/lib/lang/error.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include "triton/lang/error.h" - - -namespace triton{ - -namespace lang{ - -static int current_line = 0; -static int current_column = 0; - -// begin token -void update_location(const char *text) { - for (int i = 0; text[i] != '\0'; i++){ - if (text[i] == '\n'){ - current_column = 0; - current_line++; - } - else if (text[i] == '\t') - current_column += 8 - (current_column % 8); - else - current_column++; - } -} - -void print_error(const char *cerror) { - std::string error(cerror); - auto it = error.find("syntax error,"); - error.replace(it, 13, ""); - std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl; - throw std::runtime_error("compilation failed"); -} - -char return_impl(char t, const char * yytext) { - update_location(yytext); - return t; -} - -yytokentype return_impl(yytokentype t, const char * yytext){ - update_location(yytext); - return t; -} - -void return_void(const char * yytext){ - update_location(yytext); -} - -} - -} diff --git a/lib/lang/wgtcc/evaluator.cc b/lib/lang/evaluator.cc similarity index 97% rename from lib/lang/wgtcc/evaluator.cc rename to lib/lang/evaluator.cc index 02cb224f9..0123f4239 100644 --- a/lib/lang/wgtcc/evaluator.cc +++ b/lib/lang/evaluator.cc @@ -1,6 +1,6 @@ -#include "triton/lang/wgtcc/evaluator.h" -#include "triton/lang/wgtcc/ast.h" -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/evaluator.h" +#include "triton/lang/ast.h" +#include "triton/lang/token.h" template diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp deleted file mode 100644 index 8d5288e8b..000000000 --- a/lib/lang/expression.cpp +++ /dev/null @@ -1,359 +0,0 @@ -#include "triton/lang/expression.h" -#include "triton/lang/declaration.h" -#include "triton/ir/constant.h" -#include "triton/ir/module.h" -#include "triton/ir/builder.h" -#include "triton/ir/type.h" - - -namespace triton{ - -namespace lang{ - - -/* Binary operator */ -ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const -{ - bool is_float = false, is_ptr = false, is_int = false, is_signed = false; - implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); - implicit_broadcast(mod, lhs, rhs); - if(op_==MUL && is_float) - return builder.create_fmul(lhs, rhs, name); - if(op_==MUL && is_int) - return builder.create_mul(lhs, rhs, name); - if(op_==DIV && is_float) - return builder.create_fdiv(lhs, rhs, name); - if(op_==DIV && is_int && is_signed) - return builder.create_sdiv(lhs, rhs, name); - if(op_==DIV && is_int && !is_signed) - return builder.create_udiv(lhs, rhs, name); - if(op_==MOD && is_float) - return builder.create_frem(lhs, rhs, name); - if(op_==MOD && is_int && is_signed) - return builder.create_srem(lhs, rhs, name); - if(op_==MOD && is_int && !is_signed) - return builder.create_urem(lhs, rhs, name); - if(op_==ADD && is_float) - return builder.create_fadd(lhs, rhs, name); - if(op_==ADD && is_int) - return builder.create_add(lhs, rhs); - if(op_==ADD && is_ptr) - return builder.create_gep(lhs, {rhs}); - if(op_==SUB && is_float) - return builder.create_fsub(lhs, rhs, name); - if(op_==SUB && is_int) - return builder.create_sub(lhs, rhs, name); - if(op_==SUB && is_ptr) - return builder.create_gep(lhs, {builder.create_neg(rhs)}); - if(op_==LEFT_SHIFT) - return builder.create_shl(lhs, rhs, name); - if(op_==RIGHT_SHIFT) - return builder.create_ashr(lhs, rhs, name); - if(op_ == LT && is_float) - return builder.create_fcmpOLT(lhs, rhs, name); - if(op_ == LT && is_int && is_signed) - return builder.create_icmpSLT(lhs, rhs, name); - if(op_ == LT && is_int && !is_signed) - return builder.create_icmpULT(lhs, rhs, name); - if(op_ == GT && is_float) - return builder.create_fcmpOGT(lhs, rhs, name); - if(op_ == GT && is_int && is_signed) - return builder.create_icmpSGT(lhs, rhs, name); - if(op_ == GT && is_int && !is_signed) - return builder.create_icmpUGT(lhs, rhs, name); - if(op_ == LE && is_float) - return builder.create_fcmpOLE(lhs, rhs, name); - if(op_ == LE && is_int && is_signed) - return builder.create_icmpSLE(lhs, rhs, name); - if(op_ == LE && is_int && !is_signed) - return builder.create_icmpULE(lhs, rhs, name); - if(op_ == GE && is_float) - return builder.create_fcmpOGE(lhs, rhs, name); - if(op_ == GE && is_int && is_signed) - return builder.create_icmpSGE(lhs, rhs, name); - if(op_ == GE && is_int && !is_signed) - return builder.create_icmpUGE(lhs, rhs, name); - if(op_ == EQ && is_ptr) - return builder.create_icmpEQ(lhs, rhs, name); - if(op_ == EQ && is_float) - return builder.create_fcmpOEQ(lhs, rhs, name); - if(op_ == EQ && is_int) - return builder.create_icmpEQ(lhs, rhs, name); - if(op_ == NE && is_ptr) - return builder.create_icmpNE(lhs, rhs, name); - if(op_ == NE && is_float) - return builder.create_fcmpONE(lhs, rhs, name); - if(op_ == NE && is_int) - return builder.create_icmpNE(lhs, rhs, name); - if(op_ == AND) - return builder.create_and(lhs, rhs, name); - if(op_ == XOR) - return builder.create_xor(lhs, rhs, name); - if(op_ == OR) - return builder.create_or(lhs, rhs, name); - if(op_ == LAND) - return builder.create_and(lhs, rhs, name); - if(op_ == LOR) - return builder.create_or(lhs, rhs, name); - throw std::runtime_error("unreachable"); -} - -ir::value* binary_expression::codegen(ir::module *mod) const{ - ir::value *lhs = lhs_->codegen(mod); - ir::value *rhs = rhs_->codegen(mod); - ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, ""); - return result; -} - -/* Builtin expression */ - -// alloc constant -ir::value* alloc_const_expression::codegen(ir::module *mod) const { - ir::type *ty = spec_->type(mod); - ir::constant_int *size = (ir::constant_int*)size_->codegen(mod); - ir::alloc_const *res = new ir::alloc_const(ty, size); - return res; -} - -// get_program_id -ir::value* get_program_id_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_get_program_id(axis_->value()); -} - -// get_num_program -ir::value* get_num_program_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_get_num_program(axis_->value()); -} - -// atomic cas -ir::value* atomic_cas_expression::codegen(ir::module *mod) const { - ir::value *ptr = ptr_->codegen(mod); - ir::value *cmp = cmp_->codegen(mod); - ir::value *val = val_->codegen(mod); - return mod->get_builder().create_atomic_cas(ptr, cmp, val); -} - -// atomic exch -ir::value* atomic_exch_expression::codegen(ir::module *mod) const { - ir::value *ptr = ptr_->codegen(mod); - ir::value *val = val_->codegen(mod); - return mod->get_builder().create_atomic_exch(ptr, val); -} - -// atomic add -ir::value* atomic_add_expression::codegen(ir::module *mod) const { - ir::value *ptr = ptr_->codegen(mod); - ir::value *val = val_->codegen(mod); - return mod->get_builder().create_atomic_add(ptr, val); -} - -// matmul -ir::value* matmul_expression::codegen(ir::module *mod) const { - ir::value *A = A_->codegen(mod); - ir::value *B = B_->codegen(mod); - ir::value *C = C_->codegen(mod); -// unsigned M = A->get_type()->get_tile_shapes()[0]; -// unsigned N = B->get_type()->get_tile_shapes()[1]; -// ir::type *scalar_ty = A->get_type()->get_scalar_ty(); -// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N}); -// ir::value *tmp = ir::undef_value::get(tile_ty); -// implicit_broadcast(mod, tmp, C); - return mod->get_builder().create_dot(A, B, C); -} - -// reshape -ir::value* reshape_expression::codegen(ir::module *mod) const { - // arg - ir::value *arg = arg_->codegen(mod); - // shapes - ir::type::tile_shapes_t shapes; - for(expression *expr: shapes_->values()){ - ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); - if(shape == nullptr) - throw std::runtime_error("tile shapes must be constant expressions"); - shapes.push_back(shape); - } - // return - return mod->get_builder().create_reshape(arg, shapes); -} - -// min -ir::value* min_expression::codegen(ir::module *mod) const { - ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod); - ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); - ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); - return mod->get_builder().create_select(cmp, x, y); -} - -// max -ir::value* max_expression::codegen(ir::module *mod) const { - ir::value* cmp = binary_expression(GT, (node*)x_, (node*)y_).codegen(mod); - ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); - ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); - return mod->get_builder().create_select(cmp, x, y); -} - -// select -ir::value* select_expression::codegen(ir::module *mod) const { - ir::value* pred = pred_->codegen(mod); - ir::value* if_value = if_value_->codegen(mod); - ir::value* else_value = else_value_->codegen(mod); - return mod->get_builder().create_select(pred, if_value, else_value); -} - -// trans -ir::value* trans_expression::codegen(ir::module *mod) const { - // shapes - std::vector perm; - if(perm_) { - for(expression *expr: perm_->values()){ - ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); - if(shape == nullptr) - throw std::runtime_error("tile shapes must be constant expressions"); - perm.push_back(shape); - } - } - return mod->get_builder().create_trans(arg_->codegen(mod), perm); -} - -// sqrt -ir::value* sqrt_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_sqrt(arg_->codegen(mod)); -} - -// reduce -ir::value* reduce_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value()); -} - -/* Postfix expression */ -ir::value* indexing_expression::codegen(ir::module *mod) const{ - ir::value *in = lhs_->codegen(mod); - const std::vector &slices = slices_->values(); - auto in_shapes = in->get_type()->get_tile_shapes(); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - ir::type::tile_shapes_t out_shapes(slices.size()); - // create shapes - size_t current = 0; - for(size_t i = 0; i < out_shapes.size(); i++) - out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++]; - return mod->get_builder().create_reshape(in, out_shapes); -} - - -/* Unary operator */ -ir::value *unary_expression::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{ - ir::type *atype = arg->get_type(); - bool is_float = atype->is_floating_point_ty(); - bool is_int = atype->is_integer_ty(); - if(op_ == INC) - return builder.create_add(arg, builder.get_int32(1), name); - if(op_ == DEC) - return builder.create_sub(arg, builder.get_int32(1), name); - if(op_ == PLUS) - return arg; - if(op_ == MINUS && is_float) - return builder.create_fneg(arg, name); - if(op_ == MINUS && is_int) - return builder.create_neg(arg, name); - if(op_ == ADDR) - throw std::runtime_error("not supported"); - if(op_ == DEREF) - return builder.create_load(arg, name); - if(op_ == COMPL) - throw std::runtime_error("not supported"); - if(op_ == NOT) - return builder.create_not(arg, name); - throw std::runtime_error("unreachable"); -} - -ir::value* unary_expression::codegen(ir::module *mod) const{ - ir::value *arg = arg_->codegen(mod); - ir::value *result = llvm_op(mod->get_builder(), arg, ""); - return result; -} - -/* Cast operator */ -ir::value *cast_expression::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{ - return nullptr; -} - -ir::value* cast_expression::codegen(ir::module *mod) const{ - ir::value *arg = arg_->codegen(mod); - ir::type *T = T_->type(mod); - return llvm_op(mod->get_builder(), T, arg, ""); -} - -/* Conditional expression */ -ir::value *conditional_expression::codegen(ir::module *mod) const { - ir::builder &builder = mod->get_builder(); - ir::value *mask = cond_->codegen(mod); - ir::value *true_value = true_value_->codegen(mod); - ir::value *false_value = false_value_->codegen(mod); - bool is_float, is_ptr, is_int, is_signed; - implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); - implicit_broadcast(mod, mask, true_value); - implicit_broadcast(mod, mask, false_value); - if(ir::load_inst* load = dynamic_cast(true_value)){ - load->erase_from_parent(); - return builder.create_masked_load(load->get_pointer_operand(), mask, false_value); - } - if(ir::load_inst* load = dynamic_cast(false_value)){ - load->erase_from_parent(); - return builder.create_masked_load(load->get_pointer_operand(), mask, true_value); - } - throw std::runtime_error("not implemented"); -} - -/* Assignment expression */ -ir::value *assignment_expression::codegen(ir::module *mod) const{ - ir::value *rvalue = rvalue_->codegen(mod); - if(auto *x = dynamic_cast(lvalue_)){ - ir::type *ty = mod->get_scope().types.at(x->id()->name()); - rvalue = explicit_cast(mod->get_builder(), rvalue, ty); - implicit_broadcast(mod, ty, rvalue); - mod->set_value(x->id()->name(), rvalue); - } - else if(auto* x = dynamic_cast(lvalue_)){ - assert(x->get_op()==DEREF); - assert(x->lvalue()); - ir::value *ptr = x->lvalue()->codegen(mod); - rvalue = mod->get_builder().create_store(ptr, rvalue); - } - return rvalue; -} - - -/* String literal */ -ir::value* string_literal::codegen(ir::module *) const{ - throw std::runtime_error("not supported"); -// return ir::constant_data_array::get_string(mod->get_context(), value_); -} - -/* Constant */ -ir::value* constant::codegen(ir::module *mod) const{ - return mod->get_builder().get_int32(value_); -} - -int constant::value() const{ - return value_; -} - -/* Constant range */ -ir::value* constant_range::codegen(ir::module *mod) const{ - return ir::constant_range::get((ir::constant_int*)first_->codegen(mod), - (ir::constant_int*)last_->codegen(mod)); -} - -/* Named */ -ir::value* named_expression::codegen(ir::module *mod) const{ - const std::string &name = id()->name(); - const auto& declarations = mod->get_scope().types; - if(declarations.find(name) == declarations.end()) - throw std::runtime_error("variable " + name + " not declared"); - return mod->get_value(name); -} - -} - -} diff --git a/lib/lang/module.cpp b/lib/lang/module.cpp deleted file mode 100644 index 3455ca98f..000000000 --- a/lib/lang/module.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "triton/lang/module.h" -#include "triton/ir/module.h" - - -namespace triton{ - -namespace lang{ - -/* Translation unit */ -ir::value* translation_unit::codegen(ir::module *mod) const{ - mod->add_new_scope(); - decls_.codegen(mod); - return nullptr; -} - -} - -} diff --git a/lib/lang/node.cpp b/lib/lang/node.cpp deleted file mode 100644 index dda7126bd..000000000 --- a/lib/lang/node.cpp +++ /dev/null @@ -1,164 +0,0 @@ -#include "triton/lang/node.h" -#include "triton/ir/builder.h" -#include "triton/ir/module.h" -#include "triton/ir/constant.h" - -namespace triton{ - -namespace lang{ - -/* node */ -ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ - ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); - ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); - if(src->get_type()->is_tile_ty()) - dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type()); - bool src_signed = false; - bool dst_signed = false; - if(src_scalar_ty == dst_scalar_ty) - return src; - else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) - return builder.create_si_to_fp(src, dst_ty); - - else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) - return builder.create_ui_to_fp(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed) - return builder.create_fp_to_si(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed) - return builder.create_fp_to_ui(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width()) - return builder.create_fp_ext(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width()) - return builder.create_fp_trunc(src, dst_ty); - - else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && - src_scalar_ty->get_integer_bitwidth()) - return builder.create_int_cast(src, dst_ty, dst_signed); - - else - throw std::runtime_error("unreachable"); -} - - -void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, - bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ - // Input types - ir::type *left_ty = lhs->get_type()->get_scalar_ty(); - ir::type *right_ty = rhs->get_type()->get_scalar_ty(); - // One operand is pointer - if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){ - is_ptr = true; - } - // One operand is double - else if(left_ty->is_double_ty() || right_ty->is_double_ty()){ - ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs; - to_convert = explicit_cast(builder, to_convert, builder.get_double_ty()); - is_float = true; - } - // One operand is float - else if(left_ty->is_float_ty() || right_ty->is_float_ty()){ - ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs; - to_convert = explicit_cast(builder, to_convert, builder.get_float_ty()); - is_float = true; - } - // One operand is half - else if(left_ty->is_half_ty() || right_ty->is_half_ty()){ - ir::value *&to_convert = left_ty->is_half_ty()?rhs:lhs; - to_convert = explicit_cast(builder, to_convert, builder.get_half_ty()); - is_float = true; - } - // Both operands are integers - else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ - is_int = true; - is_signed = true; // always signed for now - if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ - ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; - ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; - to_convert = explicit_cast(builder, to_convert, dst_ty); - } - } - // Not reachable - else - throw std::runtime_error("unreachable"); -} - -void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - ir::type *res_ty = nullptr; - if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) - return; - else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) - res_ty = lhs_ty; - else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty()) - res_ty = rhs_ty; - else{ - auto lhs_shapes = lhs_ty->get_tile_shapes(); - auto rhs_shapes = rhs_ty->get_tile_shapes(); - size_t lhs_size = lhs_shapes.size(); - size_t rhs_size = rhs_shapes.size(); - size_t res_size = std::max(lhs_size, rhs_size); - ir::type::tile_shapes_t res_shapes(res_size); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - for(size_t i = 0; i < res_size; i++){ - if(i >= res_size - lhs_size && i >= res_size - rhs_size) - res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i]; - else if(i >= res_size - lhs_size) - res_shapes[i] = lhs_shapes[i]; - else if(i >= res_size - rhs_size) - res_shapes[i] = rhs_shapes[i]; - } - res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes); - } - implicit_broadcast(mod, res_ty, rhs); - implicit_broadcast(mod, res_ty, lhs); -} - -void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){ - ir::builder &builder = mod->get_builder(); - ir::type *src_ty = src->get_type(); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - // Both are scalar - if(!ty->is_tile_ty() && !src_ty->is_tile_ty()) - return; - // Broadcast scalar - if(ty->is_tile_ty() && !src_ty->is_tile_ty()){ - src = builder.create_splat(src, ty->get_tile_shapes()); - return; - } - // Downcast tile - if(!ty->is_tile_ty() && src_ty->is_tile_ty()){ - for(ir::constant *shape: src_ty->get_tile_shapes()) - if(shape != one) - throw std::runtime_error("cannot downcast"); - src = builder.create_downcast(src); - return; - } - // Both are arrays - auto dst_shapes = ty->get_tile_shapes(); - auto src_shapes = src_ty->get_tile_shapes(); - int dst_dim = dst_shapes.size(); - int src_dim = src_shapes.size(); - // Pad - int off = dst_dim - src_dim; - for(int i = 0; i < off; i++) - src_shapes.insert(src_shapes.begin(), one); - if(off > 0) - src = builder.create_reshape(src, src_shapes); - // Broadcast - for(int i = dst_dim - 1; i>= 0; i--) - if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one) - throw std::runtime_error("cannot broadcast"); - if(dst_shapes != src_shapes) - src = builder.create_broadcast(src, dst_shapes); -} - -} - -} diff --git a/lib/lang/wgtcc/parser.cc b/lib/lang/parser.cc similarity index 99% rename from lib/lang/wgtcc/parser.cc rename to lib/lang/parser.cc index ee8a8a319..35ed63e15 100644 --- a/lib/lang/wgtcc/parser.cc +++ b/lib/lang/parser.cc @@ -1,11 +1,11 @@ -#include "triton/lang/wgtcc/parser.h" +#include "triton/lang/parser.h" -#include "triton/lang/wgtcc/cpp.h" -#include "triton/lang/wgtcc/encoding.h" -#include "triton/lang/wgtcc/error.h" -#include "triton/lang/wgtcc/evaluator.h" -#include "triton/lang/wgtcc/scope.h" -#include "triton/lang/wgtcc/type.h" +#include "triton/lang/cpp.h" +#include "triton/lang/encoding.h" +#include "triton/lang/error.h" +#include "triton/lang/evaluator.h" +#include "triton/lang/scope.h" +#include "triton/lang/type.h" #include #include diff --git a/lib/lang/wgtcc/scanner.cc b/lib/lang/scanner.cc similarity index 99% rename from lib/lang/wgtcc/scanner.cc rename to lib/lang/scanner.cc index 0f0dbdfa0..9c394ecfd 100644 --- a/lib/lang/wgtcc/scanner.cc +++ b/lib/lang/scanner.cc @@ -1,4 +1,4 @@ -#include "triton/lang/wgtcc/scanner.h" +#include "triton/lang/scanner.h" #include #include diff --git a/lib/lang/wgtcc/scope.cc b/lib/lang/scope.cc similarity index 96% rename from lib/lang/wgtcc/scope.cc rename to lib/lang/scope.cc index bc1c6827c..9e487deba 100644 --- a/lib/lang/wgtcc/scope.cc +++ b/lib/lang/scope.cc @@ -1,6 +1,6 @@ -#include "triton/lang/wgtcc/scope.h" +#include "triton/lang/scope.h" -#include "triton/lang/wgtcc/ast.h" +#include "triton/lang/ast.h" #include #include diff --git a/lib/lang/statement.cpp b/lib/lang/statement.cpp deleted file mode 100644 index a768bf7b4..000000000 --- a/lib/lang/statement.cpp +++ /dev/null @@ -1,161 +0,0 @@ -#include "triton/lang/expression.h" -#include "triton/lang/statement.h" -#include "triton/lang/declaration.h" -#include "triton/ir/constant.h" -#include "triton/ir/module.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/builder.h" -#include "triton/ir/type.h" - -namespace triton{ - -namespace lang{ - -/* Helpers */ -inline bool is_terminator(ir::value* x) { - return x && dynamic_cast(x); -} - - -/* Statements */ -ir::value* compound_statement::codegen(ir::module* mod) const{ - mod->add_new_scope(); - if(items_) - items_->codegen(mod); - mod->pop_scope(); - return nullptr; -} - -/* Expression statement */ -ir::value* expression_statement::codegen(ir::module *mod) const{ - ir::builder &builder = mod->get_builder(); - // get name if applicable - std::string name = ""; - ir::value *current = nullptr; - if(assignment_expression *assignment = dynamic_cast(expr_)) - if(const named_expression* named = dynamic_cast(assignment->lvalue())){ - name = named->id()->name(); - current = mod->get_value(name); - } - // lower expression - ir::value *expr = expr_->codegen(mod); - // modify expression if predicated - if(pred_) { - ir::value *pred = pred_->codegen(mod); - if(!current) - current = ir::undef_value::get(expr->get_type()); - if(auto *x = dynamic_cast(expr)){ - x->erase_from_parent(); - expr = builder.create_masked_load(x->get_pointer_operand(), pred, current); - } - else if(auto *x = dynamic_cast(expr)){ - x->erase_from_parent(); - expr =builder.create_masked_store(x->get_pointer_operand(), x->get_value_operand(), pred); - } - else - expr = builder.create_select(pred, expr, current); - } - // update symbols table - if(!name.empty()) - mod->set_value(name, expr); - return expr; -} - -/* For statement */ -ir::value* iteration_statement::codegen(ir::module *mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::basic_block *current_bb = builder.get_insert_block(); - ir::function *fn = current_bb->get_parent(); - ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); - ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); - mod->set_continue_fn([&](){ - if(exec_) - exec_->codegen(mod); - ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); - return builder.create_cond_br(cond, loop_bb, next_bb); - }); - init_->codegen(mod); - ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); - builder.create_cond_br(cond, loop_bb, next_bb); -// builder.create_br(loop_bb); - builder.set_insert_point(loop_bb); - if(!is_terminator(statements_->codegen(mod))) - mod->get_continue_fn()(); - ir::basic_block *stop_bb = builder.get_insert_block(); - mod->seal_block(stop_bb); - mod->seal_block(loop_bb); - mod->seal_block(builder.get_insert_block()); - mod->seal_block(next_bb); - builder.set_insert_point(next_bb); - return nullptr; -} - -/* While statement */ -ir::value* while_statement::codegen(ir::module* mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::basic_block *current_bb = builder.get_insert_block(); - ir::function *fn = current_bb->get_parent(); - ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); - ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); - mod->set_continue_fn([&](){ - ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); - return builder.create_cond_br(cond, loop_bb, next_bb); - }); - ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); - builder.create_cond_br(cond, loop_bb, next_bb); - builder.set_insert_point(loop_bb); - if(!is_terminator(statements_->codegen(mod))) - mod->get_continue_fn()(); - ir::basic_block *stop_bb = builder.get_insert_block(); - mod->seal_block(stop_bb); - mod->seal_block(loop_bb); - mod->seal_block(builder.get_insert_block()); - mod->seal_block(next_bb); - builder.set_insert_point(next_bb); - return nullptr; -} - -/* Selection statement */ -ir::value* selection_statement::codegen(ir::module* mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::function *fn = builder.get_insert_block()->get_parent(); - ir::value *cond = cond_->codegen(mod); - ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn); - ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr; - ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn); - mod->seal_block(then_bb); - if(else_value_) - mod->seal_block(else_bb); - - // Branch - if(else_value_) - builder.create_cond_br(cond, then_bb, else_bb); - else - builder.create_cond_br(cond, then_bb, endif_bb); - // Then - builder.set_insert_point(then_bb); - if(!is_terminator(then_value_->codegen(mod))) - builder.create_br(endif_bb); - // Else - if(else_value_){ - builder.set_insert_point(else_bb); - if(!is_terminator(else_value_->codegen(mod))) - builder.create_br(endif_bb); - } - // Endif - mod->seal_block(endif_bb); - builder.set_insert_point(endif_bb); - return nullptr; -} - -/* Continue statement */ -ir::value* continue_statement::codegen(ir::module *mod) const{ - return mod->get_continue_fn()(); -} - -} - -} diff --git a/lib/lang/wgtcc/token.cc b/lib/lang/token.cc similarity index 98% rename from lib/lang/wgtcc/token.cc rename to lib/lang/token.cc index ba588588e..5445b2044 100644 --- a/lib/lang/wgtcc/token.cc +++ b/lib/lang/token.cc @@ -1,7 +1,7 @@ -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/token.h" -#include "triton/lang/wgtcc/mem_pool.h" -#include "triton/lang/wgtcc/parser.h" +#include "triton/lang/mem_pool.h" +#include "triton/lang/parser.h" static MemPoolImp tokenPool; diff --git a/lib/lang/wgtcc/type.cc b/lib/lang/type.cc similarity index 98% rename from lib/lang/wgtcc/type.cc rename to lib/lang/type.cc index 25d5c56ce..a1564ad97 100644 --- a/lib/lang/wgtcc/type.cc +++ b/lib/lang/type.cc @@ -1,8 +1,8 @@ -#include "triton/lang/wgtcc/type.h" +#include "triton/lang/type.h" -#include "triton/lang/wgtcc/ast.h" -#include "triton/lang/wgtcc/scope.h" -#include "triton/lang/wgtcc/token.h" +#include "triton/lang/ast.h" +#include "triton/lang/scope.h" +#include "triton/lang/token.h" #include #include diff --git a/lib/lang/wgtcc/main.cc b/lib/lang/wgtcc/main.cc deleted file mode 100644 index cc02588f6..000000000 --- a/lib/lang/wgtcc/main.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "triton/lang/wgtcc/code_gen.h" -#include "triton/lang/wgtcc/cpp.h" -#include "triton/lang/wgtcc/error.h" -#include "triton/lang/wgtcc/parser.h" -#include "triton/lang/wgtcc/scanner.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - - -std::string program; -std::string filename_in; -std::string filename_out; -bool debug = false; -static bool only_preprocess = false; -static bool only_compile = false; -static bool specified_out_name = false; -static std::list filenames_in; -static std::list gcc_filenames_in; -static std::list gcc_args; -static std::list defines; -static std::list include_paths; diff --git a/lib/runtime/function.cpp b/lib/runtime/function.cpp index 36bbc2100..11279b571 100644 --- a/lib/runtime/function.cpp +++ b/lib/runtime/function.cpp @@ -5,10 +5,9 @@ #include #include "triton/codegen/selection/selection.h" #include "triton/runtime/function.h" -#include "triton/lang/lang.h" -#include "triton/lang/wgtcc/cpp.h" -#include "triton/lang/wgtcc/parser.h" -#include "triton/lang/wgtcc/code_gen.h" +#include "triton/lang/cpp.h" +#include "triton/lang/parser.h" +#include "triton/lang/code_gen.h" #include "triton/driver/device.h" #include "triton/driver/stream.h" #include "triton/driver/kernel.h" @@ -19,16 +18,11 @@ #include "llvm/IR/Module.h" -typedef struct yy_buffer_state * YY_BUFFER_STATE; -extern int yyparse(); -extern YY_BUFFER_STATE yy_scan_string(const char * str); -extern void yy_delete_buffer(YY_BUFFER_STATE buffer); -extern triton::lang::translation_unit *ast_root; + namespace triton{ namespace runtime { - // helpers void _parallel_loop_nest(std::vector const & ranges, std::function const &)> const & f, @@ -148,7 +142,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr opt.num_warps = std::stoi(params[i++]); for(auto it: opt_space_.defines) opt.defines[it.first] = params[i++]; - // pre-process TokenSequence tokens; Preprocessor cpp(&src_, true); @@ -241,9 +234,7 @@ void function::operator()(const std::vector& args, const grid_fn_ty& grid_f } /* re-tune and re-compile */ - caller call = autotune(stream, grid_fn, args); - cache_.insert({key, call}); - + cache_.insert({key, autotune(stream, grid_fn, args)}); } void function::operator()(const std::vector& args, const grid_t& grid, driver::stream *stream) {